Skip to content

Commit

Permalink
Add support for accelerate in the pyo3 bindings. (#1167)
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare authored Oct 24, 2023
1 parent 807e3f9 commit 7bd0fab
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 1 deletion.
4 changes: 3 additions & 1 deletion candle-pyo3/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,18 @@ name = "candle"
crate-type = ["cdylib"]

[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
half = { workspace = true }
pyo3 = { version = "0.19.0", features = ["extension-module"] }
intel-mkl-src = { workspace = true, optional = true }
pyo3 = { version = "0.19.0", features = ["extension-module"] }

[build-dependencies]
pyo3-build-config = "0.19"

[features]
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
mkl = ["dep:intel-mkl-src","candle/mkl"]
3 changes: 3 additions & 0 deletions candle-pyo3/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ use half::{bf16, f16};
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};

pub fn wrap_err(err: ::candle::Error) -> PyErr {
Expand Down
5 changes: 5 additions & 0 deletions candle-pyo3/test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import candle

print(f"mkl: {candle.utils.has_mkl()}")
print(f"accelerate: {candle.utils.has_accelerate()}")
print(f"num-threads: {candle.utils.get_num_threads()}")
print(f"cuda: {candle.utils.cuda_is_available()}")

t = candle.Tensor(42.0)
print(t)
print(t.shape, t.rank, t.device)
Expand Down

0 comments on commit 7bd0fab

Please sign in to comment.