diff --git a/Cargo.lock b/Cargo.lock
index 9a976c0d4..2c12fb17c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -27,12 +27,6 @@ dependencies = [
"gimli",
]
-[[package]]
-name = "adler"
-version = "1.0.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
-
[[package]]
name = "adler2"
version = "2.0.0"
@@ -68,9 +62,9 @@ checksum = "1763692fc1416554cf051efc56a3de5595eca47299d731cc5c2b583adf8b4d2f"
[[package]]
name = "allocator-api2"
-version = "0.2.18"
+version = "0.2.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f"
+checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9"
[[package]]
name = "android-tzdata"
@@ -87,11 +81,17 @@ dependencies = [
"libc",
]
+[[package]]
+name = "anes"
+version = "0.1.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
+
[[package]]
name = "anstream"
-version = "0.6.17"
+version = "0.6.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "23a1e53f0f5d86382dafe1cf314783b2044280f406e7e1506368220ad11b1338"
+checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b"
dependencies = [
"anstyle",
"anstyle-parse",
@@ -104,9 +104,9 @@ dependencies = [
[[package]]
name = "anstyle"
-version = "1.0.9"
+version = "1.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8365de52b16c035ff4fcafe0092ba9390540e3e352870ac09933bebcaa2c8c56"
+checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9"
[[package]]
name = "anstyle-parse"
@@ -138,18 +138,18 @@ dependencies = [
[[package]]
name = "anyhow"
-version = "1.0.91"
+version = "1.0.93"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c042108f3ed77fd83760a5fd79b53be043192bb3b9dba91d8c574c0ada7850c8"
+checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775"
dependencies = [
"backtrace",
]
[[package]]
name = "arbitrary"
-version = "1.3.2"
+version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7d5a26814d8dcb93b0e5a0ff3c6d80a8843bafb21b39e8e18a6f05471870e110"
+checksum = "dde20b3d026af13f561bdd0f15edf01fc734f0dafcedbaf42bba506a9517f223"
dependencies = [
"derive_arbitrary",
]
@@ -174,7 +174,7 @@ checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -191,9 +191,9 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
[[package]]
name = "axum"
-version = "0.7.7"
+version = "0.7.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae"
+checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f"
dependencies = [
"async-trait",
"axum-core",
@@ -253,7 +253,7 @@ dependencies = [
"addr2line",
"cfg-if",
"libc",
- "miniz_oxide 0.8.0",
+ "miniz_oxide",
"object",
"rustc-demangle",
"windows-targets 0.52.6",
@@ -341,8 +341,8 @@ dependencies = [
"proc-macro2",
"quote",
"str_inflector",
- "syn 2.0.85",
- "thiserror",
+ "syn 2.0.87",
+ "thiserror 1.0.69",
"try_match",
]
@@ -354,9 +354,9 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
[[package]]
name = "bytemuck"
-version = "1.19.0"
+version = "1.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8334215b81e418a0a7bdb8ef0849474f40bb10c8b71f1c4ed315cff49f32494d"
+checksum = "8b37c88a63ffd85d15b406896cc343916d7cf57838a847b3a6f2ca5d39a5695a"
dependencies = [
"bytemuck_derive",
]
@@ -369,7 +369,7 @@ checksum = "bcfcc3cd946cb52f0bbfdbbcfa2f4e24f75ebb6c0e1002f7c25904fada18b9ec"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -391,89 +391,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da"
[[package]]
-name = "candle-core"
-version = "0.8.0"
-source = "git+https://github.com/EricLBuehler/candle.git?rev=cb8082b#cb8082bf28eadb4140c1774f002983a6d182bf3e"
-dependencies = [
- "accelerate-src",
- "byteorder",
- "candle-kernels",
- "candle-metal-kernels",
- "cudarc",
- "float8",
- "gemm",
- "half",
- "intel-mkl-src",
- "libc",
- "memmap2",
- "metal 0.27.0",
- "num-traits",
- "num_cpus",
- "rand",
- "rand_distr",
- "rayon",
- "safetensors",
- "thiserror",
- "ug",
- "ug-cuda",
- "ug-metal",
- "yoke",
- "zip",
-]
-
-[[package]]
-name = "candle-flash-attn"
-version = "0.8.0"
-source = "git+https://github.com/EricLBuehler/candle.git?rev=cb8082b#cb8082bf28eadb4140c1774f002983a6d182bf3e"
-dependencies = [
- "anyhow",
- "bindgen_cuda 0.1.5",
- "candle-core",
- "half",
-]
-
-[[package]]
-name = "candle-kernels"
-version = "0.8.0"
-source = "git+https://github.com/EricLBuehler/candle.git?rev=cb8082b#cb8082bf28eadb4140c1774f002983a6d182bf3e"
-dependencies = [
- "bindgen_cuda 0.1.5",
-]
-
-[[package]]
-name = "candle-metal-kernels"
-version = "0.8.0"
-source = "git+https://github.com/EricLBuehler/candle.git?rev=cb8082b#cb8082bf28eadb4140c1774f002983a6d182bf3e"
-dependencies = [
- "metal 0.27.0",
- "once_cell",
- "thiserror",
- "tracing",
-]
-
-[[package]]
-name = "candle-nn"
-version = "0.8.0"
-source = "git+https://github.com/EricLBuehler/candle.git?rev=cb8082b#cb8082bf28eadb4140c1774f002983a6d182bf3e"
-dependencies = [
- "accelerate-src",
- "candle-core",
- "candle-metal-kernels",
- "half",
- "intel-mkl-src",
- "metal 0.27.0",
- "num-traits",
- "rayon",
- "safetensors",
- "serde",
- "thiserror",
-]
+name = "cast"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "cc"
-version = "1.1.31"
+version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c2e7962b54006dcfcc61cb72735f4d89bb97061dd6a7ed882ec6b8ee53714c6f"
+checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47"
dependencies = [
"shlex",
]
@@ -492,9 +419,9 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "cfgrammar"
-version = "0.13.7"
+version = "0.13.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6026d8cd82ada8bbcfe337805dd1eb6afdc9e80fa4d57e977b3a36315e0c5525"
+checksum = "6d621f687a04efa1f269f1cd13d8cfea9660852bdb3d1cd2c3c9fb6fdd34daf2"
dependencies = [
"indexmap",
"lazy_static",
@@ -540,11 +467,38 @@ dependencies = [
"phf_codegen",
]
+[[package]]
+name = "ciborium"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e"
+dependencies = [
+ "ciborium-io",
+ "ciborium-ll",
+ "serde",
+]
+
+[[package]]
+name = "ciborium-io"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757"
+
+[[package]]
+name = "ciborium-ll"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9"
+dependencies = [
+ "ciborium-io",
+ "half",
+]
+
[[package]]
name = "clap"
-version = "4.5.20"
+version = "4.5.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8"
+checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f"
dependencies = [
"clap_builder",
"clap_derive",
@@ -552,9 +506,9 @@ dependencies = [
[[package]]
name = "clap_builder"
-version = "4.5.20"
+version = "4.5.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54"
+checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec"
dependencies = [
"anstream",
"anstyle",
@@ -572,14 +526,14 @@ dependencies = [
"heck",
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
name = "clap_lex"
-version = "0.7.2"
+version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97"
+checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7"
[[package]]
name = "cli-table"
@@ -590,7 +544,7 @@ dependencies = [
"cli-table-derive",
"csv",
"termcolor",
- "unicode-width",
+ "unicode-width 0.1.14",
]
[[package]]
@@ -625,7 +579,7 @@ dependencies = [
"encode_unicode",
"lazy_static",
"libc",
- "unicode-width",
+ "unicode-width 0.1.14",
"windows-sys 0.52.0",
]
@@ -658,9 +612,9 @@ dependencies = [
[[package]]
name = "cpufeatures"
-version = "0.2.14"
+version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "608697df725056feaccfa42cffdaeeec3fccc4ffc38358ecd19b243e716a78e0"
+checksum = "0ca741a962e1b0bff6d724a1a0958b686406e853bb14061f218562e1896f95e6"
dependencies = [
"libc",
]
@@ -674,6 +628,40 @@ dependencies = [
"cfg-if",
]
+[[package]]
+name = "criterion"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f"
+dependencies = [
+ "anes",
+ "cast",
+ "ciborium",
+ "clap",
+ "criterion-plot",
+ "is-terminal",
+ "itertools 0.10.5",
+ "num-traits",
+ "once_cell",
+ "oorandom",
+ "regex",
+ "serde",
+ "serde_derive",
+ "serde_json",
+ "tinytemplate",
+ "walkdir",
+]
+
+[[package]]
+name = "criterion-plot"
+version = "0.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
+dependencies = [
+ "cast",
+ "itertools 0.10.5",
+]
+
[[package]]
name = "crossbeam-deque"
version = "0.8.5"
@@ -742,9 +730,9 @@ dependencies = [
[[package]]
name = "csv"
-version = "1.3.0"
+version = "1.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe"
+checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf"
dependencies = [
"csv-core",
"itoa",
@@ -826,7 +814,7 @@ dependencies = [
"proc-macro2",
"quote",
"strsim 0.11.1",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -848,7 +836,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806"
dependencies = [
"darling_core 0.20.10",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -871,18 +859,18 @@ checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
name = "derive_arbitrary"
-version = "1.3.2"
+version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "67e77553c4162a157adbf834ebae5b415acbecbeafc7a74b0e886657506a7611"
+checksum = "30542c1ad912e0e3d22a1935c290e12e8a29d704a420177a31faad4a601a0800"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -903,7 +891,7 @@ dependencies = [
"darling 0.20.10",
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -913,7 +901,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c"
dependencies = [
"derive_builder_core",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -924,7 +912,7 @@ checksum = "5f33878137e4dafd7fa914ad4e259e18a4e8e532b9617a2d0150262bf53abfce"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -975,7 +963,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -1033,7 +1021,7 @@ dependencies = [
"heck",
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -1063,15 +1051,14 @@ dependencies = [
[[package]]
name = "exr"
-version = "1.72.0"
+version = "1.73.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "887d93f60543e9a9362ef8a21beedd0a833c5d9610e18c67abe15a5963dcb1a4"
+checksum = "f83197f59927b46c04a183a619b7c29df34e63e63c7869320862268c0ef687e0"
dependencies = [
"bit_field",
- "flume",
"half",
"lebe",
- "miniz_oxide 0.7.4",
+ "miniz_oxide",
"rayon-core",
"smallvec",
"zune-inflate",
@@ -1089,15 +1076,15 @@ dependencies = [
[[package]]
name = "fastrand"
-version = "2.1.1"
+version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6"
+checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4"
[[package]]
name = "fdeflate"
-version = "0.3.5"
+version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d8090f921a24b04994d9929e204f50b498a33ea6ba559ffaa05e04f7ee7fb5ab"
+checksum = "07c6f4c64c1d33a3111c4466f7365ebdcc37c5bd1ea0d62aae2e3d722aacbedb"
dependencies = [
"simd-adler32",
]
@@ -1116,12 +1103,12 @@ dependencies = [
[[package]]
name = "flate2"
-version = "1.0.34"
+version = "1.0.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0"
+checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c"
dependencies = [
"crc32fast",
- "miniz_oxide 0.8.0",
+ "miniz_oxide",
]
[[package]]
@@ -1137,15 +1124,6 @@ dependencies = [
"rand_distr",
]
-[[package]]
-name = "flume"
-version = "0.11.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095"
-dependencies = [
- "spin",
-]
-
[[package]]
name = "fnv"
version = "1.0.7"
@@ -1179,7 +1157,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -1259,7 +1237,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -1437,8 +1415,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [
"cfg-if",
+ "js-sys",
"libc",
"wasi",
+ "wasm-bindgen",
]
[[package]]
@@ -1450,7 +1430,7 @@ dependencies = [
"proc-macro-error2",
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -1477,9 +1457,9 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
[[package]]
name = "h2"
-version = "0.4.6"
+version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "524e8ac6999421f49a846c2d4411f337e53497d8ec55d67753beffa43c5d9205"
+checksum = "ccae279728d634d083c00f6099cb58f01cc99c145b84b8be2f6c74618d79922e"
dependencies = [
"atomic-waker",
"bytes",
@@ -1520,9 +1500,9 @@ dependencies = [
[[package]]
name = "hashbrown"
-version = "0.15.0"
+version = "0.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb"
+checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3"
[[package]]
name = "heck"
@@ -1536,6 +1516,12 @@ version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
+[[package]]
+name = "hermit-abi"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc"
+
[[package]]
name = "hf-hub"
version = "0.3.2"
@@ -1549,7 +1535,7 @@ dependencies = [
"rand",
"serde",
"serde_json",
- "thiserror",
+ "thiserror 1.0.69",
"ureq",
]
@@ -1601,9 +1587,9 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]]
name = "hyper"
-version = "1.5.0"
+version = "1.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bbbff0a806a4728c99295b254c8838933b5b082d75e3cb70c8dab21fdfbcfa9a"
+checksum = "97818827ef4f364230e16705d4706e2897df2bb60617d6ca15d598025a3c481f"
dependencies = [
"bytes",
"futures-channel",
@@ -1656,9 +1642,9 @@ dependencies = [
[[package]]
name = "hyper-util"
-version = "0.1.9"
+version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "41296eb09f183ac68eec06e03cdbea2e759633d4067b2f6552fc2e009bcad08b"
+checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4"
dependencies = [
"bytes",
"futures-channel",
@@ -1696,6 +1682,124 @@ dependencies = [
"cc",
]
+[[package]]
+name = "icu_collections"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526"
+dependencies = [
+ "displaydoc",
+ "yoke",
+ "zerofrom",
+ "zerovec",
+]
+
+[[package]]
+name = "icu_locid"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637"
+dependencies = [
+ "displaydoc",
+ "litemap",
+ "tinystr",
+ "writeable",
+ "zerovec",
+]
+
+[[package]]
+name = "icu_locid_transform"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e"
+dependencies = [
+ "displaydoc",
+ "icu_locid",
+ "icu_locid_transform_data",
+ "icu_provider",
+ "tinystr",
+ "zerovec",
+]
+
+[[package]]
+name = "icu_locid_transform_data"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e"
+
+[[package]]
+name = "icu_normalizer"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f"
+dependencies = [
+ "displaydoc",
+ "icu_collections",
+ "icu_normalizer_data",
+ "icu_properties",
+ "icu_provider",
+ "smallvec",
+ "utf16_iter",
+ "utf8_iter",
+ "write16",
+ "zerovec",
+]
+
+[[package]]
+name = "icu_normalizer_data"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516"
+
+[[package]]
+name = "icu_properties"
+version = "1.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5"
+dependencies = [
+ "displaydoc",
+ "icu_collections",
+ "icu_locid_transform",
+ "icu_properties_data",
+ "icu_provider",
+ "tinystr",
+ "zerovec",
+]
+
+[[package]]
+name = "icu_properties_data"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569"
+
+[[package]]
+name = "icu_provider"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9"
+dependencies = [
+ "displaydoc",
+ "icu_locid",
+ "icu_provider_macros",
+ "stable_deref_trait",
+ "tinystr",
+ "writeable",
+ "yoke",
+ "zerofrom",
+ "zerovec",
+]
+
+[[package]]
+name = "icu_provider_macros"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.87",
+]
+
[[package]]
name = "ident_case"
version = "1.0.1"
@@ -1704,19 +1808,30 @@ checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
[[package]]
name = "idna"
-version = "0.5.0"
+version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6"
+checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e"
dependencies = [
- "unicode-bidi",
- "unicode-normalization",
+ "idna_adapter",
+ "smallvec",
+ "utf8_iter",
+]
+
+[[package]]
+name = "idna_adapter"
+version = "1.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71"
+dependencies = [
+ "icu_normalizer",
+ "icu_properties",
]
[[package]]
name = "image"
-version = "0.25.4"
+version = "0.25.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bc144d44a31d753b02ce64093d532f55ff8dc4ebf2ffb8a63c0dda691385acae"
+checksum = "cd6f44aed642f18953a158afeb30206f4d50da59fbc66ecb53c66488de73563b"
dependencies = [
"bytemuck",
"byteorder-lite",
@@ -1755,22 +1870,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da"
dependencies = [
"equivalent",
- "hashbrown 0.15.0",
+ "hashbrown 0.15.1",
"serde",
]
[[package]]
name = "indicatif"
-version = "0.17.8"
+version = "0.17.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "763a5a8f45087d6bcea4222e7b72c291a054edf80e4ef6efd2a4979878c7bea3"
+checksum = "cbf675b85ed934d3c67b5c5469701eec7db22689d0a2139d856e0925fa28b281"
dependencies = [
"console",
- "instant",
"number_prefix",
"portable-atomic",
"rayon",
- "unicode-width",
+ "unicode-width 0.2.0",
+ "web-time",
]
[[package]]
@@ -1779,15 +1894,6 @@ version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5"
-[[package]]
-name = "instant"
-version = "0.1.13"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222"
-dependencies = [
- "cfg-if",
-]
-
[[package]]
name = "intel-mkl-src"
version = "0.8.1"
@@ -1816,12 +1922,32 @@ version = "2.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708"
+[[package]]
+name = "is-terminal"
+version = "0.4.13"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "261f68e344040fbd0edea105bef17c66edf46f984ddb1115b775ce31be948f4b"
+dependencies = [
+ "hermit-abi 0.4.0",
+ "libc",
+ "windows-sys 0.52.0",
+]
+
[[package]]
name = "is_terminal_polyfill"
version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
+[[package]]
+name = "itertools"
+version = "0.10.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473"
+dependencies = [
+ "either",
+]
+
[[package]]
name = "itertools"
version = "0.11.0"
@@ -1884,9 +2010,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8"
[[package]]
name = "libc"
-version = "0.2.161"
+version = "0.2.164"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1"
+checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f"
[[package]]
name = "libloading"
@@ -1900,9 +2026,9 @@ dependencies = [
[[package]]
name = "libm"
-version = "0.2.8"
+version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
+checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa"
[[package]]
name = "libredox"
@@ -1921,6 +2047,12 @@ version = "0.4.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89"
+[[package]]
+name = "litemap"
+version = "0.7.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "643cb0b8d4fcc284004d5fd0d67ccf61dfffadb7f75e1e71bc420f4688a3a704"
+
[[package]]
name = "lock_api"
version = "0.4.12"
@@ -1939,9 +2071,9 @@ checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
[[package]]
name = "lrtable"
-version = "0.13.7"
+version = "0.13.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d42d2752cb50a171efadda0cb6fa97432e8bf05accfff3eed320b87e80a2f69e"
+checksum = "49e35162de3a5d91b380f8ebb31fc6c5e9a4618276465df4725ff1f88613312b"
dependencies = [
"cfgrammar",
"fnv",
@@ -1990,6 +2122,88 @@ version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
+[[package]]
+name = "mcandle-core"
+version = "0.3.2"
+dependencies = [
+ "accelerate-src",
+ "anyhow",
+ "byteorder",
+ "clap",
+ "criterion",
+ "cudarc",
+ "float8",
+ "gemm",
+ "half",
+ "intel-mkl-src",
+ "libc",
+ "mcandle-kernels",
+ "mcandle-metal-kernels",
+ "memmap2",
+ "metal 0.27.0",
+ "num-traits",
+ "num_cpus",
+ "rand",
+ "rand_distr",
+ "rayon",
+ "safetensors",
+ "thiserror 1.0.69",
+ "ug",
+ "ug-cuda",
+ "ug-metal",
+ "yoke",
+ "zip",
+]
+
+[[package]]
+name = "mcandle-flash-attn"
+version = "0.3.2"
+dependencies = [
+ "anyhow",
+ "bindgen_cuda 0.1.5",
+ "half",
+ "mcandle-core",
+]
+
+[[package]]
+name = "mcandle-kernels"
+version = "0.3.2"
+dependencies = [
+ "bindgen_cuda 0.1.5",
+]
+
+[[package]]
+name = "mcandle-metal-kernels"
+version = "0.3.2"
+dependencies = [
+ "metal 0.27.0",
+ "once_cell",
+ "thiserror 1.0.69",
+ "tracing",
+]
+
+[[package]]
+name = "mcandle-nn"
+version = "0.3.2"
+dependencies = [
+ "accelerate-src",
+ "anyhow",
+ "clap",
+ "criterion",
+ "half",
+ "intel-mkl-src",
+ "mcandle-core",
+ "mcandle-flash-attn",
+ "mcandle-metal-kernels",
+ "metal 0.27.0",
+ "num-traits",
+ "rand",
+ "rayon",
+ "safetensors",
+ "serde",
+ "thiserror 1.0.69",
+]
+
[[package]]
name = "memchr"
version = "2.7.4"
@@ -2063,9 +2277,9 @@ dependencies = [
[[package]]
name = "minijinja"
-version = "2.4.0"
+version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c9ca8daf4b0b4029777f1bc6e1aedd1aec7b74c276a43bc6f620a8e1a1c0a90e"
+checksum = "2c37e1b517d1dcd0e51dc36c4567b9d5a29262b3ec8da6cb5d35e27a8fb529b5"
dependencies = [
"serde",
"serde_json",
@@ -2073,9 +2287,9 @@ dependencies = [
[[package]]
name = "minijinja-contrib"
-version = "2.3.1"
+version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "39ffd46ee854be23604a20efd6c9655374fefbe4d44b949dc0f907305d92873a"
+checksum = "7fe51f1a6a8285f03fcd1544d834234fe8db285f29e1c2253600c93b3ae19242"
dependencies = [
"minijinja",
"serde",
@@ -2087,15 +2301,6 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
-[[package]]
-name = "miniz_oxide"
-version = "0.7.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08"
-dependencies = [
- "adler",
-]
-
[[package]]
name = "miniz_oxide"
version = "0.8.0"
@@ -2124,7 +2329,7 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec"
dependencies = [
- "hermit-abi",
+ "hermit-abi 0.3.9",
"libc",
"wasi",
"windows-sys 0.52.0",
@@ -2135,11 +2340,11 @@ name = "mistralrs"
version = "0.3.2"
dependencies = [
"anyhow",
- "candle-core",
"either",
"futures",
"image",
"indexmap",
+ "mcandle-core",
"mistralrs-core",
"rand",
"reqwest",
@@ -2153,9 +2358,9 @@ name = "mistralrs-bench"
version = "0.3.2"
dependencies = [
"anyhow",
- "candle-core",
"clap",
"cli-table",
+ "mcandle-core",
"mistralrs-core",
"serde",
"serde_json",
@@ -2177,9 +2382,6 @@ dependencies = [
"buildstructor",
"bytemuck",
"bytemuck_derive",
- "candle-core",
- "candle-flash-attn",
- "candle-nn",
"cfgrammar",
"chrono",
"clap",
@@ -2199,6 +2401,9 @@ dependencies = [
"intel-mkl-src",
"itertools 0.13.0",
"lrtable",
+ "mcandle-core",
+ "mcandle-flash-attn",
+ "mcandle-nn",
"minijinja",
"minijinja-contrib",
"mistralrs-paged-attn",
@@ -2211,7 +2416,7 @@ dependencies = [
"rand_isaac",
"rayon",
"regex",
- "regex-automata 0.4.8",
+ "regex-automata 0.4.9",
"reqwest",
"rustc-hash",
"safetensors",
@@ -2222,7 +2427,7 @@ dependencies = [
"serde_yaml",
"strum",
"sysinfo",
- "thiserror",
+ "thiserror 1.0.69",
"tokenizers",
"tokio",
"tokio-rayon",
@@ -2241,9 +2446,9 @@ version = "0.3.2"
dependencies = [
"anyhow",
"bindgen_cuda 0.1.6",
- "candle-core",
"float8",
"half",
+ "mcandle-core",
]
[[package]]
@@ -2253,13 +2458,13 @@ dependencies = [
"accelerate-src",
"anyhow",
"base64 0.22.1",
- "candle-core",
"data-url",
"either",
"futures",
"image",
"indexmap",
"intel-mkl-src",
+ "mcandle-core",
"mistralrs-core",
"pyo3",
"pyo3-build-config",
@@ -2276,17 +2481,17 @@ version = "0.3.2"
dependencies = [
"bindgen_cuda 0.1.5",
"byteorder",
- "candle-core",
- "candle-nn",
"float8",
"half",
"lazy_static",
+ "mcandle-core",
+ "mcandle-nn",
"metal 0.27.0",
"once_cell",
"paste",
"rayon",
"serde",
- "thiserror",
+ "thiserror 1.0.69",
"tracing",
]
@@ -2297,7 +2502,6 @@ dependencies = [
"accelerate-src",
"anyhow",
"axum",
- "candle-core",
"clap",
"ctrlc",
"data-url",
@@ -2306,6 +2510,7 @@ dependencies = [
"image",
"indexmap",
"intel-mkl-src",
+ "mcandle-core",
"mistralrs-core",
"once_cell",
"regex",
@@ -2324,8 +2529,8 @@ dependencies = [
name = "mistralrs-vision"
version = "0.3.2"
dependencies = [
- "candle-core",
"image",
+ "mcandle-core",
]
[[package]]
@@ -2346,7 +2551,7 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -2497,7 +2702,7 @@ version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43"
dependencies = [
- "hermit-abi",
+ "hermit-abi 0.3.9",
"libc",
]
@@ -2519,7 +2724,7 @@ dependencies = [
"proc-macro-crate",
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -2568,7 +2773,7 @@ dependencies = [
"serde_json",
"strum",
"strum_macros",
- "thiserror",
+ "thiserror 1.0.69",
]
[[package]]
@@ -2590,7 +2795,7 @@ dependencies = [
"serde_json",
"sha2",
"tar",
- "thiserror",
+ "thiserror 1.0.69",
"toml",
"ureq",
"url",
@@ -2626,6 +2831,12 @@ dependencies = [
"pkg-config",
]
+[[package]]
+name = "oorandom"
+version = "11.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9"
+
[[package]]
name = "openssl"
version = "0.10.68"
@@ -2649,7 +2860,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -2802,7 +3013,7 @@ dependencies = [
"crc32fast",
"fdeflate",
"flate2",
- "miniz_oxide 0.8.0",
+ "miniz_oxide",
]
[[package]]
@@ -2872,7 +3083,7 @@ dependencies = [
"proc-macro-error-attr2",
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -2898,9 +3109,9 @@ dependencies = [
[[package]]
name = "pyo3"
-version = "0.22.5"
+version = "0.22.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3d922163ba1f79c04bc49073ba7b32fd5a8d3b76a87c955921234b8e77333c51"
+checksum = "f402062616ab18202ae8319da13fa4279883a2b8a9d9f83f20dbade813ce1884"
dependencies = [
"anyhow",
"cfg-if",
@@ -2929,9 +3140,9 @@ dependencies = [
[[package]]
name = "pyo3-build-config"
-version = "0.22.5"
+version = "0.22.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bc38c5feeb496c8321091edf3d63e9a6829eab4b863b4a6a65f26f3e9cc6b179"
+checksum = "b14b5775b5ff446dd1056212d778012cbe8a0fbffd368029fd9e25b514479c38"
dependencies = [
"once_cell",
"target-lexicon",
@@ -2939,9 +3150,9 @@ dependencies = [
[[package]]
name = "pyo3-ffi"
-version = "0.22.5"
+version = "0.22.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "94845622d88ae274d2729fcefc850e63d7a3ddff5e3ce11bd88486db9f1d357d"
+checksum = "9ab5bcf04a2cdcbb50c7d6105de943f543f9ed92af55818fd17b660390fc8636"
dependencies = [
"libc",
"pyo3-build-config",
@@ -2949,27 +3160,27 @@ dependencies = [
[[package]]
name = "pyo3-macros"
-version = "0.22.5"
+version = "0.22.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e655aad15e09b94ffdb3ce3d217acf652e26bbc37697ef012f5e5e348c716e5e"
+checksum = "0fd24d897903a9e6d80b968368a34e1525aeb719d568dba8b3d4bfa5dc67d453"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
name = "pyo3-macros-backend"
-version = "0.22.5"
+version = "0.22.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ae1e3f09eecd94618f60a455a23def79f79eba4dc561a97324bf9ac8c6df30ce"
+checksum = "36c011a03ba1e50152b4b394b479826cad97e7a21eb52df179cd91ac411cbfbe"
dependencies = [
"heck",
"proc-macro2",
"pyo3-build-config",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -2989,9 +3200,9 @@ checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3"
[[package]]
name = "quinn"
-version = "0.11.5"
+version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8c7c5fdde3cdae7203427dc4f0a68fe0ed09833edc525a03456b153b79828684"
+checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef"
dependencies = [
"bytes",
"pin-project-lite",
@@ -3000,34 +3211,38 @@ dependencies = [
"rustc-hash",
"rustls",
"socket2",
- "thiserror",
+ "thiserror 2.0.3",
"tokio",
"tracing",
]
[[package]]
name = "quinn-proto"
-version = "0.11.8"
+version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fadfaed2cd7f389d0161bb73eeb07b7b78f8691047a6f3e73caaeae55310a4a6"
+checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d"
dependencies = [
"bytes",
+ "getrandom",
"rand",
"ring",
"rustc-hash",
"rustls",
+ "rustls-pki-types",
"slab",
- "thiserror",
+ "thiserror 2.0.3",
"tinyvec",
"tracing",
+ "web-time",
]
[[package]]
name = "quinn-udp"
-version = "0.5.5"
+version = "0.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4fe68c2e9e1a1234e218683dbdf9f9dfcb094113c5ac2b938dfcb9bab4c4140b"
+checksum = "7d5a626c6807713b15cac82a6acaccd6043c9a5408c24baae07611fec3f243da"
dependencies = [
+ "cfg_aliases",
"libc",
"once_cell",
"socket2",
@@ -3166,7 +3381,7 @@ checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43"
dependencies = [
"getrandom",
"libredox",
- "thiserror",
+ "thiserror 1.0.69",
]
[[package]]
@@ -3177,7 +3392,7 @@ checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191"
dependencies = [
"aho-corasick",
"memchr",
- "regex-automata 0.4.8",
+ "regex-automata 0.4.9",
"regex-syntax 0.8.5",
]
@@ -3192,9 +3407,9 @@ dependencies = [
[[package]]
name = "regex-automata"
-version = "0.4.8"
+version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3"
+checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908"
dependencies = [
"aho-corasick",
"memchr",
@@ -3215,9 +3430,9 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "reqwest"
-version = "0.12.8"
+version = "0.12.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f713147fbe92361e52392c73b8c9e48c04c6625bce969ef54dc901e58e042a7b"
+checksum = "a77c62af46e79de0a562e1a9849205ffcb7fc1238876e9bd743357570e04046f"
dependencies = [
"base64 0.22.1",
"bytes",
@@ -3297,7 +3512,7 @@ dependencies = [
"proc-macro2",
"quote",
"rust-embed-utils",
- "syn 2.0.85",
+ "syn 2.0.87",
"walkdir",
]
@@ -3344,9 +3559,9 @@ dependencies = [
[[package]]
name = "rustix"
-version = "0.38.37"
+version = "0.38.41"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811"
+checksum = "d7f649912bc1495e167a6edee79151c84b1bad49748cb4f1f1167f459f6224f6"
dependencies = [
"bitflags 2.6.0",
"errno",
@@ -3357,9 +3572,9 @@ dependencies = [
[[package]]
name = "rustls"
-version = "0.23.15"
+version = "0.23.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5fbb44d7acc4e873d613422379f69f237a1b141928c02f6bc6ccfddddc2d7993"
+checksum = "7f1a745511c54ba6d4465e8d5dfbd81b45791756de28d4981af70d6dca128f1e"
dependencies = [
"log",
"once_cell",
@@ -3384,6 +3599,9 @@ name = "rustls-pki-types"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b"
+dependencies = [
+ "web-time",
+]
[[package]]
name = "rustls-webpki"
@@ -3429,9 +3647,9 @@ dependencies = [
[[package]]
name = "schannel"
-version = "0.1.26"
+version = "0.1.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "01227be5826fa0690321a2ba6c5cd57a19cf3f6a09e76973b58e61de6ab9d1c1"
+checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d"
dependencies = [
"windows-sys 0.59.0",
]
@@ -3457,7 +3675,7 @@ dependencies = [
"proc-macro2",
"quote",
"serde_derive_internals",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -3481,9 +3699,9 @@ dependencies = [
[[package]]
name = "security-framework-sys"
-version = "2.12.0"
+version = "2.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ea4a292869320c0272d7bc55a5a6aafaff59b4f63404a003887b679a2e05b4b6"
+checksum = "fa39c7303dc58b5543c94d22c1766b0d31f2ee58306363ea622b10bbc075eaa2"
dependencies = [
"core-foundation-sys",
"libc",
@@ -3503,22 +3721,22 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4"
[[package]]
name = "serde"
-version = "1.0.213"
+version = "1.0.215"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3ea7893ff5e2466df8d720bb615088341b295f849602c6956047f8f80f0e9bc1"
+checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
-version = "1.0.213"
+version = "1.0.215"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7e85ad2009c50b58e87caa8cd6dac16bdf511bbfb7af6c33df902396aa480fa5"
+checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -3529,14 +3747,14 @@ checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
name = "serde_json"
-version = "1.0.132"
+version = "1.0.133"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03"
+checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377"
dependencies = [
"itoa",
"memchr",
@@ -3707,9 +3925,6 @@ name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
-dependencies = [
- "lock_api",
-]
[[package]]
name = "spm_precompiled"
@@ -3770,7 +3985,7 @@ dependencies = [
"proc-macro2",
"quote",
"rustversion",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -3792,9 +4007,9 @@ dependencies = [
[[package]]
name = "syn"
-version = "2.0.85"
+version = "2.0.87"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5023162dfcd14ef8f32034d8bcd4cc5ddc61ef7a247c024a33e24e1f24d21b56"
+checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d"
dependencies = [
"proc-macro2",
"quote",
@@ -3824,7 +4039,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -3837,7 +4052,7 @@ dependencies = [
"byteorder",
"enum-as-inner",
"libc",
- "thiserror",
+ "thiserror 1.0.69",
"walkdir",
]
@@ -3879,9 +4094,9 @@ dependencies = [
[[package]]
name = "tar"
-version = "0.4.42"
+version = "0.4.43"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4ff6c40d3aedb5e06b57c6f669ad17ab063dd1e63d977c6a88e7f4dfa4f04020"
+checksum = "c65998313f8e17d0d553d28f91a0df93e4dbbbf770279c7bc21ca0f09ea1a1f6"
dependencies = [
"filetime",
"libc",
@@ -3896,9 +4111,9 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
[[package]]
name = "tempfile"
-version = "3.13.0"
+version = "3.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b"
+checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c"
dependencies = [
"cfg-if",
"fastrand",
@@ -3928,22 +4143,42 @@ dependencies = [
[[package]]
name = "thiserror"
-version = "1.0.65"
+version = "1.0.69"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52"
+dependencies = [
+ "thiserror-impl 1.0.69",
+]
+
+[[package]]
+name = "thiserror"
+version = "2.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5d11abd9594d9b38965ef50805c5e469ca9cc6f197f883f717e0269a3057b3d5"
+checksum = "c006c85c7651b3cf2ada4584faa36773bd07bac24acfb39f3c431b36d7e667aa"
dependencies = [
- "thiserror-impl",
+ "thiserror-impl 2.0.3",
]
[[package]]
name = "thiserror-impl"
-version = "1.0.65"
+version = "1.0.69"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ae71770322cbd277e69d762a16c444af02aa0575ac0d174f0b9562d3b37f8602"
+checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
+]
+
+[[package]]
+name = "thiserror-impl"
+version = "2.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f077553d607adc1caf65430528a576c757a71ed73944b66ebb58ef2bbd243568"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.87",
]
[[package]]
@@ -3967,6 +4202,26 @@ dependencies = [
"weezl",
]
+[[package]]
+name = "tinystr"
+version = "0.7.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f"
+dependencies = [
+ "displaydoc",
+ "zerovec",
+]
+
+[[package]]
+name = "tinytemplate"
+version = "1.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
+dependencies = [
+ "serde",
+ "serde_json",
+]
+
[[package]]
name = "tinyvec"
version = "1.8.0"
@@ -4008,7 +4263,7 @@ dependencies = [
"serde",
"serde_json",
"spm_precompiled",
- "thiserror",
+ "thiserror 1.0.69",
"unicode-normalization-alignments",
"unicode-segmentation",
"unicode_categories",
@@ -4016,9 +4271,9 @@ dependencies = [
[[package]]
name = "tokio"
-version = "1.41.0"
+version = "1.41.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "145f3413504347a2be84393cc8a7d2fb4d863b375909ea59f2158261aa258bbb"
+checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33"
dependencies = [
"backtrace",
"bytes",
@@ -4040,7 +4295,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -4196,7 +4451,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -4261,7 +4516,7 @@ checksum = "b9c81686f7ab4065ccac3df7a910c4249f8c0f3fb70421d6ddec19b9311f63f9"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -4280,7 +4535,7 @@ dependencies = [
"num",
"serde",
"serde_json",
- "thiserror",
+ "thiserror 1.0.69",
]
[[package]]
@@ -4293,7 +4548,7 @@ dependencies = [
"half",
"serde",
"serde_json",
- "thiserror",
+ "thiserror 1.0.69",
"ug",
]
@@ -4308,7 +4563,7 @@ dependencies = [
"objc",
"serde",
"serde_json",
- "thiserror",
+ "thiserror 1.0.69",
"ug",
]
@@ -4324,27 +4579,12 @@ version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e51b68083f157f853b6379db119d1c1be0e6e4dec98101079dec41f6f5cf6df"
-[[package]]
-name = "unicode-bidi"
-version = "0.3.17"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5ab17db44d7388991a428b2ee655ce0c212e862eff1768a455c58f9aad6e7893"
-
[[package]]
name = "unicode-ident"
version = "1.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe"
-[[package]]
-name = "unicode-normalization"
-version = "0.1.24"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956"
-dependencies = [
- "tinyvec",
-]
-
[[package]]
name = "unicode-normalization-alignments"
version = "0.1.12"
@@ -4366,6 +4606,12 @@ version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af"
+[[package]]
+name = "unicode-width"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd"
+
[[package]]
name = "unicode_categories"
version = "0.1.1"
@@ -4411,15 +4657,27 @@ dependencies = [
[[package]]
name = "url"
-version = "2.5.2"
+version = "2.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c"
+checksum = "8d157f1b96d14500ffdc1f10ba712e780825526c03d9a49b4d0324b0d9113ada"
dependencies = [
"form_urlencoded",
"idna",
"percent-encoding",
]
+[[package]]
+name = "utf16_iter"
+version = "1.0.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246"
+
+[[package]]
+name = "utf8_iter"
+version = "1.0.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
+
[[package]]
name = "utf8parse"
version = "0.2.2"
@@ -4448,7 +4706,7 @@ dependencies = [
"proc-macro2",
"quote",
"regex",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -4577,7 +4835,7 @@ dependencies = [
"once_cell",
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
"wasm-bindgen-shared",
]
@@ -4611,7 +4869,7 @@ checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
@@ -4632,6 +4890,16 @@ dependencies = [
"wasm-bindgen",
]
+[[package]]
+name = "web-time"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
+dependencies = [
+ "js-sys",
+ "wasm-bindgen",
+]
+
[[package]]
name = "webpki-roots"
version = "0.26.6"
@@ -4884,6 +5152,18 @@ dependencies = [
"memchr",
]
+[[package]]
+name = "write16"
+version = "1.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936"
+
+[[package]]
+name = "writeable"
+version = "0.5.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51"
+
[[package]]
name = "xattr"
version = "1.3.1"
@@ -4915,7 +5195,7 @@ checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
"synstructure",
]
@@ -4937,7 +5217,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
]
[[package]]
@@ -4957,7 +5237,7 @@ checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.85",
+ "syn 2.0.87",
"synstructure",
]
@@ -4967,6 +5247,28 @@ version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde"
+[[package]]
+name = "zerovec"
+version = "0.10.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079"
+dependencies = [
+ "yoke",
+ "zerofrom",
+ "zerovec-derive",
+]
+
+[[package]]
+name = "zerovec-derive"
+version = "0.10.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.87",
+]
+
[[package]]
name = "zip"
version = "1.1.4"
@@ -4980,7 +5282,7 @@ dependencies = [
"flate2",
"indexmap",
"num_enum",
- "thiserror",
+ "thiserror 1.0.69",
]
[[package]]
diff --git a/Cargo.toml b/Cargo.toml
index a116c0da5..70198542e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -7,9 +7,14 @@ members = [
"mistralrs-bench",
"mistralrs-vision",
"mistralrs-quant",
+ "mcandle/mcandle-core",
+ "mcandle/mcandle-nn",
]
exclude = [
"mistralrs-paged_attn",
+ "mcandle/mcandle-flash-attn",
+ "mcandle/mcandle-kernels",
+ "mcandle/mcandle-metal-kernels",
]
resolver = "2"
@@ -25,8 +30,11 @@ license = "MIT"
[workspace.dependencies]
anyhow = "1.0.80"
-candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "cb8082b" }
-candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "cb8082b" }
+mcandle-core = { path = "./mcandle/mcandle-core", package = "mcandle-core" }
+mcandle-flash-attn = { path = "./mcandle/mcandle-flash-attn" }
+mcandle-kernels = { path = "./mcandle/mcandle-kernels" }
+mcandle-metal-kernels = { path = "./mcandle/mcandle-metal-kernels" }
+mcandle-nn = { path = "./mcandle/mcandle-nn" }
serde = "1.0.197"
serde_json = "1.0.114"
indexmap = { version = "2.2.5", features = ["serde"] }
@@ -49,5 +57,23 @@ rayon = "1.1.0"
url = "2.5.2"
data-url = "0.3.1"
buildstructor = "0.5.4"
-float8 = "0.1.1"
+float8 = { version = "0.1.0", features = ["num-traits", "rand_distr"] }
regex = "1.10.6"
+byteorder = "1.4.3"
+cudarc = { version = "0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
+gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
+libc = { version = "0.2.147" }
+memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
+yoke = { version = "0.7.2", features = ["derive"] }
+zip = { version = "1.1.1", default-features = false }
+metal = { version = "0.27.0", features = ["mps"]}
+num_cpus = "1.15.0"
+num-traits = "0.2.15"
+rand = "0.8.5"
+rand_distr = "0.4.3"
+safetensors = "0.4.1"
+thiserror = "1"
+ug = "0.0.2"
+ug-cuda = "0.0.2"
+ug-metal = "0.0.2"
+criterion = { version = "0.5.1", default-features=false }
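With the workspace hunk above, the former candle-* git dependencies are resolved to in-tree path crates, so member crates would typically declare something like `mcandle-core = { workspace = true }` in their own manifests (those manifests are not shown in this diff) and switch their imports from `candle_core` to `mcandle_core`. A minimal sketch of the import-level effect of that rename, assuming the fork keeps the upstream candle-core API surface:

```rust
// Hypothetical downstream code: the same types that used to come from
// `candle_core` are now pulled from the in-tree `mcandle_core` crate.
// The API is assumed to mirror upstream candle-core.
use mcandle_core::{DType, Device, Result, Tensor};

fn zeros_on_cpu() -> Result<Tensor> {
    // Same call shape as candle-core's Tensor::zeros.
    Tensor::zeros((2, 3), DType::F32, &Device::Cpu)
}
```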
diff --git a/README.md b/README.md
index 8b6a5be3f..af229409a 100644
--- a/README.md
+++ b/README.md
@@ -597,7 +597,7 @@ If you want to add a new model, please contact us via an issue and we can coordi
- For non-quantized models, you can specify the data type to load and run in. This must be one of `f32`, `f16`, `bf16` or `auto` to choose based on the device.
## Credits
-This project would not be possible without the excellent work at [`candle`](https://github.com/huggingface/candle). Additionally, thank you to all contributors! Contributing can range from raising an issue or suggesting a feature to adding some new functionality.
+This project would not be possible without the excellent work at [`candle`](https://github.com/huggingface/candle), on which the in-tree `mcandle` fork is based. Additionally, thank you to all contributors! Contributing can range from raising an issue or suggesting a feature to adding some new functionality.
⬆️ Back to Top
diff --git a/mcandle/mcandle-core/Cargo.toml b/mcandle/mcandle-core/Cargo.toml
new file mode 100644
index 000000000..44c9fd018
--- /dev/null
+++ b/mcandle/mcandle-core/Cargo.toml
@@ -0,0 +1,58 @@
+[package]
+name = "mcandle-core"
+version.workspace = true
+edition.workspace = true
+description.workspace = true
+repository.workspace = true
+keywords.workspace = true
+categories.workspace = true
+license.workspace = true
+readme = "README.md"
+
+[dependencies]
+accelerate-src = { workspace = true, optional = true }
+byteorder = { workspace = true }
+mcandle-kernels = { workspace = true, optional = true }
+mcandle-metal-kernels = { workspace = true, optional = true }
+metal = { workspace = true, optional = true}
+cudarc = { workspace = true, optional = true }
+gemm = { workspace = true }
+half = { workspace = true }
+float8 = { workspace = true }
+intel-mkl-src = { workspace = true, optional = true }
+libc = { workspace = true, optional = true }
+memmap2 = { workspace = true }
+num-traits = { workspace = true }
+num_cpus = { workspace = true }
+rand = { workspace = true }
+rand_distr = { workspace = true }
+rayon = { workspace = true }
+safetensors = { workspace = true }
+thiserror = { workspace = true }
+ug = { workspace = true }
+ug-cuda = { workspace = true, optional = true }
+ug-metal = { workspace = true, optional = true }
+yoke = { workspace = true }
+zip = { workspace = true }
+
+[dev-dependencies]
+anyhow = { workspace = true }
+clap = { workspace = true }
+criterion = { workspace = true }
+
+
+[features]
+default = []
+cuda = ["cudarc", "dep:mcandle-kernels", "float8/cuda", "dep:ug-cuda"]
+cudnn = ["cuda", "cudarc/cudnn"]
+mkl = ["dep:libc", "dep:intel-mkl-src"]
+accelerate = ["dep:libc", "dep:accelerate-src"]
+metal = ["dep:metal", "dep:mcandle-metal-kernels", "dep:ug-metal"]
+
+[[bench]]
+name = "bench_main"
+harness = false
+
+[[example]]
+name = "metal_basics"
+required-features = ["metal"]
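The new manifest declares a `bench_main` benchmark target with `harness = false` and `criterion` as a dev-dependency; the benchmark source itself is not part of this diff. A minimal sketch of what `benches/bench_main.rs` could look like under those assumptions (file name and benchmark body are hypothetical, and the tensor API is assumed to mirror upstream candle-core):

```rust
// Hypothetical criterion harness matching the [[bench]] entry above.
use criterion::{criterion_group, criterion_main, Criterion};
use mcandle_core::{DType, Device, Tensor};

fn bench_matmul(c: &mut Criterion) {
    // Assumes a candle-core-style Tensor API on the CPU backend.
    let device = Device::Cpu;
    let a = Tensor::ones((128, 128), DType::F32, &device).unwrap();
    let b = Tensor::ones((128, 128), DType::F32, &device).unwrap();
    c.bench_function("matmul_128x128", |bench| {
        bench.iter(|| a.matmul(&b).unwrap())
    });
}

criterion_group!(benches, bench_matmul);
criterion_main!(benches);
```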
diff --git a/mcandle/mcandle-core/LICENSE b/mcandle/mcandle-core/LICENSE
new file mode 100644
index 000000000..261eeb9e9
--- /dev/null
+++ b/mcandle/mcandle-core/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/mcandle/mcandle-core/README.md b/mcandle/mcandle-core/README.md
new file mode 100644
index 000000000..dd6c2572f
--- /dev/null
+++ b/mcandle/mcandle-core/README.md
@@ -0,0 +1,2 @@
+# mcandle
+Minimalist ML framework for Rust
diff --git a/mcandle/mcandle-core/benches/bench_main.rs b/mcandle/mcandle-core/benches/bench_main.rs
new file mode 100644
index 000000000..2e1816fd7
--- /dev/null
+++ b/mcandle/mcandle-core/benches/bench_main.rs
@@ -0,0 +1,12 @@
+mod benchmarks;
+
+use criterion::criterion_main;
+criterion_main!(
+ benchmarks::affine::benches,
+ benchmarks::matmul::benches,
+ benchmarks::random::benches,
+ benchmarks::where_cond::benches,
+ benchmarks::conv_transpose2d::benches,
+ benchmarks::qmatmul::benches,
+ benchmarks::unary::benches
+);
diff --git a/mcandle/mcandle-core/benches/benchmarks/affine.rs b/mcandle/mcandle-core/benches/benchmarks/affine.rs
new file mode 100644
index 000000000..34b50f528
--- /dev/null
+++ b/mcandle/mcandle-core/benches/benchmarks/affine.rs
@@ -0,0 +1,43 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use criterion::{black_box, criterion_group, Criterion, Throughput};
+use mcandle_core::{DType, Device, Tensor};
+use std::time::Instant;
+
+fn run(a: &Tensor) {
+ a.affine(12.34, 56.78).unwrap();
+}
+
+fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
+ let b = 1;
+ let m = 1024;
+ let k = 1024;
+
+ let tensor = Tensor::zeros((b, m, k), dtype, device).unwrap();
+
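+ // Despite the name, "flops" here counts bytes touched (elements * dtype size); Criterion reports it as byte throughput.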
+ let flops = b * m * k * dtype.size_in_bytes();
+
+ let mut group = c.benchmark_group(device.bench_name(name));
+ group.throughput(Throughput::Bytes(flops as u64));
+ group.bench_function("iter", move |b| {
+ b.iter_custom(|iters| {
+ let start = Instant::now();
+ for _i in 0..iters {
+ run(black_box(&tensor));
+ }
+ device.sync().unwrap();
+ start.elapsed()
+ })
+ });
+ group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+ let handler = BenchDeviceHandler::new().unwrap();
+ for device in handler.devices {
+ run_affine_benchmark(c, &device, DType::F32, "affine_f32");
+ run_affine_benchmark(c, &device, DType::F16, "affine_f16");
+ run_affine_benchmark(c, &device, DType::BF16, "affine_bf16");
+ }
+}
+
+criterion_group!(benches, criterion_benchmark);
diff --git a/mcandle/mcandle-core/benches/benchmarks/conv_transpose2d.rs b/mcandle/mcandle-core/benches/benchmarks/conv_transpose2d.rs
new file mode 100644
index 000000000..b38f235a5
--- /dev/null
+++ b/mcandle/mcandle-core/benches/benchmarks/conv_transpose2d.rs
@@ -0,0 +1,59 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use criterion::{black_box, criterion_group, Criterion, Throughput};
+use mcandle_core::{DType, Device, Tensor};
+use std::time::Instant;
+
+fn run(
+ x: &Tensor,
+ k: &Tensor,
+ padding: usize,
+ output_padding: usize,
+ stride: usize,
+ dilation: usize,
+) {
+ x.conv_transpose2d(k, padding, output_padding, stride, dilation)
+ .unwrap();
+}
+
+fn run_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
+ let t = Tensor::arange(0.0f32, 10000.0, device)
+ .unwrap()
+ .reshape((1, 4, 50, 50))
+ .unwrap()
+ .to_dtype(dtype)
+ .unwrap();
+
+ let kernel = Tensor::arange(0.0f32, 100.0, device)
+ .unwrap()
+ .reshape((4, 1, 5, 5))
+ .unwrap()
+ .to_dtype(dtype)
+ .unwrap();
+
+ let flops = t.dims().iter().product::<usize>() * dtype.size_in_bytes();
+
+ let mut group = c.benchmark_group(device.bench_name(name));
+ group.throughput(Throughput::Bytes(flops as u64));
+ group.bench_function("iter", move |b| {
+ b.iter_custom(|iters| {
+ let start = Instant::now();
+ for _i in 0..iters {
+ run(black_box(&t), black_box(&kernel), 1, 0, 1, 2);
+ }
+ device.sync().unwrap();
+ start.elapsed()
+ })
+ });
+ group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+ let handler = BenchDeviceHandler::new().unwrap();
+ for device in handler.devices {
+ run_benchmark(c, &device, DType::F32, "conv_transpose2d_f32");
+ run_benchmark(c, &device, DType::F16, "conv_transpose2d_f16");
+ run_benchmark(c, &device, DType::BF16, "conv_transpose2d_bf16");
+ }
+}
+
+criterion_group!(benches, criterion_benchmark);
diff --git a/mcandle/mcandle-core/benches/benchmarks/matmul.rs b/mcandle/mcandle-core/benches/benchmarks/matmul.rs
new file mode 100644
index 000000000..498b82df2
--- /dev/null
+++ b/mcandle/mcandle-core/benches/benchmarks/matmul.rs
@@ -0,0 +1,44 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use criterion::{black_box, criterion_group, Criterion, Throughput};
+use mcandle_core::{DType, Device, Tensor};
+use std::time::Instant;
+
+fn run(a: &Tensor, b: &Tensor) {
+ a.matmul(&b.t().unwrap()).unwrap();
+}
+
+fn run_bench(c: &mut Criterion, device: &Device) {
+ let b = 1;
+ let m = 1;
+ let n = 2048;
+ let k = 2048;
+
+ let dtype = DType::F32;
+ let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap();
+ let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap();
+
+ let flops = b * m * n * k;
+
+ let mut group = c.benchmark_group(device.bench_name("matmul"));
+ group.throughput(Throughput::Bytes(flops as u64));
+ group.bench_function("iter", move |b| {
+ b.iter_custom(|iters| {
+ let start = Instant::now();
+ for _i in 0..iters {
+ run(black_box(&lhs), black_box(&rhs));
+ }
+ device.sync().unwrap();
+ start.elapsed()
+ })
+ });
+ group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+ let handler = BenchDeviceHandler::new().unwrap();
+ for device in handler.devices {
+ run_bench(c, &device);
+ }
+}
+
+criterion_group!(benches, criterion_benchmark);
diff --git a/mcandle/mcandle-core/benches/benchmarks/mod.rs b/mcandle/mcandle-core/benches/benchmarks/mod.rs
new file mode 100644
index 000000000..986488e5f
--- /dev/null
+++ b/mcandle/mcandle-core/benches/benchmarks/mod.rs
@@ -0,0 +1,72 @@
+pub(crate) mod affine;
+pub(crate) mod conv_transpose2d;
+pub(crate) mod matmul;
+pub(crate) mod qmatmul;
+pub(crate) mod random;
+pub(crate) mod unary;
+pub(crate) mod where_cond;
+
+use mcandle_core::{Device, Result};
+
+pub(crate) trait BenchDevice {
+ fn sync(&self) -> Result<()>;
+
+ fn bench_name<S: Into<String>>(&self, name: S) -> String;
+}
+
+impl BenchDevice for Device {
+ fn sync(&self) -> Result<()> {
+ match self {
+ Device::Cpu => Ok(()),
+ Device::Cuda(device) => {
+ #[cfg(feature = "cuda")]
+ {
+ use cuda::WrapErr;
+ return Ok(device.synchronize().w()?);
+ }
+ #[cfg(not(feature = "cuda"))]
+ panic!("Cuda device without cuda feature enabled: {:?}", device)
+ }
+ Device::Metal(device) => {
+ #[cfg(feature = "metal")]
+ return Ok(device.wait_until_completed()?);
+ #[cfg(not(feature = "metal"))]
+ panic!("Metal device without metal feature enabled: {:?}", device)
+ }
+ }
+ }
+
+ fn bench_name<S: Into<String>>(&self, name: S) -> String {
+ match self {
+ Device::Cpu => {
+ let cpu_type = if cfg!(feature = "accelerate") {
+ "accelerate"
+ } else if cfg!(feature = "mkl") {
+ "mkl"
+ } else {
+ "cpu"
+ };
+ format!("{}_{}", cpu_type, name.into())
+ }
+ Device::Cuda(_) => format!("cuda_{}", name.into()),
+ Device::Metal(_) => format!("metal_{}", name.into()),
+ }
+ }
+}
+
+struct BenchDeviceHandler {
+ devices: Vec<Device>,
+}
+
+impl BenchDeviceHandler {
+ pub fn new() -> Result<Self> {
+ let mut devices = Vec::new();
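+ // Benchmark at most one accelerator (Metal preferred over CUDA when both features are enabled), plus the CPU as a baseline.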
+ if cfg!(feature = "metal") {
+ devices.push(Device::new_metal(0)?);
+ } else if cfg!(feature = "cuda") {
+ devices.push(Device::new_cuda(0)?);
+ }
+ devices.push(Device::Cpu);
+ Ok(Self { devices })
+ }
+}
diff --git a/mcandle/mcandle-core/benches/benchmarks/qmatmul.rs b/mcandle/mcandle-core/benches/benchmarks/qmatmul.rs
new file mode 100644
index 000000000..965266f4c
--- /dev/null
+++ b/mcandle/mcandle-core/benches/benchmarks/qmatmul.rs
@@ -0,0 +1,72 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use criterion::{black_box, criterion_group, Criterion, Throughput};
+use mcandle_core::{
+ quantized::{self, GgmlDType, QMatMul},
+ Device, Module, Tensor,
+};
+use std::time::Instant;
+
+fn run(matmul: &QMatMul, x: &Tensor) {
+ matmul.forward(x).unwrap();
+}
+
+fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
+ let b = 1;
+ let m = 1;
+ let n = 1024;
+ let k = 1024;
+
+ let lhs = (0..(m * k))
+ .map(|v| v as f32 / (m * k) as f32)
+ .collect::<Vec<_>>();
+ let rhs = (0..(k * n))
+ .map(|v| v as f32 / (n * k) as f32)
+ .collect::<Vec<_>>();
+
+ let lhs = Tensor::from_slice(&lhs, (m, k), device).unwrap();
+ let rhs = Tensor::from_slice(&rhs, (k, n), device).unwrap();
+
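+ // Quantize the transposed rhs so the weights are stored as (n, k); QMatMul then maps the (m, k) lhs to (m, n) outputs.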
+ let qtensor = quantized::QTensor::quantize(&rhs.t().unwrap(), dtype).unwrap();
+ let matmul = quantized::QMatMul::from_qtensor(qtensor).unwrap();
+
+ let flops = b * m * n * k;
+
+ let mut group = c.benchmark_group(device.bench_name(format!("qmatmul_{:?}", dtype)));
+ group.sample_size(200);
+ group.throughput(Throughput::Bytes(flops as u64));
+ group.bench_function("iter", move |b| {
+ b.iter_custom(|iters| {
+ let start = Instant::now();
+ for _i in 0..iters {
+ run(black_box(&matmul), black_box(&lhs));
+ }
+ device.sync().unwrap();
+ start.elapsed()
+ })
+ });
+ group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+ let handler = BenchDeviceHandler::new().unwrap();
+ for device in handler.devices {
+ for dtype in [
+ GgmlDType::F32,
+ GgmlDType::F16,
+ GgmlDType::Q4_0,
+ GgmlDType::Q4_1,
+ GgmlDType::Q5_0,
+ GgmlDType::Q5_1,
+ GgmlDType::Q8_0,
+ GgmlDType::Q2K,
+ GgmlDType::Q3K,
+ GgmlDType::Q4K,
+ GgmlDType::Q5K,
+ GgmlDType::Q6K,
+ ] {
+ run_bench(c, &device, dtype);
+ }
+ }
+}
+
+criterion_group!(benches, criterion_benchmark);
diff --git a/mcandle/mcandle-core/benches/benchmarks/random.rs b/mcandle/mcandle-core/benches/benchmarks/random.rs
new file mode 100644
index 000000000..e0c74b5e0
--- /dev/null
+++ b/mcandle/mcandle-core/benches/benchmarks/random.rs
@@ -0,0 +1,63 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use criterion::{black_box, criterion_group, Criterion, Throughput};
+use mcandle_core::{DType, Device, Tensor};
+use std::time::Instant;
+
+fn rand_uniform(a: &Tensor) {
+ a.rand_like(-1.0, 123.0).unwrap();
+}
+
+fn rand_normal(a: &Tensor) {
+ a.randn_like(100.0, 15.0).unwrap();
+}
+
+fn run_random_bench(c: &mut Criterion, device: &Device) {
+ let b = 1;
+
+ let rows = 2048;
+ let cols = 2048;
+
+ let dtype = DType::F32;
+ let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
+
+ let flops = b * rows * cols * dtype.size_in_bytes();
+
+ let mut group = c.benchmark_group(device.bench_name("random_uniform"));
+ group.throughput(Throughput::Bytes(flops as u64));
+ group.bench_function("iter", move |benches| {
+ benches.iter_custom(|iters| {
+ let start = Instant::now();
+ for _i in 0..iters {
+ rand_uniform(black_box(&tensor));
+ }
+ device.sync().unwrap();
+ start.elapsed()
+ })
+ });
+ group.finish();
+
+ let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
+
+ let mut group = c.benchmark_group(device.bench_name("random_normal"));
+ group.throughput(Throughput::Bytes(flops as u64));
+ group.bench_function("iter", move |benches| {
+ benches.iter_custom(|iters| {
+ let start = Instant::now();
+ for _i in 0..iters {
+ rand_normal(black_box(&tensor));
+ }
+ device.sync().unwrap();
+ start.elapsed()
+ })
+ });
+ group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+ let handler = BenchDeviceHandler::new().unwrap();
+ for device in handler.devices {
+ run_random_bench(c, &device);
+ }
+}
+
+criterion_group!(benches, criterion_benchmark);
diff --git a/mcandle/mcandle-core/benches/benchmarks/unary.rs b/mcandle/mcandle-core/benches/benchmarks/unary.rs
new file mode 100644
index 000000000..f13b17948
--- /dev/null
+++ b/mcandle/mcandle-core/benches/benchmarks/unary.rs
@@ -0,0 +1,49 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use criterion::{black_box, criterion_group, Criterion, Throughput};
+use mcandle_core::{DType, Device, Tensor};
+use std::time::Instant;
+
+fn run(a: &Tensor) {
+ a.sqrt().unwrap();
+}
+
+fn run_unary_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
+ let b = 1;
+ let m = 1024;
+ let k = 1024;
+
+ let tensor = Tensor::arange(0.0f32, (b * m * k) as f32, device)
+ .unwrap()
+ .to_dtype(dtype)
+ .unwrap()
+ .reshape((b, m, k))
+ .unwrap();
+
+ let flops = b * m * k * dtype.size_in_bytes();
+
+ let mut group = c.benchmark_group(device.bench_name(name));
+ group.throughput(Throughput::Bytes(flops as u64));
+ group.bench_function("iter", move |b| {
+ b.iter_custom(|iters| {
+ let start = Instant::now();
+ for _i in 0..iters {
+ run(black_box(&tensor));
+ }
+ device.sync().unwrap();
+ start.elapsed()
+ })
+ });
+ group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+ let handler = BenchDeviceHandler::new().unwrap();
+ for device in handler.devices {
+ for dtype in [DType::F32, DType::BF16, DType::F16] {
+ let name = format!("sqrt_{:?}", dtype);
+ run_unary_benchmark(c, &device, dtype, &name);
+ }
+ }
+}
+
+criterion_group!(benches, criterion_benchmark);
diff --git a/mcandle/mcandle-core/benches/benchmarks/where_cond.rs b/mcandle/mcandle-core/benches/benchmarks/where_cond.rs
new file mode 100644
index 000000000..2ed1f85c9
--- /dev/null
+++ b/mcandle/mcandle-core/benches/benchmarks/where_cond.rs
@@ -0,0 +1,64 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use criterion::{black_box, criterion_group, Criterion, Throughput};
+use mcandle_core::{DType, Device, Tensor};
+use std::time::Instant;
+
+fn run(a: &Tensor, b: &Tensor, c: &Tensor) {
+ a.where_cond(b, c).unwrap();
+}
+
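+ // Build the condition mask at compile time so the benchmark input is deterministic: elements alternate 0, 1, 0, 1, ...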
+const fn create_cond_arr<const N: usize>() -> [u8; N] {
+ let mut arr = [0u8; N];
+ let mut i = 0;
+ while i < N {
+ arr[i] = (i % 2) as u8;
+ i += 1;
+ }
+ arr
+}
+
+const B: usize = 1;
+const M: usize = 1024;
+const K: usize = 1024;
+const SIZE: usize = B * M * K;
+
+const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
+
+fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
+ let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), device).unwrap();
+ let on_true = Tensor::ones((B, M, K), dtype, device).unwrap();
+ let on_false = Tensor::zeros((B, M, K), dtype, device).unwrap();
+
+ let elements = B * M * K;
+ // E.g. 2 f32 tensors + 1 u8 tensor
+ let flops = (2 * elements * dtype.size_in_bytes()) + elements;
+
+ let mut group = c.benchmark_group(device.bench_name(name));
+ group.throughput(Throughput::Bytes(flops as u64));
+ group.bench_function("iter", move |b| {
+ b.iter_custom(|iters| {
+ let start = Instant::now();
+ for _i in 0..iters {
+ run(
+ black_box(&tensor),
+ black_box(&on_true),
+ black_box(&on_false),
+ );
+ }
+ device.sync().unwrap();
+ start.elapsed()
+ })
+ });
+ group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+ let device = BenchDeviceHandler::new().unwrap();
+ for d in device.devices {
+ run_where_cond_benchmark(c, &d, DType::F32, "where_cond_f32");
+ run_where_cond_benchmark(c, &d, DType::BF16, "where_cond_bf16");
+ run_where_cond_benchmark(c, &d, DType::F16, "where_cond_f16");
+ }
+}
+
+criterion_group!(benches, criterion_benchmark);
diff --git a/mcandle/mcandle-core/examples/basics.rs b/mcandle/mcandle-core/examples/basics.rs
new file mode 100644
index 000000000..b88c36fa7
--- /dev/null
+++ b/mcandle/mcandle-core/examples/basics.rs
@@ -0,0 +1,17 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::Result;
+use mcandle_core::{Device, Tensor};
+
+fn main() -> Result<()> {
+ let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?;
+ let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?;
+ let new_a = a.slice_scatter(&b, 1, 2)?;
+ assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
+ assert_eq!(new_a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
+ Ok(())
+}
diff --git a/mcandle/mcandle-core/examples/cuda_basics.rs b/mcandle/mcandle-core/examples/cuda_basics.rs
new file mode 100644
index 000000000..2457714b1
--- /dev/null
+++ b/mcandle/mcandle-core/examples/cuda_basics.rs
@@ -0,0 +1,33 @@
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+use anyhow::Result;
+use mcandle_core::{Device, Tensor};
+
+fn main() -> Result<()> {
+ let device = Device::new_cuda(0)?;
+ let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
+ .to_dtype(mcandle_core::DType::BF16)?;
+ mcandle_core::cuda::set_gemm_reduced_precision_f32(false);
+ mcandle_core::cuda::set_gemm_reduced_precision_bf16(false);
+ let _x1 = x.matmul(&x)?;
+ drop(_x1);
+ let start_time = std::time::Instant::now();
+ let _x1 = x.matmul(&x)?;
+ device.synchronize()?;
+ println!("fp32: {:?}", start_time.elapsed());
+ drop(_x1);
+ mcandle_core::cuda::set_gemm_reduced_precision_f32(true);
+ mcandle_core::cuda::set_gemm_reduced_precision_bf16(true);
+ let _x1 = x.matmul(&x)?;
+ drop(_x1);
+ let start_time = std::time::Instant::now();
+ let _x1 = x.matmul(&x)?;
+ device.synchronize()?;
+ println!("tf32: {:?}", start_time.elapsed());
+ drop(_x1);
+ Ok(())
+}
diff --git a/mcandle/mcandle-core/examples/cuda_sum_benchmark.rs b/mcandle/mcandle-core/examples/cuda_sum_benchmark.rs
new file mode 100644
index 000000000..5104694d9
--- /dev/null
+++ b/mcandle/mcandle-core/examples/cuda_sum_benchmark.rs
@@ -0,0 +1,54 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use std::str::FromStr;
+
+use anyhow::Result;
+use mcandle_core::{Device, Tensor};
+
+fn cos_sin(n: usize, device: &Device) -> Result<Tensor> {
+ let thetas: Vec<_> = (0..n).map(|i| (i as f32 / n as f32)).collect();
+ let xs: Vec<_> = thetas.iter().map(|t| t.cos().abs()).collect();
+ let ys: Vec<_> = thetas.iter().map(|t| t.sin().abs()).collect();
+ let xs = Tensor::from_vec(xs, (n, 1), device)?;
+ let ys = Tensor::from_vec(ys, (1, n), device)?;
+ let ys = Tensor::cat(&[&ys, &ys, &ys, &ys, &ys, &ys], 1)?;
+ Ok(xs.matmul(&ys)?)
+}
+
+fn main() -> Result<()> {
+ let device = Device::new_cuda(0)?;
+ let args = std::env::args().collect::<Vec<String>>();
+ let n = if args.len() < 2 {
+ 2000usize
+ } else {
+ usize::from_str(&args[1])?
+ };
+ let xys_cpu = cos_sin(n, &Device::Cpu)?;
+ let xys = cos_sin(n, &device)?;
+ println!("{xys_cpu:?} {xys:?}");
+ let sum_keepdim_cpu = xys_cpu.sum_keepdim(1)?;
+ println!("{sum_keepdim_cpu}");
+ let sum_keepdim = xys.sum_keepdim(1)?;
+ println!("{sum_keepdim}");
+ let start = std::time::Instant::now();
+ let n_iters = 100;
+ let mut v = 0f32;
+ for _i in 0..n_iters {
+ let sum_keepdim = xys.sum_keepdim(1)?;
+ let sum_keepdim = sum_keepdim.sum_keepdim(0)?;
+ let sum_keepdim: f32 = sum_keepdim.reshape(&[])?.to_scalar()?;
+ v += sum_keepdim;
+ }
+ let elapsed = start.elapsed();
+ if v > 0. {
+ println!(
+ "ran {n_iters} iterations, time per iter: {:?} ({v})",
+ elapsed.div_f64(n_iters as f64)
+ );
+ }
+ Ok(())
+}
diff --git a/mcandle/mcandle-core/examples/metal_basics.rs b/mcandle/mcandle-core/examples/metal_basics.rs
new file mode 100644
index 000000000..830b12b8d
--- /dev/null
+++ b/mcandle/mcandle-core/examples/metal_basics.rs
@@ -0,0 +1,28 @@
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+use anyhow::Result;
+use mcandle_core::{Device, Tensor};
+
+fn main() -> Result<()> {
+ // This requires the code to be run with MTL_CAPTURE_ENABLED=1
+ let device = Device::new_metal(0)?;
+ let metal_device = match &device {
+ Device::Metal(m) => m,
+ _ => anyhow::bail!("unexpected device"),
+ };
+ metal_device.capture("/tmp/mcandle.gputrace")?;
+ // This first synchronize ensures that a new command buffer gets created after setting up the
+ // capture scope.
+ device.synchronize()?;
+ let x = Tensor::randn(0f32, 1.0, (128, 128), &device)?;
+ let x1 = x.add(&x)?;
+ println!("{x1:?}");
+ // This second synchronize ensures that the command buffer gets committed before the end of the
+ // capture scope.
+ device.synchronize()?;
+ Ok(())
+}
diff --git a/mcandle/mcandle-core/src/accelerate.rs b/mcandle/mcandle-core/src/accelerate.rs
new file mode 100644
index 000000000..d371d3b33
--- /dev/null
+++ b/mcandle/mcandle-core/src/accelerate.rs
@@ -0,0 +1,476 @@
+#![allow(dead_code)]
+use libc::{c_char, c_double, c_float, c_int, c_long, c_ulong};
+
+mod ffi {
+ use super::*;
+ extern "C" {
+ // It would be nice to be able to switch to the NEWLAPACK version of the function but this
+ // seems to trigger some link error. Available function names can be seen here:
+ // /Library/Developer/CommandLineTools/SDKs/MacOSX13.3.sdk/System/Library/Frameworks/Accelerate.framework/Versions/A/Accelerate.tbd
+ #[link_name = "sgemm_"]
+ pub fn sgemm_ffi(
+ transa: *const c_char,
+ transb: *const c_char,
+ m: *const c_int,
+ n: *const c_int,
+ k: *const c_int,
+ alpha: *const c_float,
+ a: *const c_float,
+ lda: *const c_int,
+ b: *const c_float,
+ ldb: *const c_int,
+ beta: *const c_float,
+ c: *mut c_float,
+ ldc: *const c_int,
+ );
+ #[link_name = "dgemm_"]
+ pub fn dgemm_ffi(
+ transa: *const c_char,
+ transb: *const c_char,
+ m: *const c_int,
+ n: *const c_int,
+ k: *const c_int,
+ alpha: *const c_double,
+ a: *const c_double,
+ lda: *const c_int,
+ b: *const c_double,
+ ldb: *const c_int,
+ beta: *const c_double,
+ c: *mut c_double,
+ ldc: *const c_int,
+ );
+
+ pub fn vvexpf(dst: *mut c_float, src: *const c_float, len: *const c_int);
+ pub fn vvexp(dst: *mut c_double, src: *const c_double, len: *const c_int);
+ pub fn vvsqrtf(dst: *mut c_float, src: *const c_float, len: *const c_int);
+ pub fn vvsqrt(dst: *mut c_double, src: *const c_double, len: *const c_int);
+ pub fn vvsinf(dst: *mut c_float, src: *const c_float, len: *const c_int);
+ pub fn vvsin(dst: *mut c_double, src: *const c_double, len: *const c_int);
+ pub fn vvcosf(dst: *mut c_float, src: *const c_float, len: *const c_int);
+ pub fn vvcos(dst: *mut c_double, src: *const c_double, len: *const c_int);
+ pub fn vvlogf(dst: *mut c_float, src: *const c_float, len: *const c_int);
+ pub fn vvlog(dst: *mut c_double, src: *const c_double, len: *const c_int);
+ pub fn vvtanhf(dst: *mut c_float, src: *const c_float, len: *const c_int);
+ pub fn vvtanh(dst: *mut c_double, src: *const c_double, len: *const c_int);
+
+ pub fn vDSP_vaddD(
+ _: *const c_double,
+ _: c_long,
+ _: *const c_double,
+ _: c_long,
+ _: *mut c_double,
+ _: c_long,
+ _: c_ulong,
+ );
+ pub fn vDSP_vadd(
+ _: *const c_float,
+ _: c_long,
+ _: *const c_float,
+ _: c_long,
+ _: *mut c_float,
+ _: c_long,
+ _: c_ulong,
+ );
+ pub fn vDSP_vsubD(
+ _: *const c_double,
+ _: c_long,
+ _: *const c_double,
+ _: c_long,
+ _: *mut c_double,
+ _: c_long,
+ _: c_ulong,
+ );
+ pub fn vDSP_vsub(
+ _: *const c_float,
+ _: c_long,
+ _: *const c_float,
+ _: c_long,
+ _: *mut c_float,
+ _: c_long,
+ _: c_ulong,
+ );
+ pub fn vDSP_vmulD(
+ _: *const c_double,
+ _: c_long,
+ _: *const c_double,
+ _: c_long,
+ _: *mut c_double,
+ _: c_long,
+ _: c_ulong,
+ );
+ pub fn vDSP_vmul(
+ _: *const c_float,
+ _: c_long,
+ _: *const c_float,
+ _: c_long,
+ _: *mut c_float,
+ _: c_long,
+ _: c_ulong,
+ );
+ pub fn vDSP_vdivD(
+ _: *const c_double,
+ _: c_long,
+ _: *const c_double,
+ _: c_long,
+ _: *mut c_double,
+ _: c_long,
+ _: c_ulong,
+ );
+ pub fn vDSP_vdiv(
+ _: *const c_float,
+ _: c_long,
+ _: *const c_float,
+ _: c_long,
+ _: *mut c_float,
+ _: c_long,
+ _: c_ulong,
+ );
+ pub fn vDSP_vminD(
+ _: *const c_double,
+ _: c_long,
+ _: *const c_double,
+ _: c_long,
+ _: *mut c_double,
+ _: c_long,
+ _: c_ulong,
+ );
+ pub fn vDSP_vmin(
+ _: *const c_float,
+ _: c_long,
+ _: *const c_float,
+ _: c_long,
+ _: *mut c_float,
+ _: c_long,
+ _: c_ulong,
+ );
+ pub fn vDSP_vmaxD(
+ _: *const c_double,
+ _: c_long,
+ _: *const c_double,
+ _: c_long,
+ _: *mut c_double,
+ _: c_long,
+ _: c_ulong,
+ );
+ pub fn vDSP_vmax(
+ _: *const c_float,
+ _: c_long,
+ _: *const c_float,
+ _: c_long,
+ _: *mut c_float,
+ _: c_long,
+ _: c_ulong,
+ );
+ }
+}
+
+#[allow(clippy::too_many_arguments)]
+#[inline]
+pub unsafe fn sgemm(
+ transa: u8,
+ transb: u8,
+ m: i32,
+ n: i32,
+ k: i32,
+ alpha: f32,
+ a: &[f32],
+ lda: i32,
+ b: &[f32],
+ ldb: i32,
+ beta: f32,
+ c: &mut [f32],
+ ldc: i32,
+) {
+ ffi::sgemm_ffi(
+ &(transa as c_char),
+ &(transb as c_char),
+ &m,
+ &n,
+ &k,
+ &alpha,
+ a.as_ptr(),
+ &lda,
+ b.as_ptr(),
+ &ldb,
+ &beta,
+ c.as_mut_ptr(),
+ &ldc,
+ )
+}
+
+#[allow(clippy::too_many_arguments)]
+#[inline]
+pub unsafe fn dgemm(
+ transa: u8,
+ transb: u8,
+ m: i32,
+ n: i32,
+ k: i32,
+ alpha: f64,
+ a: &[f64],
+ lda: i32,
+ b: &[f64],
+ ldb: i32,
+ beta: f64,
+ c: &mut [f64],
+ ldc: i32,
+) {
+ ffi::dgemm_ffi(
+ &(transa as c_char),
+ &(transb as c_char),
+ &m,
+ &n,
+ &k,
+ &alpha,
+ a.as_ptr(),
+ &lda,
+ b.as_ptr(),
+ &ldb,
+ &beta,
+ c.as_mut_ptr(),
+ &ldc,
+ )
+}
+
+#[inline]
+pub fn vs_exp(a: &[f32], y: &mut [f32]) {
+ let a_len = a.len();
+ let y_len = y.len();
+ if a_len != y_len {
+ panic!("a and y have different lengths {a_len} <> {y_len}")
+ }
+ unsafe { ffi::vvexpf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
+}
+
+#[inline]
+pub fn vd_exp(a: &[f64], y: &mut [f64]) {
+ let a_len = a.len();
+ let y_len = y.len();
+ if a_len != y_len {
+ panic!("a and y have different lengths {a_len} <> {y_len}")
+ }
+ unsafe { ffi::vvexp(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
+}
+
+#[inline]
+pub fn vs_sqrt(a: &[f32], y: &mut [f32]) {
+ let a_len = a.len();
+ let y_len = y.len();
+ if a_len != y_len {
+ panic!("a and y have different lengths {a_len} <> {y_len}")
+ }
+ unsafe { ffi::vvsqrtf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
+}
+
+#[inline]
+pub fn vd_sqrt(a: &[f64], y: &mut [f64]) {
+ let a_len = a.len();
+ let y_len = y.len();
+ if a_len != y_len {
+ panic!("a and y have different lengths {a_len} <> {y_len}")
+ }
+ unsafe { ffi::vvsqrt(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
+}
+
+#[inline]
+pub fn vs_sin(a: &[f32], y: &mut [f32]) {
+ let a_len = a.len();
+ let y_len = y.len();
+ if a_len != y_len {
+ panic!("a and y have different lengths {a_len} <> {y_len}")
+ }
+ unsafe { ffi::vvsinf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
+}
+
+#[inline]
+pub fn vd_sin(a: &[f64], y: &mut [f64]) {
+ let a_len = a.len();
+ let y_len = y.len();
+ if a_len != y_len {
+ panic!("a and y have different lengths {a_len} <> {y_len}")
+ }
+ unsafe { ffi::vvsin(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
+}
+#[inline]
+pub fn vs_cos(a: &[f32], y: &mut [f32]) {
+ let a_len = a.len();
+ let y_len = y.len();
+ if a_len != y_len {
+ panic!("a and y have different lengths {a_len} <> {y_len}")
+ }
+ unsafe { ffi::vvcosf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
+}
+
+#[inline]
+pub fn vd_cos(a: &[f64], y: &mut [f64]) {
+ let a_len = a.len();
+ let y_len = y.len();
+ if a_len != y_len {
+ panic!("a and y have different lengths {a_len} <> {y_len}")
+ }
+ unsafe { ffi::vvcos(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
+}
+#[inline]
+pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
+ let a_len = a.len();
+ let y_len = y.len();
+ if a_len != y_len {
+ panic!("a and y have different lengths {a_len} <> {y_len}")
+ }
+ unsafe { ffi::vvtanhf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
+}
+
+#[inline]
+pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
+ let a_len = a.len();
+ let y_len = y.len();
+ if a_len != y_len {
+ panic!("a and y have different lengths {a_len} <> {y_len}")
+ }
+ unsafe { ffi::vvtanh(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
+}
+
+#[inline]
+pub fn vs_ln(a: &[f32], y: &mut [f32]) {
+ let a_len = a.len();
+ let y_len = y.len();
+ if a_len != y_len {
+ panic!("a and y have different lengths {a_len} <> {y_len}")
+ }
+ unsafe { ffi::vvlogf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
+}
+
+#[inline]
+pub fn vd_ln(a: &[f64], y: &mut [f64]) {
+ let a_len = a.len();
+ let y_len = y.len();
+ if a_len != y_len {
+ panic!("a and y have different lengths {a_len} <> {y_len}")
+ }
+ unsafe { ffi::vvlog(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
+}
+
+#[inline]
+pub fn vs_sqr(a: &[f32], y: &mut [f32]) {
+ let a_len = a.len();
+ let y_len = y.len();
+ if a_len != y_len {
+ panic!("a and y have different lengths {a_len} <> {y_len}")
+ }
+ y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a)
+}
+
+#[inline]
+pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
+ let a_len = a.len();
+ let y_len = y.len();
+ if a_len != y_len {
+ panic!("a and y have different lengths {a_len} <> {y_len}")
+ }
+ y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a)
+}
+
+#[inline]
+pub fn vs_tanh_inplace(y: &mut [f32]) {
+ unsafe { ffi::vvtanhf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
+}
+
+#[inline]
+pub fn vd_tanh_inplace(y: &mut [f64]) {
+ unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
+}
+
+#[inline]
+pub fn vs_exp_inplace(y: &mut [f32]) {
+ unsafe { ffi::vvexpf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
+}
+
+#[inline]
+pub fn vd_exp_inplace(y: &mut [f64]) {
+ unsafe { ffi::vvexp(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
+}
+
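+ // Tanh approximation of GELU, 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), evaluated in two vectorized passes.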
+#[inline]
+pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
+ for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+ *y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
+ }
+ vs_tanh_inplace(ys);
+ for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+ *y = 0.5 * v * (1.0 + *y)
+ }
+}
+
+#[inline]
+pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
+ for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+ *y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
+ }
+ vd_tanh_inplace(ys);
+ for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+ *y = 0.5 * v * (1.0 + *y)
+ }
+}
+
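+ // SiLU via the identity x * sigmoid(x) = x / (1 + exp(-x)): negate, exponentiate in place, then divide.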
+#[inline]
+pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
+ for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+ *y = -v
+ }
+ vs_exp_inplace(ys);
+ for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+ *y = v / (1.0 + *y)
+ }
+}
+
+#[inline]
+pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
+ for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+ *y = -v
+ }
+ vd_exp_inplace(ys);
+ for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+ *y = v / (1.0 + *y)
+ }
+}
+
+macro_rules! binary_op {
+ ($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
+ #[inline]
+ pub fn $fn_name(a: &[$ty], b: &[$ty], y: &mut [$ty]) {
+ let a_len = a.len();
+ let b_len = b.len();
+ let y_len = y.len();
+ if a_len != y_len || b_len != y_len {
+ panic!(
+ "{} a,b,y len mismatch {a_len} {b_len} {y_len}",
+ stringify!($fn_name)
+ );
+ }
+ unsafe {
+ // Weird quirk of accelerate, the rhs comes before the lhs.
+ ffi::$accelerate_name(
+ b.as_ptr(),
+ 1,
+ a.as_ptr(),
+ 1,
+ y.as_mut_ptr(),
+ 1,
+ a_len as u64,
+ )
+ }
+ }
+ };
+}
+binary_op!(vs_add, f32, vDSP_vadd);
+binary_op!(vd_add, f64, vDSP_vaddD);
+binary_op!(vs_sub, f32, vDSP_vsub);
+binary_op!(vd_sub, f64, vDSP_vsubD);
+binary_op!(vs_mul, f32, vDSP_vmul);
+binary_op!(vd_mul, f64, vDSP_vmulD);
+binary_op!(vs_div, f32, vDSP_vdiv);
+binary_op!(vd_div, f64, vDSP_vdivD);
+binary_op!(vs_max, f32, vDSP_vmax);
+binary_op!(vd_max, f64, vDSP_vmaxD);
+binary_op!(vs_min, f32, vDSP_vmin);
+binary_op!(vd_min, f64, vDSP_vminD);
diff --git a/mcandle/mcandle-core/src/backend.rs b/mcandle/mcandle-core/src/backend.rs
new file mode 100644
index 000000000..655c7894d
--- /dev/null
+++ b/mcandle/mcandle-core/src/backend.rs
@@ -0,0 +1,165 @@
+use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
+use crate::{CpuStorage, DType, Layout, Result, Shape};
+
+pub trait BackendStorage: Sized {
+ type Device: BackendDevice;
+
+ fn try_clone(&self, _: &Layout) -> Result<Self>;
+
+ fn dtype(&self) -> DType;
+
+ fn device(&self) -> &Self::Device;
+
+ // Maybe this should return a Cow instead so that no copy is done on the cpu case.
+ fn to_cpu_storage(&self) -> Result<CpuStorage>;
+
+ fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self>;
+
+ fn powf(&self, _: &Layout, _: f64) -> Result<Self>;
+
+ fn elu(&self, _: &Layout, _: f64) -> Result<Self>;
+
+ fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self>;
+
+ fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;
+
+ fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self>;
+
+ fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self>;
+
+ fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;
+
+ fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
+
+ fn conv1d(
+ &self,
+ _l: &Layout,
+ _kernel: &Self,
+ _kernel_l: &Layout,
+ _params: &crate::conv::ParamsConv1D,
+ ) -> Result<Self>;
+
+ fn conv_transpose1d(
+ &self,
+ _l: &Layout,
+ _kernel: &Self,
+ _kernel_l: &Layout,
+ _params: &crate::conv::ParamsConvTranspose1D,
+ ) -> Result<Self>;
+
+ fn conv2d(
+ &self,
+ _l: &Layout,
+ _kernel: &Self,
+ _kernel_l: &Layout,
+ _params: &crate::conv::ParamsConv2D,
+ ) -> Result<Self>;
+
+ fn conv_transpose2d(
+ &self,
+ _l: &Layout,
+ _kernel: &Self,
+ _kernel_l: &Layout,
+ _params: &crate::conv::ParamsConvTranspose2D,
+ ) -> Result<Self>;
+
+ fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
+ fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
+ fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self>;
+ fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
+
+ fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
+ fn scatter_add(
+ &self,
+ _: &Layout,
+ _: &Self,
+ _: &Layout,
+ _: &Self,
+ _: &Layout,
+ _: usize,
+ ) -> Result<Self>;
+ fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
+ fn index_add(
+ &self,
+ _: &Layout,
+ _: &Self,
+ _: &Layout,
+ _: &Self,
+ _: &Layout,
+ _: usize,
+ ) -> Result<Self>;
+
+ #[allow(clippy::too_many_arguments)]
+ fn matmul_with_alpha_beta(
+ &self,
+ _: &Self,
+ _: &mut Self,
+ _: Option<f64>,
+ _: (usize, usize, usize, usize),
+ _: &Layout,
+ _: &Layout,
+ _: &Layout,
+ ) -> Result<()>;
+
+ #[allow(clippy::too_many_arguments)]
+ fn matmul_with_alpha(
+ &self,
+ _: &Self,
+ _: Option<f64>,
+ _: (usize, usize, usize, usize),
+ _: &Layout,
+ _: &Layout,
+ ) -> Result<Self>;
+
+ fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>;
+
+ #[allow(clippy::too_many_arguments)]
+ // Similar to cudaMemcpy2D, though values are in elements and not in bytes.
+ fn copy2d(
+ &self,
+ _: &mut Self,
+ _d1: usize,
+ _d2: usize,
+ _src_stride1: usize,
+ _dst_stride1: usize,
+ _src_offset: usize,
+ _dst_offset: usize,
+ ) -> Result<()>;
+}
+
+pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
+ type Storage: BackendStorage;
+
+ // TODO: Make the usize generic and part of a generic DeviceLocation.
+ fn new(_: usize) -> Result<Self>;
+
+ fn location(&self) -> crate::DeviceLocation;
+
+ fn same_device(&self, _: &Self) -> bool;
+
+ fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
+
+ fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
+
+ /// # Safety
+ /// This function is unsafe as it doesn't initialize the underlying data store.
+ /// The caller should ensure that the data is properly initialized as early as possible
+ /// after this call.
+ unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
+
+ fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage>;
+
+ fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
+
+ fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;
+
+ fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
+
+ fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
+
+ fn set_seed(&self, _: u64) -> Result<()>;
+ fn get_current_seed(&self) -> Result<u64>;
+
+ /// Synchronize should block until all the operations on the device are completed.
+ fn synchronize(&self) -> Result<()>;
+}
diff --git a/mcandle/mcandle-core/src/backprop.rs b/mcandle/mcandle-core/src/backprop.rs
new file mode 100644
index 000000000..d7169d631
--- /dev/null
+++ b/mcandle/mcandle-core/src/backprop.rs
@@ -0,0 +1,764 @@
+/// Methods for backpropagation of gradients.
+use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
+use crate::{Error, Result, Tensor, TensorId};
+use std::collections::HashMap;
+
+// arg has been reduced to node via reduce_dims, expand it back to arg.
+// This has to handle keepdims.
+fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result<Tensor> {
+ if arg.rank() == node.rank() {
+ // keepdim = true
+ node.broadcast_as(arg.shape())
+ } else {
+ // keepdim = false
+ // first expand the reduced dims.
+ node.reshape(reduced_dims)?.broadcast_as(arg.shape())
+ }
+}
+
+thread_local! {
+ static CANDLE_GRAD_DO_NOT_DETACH: bool = {
+ match std::env::var("CANDLE_GRAD_DO_NOT_DETACH") {
+ Ok(s) => {
+ !s.is_empty() && s != "0"
+ },
+ Err(_) => false,
+ }
+ }
+}
+
+impl Tensor {
+ /// Return all the nodes that lead to this value in a topologically sorted vec, the first
+ /// elements having dependencies on the latter ones, e.g. the first element if any is the
+ /// argument.
+ /// This assumes that the op graph is a DAG.
+ fn sorted_nodes(&self) -> Vec<&Tensor> {
+ // The vec of sorted nodes is passed as an owned value rather than a mutable reference
+ // to get around some lifetime limitations.
+ fn walk<'a>(
+ node: &'a Tensor,
+ nodes: Vec<&'a Tensor>,
+ already_seen: &mut HashMap<TensorId, bool>,
+ ) -> (bool, Vec<&'a Tensor>) {
+ if let Some(&tg) = already_seen.get(&node.id()) {
+ return (tg, nodes);
+ }
+ let mut track_grad = false;
+ let mut nodes = if node.is_variable() {
+ // Do not call recursively on the "leaf" nodes.
+ track_grad = true;
+ nodes
+ } else if node.dtype().is_int() {
+ nodes
+ } else if let Some(op) = node.op() {
+ match op {
+ Op::IndexAdd(t1, t2, t3, _)
+ | Op::ScatterAdd(t1, t2, t3, _)
+ | Op::CustomOp3(t1, t2, t3, _)
+ | Op::WhereCond(t1, t2, t3) => {
+ let (tg, nodes) = walk(t1, nodes, already_seen);
+ track_grad |= tg;
+ let (tg, nodes) = walk(t2, nodes, already_seen);
+ track_grad |= tg;
+ let (tg, nodes) = walk(t3, nodes, already_seen);
+ track_grad |= tg;
+ nodes
+ }
+ Op::Conv1D {
+ arg: lhs,
+ kernel: rhs,
+ ..
+ }
+ | Op::ConvTranspose1D {
+ arg: lhs,
+ kernel: rhs,
+ ..
+ }
+ | Op::Conv2D {
+ arg: lhs,
+ kernel: rhs,
+ ..
+ }
+ | Op::ConvTranspose2D {
+ arg: lhs,
+ kernel: rhs,
+ ..
+ }
+ | Op::CustomOp2(lhs, rhs, _)
+ | Op::Binary(lhs, rhs, _)
+ | Op::Gather(lhs, rhs, _)
+ | Op::IndexSelect(lhs, rhs, _)
+ | Op::Matmul(lhs, rhs)
+ | Op::SliceScatter0(lhs, rhs, _) => {
+ let (tg, nodes) = walk(lhs, nodes, already_seen);
+ track_grad |= tg;
+ let (tg, nodes) = walk(rhs, nodes, already_seen);
+ track_grad |= tg;
+ nodes
+ }
+ Op::Cat(args, _) => args.iter().fold(nodes, |nodes, arg| {
+ let (tg, nodes) = walk(arg, nodes, already_seen);
+ track_grad |= tg;
+ nodes
+ }),
+ Op::Affine { arg, mul, .. } => {
+ if *mul == 0. {
+ nodes
+ } else {
+ let (tg, nodes) = walk(arg, nodes, already_seen);
+ track_grad |= tg;
+ nodes
+ }
+ }
+ Op::Unary(_node, UnaryOp::Ceil)
+ | Op::Unary(_node, UnaryOp::Floor)
+ | Op::Unary(_node, UnaryOp::Round)
+ | Op::Unary(_node, UnaryOp::Sign) => nodes,
+ Op::Reshape(node)
+ | Op::UpsampleNearest1D { arg: node, .. }
+ | Op::UpsampleNearest2D { arg: node, .. }
+ | Op::AvgPool2D { arg: node, .. }
+ | Op::MaxPool2D { arg: node, .. }
+ | Op::Copy(node)
+ | Op::Broadcast(node)
+ | Op::Cmp(node, _)
+ | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
+ | Op::ToDevice(node)
+ | Op::Transpose(node, _, _)
+ | Op::Permute(node, _)
+ | Op::Narrow(node, _, _, _)
+ | Op::Unary(node, _)
+ | Op::Elu(node, _)
+ | Op::Powf(node, _)
+ | Op::CustomOp1(node, _) => {
+ let (tg, nodes) = walk(node, nodes, already_seen);
+ track_grad |= tg;
+ nodes
+ }
+ Op::ToDType(node) => {
+ if node.dtype().is_float() {
+ let (tg, nodes) = walk(node, nodes, already_seen);
+ track_grad |= tg;
+ nodes
+ } else {
+ nodes
+ }
+ }
+ Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
+ }
+ } else {
+ nodes
+ };
+ already_seen.insert(node.id(), track_grad);
+ if track_grad {
+ nodes.push(node);
+ }
+ (track_grad, nodes)
+ }
+ let (_tg, mut nodes) = walk(self, vec![], &mut HashMap::new());
+ nodes.reverse();
+ nodes
+ }
+
+ pub fn backward(&self) -> Result<GradStore> {
+ let sorted_nodes = self.sorted_nodes();
+ let mut grads = GradStore::new();
+ grads.insert(self, self.ones_like()?.contiguous()?);
+ for node in sorted_nodes.iter() {
+ if node.is_variable() {
+ continue;
+ }
+ let grad = grads
+ .remove(node)
+ .expect("mcandle internal error - grad not populated");
+ // https://github.com/huggingface/mcandle/issues/1241
+ // Ideally, we would make these operations in place where possible to ensure that we
+ // do not have to allocate too often. Here we just call `.detach` to avoid computing
+ // the backprop graph of the backprop itself. This would be an issue for second order
+ // derivatives but these are out of scope at the moment.
+ let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
+ let grad = if do_not_detach { grad } else { grad.detach() };
+ if let Some(op) = node.op() {
+ match op {
+ Op::Binary(lhs, rhs, BinaryOp::Add) => {
+ let lhs_sum_grad = grads.or_insert(lhs)?;
+ *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
+ let rhs_sum_grad = grads.or_insert(rhs)?;
+ *rhs_sum_grad = rhs_sum_grad.add(&grad)?;
+ }
+ Op::Binary(lhs, rhs, BinaryOp::Sub) => {
+ let lhs_sum_grad = grads.or_insert(lhs)?;
+ *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
+ let rhs_sum_grad = grads.or_insert(rhs)?;
+ *rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
+ }
+ Op::Binary(lhs, rhs, BinaryOp::Mul) => {
+ let lhs_grad = grad.mul(rhs)?;
+ let lhs_sum_grad = grads.or_insert(lhs)?;
+ *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+ let rhs_grad = grad.mul(lhs)?;
+ let rhs_sum_grad = grads.or_insert(rhs)?;
+ *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+ }
+ Op::Binary(lhs, rhs, BinaryOp::Div) => {
+ let lhs_grad = grad.div(rhs)?;
+ let lhs_sum_grad = grads.or_insert(lhs)?;
+ *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
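+ // d(l/r)/dr = -l / r^2, hence the gradient is subtracted from the rhs accumulator below.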
+ let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
+ let rhs_sum_grad = grads.or_insert(rhs)?;
+ *rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
+ }
+ Op::Binary(lhs, rhs, BinaryOp::Minimum)
+ | Op::Binary(lhs, rhs, BinaryOp::Maximum) => {
+ let mask_lhs = node.eq(lhs)?.to_dtype(grad.dtype())?;
+ let mask_rhs = node.eq(rhs)?.to_dtype(grad.dtype())?;
+
+ // If both masks are 1 on the same point, we want to scale the
+ // gradient by 0.5 rather than 1.
+ let lhs_grad = mask_lhs.mul(&grad)?.div(&(&mask_rhs + 1.)?)?;
+ let lhs_sum_grad = grads.or_insert(lhs)?;
+ *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+
+ let rhs_grad = mask_rhs.mul(&grad)?.div(&(&mask_lhs + 1.)?)?;
+ let rhs_sum_grad = grads.or_insert(rhs)?;
+ *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+ }
+ Op::WhereCond(pred, t, f) => {
+ let zeros = grad.zeros_like()?;
+ let t_sum_grad = grads.or_insert(t)?;
+ let t_grad = pred.where_cond(&grad, &zeros)?;
+ *t_sum_grad = t_sum_grad.add(&t_grad)?;
+ let f_sum_grad = grads.or_insert(f)?;
+ let f_grad = pred.where_cond(&zeros, &grad)?;
+ *f_sum_grad = f_sum_grad.add(&f_grad)?;
+ }
+ Op::Conv1D {
+ arg,
+ kernel,
+ padding,
+ stride,
+ dilation,
+ } => {
+ // The output height for conv_transpose1d is:
+ // (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1
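+ // out_padding is chosen so the transposed convolution reproduces the exact input length, compensating for the size lost to integer division in the forward pass.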
+ let grad_l_in = grad.dim(2)?;
+ let k_size = kernel.dim(2)?;
+ let out_size =
+ (grad_l_in - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding;
+ let out_padding = arg.dim(2)? - out_size;
+ let grad_arg = grad.conv_transpose1d(
+ kernel,
+ *padding,
+ out_padding,
+ *stride,
+ *dilation,
+ /* groups */ 1,
+ )?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&grad_arg)?;
+
+ let grad_kernel = arg
+ .transpose(0, 1)?
+ .conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
+ .transpose(0, 1)?;
+ let sum_grad = grads.or_insert(kernel)?;
+ let (_, _, k0) = kernel.dims3()?;
+ let (_, _, g_k0) = grad_kernel.dims3()?;
+ let grad_kernel = if g_k0 != k0 {
+ grad_kernel.narrow(2, 0, k0)?
+ } else {
+ grad_kernel
+ };
+ *sum_grad = sum_grad.add(&grad_kernel)?;
+ }
+ Op::Conv2D {
+ arg,
+ kernel,
+ padding,
+ stride,
+ dilation,
+ } => {
+ // The output height for conv_transpose2d is:
+ // (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
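+ // As in the 1d case, out_padding recovers the rows dropped by the forward convolution's integer division.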
+ let grad_h = grad.dim(2)?;
+ let k_h = kernel.dim(2)?;
+ let out_size =
+ (grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding;
+ let out_padding = arg.dim(2)? - out_size;
+ let grad_arg = grad.conv_transpose2d(
+ kernel,
+ *padding,
+ out_padding,
+ *stride,
+ *dilation,
+ )?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&grad_arg)?;
+
+ let grad_kernel = arg
+ .transpose(0, 1)?
+ .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
+ .transpose(0, 1)?;
+ let sum_grad = grads.or_insert(kernel)?;
+ let (_, _, k0, k1) = kernel.dims4()?;
+ let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
+ let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
+ grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
+ } else {
+ grad_kernel
+ };
+ *sum_grad = sum_grad.add(&grad_kernel)?;
+ }
+ Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
+ op: "conv-transpose1d",
+ })?,
+ Op::ConvTranspose2D {
+ arg,
+ kernel,
+ padding,
+ stride,
+ dilation,
+ output_padding: _output_padding,
+ } => {
+ let grad_arg = grad.conv2d(kernel, *padding, *stride, *dilation, 1)?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&grad_arg)?;
+
+ let grad_kernel = grad
+ .transpose(0, 1)?
+ .conv2d(&arg.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
+ .transpose(0, 1)?;
+ let sum_grad = grads.or_insert(kernel)?;
+ let (_, _, k0, k1) = kernel.dims4()?;
+ let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
+ let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
+ grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
+ } else {
+ grad_kernel
+ };
+ *sum_grad = sum_grad.add(&grad_kernel)?;
+ }
+ Op::AvgPool2D {
+ arg,
+ kernel_size,
+ stride,
+ } => {
+ if kernel_size != stride {
+ crate::bail!("backward not supported for avgpool2d if ksize {kernel_size:?} != stride {stride:?}")
+ }
+ let (_n, _c, h, w) = arg.dims4()?;
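+ // Each element of a pooling window contributed equally to the average, so upsample the gradient and scale it by 1 / (k_h * k_w).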
+ let grad_arg = grad.upsample_nearest2d(h, w)?;
+ let grad_arg =
+ (grad_arg * (1f64 / (kernel_size.0 * kernel_size.1) as f64))?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&grad_arg)?;
+ }
+ Op::MaxPool2D {
+ arg,
+ kernel_size,
+ stride,
+ } => {
+ if kernel_size != stride {
+ crate::bail!("backward not supported for maxpool2d if ksize {kernel_size:?} != stride {stride:?}")
+ }
+ let (_n, _c, h, w) = arg.dims4()?;
+ // For computing the max-pool gradient, we compute a mask where a 1 means
+ // that the element is the maximum, then we apply this mask to the
+ // upsampled gradient (taking into account that multiple max may exist so
+ // we scale the gradient for this case).
+ let node_upsampled = node.upsample_nearest2d(h, w)?;
+ let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
+ let avg = mask.avg_pool2d_with_stride(*kernel_size, *stride)?;
+ let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&grad_arg)?;
+ }
+ Op::UpsampleNearest1D { arg, target_size } => {
+ let (_n, c, size) = arg.dims3()?;
+ if target_size % size != 0 {
+ crate::bail!("backward not supported for non integer upscaling factors")
+ }
+ let scale = target_size / size;
+
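+ // A grouped conv1d with an all-ones kernel of width scale and stride scale sums the gradient over each upsampled block.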
+ let kernel = Tensor::ones((c, 1, scale), arg.dtype(), arg.device())?;
+ let conv_sum = grad.conv1d(&kernel, 0, scale, 1, c)?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = conv_sum;
+ }
+ Op::UpsampleNearest2D {
+ arg,
+ target_h,
+ target_w,
+ } => {
+ let (_n, c, h, w) = arg.dims4()?;
+ if target_h % h != 0 || target_w % w != 0 {
+ crate::bail!("backward not supported for non integer upscaling factors")
+ }
+ let scale_h = target_h / h;
+ let scale_w = target_w / w;
+
+ if scale_h != scale_w {
+ crate::bail!("backward not supported for non uniform upscaling factors")
+ };
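+ // Same trick as the 1d case: a grouped conv with a (scale x scale) ones kernel and stride scale sums each block's gradient.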
+ let kernel =
+ Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?;
+ let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = conv_sum;
+ }
+ Op::SliceScatter0(lhs, rhs, start_rhs) => {
+ let rhs_sum_grad = grads.or_insert(rhs)?;
+ let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
+ *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+
+ let lhs_sum_grad = grads.or_insert(lhs)?;
+ let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
+ *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
+ }
+ Op::Gather(arg, indexes, dim) => {
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
+ }
+ Op::ScatterAdd(init, indexes, src, dim) => {
+ let init_sum_grad = grads.or_insert(init)?;
+ *init_sum_grad = init_sum_grad.add(&grad)?;
+
+ let src_grad = grad.gather(indexes, *dim)?;
+ let src_sum_grad = grads.or_insert(src)?;
+ *src_sum_grad = src_sum_grad.add(&src_grad)?;
+ }
+ Op::IndexAdd(init, indexes, src, dim) => {
+ let init_sum_grad = grads.or_insert(init)?;
+ *init_sum_grad = init_sum_grad.add(&grad)?;
+
+ let src_grad = grad.index_select(indexes, *dim)?;
+ let src_sum_grad = grads.or_insert(src)?;
+ *src_sum_grad = src_sum_grad.add(&src_grad)?;
+ }
+ Op::IndexSelect(arg, indexes, dim) => {
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.index_add(indexes, &grad, *dim)?;
+ }
+ Op::Matmul(lhs, rhs) => {
+ // The forward op already succeeded so we can skip
+ // the matmul size checks for now.
+
+ let lhs_grad = grad.matmul(&rhs.t()?)?;
+ let lhs_sum_grad = grads.or_insert(lhs)?;
+ *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+
+ let rhs_grad = lhs.t()?.matmul(&grad)?;
+ let rhs_sum_grad = grads.or_insert(rhs)?;
+ *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+ }
+ Op::Cat(args, dim) => {
+ let mut start_idx = 0;
+ for arg in args {
+ let len = arg.dims()[*dim];
+ let arg_grad = grad.narrow(*dim, start_idx, len)?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&arg_grad)?;
+ start_idx += len;
+ }
+ }
+ Op::Broadcast(arg) => {
+ let arg_dims = arg.dims();
+ let node_dims = node.dims();
+ // The number of dims that have been inserted on the left.
+ let left_dims = node_dims.len() - arg_dims.len();
+ let mut sum_dims: Vec<usize> = (0..left_dims).collect();
+ for (dim, (node_dim, arg_dim)) in node_dims[left_dims..]
+ .iter()
+ .zip(arg_dims.iter())
+ .enumerate()
+ {
+ if node_dim != arg_dim {
+ sum_dims.push(dim + left_dims)
+ }
+ }
+
+ let mut arg_grad = grad.sum_keepdim(sum_dims.as_slice())?;
+ for _i in 0..left_dims {
+ arg_grad = arg_grad.squeeze(0)?
+ }
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&arg_grad.broadcast_as(sum_grad.dims())?)?;
+ }
+ Op::Reduce(arg, ReduceOp::Sum, reduced_dims) => {
+ let grad = broadcast_back(arg, &grad, reduced_dims)?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&grad)?;
+ }
+ Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
+ let node = broadcast_back(arg, node, reduced_dims)?;
+ let grad = broadcast_back(arg, &grad, reduced_dims)?;
+ let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
+ }
+ Op::Reduce(arg, ReduceOp::Min, reduced_dims) => {
+ let node = broadcast_back(arg, node, reduced_dims)?;
+ let grad = broadcast_back(arg, &grad, reduced_dims)?;
+ let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
+ }
+ Op::ToDType(arg) => {
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
+ }
+ Op::Copy(arg) => {
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&grad)?
+ }
+ Op::Affine { arg, mul, .. } => {
+ let arg_grad = grad.affine(*mul, 0.)?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&arg_grad)?
+ }
+ Op::Unary(arg, UnaryOp::Log) => {
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&(grad / arg)?)?
+ }
+ Op::Unary(arg, UnaryOp::Sin) => {
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&(&grad * arg.cos())?)?
+ }
+ Op::Unary(arg, UnaryOp::Cos) => {
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
+ }
+ Op::Unary(arg, UnaryOp::Tanh) => {
+ let sum_grad = grads.or_insert(arg)?;
+ let minus_dtanh = (node.sqr()? - 1.)?;
+ *sum_grad = sum_grad.sub(&(&grad * &minus_dtanh)?)?
+ }
+ Op::Unary(arg, UnaryOp::Abs) => {
+ let sum_grad = grads.or_insert(arg)?;
+ let ones = arg.ones_like()?;
+ let abs_grad = arg.ge(&arg.zeros_like()?)?.where_cond(&ones, &ones.neg()?);
+ *sum_grad = sum_grad.add(&(&grad * abs_grad)?)?
+ }
+ Op::Unary(arg, UnaryOp::Exp) => {
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&(&grad * *node)?)?
+ }
+ Op::Unary(arg, UnaryOp::Neg) => {
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.sub(&grad)?
+ }
+ Op::Unary(arg, UnaryOp::Recip) => {
+ let sum_grad = grads.or_insert(arg)?;
+ let grad = (grad / arg.sqr()?)?;
+ *sum_grad = sum_grad.sub(&grad)?
+ }
+ &Op::Narrow(ref arg, dim, start_idx, len) => {
+ let arg_dims = arg.dims();
+ let left_pad = if start_idx == 0 {
+ None
+ } else {
+ let mut dims = arg_dims.to_vec();
+ dims[dim] = start_idx;
+ Some(Tensor::zeros(dims, grad.dtype(), grad.device())?)
+ };
+ let right_pad = arg_dims[dim] - start_idx - len;
+ let right_pad = if right_pad == 0 {
+ None
+ } else {
+ let mut dims = arg_dims.to_vec();
+ dims[dim] = right_pad;
+ Some(Tensor::zeros(dims, grad.dtype(), grad.device())?)
+ };
+ let arg_grad = match (left_pad, right_pad) {
+ (None, None) => grad,
+ (Some(l), None) => Tensor::cat(&[&l, &grad], dim)?,
+ (None, Some(r)) => Tensor::cat(&[&grad, &r], dim)?,
+ (Some(l), Some(r)) => Tensor::cat(&[&l, &grad, &r], dim)?,
+ };
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&arg_grad)?
+ }
+ Op::Unary(_, UnaryOp::Floor)
+ | Op::Unary(_, UnaryOp::Round)
+ | Op::Reduce(_, ReduceOp::ArgMin, _)
+ | Op::Reduce(_, ReduceOp::ArgMax, _)
+ | Op::Unary(_, UnaryOp::Sign)
+ | Op::Cmp(_, _) => {}
+ Op::Reshape(arg) => {
+ let arg_grad = grad.reshape(arg.dims())?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&arg_grad)?
+ }
+ Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
+ Op::Unary(arg, UnaryOp::Gelu) => {
+ let sum_grad = grads.or_insert(arg)?;
+ let cube = arg.powf(3.)?;
+ let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?;
+ let gelu_grad = (((0.5 * &tanh)?
+ + (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))?
+ + 0.5)?;
+ *sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
+ }
+ Op::Unary(arg, UnaryOp::Erf) => {
+ let sum_grad = grads.or_insert(arg)?;
+ // d/dx erf(x) = 2/sqrt(pi) * e^(-x^2)
+ let erf_grad =
+ (2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
+ *sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
+ }
+ Op::Unary(arg, UnaryOp::GeluErf) => {
+ let sum_grad = grads.or_insert(arg)?;
+ // d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2))
+ let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
+ let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
+ let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
+ let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
+ let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
+ *sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
+ }
+ Op::Unary(arg, UnaryOp::Relu) => {
+ let sum_grad = grads.or_insert(arg)?;
+ let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
+ *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
+ }
+ Op::Unary(arg, UnaryOp::Silu) => {
+ let sum_grad = grads.or_insert(arg)?;
+ // d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x))) = sigmoid(x) * (1 - node) + node
+ let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
+ let silu_grad = &sigmoid_arg * (1. - *node) + *node;
+ *sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
+ }
+ Op::Elu(arg, alpha) => {
+ // d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
+ let sum_grad = grads.or_insert(arg)?;
+ let zeros = arg.zeros_like()?;
+ let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
+ let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
+ // node == alpha * (e^x - 1) for x <= 0, reuse it
+ let negative_exp_mask = (negative_mask * (*node + *alpha))?;
+ let combined_mask = (positive_mask + negative_exp_mask)?;
+ *sum_grad = sum_grad.add(&(grad * combined_mask)?)?
+ }
+ Op::Powf(arg, e) => {
+ let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&arg_grad)?
+ }
+ Op::CustomOp1(arg, c) => {
+ if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&arg_grad)?
+ }
+ }
+ Op::CustomOp2(arg1, arg2, c) => {
+ let (arg_grad1, arg_grad2) = c.bwd(arg1, arg2, node, &grad)?;
+ if let Some(arg_grad1) = arg_grad1 {
+ let sum_grad = grads.or_insert(arg1)?;
+ *sum_grad = sum_grad.add(&arg_grad1)?
+ }
+ if let Some(arg_grad2) = arg_grad2 {
+ let sum_grad = grads.or_insert(arg2)?;
+ *sum_grad = sum_grad.add(&arg_grad2)?
+ }
+ }
+ Op::CustomOp3(arg1, arg2, arg3, c) => {
+ let (arg_grad1, arg_grad2, arg_grad3) =
+ c.bwd(arg1, arg2, arg3, node, &grad)?;
+ if let Some(arg_grad1) = arg_grad1 {
+ let sum_grad = grads.or_insert(arg1)?;
+ *sum_grad = sum_grad.add(&arg_grad1)?
+ }
+ if let Some(arg_grad2) = arg_grad2 {
+ let sum_grad = grads.or_insert(arg2)?;
+ *sum_grad = sum_grad.add(&arg_grad2)?
+ }
+ if let Some(arg_grad3) = arg_grad3 {
+ let sum_grad = grads.or_insert(arg3)?;
+ *sum_grad = sum_grad.add(&arg_grad3)?
+ }
+ }
+ Op::Unary(arg, UnaryOp::Sqr) => {
+ let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&arg_grad)?
+ }
+ Op::Unary(arg, UnaryOp::Sqrt) => {
+ let arg_grad = grad.div(node)?.affine(0.5, 0.)?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&arg_grad)?
+ }
+ Op::ToDevice(arg) => {
+ let sum_grad = grads.or_insert(arg)?;
+ let arg_grad = grad.to_device(sum_grad.device())?;
+ *sum_grad = sum_grad.add(&arg_grad)?
+ }
+ Op::Transpose(arg, dim1, dim2) => {
+ let arg_grad = grad.transpose(*dim1, *dim2)?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&arg_grad)?
+ }
+ Op::Permute(arg, dims) => {
+ let mut inv_dims = vec![0; dims.len()];
+ for (i, &dim_idx) in dims.iter().enumerate() {
+ inv_dims[dim_idx] = i
+ }
+ let arg_grad = grad.permute(inv_dims)?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&arg_grad)?
+ }
+ };
+ }
+ }
+ Ok(grads)
+ }
+}
+
+/// A store for gradients, associating a tensor id to the corresponding gradient tensor, used for back propagation.
+#[derive(Debug)]
+pub struct GradStore(HashMap<TensorId, Tensor>);
+
+impl GradStore {
+ /// Create a new gradient store
+ fn new() -> Self {
+ GradStore(HashMap::new())
+ }
+
+ /// Get the gradient tensor corresponding to the given tensor id
+ pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
+ self.0.get(&id)
+ }
+
+ /// Get the gradient tensor associated with the given tensor
+ pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
+ self.0.get(&tensor.id())
+ }
+
+ /// Remove the gradient tensor associated with the given tensor, returning it if it exists
+ pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
+ self.0.remove(&tensor.id())
+ }
+
+ /// Insert a gradient tensor associated with the given tensor, returning the previous gradient tensor if it existed
+ pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
+ self.0.insert(tensor.id(), grad)
+ }
+
+ /// Get the gradient tensor associated with the given tensor, or, if it does not exist,
+ /// insert a tensor of zeroes with the same shape and type as the given tensor, and return it
+ fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
+ use std::collections::hash_map::Entry;
+ let grad = match self.0.entry(tensor.id()) {
+ Entry::Occupied(entry) => entry.into_mut(),
+ Entry::Vacant(entry) => {
+ let grad = tensor.zeros_like()?;
+ entry.insert(grad)
+ }
+ };
+ Ok(grad)
+ }
+
+ /// Get the tensor ids of the stored gradient tensors
+ pub fn get_ids(&self) -> impl Iterator<Item = &TensorId> {
+ self.0.keys()
+ }
+}
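For context, a minimal sketch of how `backward` and `GradStore` are typically driven from the caller's side (assuming the crate follows the usual candle-style `Var`/`Tensor` API and is exposed as `mcandle_core`; neither name is confirmed by this diff):

    use mcandle_core::{Device, Result, Tensor, Var};

    fn grad_demo() -> Result<()> {
        let dev = Device::Cpu;
        // y = x^2 + 3x, so dy/dx = 2x + 3.
        let x = Var::new(&[2f32, -1., 4.], &dev)?;
        let x3 = (x.as_tensor() * 3f64)?;
        let y = (x.as_tensor().sqr()? + x3)?;
        // backward() walks the op graph handled by the match above and
        // accumulates the per-tensor gradients into a GradStore.
        let grads = y.backward()?;
        let dx = grads.get(&x).expect("no gradient for x");
        // Expected gradient: [7., 1., 11.]
        println!("{dx}");
        Ok(())
    }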
diff --git a/mcandle/mcandle-core/src/conv.rs b/mcandle/mcandle-core/src/conv.rs
new file mode 100644
index 000000000..7b3922dd7
--- /dev/null
+++ b/mcandle/mcandle-core/src/conv.rs
@@ -0,0 +1,358 @@
+use crate::{op::BackpropOp, op::Op, Error, Result, Tensor};
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct ParamsConv1D {
+ pub(crate) b_size: usize,
+ // Maybe we should have a version without l_in as this bit depends on the input and not only on
+ // the weights.
+ pub(crate) l_in: usize,
+ pub(crate) c_out: usize,
+ pub(crate) c_in: usize,
+ pub(crate) k_size: usize,
+ pub(crate) padding: usize,
+ pub(crate) stride: usize,
+ pub(crate) dilation: usize,
+}
+
+impl ParamsConv1D {
+ pub(crate) fn l_out(&self) -> usize {
+ (self.l_in + 2 * self.padding - self.dilation * (self.k_size - 1) - 1) / self.stride + 1
+ }
+
+ pub(crate) fn out_dims(&self) -> Vec<usize> {
+ let l_out = self.l_out();
+ vec![self.b_size, self.c_out, l_out]
+ }
+}
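As a quick numeric check of the `l_out` formula above (illustrative only; the fields are `pub(crate)`, so this only compiles inside the crate):

    // l_out = (l_in + 2*padding - dilation*(k_size - 1) - 1) / stride + 1
    // With l_in = 32, k_size = 3, padding = 1, stride = 2, dilation = 1:
    // (32 + 2 - 2 - 1) / 2 + 1 = 15 + 1 = 16
    let params = ParamsConv1D {
        b_size: 1, l_in: 32, c_out: 8, c_in: 4,
        k_size: 3, padding: 1, stride: 2, dilation: 1,
    };
    assert_eq!(params.l_out(), 16);
    assert_eq!(params.out_dims(), vec![1, 8, 16]);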
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct ParamsConvTranspose1D {
+ pub(crate) b_size: usize,
+ pub(crate) l_in: usize,
+ pub(crate) c_out: usize,
+ pub(crate) c_in: usize,
+ pub(crate) k_size: usize,
+ pub(crate) padding: usize,
+ pub(crate) output_padding: usize,
+ pub(crate) stride: usize,
+ pub(crate) dilation: usize,
+}
+
+impl ParamsConvTranspose1D {
+ pub(crate) fn l_out(&self) -> usize {
+ (self.l_in - 1) * self.stride - 2 * self.padding
+ + self.dilation * (self.k_size - 1)
+ + self.output_padding
+ + 1
+ }
+
+ pub(crate) fn out_dims(&self) -> Vec<usize> {
+ let l_out = self.l_out();
+ vec![self.b_size, self.c_out, l_out]
+ }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub enum CudnnFwdAlgo {
+ ImplicitGemm,
+ ImplicitPrecompGemm,
+ Gemm,
+ Direct,
+ Fft,
+ FftTiling,
+ Winograd,
+ WinogradNonFused,
+ Count,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct ParamsConv2D {
+ pub(crate) b_size: usize,
+ pub(crate) i_h: usize,
+ pub(crate) i_w: usize,
+ pub(crate) k_h: usize,
+ pub(crate) k_w: usize,
+ pub(crate) c_out: usize,
+ pub(crate) c_in: usize,
+ pub(crate) padding: usize,
+ pub(crate) stride: usize,
+ pub(crate) dilation: usize,
+ pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
+}
+
+impl ParamsConv2D {
+ pub(crate) fn out_h(&self) -> usize {
+ (self.i_h + 2 * self.padding - self.dilation * (self.k_h - 1) - 1) / self.stride + 1
+ }
+
+ pub(crate) fn out_w(&self) -> usize {
+ (self.i_w + 2 * self.padding - self.dilation * (self.k_w - 1) - 1) / self.stride + 1
+ }
+
+ pub(crate) fn out_dims(&self) -> Vec<usize> {
+ vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
+ }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct ParamsConvTranspose2D {
+ pub(crate) b_size: usize,
+ pub(crate) i_h: usize,
+ pub(crate) i_w: usize,
+ pub(crate) k_h: usize,
+ pub(crate) k_w: usize,
+ pub(crate) c_out: usize,
+ pub(crate) c_in: usize,
+ pub(crate) padding: usize,
+ pub(crate) output_padding: usize,
+ pub(crate) stride: usize,
+ pub(crate) dilation: usize,
+}
+
+impl ParamsConvTranspose2D {
+ pub(crate) fn out_h(&self) -> usize {
+ (self.i_h - 1) * self.stride + self.dilation * (self.k_h - 1) + self.output_padding + 1
+ - 2 * self.padding
+ }
+
+ pub(crate) fn out_w(&self) -> usize {
+ (self.i_w - 1) * self.stride + self.dilation * (self.k_w - 1) + self.output_padding + 1
+ - 2 * self.padding
+ }
+
+ pub(crate) fn out_dims(&self) -> Vec<usize> {
+ vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
+ }
+}
+
+impl Tensor {
+ fn conv1d_single_group(&self, kernel: &Self, params: &ParamsConv1D) -> Result<Self> {
+ let storage =
+ self.storage()
+ .conv1d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
+ let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
+ arg,
+ kernel,
+ padding: params.padding,
+ stride: params.stride,
+ dilation: params.dilation,
+ });
+ let out_dims = params.out_dims();
+ Ok(crate::tensor::from_storage(storage, out_dims, op, false))
+ }
+
+ /// Applies a 1D convolution over the input tensor.
+ pub fn conv1d(
+ &self,
+ kernel: &Self,
+ padding: usize,
+ stride: usize,
+ dilation: usize,
+ groups: usize,
+ ) -> Result<Self> {
+ let (c_out, c_in_k, k_size) = kernel.dims3()?;
+ let (b_size, c_in, l_in) = self.dims3()?;
+ if c_in != c_in_k * groups {
+ Err(Error::Conv1dInvalidArgs {
+ inp_shape: self.shape().clone(),
+ k_shape: kernel.shape().clone(),
+ padding,
+ stride,
+ msg: "the number of in-channels on the input doesn't match the kernel size",
+ }
+ .bt())?
+ }
+
+ let params = ParamsConv1D {
+ b_size,
+ l_in,
+ c_out: c_out / groups,
+ c_in: c_in / groups,
+ k_size,
+ padding,
+ stride,
+ dilation,
+ };
+ if groups == 1 {
+ self.conv1d_single_group(kernel, ¶ms)
+ } else {
+ let blocks = self.chunk(groups, 1)?;
+ let kernel = kernel.chunk(groups, 0)?;
+ let blocks = blocks
+ .iter()
+ .zip(&kernel)
+ .map(|(block, kernel)| block.conv1d_single_group(kernel, ¶ms))
+ .collect::<Result<Vec<_>>>()?;
+ Tensor::cat(&blocks, 1)
+ }
+ }
+
+ fn conv_transpose1d_single_group(
+ &self,
+ kernel: &Self,
+ params: &ParamsConvTranspose1D,
+ ) -> Result<Self> {
+ let storage = self.storage().conv_transpose1d(
+ self.layout(),
+ &kernel.storage(),
+ kernel.layout(),
+ params,
+ )?;
+ let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
+ arg,
+ kernel,
+ padding: params.padding,
+ output_padding: params.output_padding,
+ stride: params.stride,
+ dilation: params.dilation,
+ });
+ let out_dims = params.out_dims();
+ Ok(crate::tensor::from_storage(storage, out_dims, op, false))
+ }
+
+ /// Applies a 1D transposed convolution over the input tensor.
+ pub fn conv_transpose1d(
+ &self,
+ kernel: &Self,
+ padding: usize,
+ output_padding: usize,
+ stride: usize,
+ dilation: usize,
+ groups: usize,
+ ) -> Result<Self> {
+ let (c_in_k, c_out, k_size) = kernel.dims3()?;
+ let (b_size, c_in, l_in) = self.dims3()?;
+ if c_in != c_in_k {
+ crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
+ }
+ if c_in % groups != 0 {
+ crate::bail!("in_channel {c_in} is not divisible by the number of groups")
+ }
+ let params = ParamsConvTranspose1D {
+ b_size,
+ l_in,
+ k_size,
+ c_out,
+ c_in: c_in / groups,
+ padding,
+ output_padding,
+ stride,
+ dilation,
+ };
+ if groups == 1 {
+ self.conv_transpose1d_single_group(kernel, ¶ms)
+ } else {
+ let blocks = self.chunk(groups, 1)?;
+ let kernel = kernel.chunk(groups, 0)?;
+ let blocks = blocks
+ .iter()
+ .zip(&kernel)
+ .map(|(block, kernel)| block.conv_transpose1d_single_group(kernel, ¶ms))
+ .collect::<Result<Vec<_>>>()?;
+ Tensor::cat(&blocks, 1)
+ }
+ }
+
+ fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
+ let storage =
+ self.storage()
+ .conv2d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
+ let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
+ arg,
+ kernel,
+ padding: params.padding,
+ stride: params.stride,
+ dilation: params.dilation,
+ });
+ let out_dims = params.out_dims();
+ Ok(crate::tensor::from_storage(storage, out_dims, op, false))
+ }
+
+ /// Applies a 2D convolution over the input tensor.
+ pub fn conv2d(
+ &self,
+ kernel: &Self,
+ padding: usize,
+ stride: usize,
+ dilation: usize,
+ groups: usize,
+ ) -> Result<Self> {
+ let (b_size, c_in, i_h, i_w) = self.dims4()?;
+ let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
+ if c_in != c_in_k * groups {
+ crate::bail!(
+ "in_channel mismatch between input ({c_in}, groups {groups}) and kernel ({c_in_k})"
+ )
+ }
+ let params = ParamsConv2D {
+ b_size,
+ i_h,
+ i_w,
+ k_h,
+ k_w,
+ c_out: c_out / groups,
+ c_in: c_in / groups,
+ padding,
+ stride,
+ dilation,
+ cudnn_fwd_algo: None,
+ };
+ if groups == 1 {
+ self.conv2d_single_group(kernel, ¶ms)
+ } else {
+ let blocks = self.chunk(groups, 1)?;
+ let kernel = kernel.chunk(groups, 0)?;
+ let blocks = blocks
+ .iter()
+ .zip(&kernel)
+ .map(|(block, kernel)| block.conv2d_single_group(kernel, ¶ms))
+ .collect::<Result<Vec<_>>>()?;
+ Tensor::cat(&blocks, 1)
+ }
+ }
+
+ /// Applies a 2D transposed convolution over the input tensor.
+ pub fn conv_transpose2d(
+ &self,
+ kernel: &Self,
+ padding: usize,
+ output_padding: usize,
+ stride: usize,
+ dilation: usize,
+ ) -> Result<Self> {
+ let (b_size, c_in, i_h, i_w) = self.dims4()?;
+ let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?;
+ if c_in != c_in_k {
+ crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
+ }
+ let params = ParamsConvTranspose2D {
+ b_size,
+ i_h,
+ i_w,
+ k_h,
+ k_w,
+ c_out,
+ c_in,
+ padding,
+ output_padding,
+ stride,
+ dilation,
+ };
+ let storage = self.storage().conv_transpose2d(
+ self.layout(),
+ &kernel.storage(),
+ kernel.layout(),
+ ¶ms,
+ )?;
+ let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose2D {
+ arg,
+ kernel,
+ padding: params.padding,
+ output_padding: params.output_padding,
+ stride: params.stride,
+ dilation: params.dilation,
+ });
+ let out_dims = params.out_dims();
+ Ok(crate::tensor::from_storage(storage, out_dims, op, false))
+ }
+}
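A short usage sketch of the public entry points above, a grouped `conv2d` followed by a `conv_transpose2d`; the expected shapes follow the `out_h`/`out_w` formulas in the params structs (the crate name `mcandle_core` is an assumption):

    use mcandle_core::{Device, Result, Tensor};

    fn conv_demo() -> Result<()> {
        let dev = Device::Cpu;
        // Depthwise-style grouped conv: c_in = 4, groups = 4, so the kernel
        // carries c_in_k = c_in / groups = 1 input channel per group.
        let xs = Tensor::randn(0f32, 1., (1, 4, 16, 16), &dev)?;
        let k = Tensor::randn(0f32, 1., (4, 1, 3, 3), &dev)?;
        // Arguments: (padding, stride, dilation, groups).
        let ys = xs.conv2d(&k, 1, 1, 1, 4)?;
        assert_eq!(ys.dims(), &[1, 4, 16, 16]);

        // Transposed conv with stride 2 roughly undoes a stride-2 conv:
        // (16 - 1) * 2 + (3 - 1) + 1 + 1 - 2 * 1 = 32.
        let kt = Tensor::randn(0f32, 1., (4, 8, 3, 3), &dev)?;
        // Arguments: (padding, output_padding, stride, dilation).
        let zs = xs.conv_transpose2d(&kt, 1, 1, 2, 1)?;
        assert_eq!(zs.dims(), &[1, 8, 32, 32]);
        Ok(())
    }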
diff --git a/mcandle/mcandle-core/src/convert.rs b/mcandle/mcandle-core/src/convert.rs
new file mode 100644
index 000000000..173a96d6e
--- /dev/null
+++ b/mcandle/mcandle-core/src/convert.rs
@@ -0,0 +1,161 @@
+//! Implement conversion traits for tensors
+use crate::{DType, Device, Error, Tensor, WithDType};
+use float8::F8E4M3;
+use half::{bf16, f16, slice::HalfFloatSliceExt};
+use std::convert::TryFrom;
+
+impl<T: WithDType> TryFrom<&Tensor> for Vec<T> {
+ type Error = Error;
+ fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
+ tensor.to_vec1::<T>()
+ }
+}
+
+impl<T: WithDType> TryFrom<&Tensor> for Vec<Vec<T>> {
+ type Error = Error;
+ fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
+ tensor.to_vec2::<T>()
+ }
+}
+
+impl<T: WithDType> TryFrom<&Tensor> for Vec<Vec<Vec<T>>> {
+ type Error = Error;
+ fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
+ tensor.to_vec3::<T>()
+ }
+}
+
+impl<T: WithDType> TryFrom<Tensor> for Vec<T> {
+ type Error = Error;
+ fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
+ Vec::<T>::try_from(&tensor)
+ }
+}
+
+impl<T: WithDType> TryFrom<Tensor> for Vec<Vec<T>> {
+ type Error = Error;
+ fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
+ Vec::<Vec<T>>::try_from(&tensor)
+ }
+}
+
+impl<T: WithDType> TryFrom<Tensor> for Vec<Vec<Vec<T>>> {
+ type Error = Error;
+ fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
+ Vec::<Vec<Vec<T>>>::try_from(&tensor)
+ }
+}
+
+impl<T: WithDType> TryFrom<&[T]> for Tensor {
+ type Error = Error;
+ fn try_from(v: &[T]) -> Result<Self, Self::Error> {
+ Tensor::from_slice(v, v.len(), &Device::Cpu)
+ }
+}
+
+impl<T: WithDType> TryFrom<Vec<T>> for Tensor {
+ type Error = Error;
+ fn try_from(v: Vec<T>) -> Result<Self, Self::Error> {
+ let len = v.len();
+ Tensor::from_vec(v, len, &Device::Cpu)
+ }
+}
+
+macro_rules! from_tensor {
+ ($typ:ident) => {
+ impl TryFrom<&Tensor> for $typ {
+ type Error = Error;
+
+ fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
+ tensor.to_scalar::<$typ>()
+ }
+ }
+
+ impl TryFrom for $typ {
+ type Error = Error;
+
+ fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
+ $typ::try_from(&tensor)
+ }
+ }
+
+ impl TryFrom<$typ> for Tensor {
+ type Error = Error;
+
+ fn try_from(v: $typ) -> Result<Self, Self::Error> {
+ Tensor::new(v, &Device::Cpu)
+ }
+ }
+ };
+}
+
+from_tensor!(f64);
+from_tensor!(f32);
+from_tensor!(f16);
+from_tensor!(bf16);
+from_tensor!(i64);
+from_tensor!(u32);
+from_tensor!(u8);
+
+impl Tensor {
+ pub fn write_bytes<W: std::io::Write>(&self, f: &mut W) -> crate::Result<()> {
+ use byteorder::{LittleEndian, WriteBytesExt};
+
+ let vs = self.flatten_all()?;
+ match self.dtype() {
+ DType::BF16 => {
+ let vs = vs.to_vec1::<bf16>()?;
+ for &v in vs.reinterpret_cast() {
+ f.write_u16::<LittleEndian>(v)?
+ }
+ }
+ DType::F16 => {
+ let vs = vs.to_vec1::<f16>()?;
+ for &v in vs.reinterpret_cast() {
+ f.write_u16::<LittleEndian>(v)?
+ }
+ }
+ DType::F32 => {
+ // TODO: Avoid using a buffer when data is already on the CPU.
+ for v in vs.to_vec1::<f32>()? {
+ f.write_f32::<LittleEndian>(v)?
+ }
+ }
+ DType::F64 => {
+ for v in vs.to_vec1::<f64>()? {
+ f.write_f64::<LittleEndian>(v)?
+ }
+ }
+ DType::U32 => {
+ for v in vs.to_vec1::<u32>()? {
+ f.write_u32::<LittleEndian>(v)?
+ }
+ }
+ DType::I16 => {
+ for v in vs.to_vec1::<i16>()? {
+ f.write_i16::<LittleEndian>(v)?
+ }
+ }
+ DType::I32 => {
+ for v in vs.to_vec1::<i32>()? {
+ f.write_i32::<LittleEndian>(v)?
+ }
+ }
+ DType::I64 => {
+ for v in vs.to_vec1::<i64>()? {
+ f.write_i64::<LittleEndian>(v)?
+ }
+ }
+ DType::U8 => {
+ let vs = vs.to_vec1::<u8>()?;
+ f.write_all(&vs)?;
+ }
+ DType::F8E4M3 => {
+ for v in vs.to_vec1::<F8E4M3>()? {
+ f.write_u8(v.to_bits())?
+ }
+ }
+ }
+ Ok(())
+ }
+}
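A hedged sketch of the conversions and `write_bytes` defined above (crate name assumed to be `mcandle_core`):

    use mcandle_core::{Result, Tensor};
    use std::convert::TryFrom;

    fn convert_demo() -> Result<()> {
        // Vec -> 1D CPU tensor, and back through the TryFrom impls.
        let t = Tensor::try_from(vec![1f32, 2., 3.])?;
        let v = Vec::<f32>::try_from(&t)?;
        assert_eq!(v, vec![1., 2., 3.]);

        // Scalar round-trip through the from_tensor! impls.
        let s = Tensor::try_from(2.5f64)?;
        assert_eq!(f64::try_from(&s)?, 2.5);

        // Little-endian serialization of the flattened data.
        let mut buf: Vec<u8> = Vec::new();
        t.write_bytes(&mut buf)?;
        assert_eq!(buf.len(), 3 * std::mem::size_of::<f32>());
        Ok(())
    }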
diff --git a/mcandle/mcandle-core/src/cpu/avx.rs b/mcandle/mcandle-core/src/cpu/avx.rs
new file mode 100644
index 000000000..113fc14ce
--- /dev/null
+++ b/mcandle/mcandle-core/src/cpu/avx.rs
@@ -0,0 +1,227 @@
+use super::{Cpu, CpuBF16, CpuF16};
+#[cfg(target_arch = "x86")]
+use core::arch::x86::*;
+#[cfg(target_arch = "x86_64")]
+use core::arch::x86_64::*;
+
+use half::{bf16, f16};
+
+pub struct CurrentCpu {}
+
+const STEP: usize = 32;
+const EPR: usize = 8;
+const ARR: usize = STEP / EPR;
+
+impl Cpu for CurrentCpu {
+ type Unit = __m256;
+ type Array = [__m256; ARR];
+
+ const STEP: usize = STEP;
+ const EPR: usize = EPR;
+
+ fn n() -> usize {
+ ARR
+ }
+
+ unsafe fn zero() -> Self::Unit {
+ _mm256_setzero_ps()
+ }
+
+ unsafe fn zero_array() -> Self::Array {
+ [Self::zero(); ARR]
+ }
+
+ unsafe fn from_f32(v: f32) -> Self::Unit {
+ _mm256_set1_ps(v)
+ }
+
+ unsafe fn load(mem_addr: *const f32) -> Self::Unit {
+ _mm256_loadu_ps(mem_addr)
+ }
+
+ unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
+ _mm256_add_ps(a, b)
+ }
+
+ unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
+ _mm256_add_ps(_mm256_mul_ps(b, c), a)
+ }
+
+ unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
+ _mm256_storeu_ps(mem_addr, a);
+ }
+
+ unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
+ for i in 0..ARR / 2 {
+ x[2 * i] = _mm256_add_ps(x[2 * i], x[2 * i + 1]);
+ }
+ for i in 0..ARR / 4 {
+ x[4 * i] = _mm256_add_ps(x[4 * i], x[4 * i + 2]);
+ }
+ #[allow(clippy::reversed_empty_ranges)]
+ for i in 0..ARR / 8 {
+ x[8 * i] = _mm256_add_ps(x[8 * i], x[8 * i + 4]);
+ }
+ let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
+ let t1 = _mm_hadd_ps(t0, t0);
+ *y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
+ }
+}
+
+pub struct CurrentCpuF16 {}
+impl CpuF16 for CurrentCpuF16 {
+ type Unit = __m256;
+ type Array = [__m256; ARR];
+
+ const STEP: usize = STEP;
+ const EPR: usize = EPR;
+
+ fn n() -> usize {
+ ARR
+ }
+
+ unsafe fn zero() -> Self::Unit {
+ _mm256_setzero_ps()
+ }
+
+ unsafe fn zero_array() -> Self::Array {
+ [Self::zero(); ARR]
+ }
+
+ unsafe fn from_f32(v: f32) -> Self::Unit {
+ _mm256_set1_ps(v)
+ }
+
+ #[cfg(target_feature = "f16c")]
+ unsafe fn load(mem_addr: *const f16) -> Self::Unit {
+ _mm256_cvtph_ps(_mm_loadu_si128(mem_addr as *const __m128i))
+ }
+
+ #[cfg(not(target_feature = "f16c"))]
+ unsafe fn load(mem_addr: *const f16) -> Self::Unit {
+ let mut tmp = [0.0f32; 8];
+ for i in 0..8 {
+ tmp[i] = (*mem_addr.add(i)).to_f32();
+ }
+ _mm256_loadu_ps(tmp.as_ptr())
+ }
+
+ unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
+ _mm256_add_ps(a, b)
+ }
+
+ unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
+ _mm256_add_ps(_mm256_mul_ps(b, c), a)
+ }
+
+ #[cfg(target_feature = "f16c")]
+ unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit) {
+ _mm_storeu_si128(mem_addr as *mut __m128i, _mm256_cvtps_ph(a, 0))
+ }
+
+ #[cfg(not(target_feature = "f16c"))]
+ unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit) {
+ let mut tmp = [0.0f32; 8];
+ _mm256_storeu_ps(tmp.as_mut_ptr(), a);
+ for i in 0..8 {
+ *mem_addr.add(i) = f16::from_f32(tmp[i]);
+ }
+ }
+
+ unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
+ let mut offset = ARR >> 1;
+ for i in 0..offset {
+ x[i] = _mm256_add_ps(x[i], x[offset + i]);
+ }
+ offset >>= 1;
+ for i in 0..offset {
+ x[i] = _mm256_add_ps(x[i], x[offset + i]);
+ }
+ offset >>= 1;
+ for i in 0..offset {
+ x[i] = _mm256_add_ps(x[i], x[offset + i]);
+ }
+ let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
+ let t1 = _mm_hadd_ps(t0, t0);
+ *y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
+ }
+}
+
+pub struct CurrentCpuBF16 {}
+impl CpuBF16 for CurrentCpuBF16 {
+ type Unit = __m256;
+ type Array = [__m256; ARR];
+
+ const STEP: usize = STEP;
+ const EPR: usize = EPR;
+
+ fn n() -> usize {
+ ARR
+ }
+
+ unsafe fn zero() -> Self::Unit {
+ _mm256_setzero_ps()
+ }
+
+ unsafe fn zero_array() -> Self::Array {
+ [Self::zero(); ARR]
+ }
+
+ unsafe fn from_f32(v: f32) -> Self::Unit {
+ _mm256_set1_ps(v)
+ }
+
+ #[cfg(target_feature = "f16c")]
+ unsafe fn load(mem_addr: *const bf16) -> Self::Unit {
+ _mm256_cvtph_ps(_mm_loadu_si128(mem_addr as *const __m128i))
+ }
+
+ #[cfg(not(target_feature = "f16c"))]
+ unsafe fn load(mem_addr: *const bf16) -> Self::Unit {
+ let mut tmp = [0.0f32; 8];
+ for i in 0..8 {
+ tmp[i] = (*mem_addr.add(i)).to_f32();
+ }
+ _mm256_loadu_ps(tmp.as_ptr())
+ }
+
+ unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
+ _mm256_add_ps(a, b)
+ }
+
+ unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
+ _mm256_add_ps(_mm256_mul_ps(b, c), a)
+ }
+
+ #[cfg(target_feature = "f16c")]
+ unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) {
+ _mm_storeu_si128(mem_addr as *mut __m128i, _mm256_cvtps_ph(a, 0))
+ }
+
+ #[cfg(not(target_feature = "f16c"))]
+ unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) {
+ let mut tmp = [0.0f32; 8];
+ _mm256_storeu_ps(tmp.as_mut_ptr(), a);
+ for i in 0..8 {
+ *mem_addr.add(i) = bf16::from_f32(tmp[i]);
+ }
+ }
+
+ unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
+ let mut offset = ARR >> 1;
+ for i in 0..offset {
+ x[i] = _mm256_add_ps(x[i], x[offset + i]);
+ }
+ offset >>= 1;
+ for i in 0..offset {
+ x[i] = _mm256_add_ps(x[i], x[offset + i]);
+ }
+ offset >>= 1;
+ for i in 0..offset {
+ x[i] = _mm256_add_ps(x[i], x[offset + i]);
+ }
+ let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
+ let t1 = _mm_hadd_ps(t0, t0);
+ *y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
+ }
+}
diff --git a/mcandle/mcandle-core/src/cpu/erf.rs b/mcandle/mcandle-core/src/cpu/erf.rs
new file mode 100644
index 000000000..ca6be53fd
--- /dev/null
+++ b/mcandle/mcandle-core/src/cpu/erf.rs
@@ -0,0 +1,763 @@
+#![allow(clippy::excessive_precision)]
+// Code taken from https://github.com/statrs-dev/statrs
+//! Provides the [error](https://en.wikipedia.org/wiki/Error_function) and
+//! related functions
+
+mod evaluate {
+ //! Provides functions that don't have a closed-form solution and must
+ //! be evaluated computationally (e.g. evaluation of a polynomial)
+
+ /// Evaluates a polynomial at `z` where `coeff` are the coefficients
+ /// of a polynomial of order `k`, `k` being the length of `coeff`; the
+ /// coefficient of the `k`th power is the `k`th element in `coeff`.
+ /// E.g. `[3, -1, 2]` equates to `2z^2 - z + 3`.
+ ///
+ /// # Remarks
+ ///
+ /// Returns 0 for a 0 length coefficient slice
+ pub fn polynomial(z: f64, coeff: &[f64]) -> f64 {
+ let n = coeff.len();
+ if n == 0 {
+ return 0.0;
+ }
+
+ let mut sum = *coeff.last().unwrap();
+ for c in coeff[0..n - 1].iter().rev() {
+ sum = *c + z * sum;
+ }
+ sum
+ }
+}
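The `polynomial` helper is Horner's rule over the documented coefficient order; a quick check matching the doc example (this only compiles within this file, since `evaluate` is a private module):

    // 2z^2 - z + 3 at z = 2, coefficients ordered [constant, z, z^2]:
    assert_eq!(evaluate::polynomial(2.0, &[3.0, -1.0, 2.0]), 9.0);
    // An empty coefficient slice evaluates to 0.
    assert_eq!(evaluate::polynomial(1.5, &[]), 0.0);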
+use std::f64;
+
+/// `erf` calculates the error function at `x`.
+pub fn erf(x: f64) -> f64 {
+ if x.is_nan() {
+ f64::NAN
+ } else if x >= 0.0 && x.is_infinite() {
+ 1.0
+ } else if x <= 0.0 && x.is_infinite() {
+ -1.0
+ } else if x == 0. {
+ 0.0
+ } else {
+ erf_impl(x, false)
+ }
+}
+
+/// `erf_inv` calculates the inverse error function
+/// at `x`.
+pub fn erf_inv(x: f64) -> f64 {
+ if x == 0.0 {
+ 0.0
+ } else if x >= 1.0 {
+ f64::INFINITY
+ } else if x <= -1.0 {
+ f64::NEG_INFINITY
+ } else if x < 0.0 {
+ erf_inv_impl(-x, 1.0 + x, -1.0)
+ } else {
+ erf_inv_impl(x, 1.0 - x, 1.0)
+ }
+}
+
+/// `erfc` calculates the complementary error function
+/// at `x`.
+pub fn erfc(x: f64) -> f64 {
+ if x.is_nan() {
+ f64::NAN
+ } else if x == f64::INFINITY {
+ 0.0
+ } else if x == f64::NEG_INFINITY {
+ 2.0
+ } else {
+ erf_impl(x, true)
+ }
+}
+
+/// `erfc_inv` calculates the complementary inverse
+/// error function at `x`.
+pub fn erfc_inv(x: f64) -> f64 {
+ if x <= 0.0 {
+ f64::INFINITY
+ } else if x >= 2.0 {
+ f64::NEG_INFINITY
+ } else if x > 1.0 {
+ erf_inv_impl(-1.0 + x, 2.0 - x, -1.0)
+ } else {
+ erf_inv_impl(1.0 - x, x, 1.0)
+ }
+}
+
+// **********************************************************
+// ********** Coefficients for erf_impl polynomial **********
+// **********************************************************
+
+/// Polynomial coefficients for a numerator of `erf_impl`
+/// in the interval [1e-10, 0.5].
+const ERF_IMPL_AN: &[f64] = &[
+ 0.00337916709551257388990745,
+ -0.00073695653048167948530905,
+ -0.374732337392919607868241,
+ 0.0817442448733587196071743,
+ -0.0421089319936548595203468,
+ 0.0070165709512095756344528,
+ -0.00495091255982435110337458,
+ 0.000871646599037922480317225,
+];
+
+/// Polynomial coefficients for a denominator of `erf_impl`
+/// in the interval [1e-10, 0.5]
+const ERF_IMPL_AD: &[f64] = &[
+ 1.0,
+ -0.218088218087924645390535,
+ 0.412542972725442099083918,
+ -0.0841891147873106755410271,
+ 0.0655338856400241519690695,
+ -0.0120019604454941768171266,
+ 0.00408165558926174048329689,
+ -0.000615900721557769691924509,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [0.5, 0.75].
+const ERF_IMPL_BN: &[f64] = &[
+ -0.0361790390718262471360258,
+ 0.292251883444882683221149,
+ 0.281447041797604512774415,
+ 0.125610208862766947294894,
+ 0.0274135028268930549240776,
+ 0.00250839672168065762786937,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [0.5, 0.75].
+const ERF_IMPL_BD: &[f64] = &[
+ 1.0,
+ 1.8545005897903486499845,
+ 1.43575803037831418074962,
+ 0.582827658753036572454135,
+ 0.124810476932949746447682,
+ 0.0113724176546353285778481,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [0.75, 1.25].
+const ERF_IMPL_CN: &[f64] = &[
+ -0.0397876892611136856954425,
+ 0.153165212467878293257683,
+ 0.191260295600936245503129,
+ 0.10276327061989304213645,
+ 0.029637090615738836726027,
+ 0.0046093486780275489468812,
+ 0.000307607820348680180548455,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [0.75, 1.25].
+const ERF_IMPL_CD: &[f64] = &[
+ 1.0,
+ 1.95520072987627704987886,
+ 1.64762317199384860109595,
+ 0.768238607022126250082483,
+ 0.209793185936509782784315,
+ 0.0319569316899913392596356,
+ 0.00213363160895785378615014,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [1.25, 2.25].
+const ERF_IMPL_DN: &[f64] = &[
+ -0.0300838560557949717328341,
+ 0.0538578829844454508530552,
+ 0.0726211541651914182692959,
+ 0.0367628469888049348429018,
+ 0.00964629015572527529605267,
+ 0.00133453480075291076745275,
+ 0.778087599782504251917881e-4,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [1.25, 2.25].
+const ERF_IMPL_DD: &[f64] = &[
+ 1.0,
+ 1.75967098147167528287343,
+ 1.32883571437961120556307,
+ 0.552528596508757581287907,
+ 0.133793056941332861912279,
+ 0.0179509645176280768640766,
+ 0.00104712440019937356634038,
+ -0.106640381820357337177643e-7,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [2.25, 3.5].
+const ERF_IMPL_EN: &[f64] = &[
+ -0.0117907570137227847827732,
+ 0.014262132090538809896674,
+ 0.0202234435902960820020765,
+ 0.00930668299990432009042239,
+ 0.00213357802422065994322516,
+ 0.00025022987386460102395382,
+ 0.120534912219588189822126e-4,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [2.25, 3.5].
+const ERF_IMPL_ED: &[f64] = &[
+ 1.0,
+ 1.50376225203620482047419,
+ 0.965397786204462896346934,
+ 0.339265230476796681555511,
+ 0.0689740649541569716897427,
+ 0.00771060262491768307365526,
+ 0.000371421101531069302990367,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [3.5, 5.25].
+const ERF_IMPL_FN: &[f64] = &[
+ -0.00546954795538729307482955,
+ 0.00404190278731707110245394,
+ 0.0054963369553161170521356,
+ 0.00212616472603945399437862,
+ 0.000394984014495083900689956,
+ 0.365565477064442377259271e-4,
+ 0.135485897109932323253786e-5,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [3.5, 5.25].
+const ERF_IMPL_FD: &[f64] = &[
+ 1.0,
+ 1.21019697773630784832251,
+ 0.620914668221143886601045,
+ 0.173038430661142762569515,
+ 0.0276550813773432047594539,
+ 0.00240625974424309709745382,
+ 0.891811817251336577241006e-4,
+ -0.465528836283382684461025e-11,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [5.25, 8].
+const ERF_IMPL_GN: &[f64] = &[
+ -0.00270722535905778347999196,
+ 0.0013187563425029400461378,
+ 0.00119925933261002333923989,
+ 0.00027849619811344664248235,
+ 0.267822988218331849989363e-4,
+ 0.923043672315028197865066e-6,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [5.25, 8].
+const ERF_IMPL_GD: &[f64] = &[
+ 1.0,
+ 0.814632808543141591118279,
+ 0.268901665856299542168425,
+ 0.0449877216103041118694989,
+ 0.00381759663320248459168994,
+ 0.000131571897888596914350697,
+ 0.404815359675764138445257e-11,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [8, 11.5].
+const ERF_IMPL_HN: &[f64] = &[
+ -0.00109946720691742196814323,
+ 0.000406425442750422675169153,
+ 0.000274499489416900707787024,
+ 0.465293770646659383436343e-4,
+ 0.320955425395767463401993e-5,
+ 0.778286018145020892261936e-7,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [8, 11.5].
+const ERF_IMPL_HD: &[f64] = &[
+ 1.0,
+ 0.588173710611846046373373,
+ 0.139363331289409746077541,
+ 0.0166329340417083678763028,
+ 0.00100023921310234908642639,
+ 0.24254837521587225125068e-4,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [11.5, 17].
+const ERF_IMPL_IN: &[f64] = &[
+ -0.00056907993601094962855594,
+ 0.000169498540373762264416984,
+ 0.518472354581100890120501e-4,
+ 0.382819312231928859704678e-5,
+ 0.824989931281894431781794e-7,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [11.5, 17].
+const ERF_IMPL_ID: &[f64] = &[
+ 1.0,
+ 0.339637250051139347430323,
+ 0.043472647870310663055044,
+ 0.00248549335224637114641629,
+ 0.535633305337152900549536e-4,
+ -0.117490944405459578783846e-12,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [17, 24].
+const ERF_IMPL_JN: &[f64] = &[
+ -0.000241313599483991337479091,
+ 0.574224975202501512365975e-4,
+ 0.115998962927383778460557e-4,
+ 0.581762134402593739370875e-6,
+ 0.853971555085673614607418e-8,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [17, 24].
+const ERF_IMPL_JD: &[f64] = &[
+ 1.0,
+ 0.233044138299687841018015,
+ 0.0204186940546440312625597,
+ 0.000797185647564398289151125,
+ 0.117019281670172327758019e-4,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [24, 38].
+const ERF_IMPL_KN: &[f64] = &[
+ -0.000146674699277760365803642,
+ 0.162666552112280519955647e-4,
+ 0.269116248509165239294897e-5,
+ 0.979584479468091935086972e-7,
+ 0.101994647625723465722285e-8,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [24, 38].
+const ERF_IMPL_KD: &[f64] = &[
+ 1.0,
+ 0.165907812944847226546036,
+ 0.0103361716191505884359634,
+ 0.000286593026373868366935721,
+ 0.298401570840900340874568e-5,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [38, 60].
+const ERF_IMPL_LN: &[f64] = &[
+ -0.583905797629771786720406e-4,
+ 0.412510325105496173512992e-5,
+ 0.431790922420250949096906e-6,
+ 0.993365155590013193345569e-8,
+ 0.653480510020104699270084e-10,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [38, 60].
+const ERF_IMPL_LD: &[f64] = &[
+ 1.0,
+ 0.105077086072039915406159,
+ 0.00414278428675475620830226,
+ 0.726338754644523769144108e-4,
+ 0.477818471047398785369849e-6,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [60, 85].
+const ERF_IMPL_MN: &[f64] = &[
+ -0.196457797609229579459841e-4,
+ 0.157243887666800692441195e-5,
+ 0.543902511192700878690335e-7,
+ 0.317472492369117710852685e-9,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [60, 85].
+const ERF_IMPL_MD: &[f64] = &[
+ 1.0,
+ 0.052803989240957632204885,
+ 0.000926876069151753290378112,
+ 0.541011723226630257077328e-5,
+ 0.535093845803642394908747e-15,
+];
+
+/// Polynomial coefficients for a numerator in `erf_impl`
+/// in the interval [85, 110].
+const ERF_IMPL_NN: &[f64] = &[
+ -0.789224703978722689089794e-5,
+ 0.622088451660986955124162e-6,
+ 0.145728445676882396797184e-7,
+ 0.603715505542715364529243e-10,
+];
+
+/// Polynomial coefficients for a denominator in `erf_impl`
+/// in the interval [85, 110].
+const ERF_IMPL_ND: &[f64] = &[
+ 1.0,
+ 0.0375328846356293715248719,
+ 0.000467919535974625308126054,
+ 0.193847039275845656900547e-5,
+];
+
+// **********************************************************
+// ********** Coefficients for erf_inv_impl polynomial ******
+// **********************************************************
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0, 0.5].
+const ERF_INV_IMPL_AN: &[f64] = &[
+ -0.000508781949658280665617,
+ -0.00836874819741736770379,
+ 0.0334806625409744615033,
+ -0.0126926147662974029034,
+ -0.0365637971411762664006,
+ 0.0219878681111168899165,
+ 0.00822687874676915743155,
+ -0.00538772965071242932965,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0, 0.5].
+const ERF_INV_IMPL_AD: &[f64] = &[
+ 1.0,
+ -0.970005043303290640362,
+ -1.56574558234175846809,
+ 1.56221558398423026363,
+ 0.662328840472002992063,
+ -0.71228902341542847553,
+ -0.0527396382340099713954,
+ 0.0795283687341571680018,
+ -0.00233393759374190016776,
+ 0.000886216390456424707504,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.5, 0.75].
+const ERF_INV_IMPL_BN: &[f64] = &[
+ -0.202433508355938759655,
+ 0.105264680699391713268,
+ 8.37050328343119927838,
+ 17.6447298408374015486,
+ -18.8510648058714251895,
+ -44.6382324441786960818,
+ 17.445385985570866523,
+ 21.1294655448340526258,
+ -3.67192254707729348546,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.5, 0.75].
+const ERF_INV_IMPL_BD: &[f64] = &[
+ 1.0,
+ 6.24264124854247537712,
+ 3.9713437953343869095,
+ -28.6608180499800029974,
+ -20.1432634680485188801,
+ 48.5609213108739935468,
+ 10.8268667355460159008,
+ -22.6436933413139721736,
+ 1.72114765761200282724,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x less than 3.
+const ERF_INV_IMPL_CN: &[f64] = &[
+ -0.131102781679951906451,
+ -0.163794047193317060787,
+ 0.117030156341995252019,
+ 0.387079738972604337464,
+ 0.337785538912035898924,
+ 0.142869534408157156766,
+ 0.0290157910005329060432,
+ 0.00214558995388805277169,
+ -0.679465575181126350155e-6,
+ 0.285225331782217055858e-7,
+ -0.681149956853776992068e-9,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x less than 3.
+const ERF_INV_IMPL_CD: &[f64] = &[
+ 1.0,
+ 3.46625407242567245975,
+ 5.38168345707006855425,
+ 4.77846592945843778382,
+ 2.59301921623620271374,
+ 0.848854343457902036425,
+ 0.152264338295331783612,
+ 0.01105924229346489121,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 3 and 6.
+const ERF_INV_IMPL_DN: &[f64] = &[
+ -0.0350353787183177984712,
+ -0.00222426529213447927281,
+ 0.0185573306514231072324,
+ 0.00950804701325919603619,
+ 0.00187123492819559223345,
+ 0.000157544617424960554631,
+ 0.460469890584317994083e-5,
+ -0.230404776911882601748e-9,
+ 0.266339227425782031962e-11,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 3 and 6.
+const ERF_INV_IMPL_DD: &[f64] = &[
+ 1.0,
+ 1.3653349817554063097,
+ 0.762059164553623404043,
+ 0.220091105764131249824,
+ 0.0341589143670947727934,
+ 0.00263861676657015992959,
+ 0.764675292302794483503e-4,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 6 and 18.
+const ERF_INV_IMPL_EN: &[f64] = &[
+ -0.0167431005076633737133,
+ -0.00112951438745580278863,
+ 0.00105628862152492910091,
+ 0.000209386317487588078668,
+ 0.149624783758342370182e-4,
+ 0.449696789927706453732e-6,
+ 0.462596163522878599135e-8,
+ -0.281128735628831791805e-13,
+ 0.99055709973310326855e-16,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 6 and 18.
+const ERF_INV_IMPL_ED: &[f64] = &[
+ 1.0,
+ 0.591429344886417493481,
+ 0.138151865749083321638,
+ 0.0160746087093676504695,
+ 0.000964011807005165528527,
+ 0.275335474764726041141e-4,
+ 0.282243172016108031869e-6,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 18 and 44.
+const ERF_INV_IMPL_FN: &[f64] = &[
+ -0.0024978212791898131227,
+ -0.779190719229053954292e-5,
+ 0.254723037413027451751e-4,
+ 0.162397777342510920873e-5,
+ 0.396341011304801168516e-7,
+ 0.411632831190944208473e-9,
+ 0.145596286718675035587e-11,
+ -0.116765012397184275695e-17,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x between 18 and 44.
+const ERF_INV_IMPL_FD: &[f64] = &[
+ 1.0,
+ 0.207123112214422517181,
+ 0.0169410838120975906478,
+ 0.000690538265622684595676,
+ 0.145007359818232637924e-4,
+ 0.144437756628144157666e-6,
+ 0.509761276599778486139e-9,
+];
+
+/// Polynomial coefficients for a numerator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x greater than 44.
+const ERF_INV_IMPL_GN: &[f64] = &[
+ -0.000539042911019078575891,
+ -0.28398759004727721098e-6,
+ 0.899465114892291446442e-6,
+ 0.229345859265920864296e-7,
+ 0.225561444863500149219e-9,
+ 0.947846627503022684216e-12,
+ 0.135880130108924861008e-14,
+ -0.348890393399948882918e-21,
+];
+
+/// Polynomial coefficients for a denominator of `erf_inv_impl`
+/// in the interval [0.75, 1] with x greater than 44.
+const ERF_INV_IMPL_GD: &[f64] = &[
+ 1.0,
+ 0.0845746234001899436914,
+ 0.00282092984726264681981,
+ 0.468292921940894236786e-4,
+ 0.399968812193862100054e-6,
+ 0.161809290887904476097e-8,
+ 0.231558608310259605225e-11,
+];
+
+/// `erf_impl` computes the error function at `z`.
+/// If `inv` is true, `1 - erf` is calculated as opposed to `erf`
+fn erf_impl(z: f64, inv: bool) -> f64 {
+ if z < 0.0 {
+ if !inv {
+ return -erf_impl(-z, false);
+ }
+ if z < -0.5 {
+ return 2.0 - erf_impl(-z, true);
+ }
+ return 1.0 + erf_impl(-z, false);
+ }
+
+ let result = if z < 0.5 {
+ if z < 1e-10 {
+ z * 1.125 + z * 0.003379167095512573896158903121545171688
+ } else {
+ z * 1.125
+ + z * evaluate::polynomial(z, ERF_IMPL_AN) / evaluate::polynomial(z, ERF_IMPL_AD)
+ }
+ } else if z < 110.0 {
+ let (r, b) = if z < 0.75 {
+ (
+ evaluate::polynomial(z - 0.5, ERF_IMPL_BN)
+ / evaluate::polynomial(z - 0.5, ERF_IMPL_BD),
+ 0.3440242112,
+ )
+ } else if z < 1.25 {
+ (
+ evaluate::polynomial(z - 0.75, ERF_IMPL_CN)
+ / evaluate::polynomial(z - 0.75, ERF_IMPL_CD),
+ 0.419990927,
+ )
+ } else if z < 2.25 {
+ (
+ evaluate::polynomial(z - 1.25, ERF_IMPL_DN)
+ / evaluate::polynomial(z - 1.25, ERF_IMPL_DD),
+ 0.4898625016,
+ )
+ } else if z < 3.5 {
+ (
+ evaluate::polynomial(z - 2.25, ERF_IMPL_EN)
+ / evaluate::polynomial(z - 2.25, ERF_IMPL_ED),
+ 0.5317370892,
+ )
+ } else if z < 5.25 {
+ (
+ evaluate::polynomial(z - 3.5, ERF_IMPL_FN)
+ / evaluate::polynomial(z - 3.5, ERF_IMPL_FD),
+ 0.5489973426,
+ )
+ } else if z < 8.0 {
+ (
+ evaluate::polynomial(z - 5.25, ERF_IMPL_GN)
+ / evaluate::polynomial(z - 5.25, ERF_IMPL_GD),
+ 0.5571740866,
+ )
+ } else if z < 11.5 {
+ (
+ evaluate::polynomial(z - 8.0, ERF_IMPL_HN)
+ / evaluate::polynomial(z - 8.0, ERF_IMPL_HD),
+ 0.5609807968,
+ )
+ } else if z < 17.0 {
+ (
+ evaluate::polynomial(z - 11.5, ERF_IMPL_IN)
+ / evaluate::polynomial(z - 11.5, ERF_IMPL_ID),
+ 0.5626493692,
+ )
+ } else if z < 24.0 {
+ (
+ evaluate::polynomial(z - 17.0, ERF_IMPL_JN)
+ / evaluate::polynomial(z - 17.0, ERF_IMPL_JD),
+ 0.5634598136,
+ )
+ } else if z < 38.0 {
+ (
+ evaluate::polynomial(z - 24.0, ERF_IMPL_KN)
+ / evaluate::polynomial(z - 24.0, ERF_IMPL_KD),
+ 0.5638477802,
+ )
+ } else if z < 60.0 {
+ (
+ evaluate::polynomial(z - 38.0, ERF_IMPL_LN)
+ / evaluate::polynomial(z - 38.0, ERF_IMPL_LD),
+ 0.5640528202,
+ )
+ } else if z < 85.0 {
+ (
+ evaluate::polynomial(z - 60.0, ERF_IMPL_MN)
+ / evaluate::polynomial(z - 60.0, ERF_IMPL_MD),
+ 0.5641309023,
+ )
+ } else {
+ (
+ evaluate::polynomial(z - 85.0, ERF_IMPL_NN)
+ / evaluate::polynomial(z - 85.0, ERF_IMPL_ND),
+ 0.5641584396,
+ )
+ };
+ let g = (-z * z).exp() / z;
+ g * b + g * r
+ } else {
+ 0.0
+ };
+
+ if inv && z >= 0.5 {
+ result
+ } else if z >= 0.5 || inv {
+ 1.0 - result
+ } else {
+ result
+ }
+}
+
+// `erf_inv_impl` computes the inverse error function where
+// `p`,`q`, and `s` are the first, second, and third intermediate
+// parameters respectively
+fn erf_inv_impl(p: f64, q: f64, s: f64) -> f64 {
+ let result = if p <= 0.5 {
+ let y = 0.0891314744949340820313;
+ let g = p * (p + 10.0);
+ let r = evaluate::polynomial(p, ERF_INV_IMPL_AN) / evaluate::polynomial(p, ERF_INV_IMPL_AD);
+ g * y + g * r
+ } else if q >= 0.25 {
+ let y = 2.249481201171875;
+ let g = (-2.0 * q.ln()).sqrt();
+ let xs = q - 0.25;
+ let r =
+ evaluate::polynomial(xs, ERF_INV_IMPL_BN) / evaluate::polynomial(xs, ERF_INV_IMPL_BD);
+ g / (y + r)
+ } else {
+ let x = (-q.ln()).sqrt();
+ if x < 3.0 {
+ let y = 0.807220458984375;
+ let xs = x - 1.125;
+ let r = evaluate::polynomial(xs, ERF_INV_IMPL_CN)
+ / evaluate::polynomial(xs, ERF_INV_IMPL_CD);
+ y * x + r * x
+ } else if x < 6.0 {
+ let y = 0.93995571136474609375;
+ let xs = x - 3.0;
+ let r = evaluate::polynomial(xs, ERF_INV_IMPL_DN)
+ / evaluate::polynomial(xs, ERF_INV_IMPL_DD);
+ y * x + r * x
+ } else if x < 18.0 {
+ let y = 0.98362827301025390625;
+ let xs = x - 6.0;
+ let r = evaluate::polynomial(xs, ERF_INV_IMPL_EN)
+ / evaluate::polynomial(xs, ERF_INV_IMPL_ED);
+ y * x + r * x
+ } else if x < 44.0 {
+ let y = 0.99714565277099609375;
+ let xs = x - 18.0;
+ let r = evaluate::polynomial(xs, ERF_INV_IMPL_FN)
+ / evaluate::polynomial(xs, ERF_INV_IMPL_FD);
+ y * x + r * x
+ } else {
+ let y = 0.99941349029541015625;
+ let xs = x - 44.0;
+ let r = evaluate::polynomial(xs, ERF_INV_IMPL_GN)
+ / evaluate::polynomial(xs, ERF_INV_IMPL_GD);
+ y * x + r * x
+ }
+ };
+ s * result
+}
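A small test sketch for the functions above; the reference value and tolerances are illustrative rather than part of this patch:

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn erf_basic_values() {
            // Known value: erf(1) ≈ 0.8427007929.
            assert!((erf(1.0) - 0.8427007929).abs() < 1e-7);
            // erfc is the complement of erf.
            assert!((erf(0.3) + erfc(0.3) - 1.0).abs() < 1e-12);
            // erf_inv inverts erf on (-1, 1).
            let x = 0.6;
            assert!((erf_inv(erf(x)) - x).abs() < 1e-9);
        }
    }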
diff --git a/mcandle/mcandle-core/src/cpu/kernels.rs b/mcandle/mcandle-core/src/cpu/kernels.rs
new file mode 100644
index 000000000..f81ad625d
--- /dev/null
+++ b/mcandle/mcandle-core/src/cpu/kernels.rs
@@ -0,0 +1,220 @@
+pub trait VecOps: num_traits::NumAssign + Copy {
+ fn min(self, rhs: Self) -> Self;
+ fn max(self, rhs: Self) -> Self;
+
+ /// Dot-product of two vectors.
+ ///
+ /// # Safety
+ ///
+ /// The lengths of `lhs` and `rhs` have to be at least `len`. `res` has to point to a valid
+ /// element.
+ #[inline(always)]
+ unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
+ *res = Self::zero();
+ for i in 0..len {
+ *res += *lhs.add(i) * *rhs.add(i)
+ }
+ }
+
+ /// Sum of all elements in a vector.
+ ///
+ /// # Safety
+ ///
+ /// The length of `xs` must be at least `len`. `res` has to point to a valid
+ /// element.
+ #[inline(always)]
+ unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
+ *res = Self::zero();
+ for i in 0..len {
+ *res += *xs.add(i)
+ }
+ }
+
+ /// Maximum element in a non-empty vector.
+ ///
+ /// # Safety
+ ///
+ /// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
+ /// element.
+ #[inline(always)]
+ unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
+ *res = *xs;
+ for i in 1..len {
+ *res = (*res).max(*xs.add(i))
+ }
+ }
+
+ /// Minimum element in a non-empty vector.
+ ///
+ /// # Safety
+ ///
+ /// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
+ /// element.
+ #[inline(always)]
+ unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
+ *res = *xs;
+ for i in 1..len {
+ *res = (*res).min(*xs.add(i))
+ }
+ }
+}
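For reference, a tiny safe wrapper over the default `vec_dot` (a sketch only; the SIMD-backed overrides for specific dtypes follow below):

    fn dot<T: VecOps>(a: &[T], b: &[T]) -> T {
        assert_eq!(a.len(), b.len());
        let mut out = T::zero();
        // Safety: both pointers cover a.len() elements and `out` is a valid target.
        unsafe { T::vec_dot(a.as_ptr(), b.as_ptr(), &mut out, a.len()) };
        out
    }

    // Example: dot(&[1.0f32, 2.0, 3.0], &[4.0, 5.0, 6.0]) == 32.0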
+
+impl VecOps for f32 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ Self::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ Self::max(self, other)
+ }
+
+ #[inline(always)]
+ unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
+ super::vec_dot_f32(lhs, rhs, res, len)
+ }
+
+ #[inline(always)]
+ unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
+ super::vec_sum(xs, res, len)
+ }
+}
+
+impl VecOps for half::f16 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ Self::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ Self::max(self, other)
+ }
+
+ #[inline(always)]
+ unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
+ let mut res_f32 = 0f32;
+ super::vec_dot_f16(lhs, rhs, &mut res_f32, len);
+ *res = half::f16::from_f32(res_f32);
+ }
+}
+
+impl VecOps for f64 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ Self::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ Self::max(self, other)
+ }
+}
+impl VecOps for half::bf16 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ Self::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ Self::max(self, other)
+ }
+
+ #[inline(always)]
+ unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
+ let mut res_f32 = 0f32;
+ super::vec_dot_bf16(lhs, rhs, &mut res_f32, len);
+ *res = half::bf16::from_f32(res_f32);
+ }
+}
+impl VecOps for u8 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ <Self as Ord>::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ <Self as Ord>::max(self, other)
+ }
+}
+impl VecOps for u32 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ <Self as Ord>::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ <Self as Ord>::max(self, other)
+ }
+}
+impl VecOps for i16 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ <Self as Ord>::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ <Self as Ord>::max(self, other)
+ }
+}
+impl VecOps for i32 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ <Self as Ord>::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ <Self as Ord>::max(self, other)
+ }
+}
+impl VecOps for i64 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ <Self as Ord>::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ <Self as Ord>::max(self, other)
+ }
+}
+
+#[inline(always)]
+pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
+ if n_threads == 1 {
+ func(0)
+ } else {
+ rayon::scope(|s| {
+ for thread_idx in 0..n_threads {
+ let func = &func;
+ s.spawn(move |_| func(thread_idx));
+ }
+ })
+ }
+}
+
+#[inline(always)]
+pub fn par_range(lo: usize, up: usize, n_threads: usize, func: impl Fn(usize) + Send + Sync) {
+ if n_threads == 1 {
+ for i in lo..up {
+ func(i)
+ }
+ } else {
+ rayon::scope(|s| {
+ for thread_idx in 0..n_threads {
+ let func = &func;
+ s.spawn(move |_| {
+ for i in (thread_idx..up).step_by(n_threads) {
+ func(i)
+ }
+ });
+ }
+ })
+ }
+}
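A usage sketch for `par_range`; the `mcandle_core::cpu::kernels` path is assumed to be reachable from the crate root and should be adjusted if the module is private:

    fn par_demo() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        // Sum 0..1000 over 4 threads; each thread walks a strided sub-range.
        let total = AtomicUsize::new(0);
        mcandle_core::cpu::kernels::par_range(0, 1000, 4, |i| {
            total.fetch_add(i, Ordering::Relaxed);
        });
        assert_eq!(total.into_inner(), 499_500);
    }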
diff --git a/mcandle/mcandle-core/src/cpu/mod.rs b/mcandle/mcandle-core/src/cpu/mod.rs
new file mode 100644
index 000000000..0b77e6ecb
--- /dev/null
+++ b/mcandle/mcandle-core/src/cpu/mod.rs
@@ -0,0 +1,240 @@
+pub mod erf;
+pub mod kernels;
+
+#[allow(unused)]
+trait Cpu {
+ type Unit;
+ type Array;
+ const STEP: usize;
+ const EPR: usize;
+
+ fn n() -> usize;
+ unsafe fn zero() -> Self::Unit;
+ unsafe fn zero_array() -> Self::Array;
+ unsafe fn load(mem_addr: *const f32) -> Self::Unit;
+ unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
+ unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
+ unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
+ unsafe fn from_f32(v: f32) -> Self::Unit;
+ unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit);
+}
+
+#[allow(unused)]
+trait CpuF16 {
+ type Unit;
+ type Array;
+ const STEP: usize;
+ const EPR: usize;
+
+ fn n() -> usize;
+ unsafe fn zero() -> Self::Unit;
+ unsafe fn zero_array() -> Self::Array;
+ unsafe fn load(mem_addr: *const f16) -> Self::Unit;
+ unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
+ unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
+ unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
+ unsafe fn from_f32(v: f32) -> Self::Unit;
+ unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit);
+}
+
+#[allow(unused)]
+trait CpuBF16 {
+ type Unit;
+ type Array;
+ const STEP: usize;
+ const EPR: usize;
+
+ fn n() -> usize;
+ unsafe fn zero() -> Self::Unit;
+ unsafe fn zero_array() -> Self::Array;
+ unsafe fn load(mem_addr: *const bf16) -> Self::Unit;
+ unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
+ unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
+ unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
+ unsafe fn from_f32(v: f32) -> Self::Unit;
+ unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit);
+}
+
+use half::{bf16, f16};
+
+#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+#[cfg(target_feature = "avx")]
+pub mod avx;
+#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+#[cfg(target_feature = "avx")]
+pub use avx::{CurrentCpu, CurrentCpuBF16, CurrentCpuF16};
+
+#[cfg(target_arch = "wasm32")]
+#[cfg(target_feature = "simd128")]
+pub mod simd128;
+#[cfg(target_arch = "wasm32")]
+#[cfg(target_feature = "simd128")]
+pub use simd128::CurrentCpu;
+
+#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
+#[cfg(target_feature = "neon")]
+pub mod neon;
+#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
+#[cfg(target_feature = "neon")]
+pub use neon::CurrentCpu;
+
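+// The vec_* helpers below each come in two flavours: a SIMD path compiled when the
+// matching target feature (neon/avx/simd128) is enabled, and a scalar fallback.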
+#[cfg(any(
+ target_feature = "neon",
+ target_feature = "avx",
+ target_feature = "simd128"
+))]
+#[inline(always)]
+pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f32, k: usize) {
+ let np = k & !(CurrentCpu::STEP - 1);
+
+ let mut sum = CurrentCpu::zero_array();
+ let mut ax = CurrentCpu::zero_array();
+ let mut ay = CurrentCpu::zero_array();
+
+ for i in (0..np).step_by(CurrentCpu::STEP) {
+ for j in 0..CurrentCpu::n() {
+ ax[j] = CurrentCpu::load(a_row.add(i + j * CurrentCpu::EPR));
+ ay[j] = CurrentCpu::load(b_row.add(i + j * CurrentCpu::EPR));
+
+ sum[j] = CurrentCpu::vec_fma(sum[j], ax[j], ay[j]);
+ }
+ }
+
+ CurrentCpu::vec_reduce(sum, c);
+
+ // leftovers
+ for i in np..k {
+ *c += *a_row.add(i) * (*b_row.add(i));
+ }
+}
+
+#[cfg(not(any(
+ target_feature = "neon",
+ target_feature = "avx",
+ target_feature = "simd128"
+)))]
+#[inline(always)]
+pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f32, k: usize) {
+ // leftovers
+ for i in 0..k {
+ *c += *a_row.add(i) * (*b_row.add(i));
+ }
+}
+
+#[cfg(any(
+ target_feature = "neon",
+ target_feature = "avx",
+ target_feature = "simd128"
+))]
+#[inline(always)]
+pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) {
+ let np = k & !(CurrentCpu::STEP - 1);
+
+ let mut sum = CurrentCpu::zero_array();
+ let mut x = CurrentCpu::zero_array();
+
+ for i in (0..np).step_by(CurrentCpu::STEP) {
+ for j in 0..CurrentCpu::n() {
+ x[j] = CurrentCpu::load(row.add(i + j * CurrentCpu::EPR));
+ sum[j] = CurrentCpu::vec_add(sum[j], x[j]);
+ }
+ }
+
+ CurrentCpu::vec_reduce(sum, b);
+
+ // leftovers
+ for i in np..k {
+ *b += *row.add(i)
+ }
+}
+
+#[cfg(not(any(
+ target_feature = "neon",
+ target_feature = "avx",
+ target_feature = "simd128"
+)))]
+#[inline(always)]
+pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) {
+ *b = 0f32;
+ for i in 0..k {
+ *b += *row.add(i)
+ }
+}
+
+#[cfg(target_feature = "avx")]
+#[inline(always)]
+pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
+ let mut sumf = 0.0f32;
+ let np = k & !(CurrentCpuF16::STEP - 1);
+
+ let mut sum = CurrentCpuF16::zero_array();
+ let mut ax = CurrentCpuF16::zero_array();
+ let mut ay = CurrentCpuF16::zero_array();
+
+ for i in (0..np).step_by(CurrentCpuF16::STEP) {
+ for j in 0..CurrentCpuF16::n() {
+ ax[j] = CurrentCpuF16::load(a_row.add(i + j * CurrentCpuF16::EPR));
+ ay[j] = CurrentCpuF16::load(b_row.add(i + j * CurrentCpuF16::EPR));
+
+ sum[j] = CurrentCpuF16::vec_fma(sum[j], ax[j], ay[j]);
+ }
+ }
+
+ CurrentCpuF16::vec_reduce(sum, &mut sumf);
+
+ // leftovers
+ for i in np..k {
+ sumf += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
+ }
+ *c = sumf;
+}
+
+#[cfg(target_feature = "avx")]
+#[inline(always)]
+pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) {
+ let mut sumf = 0.0f32;
+ let np = k & !(CurrentCpuBF16::STEP - 1);
+
+ let mut sum = CurrentCpuBF16::zero_array();
+ let mut ax = CurrentCpuBF16::zero_array();
+ let mut ay = CurrentCpuBF16::zero_array();
+
+ for i in (0..np).step_by(CurrentCpuBF16::STEP) {
+ for j in 0..CurrentCpuBF16::n() {
+ ax[j] = CurrentCpuBF16::load(a_row.add(i + j * CurrentCpuBF16::EPR));
+ ay[j] = CurrentCpuBF16::load(b_row.add(i + j * CurrentCpuBF16::EPR));
+
+ sum[j] = CurrentCpuBF16::vec_fma(sum[j], ax[j], ay[j]);
+ }
+ }
+
+ CurrentCpuBF16::vec_reduce(sum, &mut sumf);
+
+ // leftovers
+ for i in np..k {
+ sumf += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
+ }
+ *c = sumf;
+}
+
+#[cfg(not(target_feature = "avx"))]
+#[inline(always)]
+pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
+ // leftovers
+ let mut sum = 0.0;
+ for i in 0..k {
+ sum += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
+ }
+ *c = sum;
+}
+
+#[cfg(not(target_feature = "avx"))]
+#[inline(always)]
+pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) {
+ // leftovers
+ let mut sum = 0.0;
+ for i in 0..k {
+ sum += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
+ }
+ *c = sum;
+}
diff --git a/mcandle/mcandle-core/src/cpu/neon.rs b/mcandle/mcandle-core/src/cpu/neon.rs
new file mode 100644
index 000000000..66b8b45e1
--- /dev/null
+++ b/mcandle/mcandle-core/src/cpu/neon.rs
@@ -0,0 +1,74 @@
+use super::Cpu;
+#[cfg(target_arch = "arm")]
+use core::arch::arm::*;
+
+#[cfg(target_arch = "aarch64")]
+use core::arch::aarch64::*;
+
+pub struct CurrentCpu {}
+
+const STEP: usize = 16;
+const EPR: usize = 4;
+const ARR: usize = STEP / EPR;
+
+impl CurrentCpu {
+ #[cfg(target_arch = "aarch64")]
+ unsafe fn reduce_one(x: float32x4_t) -> f32 {
+ vaddvq_f32(x)
+ }
+
+ #[cfg(target_arch = "arm")]
+ unsafe fn reduce_one(x: float32x4_t) -> f32 {
+ vgetq_lane_f32(x, 0) + vgetq_lane_f32(x, 1) + vgetq_lane_f32(x, 2) + vgetq_lane_f32(x, 3)
+ }
+}
+
+impl Cpu for CurrentCpu {
+ type Unit = float32x4_t;
+ type Array = [float32x4_t; ARR];
+
+ const STEP: usize = STEP;
+ const EPR: usize = EPR;
+
+ fn n() -> usize {
+ ARR
+ }
+
+ unsafe fn zero() -> Self::Unit {
+ vdupq_n_f32(0.0)
+ }
+
+ unsafe fn from_f32(x: f32) -> Self::Unit {
+ vdupq_n_f32(x)
+ }
+
+ unsafe fn zero_array() -> Self::Array {
+ [Self::zero(); ARR]
+ }
+
+ unsafe fn load(mem_addr: *const f32) -> Self::Unit {
+ vld1q_f32(mem_addr)
+ }
+
+ unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
+ vaddq_f32(a, b)
+ }
+
+ unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
+ vfmaq_f32(a, b, c)
+ }
+
+ unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
+ vst1q_f32(mem_addr, a);
+ }
+
+ unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
+ for i in 0..ARR / 2 {
+ x[2 * i] = vaddq_f32(x[2 * i], x[2 * i + 1]);
+ }
+ for i in 0..ARR / 4 {
+ x[4 * i] = vaddq_f32(x[4 * i], x[4 * i + 2]);
+ }
+ *y = Self::reduce_one(x[0]);
+ }
+}
diff --git a/mcandle/mcandle-core/src/cpu/simd128.rs b/mcandle/mcandle-core/src/cpu/simd128.rs
new file mode 100644
index 000000000..64b94588b
--- /dev/null
+++ b/mcandle/mcandle-core/src/cpu/simd128.rs
@@ -0,0 +1,64 @@
+use super::Cpu;
+use core::arch::wasm32::*;
+
+pub struct CurrentCpu {}
+
+const STEP: usize = 16;
+const EPR: usize = 4;
+const ARR: usize = STEP / EPR;
+
+impl Cpu for CurrentCpu {
+ type Unit = v128;
+ type Array = [v128; ARR];
+
+ const STEP: usize = STEP;
+ const EPR: usize = EPR;
+
+ fn n() -> usize {
+ ARR
+ }
+
+ unsafe fn zero() -> Self::Unit {
+ f32x4_splat(0.0)
+ }
+
+ unsafe fn zero_array() -> Self::Array {
+ [Self::zero(); ARR]
+ }
+
+ unsafe fn from_f32(v: f32) -> Self::Unit {
+ f32x4_splat(v)
+ }
+
+ unsafe fn load(mem_addr: *const f32) -> Self::Unit {
+ v128_load(mem_addr as *mut v128)
+ }
+
+ unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
+ f32x4_add(a, b)
+ }
+
+ unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
+ f32x4_add(f32x4_mul(b, c), a)
+ }
+
+ unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
+ v128_store(mem_addr as *mut v128, a);
+ }
+
+ unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
+ for i in 0..ARR / 2 {
+ x[2 * i] = f32x4_add(x[2 * i], x[2 * i + 1]);
+ }
+ for i in 0..ARR / 4 {
+ x[4 * i] = f32x4_add(x[4 * i], x[4 * i + 2]);
+ }
+ for i in 0..ARR / 8 {
+ x[8 * i] = f32x4_add(x[8 * i], x[8 * i + 4]);
+ }
+ *y = f32x4_extract_lane::<0>(x[0])
+ + f32x4_extract_lane::<1>(x[0])
+ + f32x4_extract_lane::<2>(x[0])
+ + f32x4_extract_lane::<3>(x[0]);
+ }
+}
diff --git a/mcandle/mcandle-core/src/cpu_backend/mod.rs b/mcandle/mcandle-core/src/cpu_backend/mod.rs
new file mode 100644
index 000000000..6ef74c072
--- /dev/null
+++ b/mcandle/mcandle-core/src/cpu_backend/mod.rs
@@ -0,0 +1,3759 @@
+use std::ops::Deref;
+
+use crate::backend::{BackendDevice, BackendStorage};
+use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
+use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
+use float8::F8E4M3;
+use half::{bf16, f16};
+use rayon::prelude::*;
+
+mod utils;
+pub use utils::{
+ binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2Alpha, Map2U8,
+ Map3,
+};
+
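+// When enabled, the convolution ops lower to im2col/col2im plus a matmul instead of
+// the direct nested-loop kernels defined below.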
+const USE_IM2COL_CONV1D: bool = true;
+const USE_COL2IM_CONV1D_TR: bool = true;
+const USE_IM2COL_CONV2D: bool = true;
+
+// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
+// intercept the oom errors to avoid panicking and provide a proper error.
+#[derive(Debug, Clone)]
+pub enum CpuStorage {
+ U8(Vec<u8>),
+ U32(Vec<u32>),
+ I16(Vec<i16>),
+ I32(Vec<i32>),
+ I64(Vec<i64>),
+ BF16(Vec<bf16>),
+ F16(Vec<f16>),
+ F32(Vec<f32>),
+ F64(Vec<f64>),
+ F8E4M3(Vec<F8E4M3>),
+}
+
+#[derive(Debug, Clone)]
+pub enum CpuStorageRef<'a> {
+ U8(&'a [u8]),
+ U32(&'a [u32]),
+ I16(&'a [i16]),
+ I32(&'a [i32]),
+ I64(&'a [i64]),
+ BF16(&'a [bf16]),
+ F16(&'a [f16]),
+ F32(&'a [f32]),
+ F64(&'a [f64]),
+ F8E4M3(&'a [F8E4M3]),
+}
+
+#[derive(Debug, Clone)]
+pub struct CpuDevice;
+
+struct Cmp(CmpOp);
+impl Map2U8 for Cmp {
+ const OP: &'static str = "cmp";
+ #[inline(always)]
+ fn f<T: WithDType>(
+ &self,
+ lhs: &[T],
+ lhs_l: &Layout,
+ rhs: &[T],
+ rhs_l: &Layout,
+ ) -> Result<Vec<u8>> {
+ let dst = match self.0 {
+ CmpOp::Eq => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x == y)),
+ CmpOp::Ne => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x != y)),
+ CmpOp::Lt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x < y)),
+ CmpOp::Le => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x <= y)),
+ CmpOp::Gt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x > y)),
+ CmpOp::Ge => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x >= y)),
+ };
+ Ok(dst)
+ }
+}
+
+struct WCond<'a, T: IntDType>(&'a [T], &'a Layout);
+
+impl<'a, I: IntDType> Map2 for WCond<'a, I> {
+ const OP: &'static str = "where";
+ #[inline(always)]
+ fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
+ let vs = match (
+ self.1.contiguous_offsets(),
+ t_l.contiguous_offsets(),
+ f_l.contiguous_offsets(),
+ ) {
+ (Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => {
+ let pred = &self.0[o1..o2];
+ let t = &t[o_t1..o_t2];
+ let f = &f[o_f1..o_f2];
+ pred.iter()
+ .zip(t.iter().zip(f.iter()))
+ .map(|(p, (&t, &f))| if p.is_true() { t } else { f })
+ .collect::<Vec<_>>()
+ }
+ _ => self
+ .1
+ .strided_index()
+ .zip(t_l.strided_index().zip(f_l.strided_index()))
+ .map(|(i_p, (i_t, i_f))| {
+ if self.0[i_p].is_true() {
+ t[i_t]
+ } else {
+ f[i_f]
+ }
+ })
+ .collect::<Vec<_>>(),
+ };
+ Ok(vs)
+ }
+}
+
+struct ReduceIndex {
+ reduce_dim_index: usize,
+ use_min: bool,
+ return_index: bool,
+}
+
+impl ReduceIndex {
+ // The value gets replaced if f(s[current_acc], s[i]) returns true.
+ #[inline(always)]
+ fn fold_impl<T, U, F, G>(&self, src: &[T], src_l: &Layout, f: F, g: G) -> Result<Vec<U>>
+ where
+ T: Clone + Copy,
+ U: Clone + Copy,
+ F: Fn(T, T) -> bool,
+ G: Fn(T, usize) -> U,
+ {
+ let reduce_dim_size = src_l.dims()[self.reduce_dim_index];
+ let reduce_dim_stride = src_l.stride()[self.reduce_dim_index];
+ let dst_len = src_l.shape().elem_count() / reduce_dim_size;
+ let mut dst: Vec<U> = Vec::with_capacity(dst_len);
+ let dst_to_set = dst.spare_capacity_mut();
+ let dst_to_set =
+ unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(dst_to_set) };
+ match src_l.contiguous_offsets() {
+ Some((o1, o2)) => {
+ let src = &src[o1..o2];
+ if reduce_dim_stride == 1 {
+ for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
+ let start_src_i = start_src_i * reduce_dim_size;
+ let src = &src[start_src_i..start_src_i + reduce_dim_size];
+ let mut acc = 0;
+ let mut val = src[0];
+ for (src_i, &s) in src.iter().enumerate() {
+ if f(val, s) {
+ acc = src_i;
+ val = s
+ }
+ }
+ *dst_v = g(val, acc)
+ }
+ } else {
+ for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
+ let (p, q) = (
+ start_src_i / reduce_dim_stride,
+ start_src_i % reduce_dim_stride,
+ );
+ // start_src_i = p * reduce_dim_stride + q
+ let start_src_i = p * reduce_dim_stride * reduce_dim_size + q;
+ let src = &src[start_src_i..];
+ let mut acc = 0;
+ let mut val = src[0];
+ for src_i in 0..reduce_dim_size {
+ let s = src[src_i * reduce_dim_stride];
+ if f(val, s) {
+ acc = src_i;
+ val = s
+ }
+ }
+ *dst_v = g(val, acc)
+ }
+ }
+ }
+ None => {
+ let l = src_l.narrow(self.reduce_dim_index, 0, 1)?;
+ for (unstr_index, src_index) in l.strided_index().enumerate() {
+ let src = &src[src_index..];
+ let mut acc = 0;
+ let mut val = src[0];
+ for src_i in 0..reduce_dim_size {
+ let s = src[src_i * reduce_dim_stride];
+ if f(val, s) {
+ acc = src_i;
+ val = s
+ }
+ }
+ dst_to_set[unstr_index] = g(val, acc)
+ }
+ }
+ }
+ unsafe { dst.set_len(dst_len) };
+ Ok(dst)
+ }
+}
+
+impl Map1Any for ReduceIndex {
+ #[inline(always)]
+ fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
+ &self,
+ src: &[T],
+ src_l: &Layout,
+ wrap: W,
+ ) -> Result<CpuStorage> {
+ if src_l.shape().elem_count() == 0 {
+ Err(Error::EmptyTensor { op: "reduce" }.bt())?
+ }
+ let dst = match (self.return_index, self.use_min) {
+ (false, true) => wrap(self.fold_impl(src, src_l, |x, y| x > y, |v, _i| v)?),
+ (false, false) => wrap(self.fold_impl(src, src_l, |x, y| x < y, |v, _i| v)?),
+ (true, true) => {
+ CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x > y, |_v, i| i as u32)?)
+ }
+ (true, false) => {
+ CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x < y, |_v, i| i as u32)?)
+ }
+ };
+ Ok(dst)
+ }
+}
+
+struct ReduceSum<'a> {
+ dst_shape: &'a Shape,
+ reduce_dims: &'a [usize],
+ reduce_dims_and_stride: Vec<(usize, usize)>,
+}
+
+impl<'a> ReduceSum<'a> {
+ #[inline(always)]
+ fn fold_impl<T>(&self, src: &[T], src_l: &Layout, start_elt: T) -> Result<Vec<T>>
+ where
+ T: WithDType,
+ {
+ let mut dst = vec![start_elt; self.dst_shape.elem_count()];
+ match src_l.contiguous_offsets() {
+ Some((o1, o2)) => {
+ let src = &src[o1..o2];
+ // Handle the case where we reduce over the last dimensions separately as it is
+ // fairly common and easy to optimize. This relies on the layout being contiguous!
+ // reduce_dims is sorted, check if it is ranging from a to n-1.
+ let reduce_over_last_dims = self
+ .reduce_dims
+ .iter()
+ .rev()
+ .enumerate()
+ .all(|(i, &v)| v == src_l.shape().rank() - 1 - i);
+ if reduce_over_last_dims {
+ let reduce_sz = self
+ .reduce_dims_and_stride
+ .iter()
+ .map(|(u, _)| u)
+ .product::<usize>();
+ for (dst_i, dst_v) in dst.iter_mut().enumerate() {
+ let src_i = dst_i * reduce_sz;
+ unsafe {
+ T::vec_reduce_sum(
+ src[src_i..src_i + reduce_sz].as_ptr(),
+ dst_v,
+ reduce_sz,
+ )
+ };
+ }
+ return Ok(dst);
+ };
+ for (unstr_index, &src) in src.iter().enumerate() {
+ let mut dst_index = unstr_index;
+ // Set the reduce_dims indexes to 0.
+ for &(dim, stride) in self.reduce_dims_and_stride.iter() {
+ // The compiler is able to optimize the following in a single divmod op.
+ let (pre, post) = (dst_index / stride, dst_index % stride);
+ dst_index = (pre / dim) * stride + post;
+ }
+ dst[dst_index] += src;
+ }
+ }
+ None => {
+ for (unstr_index, src_index) in src_l.strided_index().enumerate() {
+ let mut dst_index = unstr_index;
+ // Set the reduce_dims indexes to 0.
+ for &(dim, stride) in self.reduce_dims_and_stride.iter() {
+ // The compiler is able to optimize the following in a single divmod op.
+ let (pre, post) = (dst_index / stride, dst_index % stride);
+ dst_index = (pre / dim) * stride + post;
+ }
+ dst[dst_index] += src[src_index];
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
+impl<'a> Map1 for ReduceSum<'a> {
+ #[inline(always)]
+ fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
+ self.fold_impl(src, src_l, T::zero())
+ }
+}
+
+struct Affine(f64, f64);
+
+impl Map1 for Affine {
+ fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
+ let mul = T::from_f64(self.0);
+ let add = T::from_f64(self.1);
+ Ok(unary_map(vs, layout, |v| v * mul + add))
+ }
+}
+
+struct AvgPool2D((usize, usize), (usize, usize));
+
+impl Map1 for AvgPool2D {
+ fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
+ // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html
+ let (k_h, k_w) = self.0;
+ let (s_h, s_w) = self.1;
+ let (b_sz, c, h, w) = layout.shape().dims4()?;
+ let stride = layout.stride();
+ let (stride_h, stride_w) = (stride[2], stride[3]);
+ let h_out = (h - k_h) / s_h + 1;
+ let w_out = (w - k_w) / s_w + 1;
+ let src_index = layout.start_offset();
+ let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
+ let scale = 1f64 / (k_h * k_w) as f64;
+ let scale = T::from_f64(scale);
+ for b_idx in 0..b_sz {
+ let dst = &mut dst[b_idx * c * h_out * w_out..];
+ let src_index = src_index + b_idx * stride[0];
+ for c_idx in 0..c {
+ let dst = &mut dst[c_idx * h_out * w_out..];
+ let src_index = src_index + c_idx * stride[1];
+ for h_idx in 0..h_out {
+ for w_idx in 0..w_out {
+ let mut sum = T::zero();
+ for m in 0..k_h {
+ for n in 0..k_w {
+ let m = s_h * h_idx + m;
+ let n = s_w * w_idx + n;
+ sum += src[src_index + m * stride_h + n * stride_w]
+ }
+ }
+ dst[h_idx * w_out + w_idx] = sum * scale;
+ }
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
+struct MaxPool2D((usize, usize), (usize, usize));
+
+impl Map1 for MaxPool2D {
+ fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
+ // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html
+ let (k_h, k_w) = self.0;
+ let (s_h, s_w) = self.1;
+ let (b_sz, c, h, w) = layout.shape().dims4()?;
+ let stride = layout.stride();
+ let (stride_h, stride_w) = (stride[2], stride[3]);
+ let h_out = (h - k_h) / s_h + 1;
+ let w_out = (w - k_w) / s_w + 1;
+ let src_index = layout.start_offset();
+ let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
+ for b_idx in 0..b_sz {
+ let dst = &mut dst[b_idx * c * h_out * w_out..];
+ let src_index = src_index + b_idx * stride[0];
+ for c_idx in 0..c {
+ let dst = &mut dst[c_idx * h_out * w_out..];
+ let src_index = src_index + c_idx * stride[1];
+ for h_idx in 0..h_out {
+ for w_idx in 0..w_out {
+ let mut largest =
+ src[src_index + s_h * h_idx * stride_h + s_w * w_idx * stride_w];
+ for m in 0..k_h {
+ for n in 0..k_w {
+ let m = s_h * h_idx + m;
+ let n = s_w * w_idx + n;
+ if largest < src[src_index + m * stride_h + n * stride_w] {
+ largest = src[src_index + m * stride_h + n * stride_w]
+ }
+ }
+ }
+ dst[h_idx * w_out + w_idx] = largest;
+ }
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
+struct UpsampleNearest1D(usize);
+
+impl Map1 for UpsampleNearest1D {
+ fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
+ // TODO: Specialized implementation for the case 2*sz?
+ let dst_sz = self.0;
+ let (b_sz, c, src_sz) = layout.shape().dims3()?;
+ let stride = layout.stride();
+ let stride_sz = stride[2];
+ let src_index = layout.start_offset();
+ let scale_sz = src_sz as f64 / dst_sz as f64;
+ let mut dst = vec![T::zero(); b_sz * c * dst_sz];
+ let src_idxs = (0..dst_sz)
+ .map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize))
+ .collect::<Vec<_>>();
+ for b_idx in 0..b_sz {
+ let dst = &mut dst[b_idx * c * dst_sz..];
+ let src_index = src_index + b_idx * stride[0];
+ for c_idx in 0..c {
+ let dst = &mut dst[c_idx * dst_sz..];
+ let src_index = src_index + c_idx * stride[1];
+ for (idx, src_idx) in src_idxs.iter().enumerate() {
+ dst[idx] = src[src_index + src_idx * stride_sz]
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
+struct UpsampleNearest2D(usize, usize);
+
+impl Map1 for UpsampleNearest2D {
+ fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
+ // TODO: Specialized implementation for the case 2*h, 2*w?
+ let (dst_h, dst_w) = (self.0, self.1);
+ let (b_sz, c, src_h, src_w) = layout.shape().dims4()?;
+ let stride = layout.stride();
+ let (stride_h, stride_w) = (stride[2], stride[3]);
+ let src_index = layout.start_offset();
+ let scale_h = src_h as f64 / dst_h as f64;
+ let scale_w = src_w as f64 / dst_w as f64;
+ let mut dst = vec![T::zero(); b_sz * c * dst_h * dst_w];
+ let src_h_idxs = (0..dst_h)
+ .map(|h_idx| usize::min(src_h - 1, (h_idx as f64 * scale_h) as usize))
+ .collect::<Vec<_>>();
+ let src_w_idxs = (0..dst_w)
+ .map(|w_idx| usize::min(src_w - 1, (w_idx as f64 * scale_w) as usize))
+ .collect::<Vec<_>>();
+ for b_idx in 0..b_sz {
+ let dst = &mut dst[b_idx * c * dst_h * dst_w..];
+ let src_index = src_index + b_idx * stride[0];
+ for c_idx in 0..c {
+ let dst = &mut dst[c_idx * dst_h * dst_w..];
+ let src_index = src_index + c_idx * stride[1];
+ for (h_idx, src_h_idx) in src_h_idxs.iter().enumerate() {
+ for (w_idx, src_w_idx) in src_w_idxs.iter().enumerate() {
+ let src_index = src_index + src_h_idx * stride_h + src_w_idx * stride_w;
+ dst[h_idx * dst_w + w_idx] = src[src_index]
+ }
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
+struct Gather<'a, I: IntDType> {
+ ids: &'a [I],
+ ids_l: &'a Layout,
+ dim: usize,
+}
+
+impl<'a, I: IntDType> Map1 for Gather<'a, I> {
+ fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
+ let ids = match self.ids_l.contiguous_offsets() {
+ Some((a, b)) => &self.ids[a..b],
+ None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
+ };
+ let src = match src_l.contiguous_offsets() {
+ Some((a, b)) => &src[a..b],
+ None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
+ };
+ let dim = self.dim;
+ let ids_dims = self.ids_l.dims();
+ let src_dims = src_l.dims();
+ let dst_len: usize = ids_dims.iter().product();
+ let dst_left_len: usize = ids_dims[..dim].iter().product();
+ let dst_dim_len = ids_dims[dim];
+ let dst_right_len: usize = ids_dims[dim + 1..].iter().product();
+
+ let src_dim_len = src_dims[dim];
+ let src_right_len: usize = src_dims[dim + 1..].iter().product();
+
+ let mut dst = vec![T::zero(); dst_len];
+ for left_i in 0..dst_left_len {
+ let start_src_idx = left_i * src_right_len * src_dim_len;
+ let start_dst_idx = left_i * dst_right_len * dst_dim_len;
+ for i in 0..dst_dim_len {
+ let start_dst_idx = start_dst_idx + i * dst_right_len;
+ for right_i in 0..dst_right_len {
+ let dst_idx = start_dst_idx + right_i;
+ let index = ids[dst_idx].as_usize();
+ if index >= src_dim_len {
+ Err(Error::InvalidIndex {
+ index,
+ size: src_dim_len,
+ op: "gather",
+ }
+ .bt())?
+ }
+ let src_idx = start_src_idx + index * src_right_len + right_i;
+ dst[dst_idx] = src[src_idx]
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
+struct IndexSelect<'a, T: IntDType> {
+ ids: &'a [T],
+ ids_l: &'a Layout,
+ dim: usize,
+}
+
+impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
+ fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
+ let src = match layout.contiguous_offsets() {
+ Some((a, b)) => &src[a..b],
+ None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?,
+ };
+ let dim = self.dim;
+ let n_ids = match self.ids_l.dims() {
+ [n_ids] => *n_ids,
+ d => Err(Error::UnexpectedNumberOfDims {
+ expected: 1,
+ got: d.len(),
+ shape: self.ids_l.shape().clone(),
+ }
+ .bt())?,
+ };
+ let stride_ids = self.ids_l.stride()[0];
+ let mut dst_dims = layout.dims().to_vec();
+ let src_dim = dst_dims[dim];
+ dst_dims[dim] = n_ids;
+ let dst_len: usize = dst_dims.iter().product();
+ let left_len: usize = dst_dims[..dim].iter().product();
+ let right_len: usize = dst_dims[dim + 1..].iter().product();
+ let mut dst = vec![T::zero(); dst_len];
+ for left_i in 0..left_len {
+ let start_src_idx = left_i * right_len * src_dim;
+ let start_dst_idx = left_i * right_len * n_ids;
+ for i in 0..n_ids {
+ let index = self.ids[self.ids_l.start_offset() + stride_ids * i].as_usize();
+ if index >= src_dim {
+ Err(Error::InvalidIndex {
+ index,
+ size: src_dim,
+ op: "index-select",
+ }
+ .bt())?
+ }
+ let start_src_idx = start_src_idx + index * right_len;
+ let start_dst_idx = start_dst_idx + i * right_len;
+ dst[start_dst_idx..start_dst_idx + right_len]
+ .copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
+ }
+ }
+ Ok(dst)
+ }
+}
+
+struct ScatterAdd<'a, I: IntDType> {
+ ids: &'a [I],
+ ids_l: &'a Layout,
+ dim: usize,
+}
+
+impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
+ const OP: &'static str = "scatter-add";
+ fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
+ let dst_len = l1.shape().elem_count();
+ let mut dst = vec![T::zero(); dst_len];
+ copy_strided_src_(v1, &mut dst, 0, l1);
+ let src = match src_l.contiguous_offsets() {
+ None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
+ Some((o1, o2)) => &src[o1..o2],
+ };
+
+ let dim = self.dim;
+ let ids_dims = self.ids_l.dims();
+ let dst_dims = l1.dims();
+ let dst_dim_len = dst_dims[dim];
+ let dst_right_len: usize = dst_dims[dim + 1..].iter().product();
+
+ let ids_left_len: usize = ids_dims[..dim].iter().product();
+ let ids_dim_len = ids_dims[dim];
+ let ids_right_len: usize = ids_dims[dim + 1..].iter().product();
+
+ let ids = match self.ids_l.contiguous_offsets() {
+ Some((a, b)) => &self.ids[a..b],
+ None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
+ };
+ for left_i in 0..ids_left_len {
+ let start_ids_idx = left_i * ids_right_len * ids_dim_len;
+ let start_dst_idx = left_i * dst_right_len * dst_dim_len;
+ for i in 0..ids_dim_len {
+ let start_ids_idx = start_ids_idx + i * ids_right_len;
+ for right_i in 0..dst_right_len {
+ let ids_idx = start_ids_idx + right_i;
+ let index = ids[ids_idx].as_usize();
+ if index >= dst_dim_len {
+ Err(Error::InvalidIndex {
+ index,
+ size: dst_dim_len,
+ op: "gather",
+ }
+ .bt())?
+ }
+ let dst_idx = start_dst_idx + index * dst_right_len + right_i;
+ dst[dst_idx] += src[ids_idx]
+ }
+ }
+ }
+
+ Ok(dst)
+ }
+}
+
+struct IndexAdd<'a, I: IntDType> {
+ ids: &'a [I],
+ dim: usize,
+}
+
+impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
+ const OP: &'static str = "index-add";
+ // https://pytorch.org/docs/stable/generated/torch.Tensor.index_add_.html#torch.Tensor.index_add_
+ // v1, l1 -> self
+ fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
+ let dst_len = l1.shape().elem_count();
+ let mut dst = vec![T::zero(); dst_len];
+ copy_strided_src_(v1, &mut dst, 0, l1);
+ let src = match src_l.contiguous_offsets() {
+ None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
+ Some((o1, o2)) => &src[o1..o2],
+ };
+ let dim = self.dim;
+ let max_idx = l1.dims()[dim];
+ let pre_dim = src_l.dims()[..dim].iter().product::<usize>();
+ let src_dim_sz = src_l.dims()[dim];
+ let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
+ if dim == 0 {
+ for (src_idx, dst_idx) in self.ids.iter().enumerate() {
+ let dst_idx = dst_idx.as_usize();
+ if dst_idx >= max_idx {
+ Err(Error::InvalidIndex {
+ index: dst_idx,
+ op: "index-add",
+ size: max_idx,
+ })?
+ }
+ let src_idx = src_idx * post_dim;
+ let dst_idx = dst_idx * post_dim;
+ let src = &src[src_idx..src_idx + post_dim];
+ let dst = &mut dst[dst_idx..dst_idx + post_dim];
+ for (d, &s) in dst.iter_mut().zip(src.iter()) {
+ *d += s
+ }
+ }
+ } else {
+ for (src_idx, dst_idx) in self.ids.iter().enumerate() {
+ let dst_idx = dst_idx.as_usize();
+ if dst_idx >= max_idx {
+ Err(Error::InvalidIndex {
+ index: dst_idx,
+ op: "index-add",
+ size: max_idx,
+ })?
+ }
+ for pre_i in 0..pre_dim {
+ let pre_src_i = (pre_i * src_dim_sz + src_idx) * post_dim;
+ let pre_dst_i = (pre_i * max_idx + dst_idx) * post_dim;
+ let src = &src[pre_src_i..pre_src_i + post_dim];
+ let dst = &mut dst[pre_dst_i..pre_dst_i + post_dim];
+ for (d, &s) in dst.iter_mut().zip(src.iter()) {
+ *d += s
+ }
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
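+// Copy `d1` rows of `d2` contiguous elements each, with independent row strides and
+// start offsets for `src` and `dst`.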
+#[allow(clippy::too_many_arguments)]
+fn copy2d_<T: Copy>(
+ src: &[T],
+ dst: &mut [T],
+ d1: usize,
+ d2: usize,
+ src_stride1: usize,
+ dst_stride1: usize,
+ src_offset: usize,
+ dst_offset: usize,
+) {
+ for i1 in 0..d1 {
+ let dst_idx = i1 * dst_stride1 + dst_offset;
+ let src_idx = i1 * src_stride1 + src_offset;
+ let dst = &mut dst[dst_idx..dst_idx + d2];
+ let src = &src[src_idx..src_idx + d2];
+ dst.copy_from_slice(src)
+ }
+}
+
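+// Copy a possibly strided source (described by `src_l`) into the contiguous
+// destination buffer starting at `dst_offset`, block by block.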
+fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
+ match src_l.strided_blocks() {
+ crate::StridedBlocks::SingleBlock { start_offset, len } => {
+ let to_copy = (dst.len() - dst_offset).min(len);
+ dst[dst_offset..dst_offset + to_copy]
+ .copy_from_slice(&src[start_offset..start_offset + to_copy])
+ }
+ crate::StridedBlocks::MultipleBlocks {
+ block_start_index,
+ block_len: 1,
+ } => {
+ for (dst_index, src_index) in block_start_index.enumerate() {
+ let dst_index = dst_index + dst_offset;
+ if dst_index >= dst.len() {
+ break;
+ }
+ dst[dst_index] = src[src_index]
+ }
+ }
+ crate::StridedBlocks::MultipleBlocks {
+ block_start_index,
+ block_len,
+ } => {
+ let mut dst_index = dst_offset;
+ for src_index in block_start_index {
+ let next_dst_index = dst_index + block_len;
+ if dst_index >= dst.len() {
+ break;
+ }
+ let to_copy = usize::min(block_len, dst.len() - dst_index);
+ dst[dst_index..dst_index + to_copy]
+ .copy_from_slice(&src[src_index..src_index + to_copy]);
+ dst_index = next_dst_index
+ }
+ }
+ }
+}
+
+struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
+
+impl<'a> Map2 for Conv1D<'a> {
+ const OP: &'static str = "conv1d";
+ fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
+ let p = self.0;
+ let inp = &inp[inp_l.start_offset()..];
+ let k = &k[k_l.start_offset()..];
+ let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
+ let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
+ let l_out = p.l_out();
+ let dst_elems = p.c_out * l_out * p.b_size;
+ // The output shape is [b_size, c_out, l_out]
+ let dst = vec![T::zero(); dst_elems];
+
+ // TODO: Avoid making this copy if `inp` already has the appropriate layout.
+ let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
+ for b_idx in 0..p.b_size {
+ for src_l in 0..p.l_in {
+ for src_c_idx in 0..p.c_in {
+ let inp_idx = b_idx * inp_s0 + src_c_idx * inp_s1 + src_l * inp_s2;
+ inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in + src_c_idx] = inp[inp_idx]
+ }
+ }
+ }
+
+ for offset in 0..p.k_size {
+ (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
+ let dst_idx = dst_c_idx * l_out;
+ let k_cont = (0..p.c_in)
+ .map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
+ .collect::<Vec<_>>();
+ for b_idx in 0..p.b_size {
+ let dst_idx = dst_idx + b_idx * p.c_out * l_out;
+ for dst_l in 0..l_out {
+ let dst_idx = dst_idx + dst_l;
+ let src_l = p.stride * dst_l + offset * p.dilation;
+ if src_l < p.padding || src_l >= p.padding + p.l_in {
+ continue;
+ }
+ let src_l = src_l - p.padding;
+ let inp_cont = &inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in..];
+ assert!(inp_cont.len() >= p.c_in);
+ assert!(k_cont.len() >= p.c_in);
+ let mut d = T::zero();
+ unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) }
+ let dst_p = dst.as_ptr();
+ // Safety: dst_idx values are unique per dst_c_idx, which is used to parallelise
+ // the different tasks, so no two threads can try to write to the same
+ // location.
+ unsafe {
+ let ptr = dst_p.add(dst_idx) as *mut T;
+ *ptr += d
+ }
+ }
+ }
+ })
+ }
+ Ok(dst)
+ }
+}
+
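+// im2col for 1d convolutions: unfold the input into a [b, l_out, c * l_k] buffer so
+// that the convolution itself reduces to a matmul.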
+struct Im2Col1D {
+ l_k: usize,
+ stride: usize,
+ dilation: usize,
+ padding: usize,
+}
+
+impl Im2Col1D {
+ fn l_out(&self, l: usize) -> usize {
+ (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
+ }
+}
+
+impl Map1 for Im2Col1D {
+ fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
+ let &Self {
+ l_k,
+ stride,
+ dilation,
+ padding,
+ } = self;
+ let (b, c, l) = layout.shape().dims3()?;
+ let l_out = self.l_out(l);
+ let src = &vs[layout.start_offset()..];
+ let mut dst = vec![T::zero(); b * l_out * c * l_k];
+ let (src_s0, src_s1, src_s2) = {
+ let s = layout.stride();
+ (s[0], s[1], s[2])
+ };
+ // TODO: provide specialized kernels for the common use cases.
+ // - l_k = 1
+ // - padding = 0
+ // - stride = 1
+ // - dilation = 1
+ for b_idx in 0..b {
+ let src_idx = b_idx * src_s0;
+ let dst_idx = b_idx * l_out * c * l_k;
+ for l_idx in 0..l_out {
+ let dst_idx = dst_idx + l_idx * c * l_k;
+ for c_idx in 0..c {
+ let dst_idx = dst_idx + c_idx * l_k;
+ let src_idx = c_idx * src_s1 + src_idx;
+ for l_k_idx in 0..l_k {
+ let src_l = l_idx * stride + l_k_idx * dilation;
+ if padding != 0 && (src_l < padding || src_l >= l + padding) {
+ continue;
+ }
+ let src_l = src_l - padding;
+ let src_idx = src_idx + src_l * src_s2;
+ let dst_idx = dst_idx + l_k_idx;
+ dst[dst_idx] = src[src_idx]
+ }
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
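+// im2col for 2d convolutions: unfold the input into a [b, h_out * w_out, c * h_k * w_k]
+// buffer, again turning the convolution into a matmul.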
+struct Im2Col {
+ h_k: usize,
+ w_k: usize,
+ stride: usize,
+ dilation: usize,
+ padding: usize,
+}
+
+impl Im2Col {
+ fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
+ let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
+ let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
+ (h_out, w_out)
+ }
+}
+
+impl Map1 for Im2Col {
+ fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
+ let &Self {
+ h_k,
+ w_k,
+ stride,
+ dilation,
+ padding,
+ } = self;
+ let (b, c, h, w) = layout.shape().dims4()?;
+ let (h_out, w_out) = self.hw_out(h, w);
+ let src = &vs[layout.start_offset()..];
+ let mut dst = vec![T::zero(); b * h_out * w_out * c * h_k * w_k];
+ let (src_s0, src_s1, src_s2, src_s3) = {
+ let s = layout.stride();
+ (s[0], s[1], s[2], s[3])
+ };
+ // TODO: provide specialized kernels for the common use cases.
+ // - h_k = w_k = 1
+ // - padding = 0
+ // - stride = 1
+ // - dilation = 1
+ for b_idx in 0..b {
+ let src_idx = b_idx * src_s0;
+ let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
+ for h_idx in 0..h_out {
+ let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
+ for w_idx in 0..w_out {
+ let dst_idx = dst_idx + w_idx * c * h_k * w_k;
+ for c_idx in 0..c {
+ let dst_idx = dst_idx + c_idx * h_k * w_k;
+ let src_idx = c_idx * src_s1 + src_idx;
+ for h_k_idx in 0..h_k {
+ let src_h = h_idx * stride + h_k_idx * dilation;
+ if padding != 0 && (src_h < padding || src_h >= h + padding) {
+ continue;
+ }
+ let src_h = src_h - padding;
+ let src_idx = src_idx + src_h * src_s2;
+ let dst_idx = dst_idx + h_k_idx * w_k;
+ for w_k_idx in 0..w_k {
+ let src_w = w_idx * stride + w_k_idx * dilation;
+ if padding != 0 && (src_w < padding || src_w >= w + padding) {
+ continue;
+ }
+ let src_w = src_w - padding;
+ let src_idx = src_idx + src_w * src_s3;
+ let dst_idx = dst_idx + w_k_idx;
+ dst[dst_idx] = src[src_idx]
+ }
+ }
+ }
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
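+// col2im for the transposed 1d convolution: scatter-add the unfolded columns back
+// into a [b_size, c_out, l_out] output.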
+struct Col2Im1D {
+ stride: usize,
+}
+
+impl Map1 for Col2Im1D {
+ fn f<T: WithDType>(&self, col: &[T], l: &Layout) -> Result<Vec<T>> {
+ let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
+ let stride = self.stride;
+ let l_out = (l_in - 1) * stride + k_size;
+ let mut im = vec![T::zero(); b_size * c_out * l_out];
+ let (dst_s0, dst_s1) = (c_out * l_out, l_out);
+ let (src_s0, src_s1, src_s2) = (c_out * k_size * l_in, c_out * k_size, k_size);
+ for l_in_i in 0..l_in {
+ for k_i in 0..k_size {
+ let l_out_i = l_in_i * stride + k_i;
+ for b_i in 0..b_size {
+ for c_i in 0..c_out {
+ let dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_out_i;
+ let src_idx = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i;
+ im[dst_idx] += col[src_idx]
+ }
+ }
+ }
+ }
+ Ok(im)
+ }
+}
+
+struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
+
+impl<'a> Map2 for ConvTranspose1D<'a> {
+ const OP: &'static str = "conv_transpose1d";
+ fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
+ let p = self.0;
+ let inp = &inp[inp_l.start_offset()..];
+ let k = &k[k_l.start_offset()..];
+ let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
+ let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
+ let l_out = p.l_out();
+
+ // Output shape: [b_size, c_out, l_out].
+ let dst_elems = p.c_out * l_out * p.b_size;
+ let dst = vec![T::zero(); dst_elems];
+ let dst_s0 = p.c_out * l_out;
+ let dst_s1 = l_out;
+ let dst_s2 = 1;
+
+ // TODO: Avoid making this copy if `inp` already has the appropriate layout.
+ let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
+ let cont_s0 = p.l_in * p.c_in;
+ let cont_s1 = p.c_in;
+ for b_idx in 0..p.b_size {
+ for l_idx in 0..p.l_in {
+ for c_idx in 0..p.c_in {
+ let src_idx = b_idx * inp_s0 + c_idx * inp_s1 + l_idx * inp_s2;
+ let dst_idx = b_idx * cont_s0 + l_idx * cont_s1 + c_idx;
+ inp_cont[dst_idx] = inp[src_idx]
+ }
+ }
+ }
+
+ for k_idx in 0..p.k_size {
+ (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
+ let k_cont = (0..p.c_in)
+ .map(|c_in_idx| k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_idx * k_s2])
+ .collect::<Vec<_>>();
+ for b_idx in 0..p.b_size {
+ for l_idx in 0..p.l_in {
+ let out_idx = l_idx * p.stride + k_idx * p.dilation;
+ if out_idx < p.padding {
+ continue;
+ }
+ let out_idx = out_idx - p.padding;
+ if out_idx < l_out {
+ let inp_cont = &inp_cont[b_idx * cont_s0 + l_idx * cont_s1..];
+ let dst_idx = b_idx * dst_s0 + out_idx * dst_s2 + dst_c_idx * dst_s1;
+ let mut d = T::zero();
+ unsafe {
+ T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
+ }
+ let dst_p = dst.as_ptr();
+ // Safety: dst_idx values are unique per dst_c_idx, which is used to
+ // parallelise the different tasks, so no two threads can try to
+ // write to the same location.
+ unsafe {
+ let ptr = dst_p.add(dst_idx) as *mut T;
+ *ptr += d
+ }
+ }
+ }
+ }
+ })
+ }
+ Ok(dst)
+ }
+}
+
+struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
+
+impl<'a> Map2 for Conv2D<'a> {
+ const OP: &'static str = "conv2d";
+ fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
+ let p = self.0;
+ let inp = &inp[inp_l.start_offset()..];
+ let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
+ let k = &k[k_l.start_offset()..];
+ let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
+ let (out_h, out_w) = (p.out_h(), p.out_w());
+
+ // Output shape: [b_size, c_out, out_h, out_w].
+ let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
+
+ // TODO: Avoid making this copy if `inp` already has the appropriate layout.
+ let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
+ let cont_s0 = p.i_h * p.i_w * p.c_in;
+ let cont_s1 = p.i_w * p.c_in;
+ let cont_s2 = p.c_in;
+ for b_idx in 0..p.b_size {
+ for h_idx in 0..p.i_h {
+ for w_idx in 0..p.i_w {
+ for c_idx in 0..p.c_in {
+ let src_idx =
+ b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
+ let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
+ inp_cont[dst_idx] = inp[src_idx]
+ }
+ }
+ }
+ }
+
+ for offset_h in 0..p.k_h {
+ for offset_w in 0..p.k_w {
+ (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
+ let dst_idx = dst_c_idx * out_w * out_h;
+ let k_cont = (0..p.c_in)
+ .map(|c_in_idx| {
+ k[dst_c_idx * k_s0
+ + c_in_idx * k_s1
+ + offset_h * k_s2
+ + offset_w * k_s3]
+ })
+ .collect::<Vec<_>>();
+ for b_idx in 0..p.b_size {
+ let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w;
+ for dst_h in 0..out_h {
+ let dst_idx = dst_idx + dst_h * out_w;
+ let src_h = p.stride * dst_h + offset_h * p.dilation;
+ if src_h < p.padding || src_h >= p.i_h + p.padding {
+ continue;
+ }
+ let src_h = src_h - p.padding;
+ for dst_w in 0..out_w {
+ let dst_idx = dst_idx + dst_w;
+ let src_w = p.stride * dst_w + offset_w * p.dilation;
+ if src_w < p.padding || src_w >= p.i_w + p.padding {
+ continue;
+ }
+ let src_w = src_w - p.padding;
+ let inp_cont = &inp_cont
+ [b_idx * cont_s0 + src_h * cont_s1 + src_w * cont_s2..];
+ assert!(inp_cont.len() >= p.c_in);
+ assert!(k_cont.len() >= p.c_in);
+ let mut d = T::zero();
+ unsafe {
+ T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
+ }
+ let dst_p = dst.as_ptr();
+ // Safety: dst_idx values are unique per dst_c_idx, which is used to parallelise
+ // the different tasks, so no two threads can try to write to the same
+ // location.
+ unsafe {
+ let ptr = dst_p.add(dst_idx) as *mut T;
+ *ptr += d
+ }
+ }
+ }
+ }
+ });
+ }
+ }
+
+ Ok(dst)
+ }
+}
+
+struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
+
+impl<'a> Map2 for ConvTranspose2D<'a> {
+ const OP: &'static str = "conv_transpose2d";
+ fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
+ let p = self.0;
+ let inp = &inp[inp_l.start_offset()..];
+ let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
+ let k = &k[k_l.start_offset()..];
+ let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
+ let (out_h, out_w) = (p.out_h(), p.out_w());
+
+ // Output shape: [b_size, c_out, out_h, out_w].
+ let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
+ let dst_s0 = p.c_out * out_h * out_w;
+ let dst_s1 = out_h * out_w;
+ let dst_s2 = out_w;
+ let dst_s3 = 1;
+
+ // TODO: Avoid making this copy if `inp` already has the appropriate layout.
+ let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
+ let cont_s0 = p.i_h * p.i_w * p.c_in;
+ let cont_s1 = p.i_w * p.c_in;
+ let cont_s2 = p.c_in;
+ for b_idx in 0..p.b_size {
+ for h_idx in 0..p.i_h {
+ for w_idx in 0..p.i_w {
+ for c_idx in 0..p.c_in {
+ let src_idx =
+ b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
+ let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
+ inp_cont[dst_idx] = inp[src_idx]
+ }
+ }
+ }
+ }
+
+ for k_y in 0..p.k_h {
+ for k_x in 0..p.k_w {
+ (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
+ let k_cont = (0..p.c_in)
+ .map(|c_in_idx| {
+ k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]
+ })
+ .collect::<Vec<_>>();
+ for b_idx in 0..p.b_size {
+ for inp_y in 0..p.i_h {
+ for inp_x in 0..p.i_w {
+ let out_x = inp_x * p.stride + k_x * p.dilation;
+ let out_y = inp_y * p.stride + k_y * p.dilation;
+ if out_x < p.padding || out_y < p.padding {
+ continue;
+ }
+ let out_x = out_x - p.padding;
+ let out_y = out_y - p.padding;
+ if out_x < out_w && out_y < out_h {
+ let inp_cont = &inp_cont
+ [b_idx * cont_s0 + inp_y * cont_s1 + inp_x * cont_s2..];
+ let dst_idx = b_idx * dst_s0
+ + out_y * dst_s2
+ + out_x * dst_s3
+ + dst_c_idx * dst_s1;
+ let mut d = T::zero();
+ unsafe {
+ T::vec_dot(
+ inp_cont.as_ptr(),
+ k_cont.as_ptr(),
+ &mut d,
+ p.c_in,
+ )
+ }
+ let dst_p = dst.as_ptr();
+ // Safety: dst_idx values are unique per dst_c_idx, which is used to
+ // parallelise the different tasks, so no two threads can try to
+ // write to the same location.
+ unsafe {
+ let ptr = dst_p.add(dst_idx) as *mut T;
+ *ptr += d
+ }
+ }
+ }
+ }
+ }
+ })
+ }
+ }
+ Ok(dst)
+ }
+}
+
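+// Batched matmul parameterized by (batch, m, n, k); the implementation is selected at
+// compile time between the gemm crate, Accelerate, and MKL.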
+struct MatMul((usize, usize, usize, usize));
+
+impl MatMul {
+ fn striding_error(&self, lhs_l: &Layout, rhs_l: &Layout, msg: &'static str) -> Error {
+ Error::MatMulUnexpectedStriding(Box::new(crate::error::MatMulUnexpectedStriding {
+ lhs_l: lhs_l.clone(),
+ rhs_l: rhs_l.clone(),
+ bmnk: self.0,
+ msg,
+ }))
+ .bt()
+ }
+
+ fn ab_skip(&self, lhs_l: &Layout, rhs_l: &Layout) -> Result<(usize, usize)> {
+ let lhs_stride = lhs_l.stride();
+ let rhs_stride = rhs_l.stride();
+ let rank = lhs_stride.len();
+ let (_b, m, n, k) = self.0;
+ let a_skip: usize = match lhs_stride[..rank - 2] {
+ [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
+ [_, stride] if lhs_l.dims()[0] == 1 => stride,
+ [stride, _] if lhs_l.dims()[1] == 1 => stride,
+ [stride] => stride,
+ [] => m * k,
+ _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
+ };
+ let b_skip: usize = match rhs_stride[..rank - 2] {
+ [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
+ [_, stride] if rhs_l.dims()[0] == 1 => stride,
+ [stride, _] if rhs_l.dims()[1] == 1 => stride,
+ [stride] => stride,
+ [] => n * k,
+ _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
+ };
+ Ok((a_skip, b_skip))
+ }
+}
+
+impl Map2 for MatMul {
+ const OP: &'static str = "mat_mul";
+
+ #[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))]
+ fn f<T: 'static + WithDType + num_traits::Num + Copy>(
+ &self,
+ lhs: &[T],
+ lhs_l: &Layout,
+ rhs: &[T],
+ rhs_l: &Layout,
+ ) -> Result<Vec<T>> {
+ use gemm::{gemm, Parallelism};
+
+ match T::DTYPE {
+ DType::F16 | DType::F32 | DType::F64 => {}
+ _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?,
+ }
+
+ let (b, m, n, k) = self.0;
+ let lhs = &lhs[lhs_l.start_offset()..];
+ let rhs = &rhs[rhs_l.start_offset()..];
+
+ let lhs_stride = lhs_l.stride();
+ let rhs_stride = rhs_l.stride();
+ let rank = lhs_stride.len();
+ let lhs_cs = lhs_stride[rank - 1];
+ let lhs_rs = lhs_stride[rank - 2];
+
+ let rhs_cs = rhs_stride[rank - 1];
+ let rhs_rs = rhs_stride[rank - 2];
+
+ let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
+ let c_skip: usize = m * n;
+
+ let dst_shape: Shape = (m, n).into();
+ let dst_strides = dst_shape.stride_contiguous();
+ let dst_rs = dst_strides[0];
+ let dst_cs = dst_strides[1];
+
+ let mut dst = vec![T::zero(); b * m * n];
+ let num_threads = crate::utils::get_num_threads();
+ let parallelism = if num_threads > 1 {
+ Parallelism::Rayon(num_threads)
+ } else {
+ Parallelism::None
+ };
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut dst[step * c_skip..];
+ unsafe {
+ gemm(
+ /* m: usize = */ m,
+ /* n: usize = */ n,
+ /* k: usize = */ k,
+ /* dst: *mut T = */ dst_p.as_mut_ptr(),
+ /* dst_cs: isize = */ dst_cs as isize,
+ /* dst_rs: isize = */ dst_rs as isize,
+ /* read_dst: bool = */ false,
+ /* lhs: *const T = */ lhs_p.as_ptr(),
+ /* lhs_cs: isize = */ lhs_cs as isize,
+ /* lhs_rs: isize = */ lhs_rs as isize,
+ /* rhs: *const T = */ rhs_p.as_ptr(),
+ /* rhs_cs: isize = */ rhs_cs as isize,
+ /* rhs_rs: isize = */ rhs_rs as isize,
+ /* alpha: T = */ T::zero(),
+ /* beta: T = */ T::one(),
+ /* conj_dst: bool = */ false,
+ /* conj_lhs: bool = */ false,
+ /* conj_rhs: bool = */ false,
+ parallelism,
+ )
+ }
+ }
+ Ok(dst)
+ }
+
+ #[cfg(feature = "accelerate")]
+ fn f<T: 'static + WithDType + num_traits::Num + Copy>(
+ &self,
+ lhs: &[T],
+ lhs_l: &Layout,
+ rhs: &[T],
+ rhs_l: &Layout,
+ ) -> Result<Vec<T>> {
+ let (b, m, n, k) = self.0;
+ let lhs = &lhs[lhs_l.start_offset()..];
+ let rhs = &rhs[rhs_l.start_offset()..];
+
+ let lhs_stride = lhs_l.stride();
+ let rhs_stride = rhs_l.stride();
+
+ let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
+ let c_skip: usize = m * n;
+
+ let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
+ let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
+ let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
+ let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
+
+ let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
+ (n as i32, b'N')
+ } else if rhs_m1 == k && rhs_m2 == 1 {
+ (k as i32, b'T')
+ } else {
+ Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
+ };
+ // The b tensor has dims batching, m, k (lhs)
+ let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
+ (k as i32, b'N')
+ } else if lhs_m1 == m && lhs_m2 == 1 {
+ (m as i32, b'T')
+ } else {
+ Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
+ };
+
+ let mut dst = vec![T::zero(); b * m * n];
+ match T::DTYPE {
+ DType::F16 => {
+ crate::bail!("the accelerate backend does not support f16 matmul")
+ }
+ DType::F32 => {
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut dst[step * c_skip..];
+ unsafe {
+ let a = rhs_p.as_ptr() as *const f32;
+ let b = lhs_p.as_ptr() as *const f32;
+ let c = dst_p.as_mut_ptr() as *mut f32;
+ let a = std::slice::from_raw_parts(a, a_skip);
+ let b = std::slice::from_raw_parts(b, b_skip);
+ let c = std::slice::from_raw_parts_mut(c, c_skip);
+ crate::accelerate::sgemm(
+ transa, transb, /* m= */ n as i32, /* n= */ m as i32,
+ /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
+ /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
+ /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
+ )
+ }
+ }
+ }
+ DType::F64 => {
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut dst[step * c_skip..];
+ unsafe {
+ let a = rhs_p.as_ptr() as *const f64;
+ let b = lhs_p.as_ptr() as *const f64;
+ let c = dst_p.as_mut_ptr() as *mut f64;
+ let a = std::slice::from_raw_parts(a, a_skip);
+ let b = std::slice::from_raw_parts(b, b_skip);
+ let c = std::slice::from_raw_parts_mut(c, c_skip);
+ crate::accelerate::dgemm(
+ transa, transb, /* m= */ n as i32, /* n= */ m as i32,
+ /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
+ /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
+ /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
+ )
+ }
+ }
+ }
+ dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
+ }
+ Ok(dst)
+ }
+
+ #[cfg(feature = "mkl")]
+ fn f<T: 'static + WithDType + num_traits::Num + Copy>(
+ &self,
+ lhs: &[T],
+ lhs_l: &Layout,
+ rhs: &[T],
+ rhs_l: &Layout,
+ ) -> Result<Vec<T>> {
+ let (b, m, n, k) = self.0;
+ let lhs = &lhs[lhs_l.start_offset()..];
+ let rhs = &rhs[rhs_l.start_offset()..];
+
+ let lhs_stride = lhs_l.stride();
+ let rhs_stride = rhs_l.stride();
+
+ let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
+ let c_skip: usize = m * n;
+
+ let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
+ let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
+ let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
+ let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
+
+ let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
+ (n as i32, b'N')
+ } else if rhs_m1 == k && rhs_m2 == 1 {
+ (k as i32, b'T')
+ } else {
+ Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
+ };
+ // The b tensor has dims batching, m, k (lhs)
+ let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
+ (k as i32, b'N')
+ } else if lhs_m1 == m && lhs_m2 == 1 {
+ (m as i32, b'T')
+ } else {
+ Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
+ };
+
+ let mut dst = vec![T::zero(); b * m * n];
+ match T::DTYPE {
+ DType::F16 => {
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut dst[step * c_skip..];
+ unsafe {
+ let a = rhs_p.as_ptr() as *const f16;
+ let b = lhs_p.as_ptr() as *const f16;
+ let c = dst_p.as_mut_ptr() as *mut f16;
+ let a = std::slice::from_raw_parts(a, a_skip);
+ let b = std::slice::from_raw_parts(b, b_skip);
+ let c = std::slice::from_raw_parts_mut(c, c_skip);
+ crate::mkl::hgemm(
+ transa,
+ transb,
+ /* m= */ n as i32,
+ /* n= */ m as i32,
+ /* k= */ k as i32,
+ /* alpha= */ f16::ONE,
+ /* a= */ a,
+ /* lda= */ lda,
+ /* b= */ b,
+ /* ldb= */ ldb,
+ /* beta= */ f16::ZERO,
+ /* c= */ c,
+ /* ldc= */ n as i32,
+ )
+ }
+ }
+ }
+ DType::F32 => {
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut dst[step * c_skip..];
+ unsafe {
+ let a = rhs_p.as_ptr() as *const f32;
+ let b = lhs_p.as_ptr() as *const f32;
+ let c = dst_p.as_mut_ptr() as *mut f32;
+ let a = std::slice::from_raw_parts(a, a_skip);
+ let b = std::slice::from_raw_parts(b, b_skip);
+ let c = std::slice::from_raw_parts_mut(c, c_skip);
+ crate::mkl::sgemm(
+ transa, transb, /* m= */ n as i32, /* n= */ m as i32,
+ /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
+ /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
+ /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
+ )
+ }
+ }
+ }
+ DType::F64 => {
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut dst[step * c_skip..];
+ unsafe {
+ let a = rhs_p.as_ptr() as *const f64;
+ let b = lhs_p.as_ptr() as *const f64;
+ let c = dst_p.as_mut_ptr() as *mut f64;
+ let a = std::slice::from_raw_parts(a, a_skip);
+ let b = std::slice::from_raw_parts(b, b_skip);
+ let c = std::slice::from_raw_parts_mut(c, c_skip);
+ crate::mkl::dgemm(
+ transa, transb, /* m= */ n as i32, /* n= */ m as i32,
+ /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
+ /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
+ /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
+ )
+ }
+ }
+ }
+ dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
+ }
+ Ok(dst)
+ }
+}
+
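+// Variant of MatMul that takes a caller-provided, contiguous `c` buffer covering the
+// full [b, m, n] output and an optional scale `s` applied to the lhs * rhs product.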
+struct MatMulWithBias(MatMul);
+
+impl Deref for MatMulWithBias {
+ type Target = MatMul;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl Map3 for MatMulWithBias {
+ const OP: &'static str = "mat_mul_ac";
+
+ #[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))]
+ fn f<T: 'static + WithDType + num_traits::Num + Copy>(
+ &self,
+ lhs: &[T],
+ lhs_l: &Layout,
+ rhs: &[T],
+ rhs_l: &Layout,
+ c: &mut [T],
+ c_l: &Layout,
+ s: Option<f64>,
+ ) -> Result<()> {
+ use gemm::{gemm, Parallelism};
+
+ match T::DTYPE {
+ DType::F16 | DType::F32 | DType::F64 => {}
+ _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?,
+ }
+
+ let (b, m, n, k) = self.0 .0;
+ let lhs = &lhs[lhs_l.start_offset()..];
+ let rhs = &rhs[rhs_l.start_offset()..];
+
+ let lhs_stride = lhs_l.stride();
+ let rhs_stride = rhs_l.stride();
+ let rank = lhs_stride.len();
+ let lhs_cs = lhs_stride[rank - 1];
+ let lhs_rs = lhs_stride[rank - 2];
+
+ let rhs_cs = rhs_stride[rank - 1];
+ let rhs_rs = rhs_stride[rank - 2];
+
+ let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
+ let c_skip: usize = m * n;
+
+ let dst_shape: Shape = (m, n).into();
+ let dst_strides = dst_shape.stride_contiguous();
+ let dst_rs = dst_strides[0];
+ let dst_cs = dst_strides[1];
+
+ let num_threads = crate::utils::get_num_threads();
+ let parallelism = if num_threads > 1 {
+ Parallelism::Rayon(num_threads)
+ } else {
+ Parallelism::None
+ };
+
+ match c_l.contiguous_offsets() {
+ Some((o1, o2)) => {
+ if o1 != 0 {
+ crate::bail!("`c` start offset must be 0");
+ }
+ if o2 != b * m * n {
+ crate::bail!("`c` end offset must be {}", b * m * n)
+ }
+ }
+ None => crate::bail!("`c` has to be contiguous"),
+ };
+
+ let alpha = T::from_f64(s.unwrap_or(1.0));
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut c[step * c_skip..];
+ unsafe {
+ gemm(
+ /* m: usize = */ m,
+ /* n: usize = */ n,
+ /* k: usize = */ k,
+ /* dst: *mut T = */ dst_p.as_mut_ptr(),
+ /* dst_cs: isize = */ dst_cs as isize,
+ /* dst_rs: isize = */ dst_rs as isize,
+ /* read_dst: bool = */ true,
+ /* lhs: *const T = */ lhs_p.as_ptr(),
+ /* lhs_cs: isize = */ lhs_cs as isize,
+ /* lhs_rs: isize = */ lhs_rs as isize,
+ /* rhs: *const T = */ rhs_p.as_ptr(),
+ /* rhs_cs: isize = */ rhs_cs as isize,
+ /* rhs_rs: isize = */ rhs_rs as isize,
+ /* alpha: T = */ T::one(),
+ /* beta: T = */ alpha,
+ /* conj_dst: bool = */ false,
+ /* conj_lhs: bool = */ false,
+ /* conj_rhs: bool = */ false,
+ parallelism,
+ )
+ }
+ }
+ Ok(())
+ }
+
+ #[cfg(feature = "accelerate")]
+ fn f<T: 'static + WithDType + num_traits::Num + Copy>(
+ &self,
+ lhs: &[T],
+ lhs_l: &Layout,
+ rhs: &[T],
+ rhs_l: &Layout,
+ c: &mut [T],
+ c_l: &Layout,
+ s: Option<f64>,
+ ) -> Result<()> {
+ let (b, m, n, k) = self.0 .0;
+ let lhs = &lhs[lhs_l.start_offset()..];
+ let rhs = &rhs[rhs_l.start_offset()..];
+
+ let lhs_stride = lhs_l.stride();
+ let rhs_stride = rhs_l.stride();
+
+ let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
+ let c_skip: usize = m * n;
+
+ let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
+ let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
+ let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
+ let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
+
+ let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
+ (n as i32, b'N')
+ } else if rhs_m1 == k && rhs_m2 == 1 {
+ (k as i32, b'T')
+ } else {
+ Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
+ };
+ // The b tensor has dims batching, m, k (lhs)
+ let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
+ (k as i32, b'N')
+ } else if lhs_m1 == m && lhs_m2 == 1 {
+ (m as i32, b'T')
+ } else {
+ Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
+ };
+
+ match c_l.contiguous_offsets() {
+ Some((o1, o2)) => {
+ if o1 != 0 {
+ crate::bail!("`c` start offset must be 0");
+ }
+ if o2 != b * m * n {
+ crate::bail!("`c` end offset must be {}", b * m * n)
+ }
+ }
+ None => crate::bail!("`c` has to be contiguous"),
+ };
+
+ match T::DTYPE {
+ DType::F16 => {
+ crate::bail!("the accelerate backend does not support f16 matmul")
+ }
+ DType::F32 => {
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut c[step * c_skip..];
+ unsafe {
+ let a = rhs_p.as_ptr() as *const f32;
+ let b = lhs_p.as_ptr() as *const f32;
+ let c = dst_p.as_mut_ptr() as *mut f32;
+ let a = std::slice::from_raw_parts(a, a_skip);
+ let b = std::slice::from_raw_parts(b, b_skip);
+ let c = std::slice::from_raw_parts_mut(c, c_skip);
+ crate::accelerate::sgemm(
+ transa,
+ transb,
+ /* m= */ n as i32,
+ /* n= */ m as i32,
+ /* k= */ k as i32,
+ /* alpha= */ s.unwrap_or(1.) as f32,
+ /* a= */ a,
+ /* lda= */ lda,
+ /* b= */ b,
+ /* ldb= */ ldb,
+ /* beta= */ 1.,
+ /* c= */ c,
+ /* ldc= */ n as i32,
+ )
+ }
+ }
+ }
+ DType::F64 => {
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut c[step * c_skip..];
+ unsafe {
+ let a = rhs_p.as_ptr() as *const f64;
+ let b = lhs_p.as_ptr() as *const f64;
+ let c = dst_p.as_mut_ptr() as *mut f64;
+ let a = std::slice::from_raw_parts(a, a_skip);
+ let b = std::slice::from_raw_parts(b, b_skip);
+ let c = std::slice::from_raw_parts_mut(c, c_skip);
+ crate::accelerate::dgemm(
+ transa,
+ transb,
+ /* m= */ n as i32,
+ /* n= */ m as i32,
+ /* k= */ k as i32,
+ /* alpha= */ s.unwrap_or(1.) as f64,
+ /* a= */ a,
+ /* lda= */ lda,
+ /* b= */ b,
+ /* ldb= */ ldb,
+ /* beta= */ 1.,
+ /* c= */ c,
+ /* ldc= */ n as i32,
+ )
+ }
+ }
+ }
+ dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
+ }
+ Ok(())
+ }
+
+ #[cfg(feature = "mkl")]
+ fn f<T: 'static + WithDType + num_traits::Num + Copy>(
+ &self,
+ lhs: &[T],
+ lhs_l: &Layout,
+ rhs: &[T],
+ rhs_l: &Layout,
+ c: &mut [T],
+ c_l: &Layout,
+ s: Option<f64>,
+ ) -> Result<()> {
+ let (b, m, n, k) = self.0 .0;
+ let lhs = &lhs[lhs_l.start_offset()..];
+ let rhs = &rhs[rhs_l.start_offset()..];
+
+ let lhs_stride = lhs_l.stride();
+ let rhs_stride = rhs_l.stride();
+
+ let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
+ let c_skip: usize = m * n;
+
+ let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
+ let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
+ let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
+ let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
+
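+ // Same layout analysis as the Accelerate path: leading dimensions and
+ // transpose flags are derived from the last two strides of each operand.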
+ let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
+ (n as i32, b'N')
+ } else if rhs_m1 == k && rhs_m2 == 1 {
+ (k as i32, b'T')
+ } else {
+ Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
+ };
+ // The b tensor has dims batching, m, k (lhs)
+ let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
+ (k as i32, b'N')
+ } else if lhs_m1 == m && lhs_m2 == 1 {
+ (m as i32, b'T')
+ } else {
+ Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
+ };
+
+ match c_l.contiguous_offsets() {
+ Some((o1, o2)) => {
+ if o1 != 0 {
+ crate::bail!("`c` start offset must be 0");
+ }
+ if o2 != b * m * n {
+ crate::bail!("`c` end offset must be {}", b * m * n)
+ }
+ }
+ None => crate::bail!("`c` has to be contiguous"),
+ };
+
+ match T::DTYPE {
+ DType::F16 => {
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut c[step * c_skip..];
+ unsafe {
+ let a = rhs_p.as_ptr() as *const f16;
+ let b = lhs_p.as_ptr() as *const f16;
+ let c = dst_p.as_mut_ptr() as *mut f16;
+ let a = std::slice::from_raw_parts(a, a_skip);
+ let b = std::slice::from_raw_parts(b, b_skip);
+ let c = std::slice::from_raw_parts_mut(c, c_skip);
+ crate::mkl::hgemm(
+ transa,
+ transb,
+ /* m= */ n as i32,
+ /* n= */ m as i32,
+ /* k= */ k as i32,
+ /* alpha= */ f16::from_f64(s.unwrap_or(1.)),
+ /* a= */ a,
+ /* lda= */ lda,
+ /* b= */ b,
+ /* ldb= */ ldb,
+ /* beta= */ f16::ONE,
+ /* c= */ c,
+ /* ldc= */ n as i32,
+ )
+ }
+ }
+ }
+ DType::F32 => {
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut c[step * c_skip..];
+ unsafe {
+ let a = rhs_p.as_ptr() as *const f32;
+ let b = lhs_p.as_ptr() as *const f32;
+ let c = dst_p.as_mut_ptr() as *mut f32;
+ let a = std::slice::from_raw_parts(a, a_skip);
+ let b = std::slice::from_raw_parts(b, b_skip);
+ let c = std::slice::from_raw_parts_mut(c, c_skip);
+ crate::mkl::sgemm(
+ transa,
+ transb,
+ /* m= */ n as i32,
+ /* n= */ m as i32,
+ /* k= */ k as i32,
+ /* alpha= */ s.unwrap_or(1.) as f32,
+ /* a= */ a,
+ /* lda= */ lda,
+ /* b= */ b,
+ /* ldb= */ ldb,
+ /* beta= */ 0.,
+ /* c= */ c,
+ /* ldc= */ n as i32,
+ )
+ }
+ }
+ }
+ DType::F64 => {
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut c[step * c_skip..];
+ unsafe {
+ let a = rhs_p.as_ptr() as *const f64;
+ let b = lhs_p.as_ptr() as *const f64;
+ let c = dst_p.as_mut_ptr() as *mut f64;
+ let a = std::slice::from_raw_parts(a, a_skip);
+ let b = std::slice::from_raw_parts(b, b_skip);
+ let c = std::slice::from_raw_parts_mut(c, c_skip);
+ crate::mkl::dgemm(
+ transa,
+ transb,
+ /* m= */ n as i32,
+ /* n= */ m as i32,
+ /* k= */ k as i32,
+ /* alpha= */ s.unwrap_or(1.),
+ /* a= */ a,
+ /* lda= */ lda,
+ /* b= */ b,
+ /* ldb= */ ldb,
+ /* beta= */ 0.,
+ /* c= */ c,
+ /* ldc= */ n as i32,
+ )
+ }
+ }
+ }
+ dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
+ }
+ Ok(())
+ }
+}
+
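+/// Variant of `MatMul` that scales the product by an optional factor `s`;
+/// `Deref` to the inner `MatMul` reuses its dimension and stride helpers
+/// (e.g. `ab_skip`, `striding_error`).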
+struct MatMulWithAlpha(MatMul);
+
+impl Deref for MatMulWithAlpha {
+ type Target = MatMul;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl Map2Alpha for MatMulWithAlpha {
+ const OP: &'static str = "mat_mul_a";
+
+ #[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))]
+ fn f<T: 'static + WithDType + num_traits::Num + Copy>(
+ &self,
+ lhs: &[T],
+ lhs_l: &Layout,
+ rhs: &[T],
+ rhs_l: &Layout,
+ s: Option<f64>,
+ ) -> Result<Vec<T>> {
+ use gemm::{gemm, Parallelism};
+
+ match T::DTYPE {
+ DType::F16 | DType::F32 | DType::F64 => {}
+ _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?,
+ }
+
+ let (b, m, n, k) = self.0 .0;
+ let lhs = &lhs[lhs_l.start_offset()..];
+ let rhs = &rhs[rhs_l.start_offset()..];
+
+ let lhs_stride = lhs_l.stride();
+ let rhs_stride = rhs_l.stride();
+ let rank = lhs_stride.len();
+ let lhs_cs = lhs_stride[rank - 1];
+ let lhs_rs = lhs_stride[rank - 2];
+
+ let rhs_cs = rhs_stride[rank - 1];
+ let rhs_rs = rhs_stride[rank - 2];
+
+ let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
+ let c_skip: usize = m * n;
+
+ let dst_shape: Shape = (m, n).into();
+ let dst_strides = dst_shape.stride_contiguous();
+ let dst_rs = dst_strides[0];
+ let dst_cs = dst_strides[1];
+
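+ // The output buffer starts zeroed; with `read_dst = true` and the gemm call
+ // below using `alpha = 1`, `beta = s`, the value written here is `s * lhs * rhs`.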
+ let mut dst = vec![T::zero(); b * m * n];
+ let num_threads = crate::utils::get_num_threads();
+ let parallelism = if num_threads > 1 {
+ Parallelism::Rayon(num_threads)
+ } else {
+ Parallelism::None
+ };
+
+ let alpha = T::from_f64(s.unwrap_or(1.0));
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut dst[step * c_skip..];
+ unsafe {
+ gemm(
+ /* m: usize = */ m,
+ /* n: usize = */ n,
+ /* k: usize = */ k,
+ /* dst: *mut T = */ dst_p.as_mut_ptr(),
+ /* dst_cs: isize = */ dst_cs as isize,
+ /* dst_rs: isize = */ dst_rs as isize,
+ /* read_dst: bool = */ true,
+ /* lhs: *const T = */ lhs_p.as_ptr(),
+ /* lhs_cs: isize = */ lhs_cs as isize,
+ /* lhs_rs: isize = */ lhs_rs as isize,
+ /* rhs: *const T = */ rhs_p.as_ptr(),
+ /* rhs_cs: isize = */ rhs_cs as isize,
+ /* rhs_rs: isize = */ rhs_rs as isize,
+ /* alpha: T = */ T::one(),
+ /* beta: T = */ alpha,
+ /* conj_dst: bool = */ false,
+ /* conj_lhs: bool = */ false,
+ /* conj_rhs: bool = */ false,
+ parallelism,
+ )
+ }
+ }
+ Ok(dst)
+ }
+
+ #[cfg(feature = "accelerate")]
+ fn f<T: 'static + WithDType + num_traits::Num + Copy>(
+ &self,
+ lhs: &[T],
+ lhs_l: &Layout,
+ rhs: &[T],
+ rhs_l: &Layout,
+ s: Option<f64>,
+ ) -> Result<Vec<T>> {
+ let (b, m, n, k) = self.0 .0;
+ let lhs = &lhs[lhs_l.start_offset()..];
+ let rhs = &rhs[rhs_l.start_offset()..];
+
+ let lhs_stride = lhs_l.stride();
+ let rhs_stride = rhs_l.stride();
+
+ let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
+ let c_skip: usize = m * n;
+
+ let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
+ let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
+ let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
+ let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
+
+ let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
+ (n as i32, b'N')
+ } else if rhs_m1 == k && rhs_m2 == 1 {
+ (k as i32, b'T')
+ } else {
+ Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
+ };
+ // The b tensor has dims batching, m, k (lhs)
+ let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
+ (k as i32, b'N')
+ } else if lhs_m1 == m && lhs_m2 == 1 {
+ (m as i32, b'T')
+ } else {
+ Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
+ };
+
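+ // `dst` starts zeroed, so calling the BLAS kernels with `beta = 1` simply
+ // leaves the scaled product `alpha * a * b` in the output.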
+ let mut dst = vec![T::zero(); b * m * n];
+ match T::DTYPE {
+ DType::F16 => {
+ crate::bail!("the accelerate backend does not support f16 matmul")
+ }
+ DType::F32 => {
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut dst[step * c_skip..];
+ unsafe {
+ let a = rhs_p.as_ptr() as *const f32;
+ let b = lhs_p.as_ptr() as *const f32;
+ let c = dst_p.as_mut_ptr() as *mut f32;
+ let a = std::slice::from_raw_parts(a, a_skip);
+ let b = std::slice::from_raw_parts(b, b_skip);
+ let c = std::slice::from_raw_parts_mut(c, c_skip);
+ crate::accelerate::sgemm(
+ transa,
+ transb,
+ /* m= */ n as i32,
+ /* n= */ m as i32,
+ /* k= */ k as i32,
+ /* alpha= */ s.unwrap_or(1.) as f32,
+ /* a= */ a,
+ /* lda= */ lda,
+ /* b= */ b,
+ /* ldb= */ ldb,
+ /* beta= */ 1.,
+ /* c= */ c,
+ /* ldc= */ n as i32,
+ )
+ }
+ }
+ }
+ DType::F64 => {
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut dst[step * c_skip..];
+ unsafe {
+ let a = rhs_p.as_ptr() as *const f64;
+ let b = lhs_p.as_ptr() as *const f64;
+ let c = dst_p.as_mut_ptr() as *mut f64;
+ let a = std::slice::from_raw_parts(a, a_skip);
+ let b = std::slice::from_raw_parts(b, b_skip);
+ let c = std::slice::from_raw_parts_mut(c, c_skip);
+ crate::accelerate::dgemm(
+ transa,
+ transb,
+ /* m= */ n as i32,
+ /* n= */ m as i32,
+ /* k= */ k as i32,
+ /* alpha= */ s.unwrap_or(1.),
+ /* a= */ a,
+ /* lda= */ lda,
+ /* b= */ b,
+ /* ldb= */ ldb,
+ /* beta= */ 1.,
+ /* c= */ c,
+ /* ldc= */ n as i32,
+ )
+ }
+ }
+ }
+ dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
+ }
+ Ok(dst)
+ }
+
+ #[cfg(feature = "mkl")]
+ fn f<T: 'static + WithDType + num_traits::Num + Copy>(
+ &self,
+ lhs: &[T],
+ lhs_l: &Layout,
+ rhs: &[T],
+ rhs_l: &Layout,
+ s: Option<f64>,
+ ) -> Result<Vec<T>> {
+ let (b, m, n, k) = self.0 .0;
+ let lhs = &lhs[lhs_l.start_offset()..];
+ let rhs = &rhs[rhs_l.start_offset()..];
+
+ let lhs_stride = lhs_l.stride();
+ let rhs_stride = rhs_l.stride();
+
+ let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
+ let c_skip: usize = m * n;
+
+ let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
+ let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
+ let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
+ let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
+
+ let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
+ (n as i32, b'N')
+ } else if rhs_m1 == k && rhs_m2 == 1 {
+ (k as i32, b'T')
+ } else {
+ Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
+ };
+ // The b tensor has dims batching, m, k (lhs)
+ let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
+ (k as i32, b'N')
+ } else if lhs_m1 == m && lhs_m2 == 1 {
+ (m as i32, b'T')
+ } else {
+ Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
+ };
+
+ let mut dst = vec![T::zero(); b * m * n];
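+ // Unlike the Accelerate path, an f16 GEMM is available here through the
+ // `crate::mkl::hgemm` wrapper, so F16 inputs are handled directly below.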
+ match T::DTYPE {
+ DType::F16 => {
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut dst[step * c_skip..];
+ unsafe {
+ let a = rhs_p.as_ptr() as *const f16;
+ let b = lhs_p.as_ptr() as *const f16;
+ let c = dst_p.as_mut_ptr() as *mut f16;
+ let a = std::slice::from_raw_parts(a, a_skip);
+ let b = std::slice::from_raw_parts(b, b_skip);
+ let c = std::slice::from_raw_parts_mut(c, c_skip);
+ crate::mkl::hgemm(
+ transa,
+ transb,
+ /* m= */ n as i32,
+ /* n= */ m as i32,
+ /* k= */ k as i32,
+ /* alpha= */ f16::from_f64(s.unwrap_or(1.)),
+ /* a= */ a,
+ /* lda= */ lda,
+ /* b= */ b,
+ /* ldb= */ ldb,
+ /* beta= */ f16::ONE,
+ /* c= */ c,
+ /* ldc= */ n as i32,
+ )
+ }
+ }
+ }
+ DType::F32 => {
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut dst[step * c_skip..];
+ unsafe {
+ let a = rhs_p.as_ptr() as *const f32;
+ let b = lhs_p.as_ptr() as *const f32;
+ let c = dst_p.as_mut_ptr() as *mut f32;
+ let a = std::slice::from_raw_parts(a, a_skip);
+ let b = std::slice::from_raw_parts(b, b_skip);
+ let c = std::slice::from_raw_parts_mut(c, c_skip);
+ crate::mkl::sgemm(
+ transa,
+ transb,
+ /* m= */ n as i32,
+ /* n= */ m as i32,
+ /* k= */ k as i32,
+ /* alpha= */ s.unwrap_or(1.) as f32,
+ /* a= */ a,
+ /* lda= */ lda,
+ /* b= */ b,
+ /* ldb= */ ldb,
+ /* beta= */ 0.,
+ /* c= */ c,
+ /* ldc= */ n as i32,
+ )
+ }
+ }
+ }
+ DType::F64 => {
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut dst[step * c_skip..];
+ unsafe {
+ let a = rhs_p.as_ptr() as *const f64;
+ let b = lhs_p.as_ptr() as *const f64;
+ let c = dst_p.as_mut_ptr() as *mut f64;
+ let a = std::slice::from_raw_parts(a, a_skip);
+ let b = std::slice::from_raw_parts(b, b_skip);
+ let c = std::slice::from_raw_parts_mut(c, c_skip);
+ crate::mkl::dgemm(
+ transa,
+ transb,
+ /* m= */ n as i32,
+ /* n= */ m as i32,
+ /* k= */ k as i32,
+ /* alpha= */ s.unwrap_or(1.),
+ /* a= */ a,
+ /* lda= */ lda,
+ /* b= */ b,
+ /* ldb= */ ldb,
+ /* beta= */ 0.,
+ /* c= */ c,
+ /* ldc= */ n as i32,
+ )
+ }
+ }
+ }
+ dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
+ }
+ Ok(dst)
+ }
+}
+
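+/// ELU activation: returns `v` for positively-signed values and
+/// `alpha * (exp(v) - 1)` otherwise.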
+fn elu<T: num_traits::Float>(v: T, alpha: T) -> T {
+ if v.is_sign_positive() {
+ v
+ } else {
+ (v.exp() - T::one()) * alpha
+ }
+}
+
+impl CpuStorage {
+ pub fn as_slice<D: WithDType>(&self) -> Result<&[D]> {
+ D::cpu_storage_as_slice(self)
+ }
+
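+ /// Concatenates storages that all share the dtype of the first one; any
+ /// element with a different dtype produces a "dtype mismatch" error.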
+ pub fn concat(storages: &[CpuStorage]) -> Result<CpuStorage> {
+ let storage0 = &storages[0];
+ let s = match storage0 {
+ Self::U8(_) => {
+ let storages = storages
+ .iter()
+ .map(|s| match s {
+ Self::U8(s) => Ok(s.as_slice()),
+ _ => crate::bail!("dtype mismatch"),
+ })
+ .collect::<Result<Vec<_>>>()?
+ .concat();
+ Self::U8(storages)
+ }
+ Self::U32(_) => {
+ let storages = storages
+ .iter()
+ .map(|s| match s {
+ Self::U32(s) => Ok(s.as_slice()),
+ _ => crate::bail!("dtype mismatch"),
+ })
+ .collect::<Result<Vec<_>>>()?
+ .concat();
+ Self::U32(storages)
+ }
+ Self::I16(_) => {
+ let storages = storages
+ .iter()
+ .map(|s| match s {
+ Self::I16(s) => Ok(s.as_slice()),
+ _ => crate::bail!("dtype mismatch"),
+ })
+ .collect::<Result<Vec<_>>>()?
+ .concat();
+ Self::I16(storages)
+ }
+ Self::I32(_) => {
+ let storages = storages
+ .iter()
+ .map(|s| match s {
+ Self::I32(s) => Ok(s.as_slice()),
+ _ => crate::bail!("dtype mismatch"),
+ })
+ .collect::<Result<Vec<_>>>()?
+ .concat();
+ Self::I32(storages)
+ }
+ Self::I64(_) => {
+ let storages = storages
+ .iter()
+ .map(|s| match s {
+ Self::I64(s) => Ok(s.as_slice()),
+ _ => crate::bail!("dtype mismatch"),
+ })
+ .collect::<Result<Vec<_>>>()?
+ .concat();
+ Self::I64(storages)
+ }
+ Self::BF16(_) => {
+ let storages = storages
+ .iter()
+ .map(|s| match s {
+ Self::BF16(s) => Ok(s.as_slice()),
+ _ => crate::bail!("dtype mismatch"),
+ })
+ .collect::<Result<Vec<_>>>()?
+ .concat();
+ Self::BF16(storages)
+ }
+ Self::F16(_) => {
+ let storages = storages
+ .iter()
+ .map(|s| match s {
+ Self::F16(s) => Ok(s.as_slice()),
+ _ => crate::bail!("dtype mismatch"),
+ })
+ .collect::<Result<Vec<_>>>()?
+ .concat();
+ Self::F16(storages)
+ }
+ Self::F32(_) => {
+ let storages = storages
+ .iter()
+ .map(|s| match s {
+ Self::F32(s) => Ok(s.as_slice()),
+ _ => crate::bail!("dtype mismatch"),
+ })
+ .collect::<Result<Vec<_>>>()?
+ .concat();
+ Self::F32(storages)
+ }
+ Self::F64(_) => {
+ let storages = storages
+ .iter()
+ .map(|s| match s {
+ Self::F64(s) => Ok(s.as_slice()),
+ _ => crate::bail!("dtype mismatch"),
+ })
+ .collect::<Result<Vec<_>>>()?
+ .concat();
+ Self::F64(storages)
+ }
+ Self::F8E4M3(_) => {
+ let storages = storages
+ .iter()
+ .map(|s| match s {
+ Self::F8E4M3(s) => Ok(s.as_slice()),
+ _ => crate::bail!("dtype mismatch"),
+ })
+ .collect::<Result<Vec<_>>>()?