From 26c4e5bf1d10532c9b681f07a7b08b2c84844bee Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 10 Nov 2023 08:35:48 +0100 Subject: [PATCH] Metal part 1 - Scaffolding for metal. (#1308) * Metal part 1 - Scaffolding for metal. * Remove tracing. --- Cargo.toml | 1 + candle-core/Cargo.toml | 2 + candle-core/src/device.rs | 51 +++++- candle-core/src/display.rs | 2 + candle-core/src/dummy_metal_backend.rs | 223 +++++++++++++++++++++++++ candle-core/src/error.rs | 8 +- candle-core/src/lib.rs | 7 + candle-core/src/op.rs | 44 ++++- candle-core/src/storage.rs | 108 +++++++++++- candle-core/src/tensor.rs | 9 +- candle-core/src/utils.rs | 4 + candle-examples/src/lib.rs | 17 +- candle-pyo3/src/lib.rs | 13 ++ 13 files changed, 473 insertions(+), 16 deletions(-) create mode 100644 candle-core/src/dummy_metal_backend.rs diff --git a/Cargo.toml b/Cargo.toml index a1981993fc..a0d597e77a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,6 +60,7 @@ tracing-subscriber = "0.3.7" wav = "1.0.0" yoke = { version = "0.7.2", features = ["derive"] } zip = { version = "0.6.6", default-features = false } +metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } [profile.release-with-debug] inherits = "release" diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 8e57127a46..c5521c92fc 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -13,6 +13,7 @@ readme = "README.md" accelerate-src = { workspace = true, optional = true } byteorder = { workspace = true } candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true } +metal = { workspace = true, optional = true} cudarc = { workspace = true, optional = true } gemm = { workspace = true } half = { workspace = true } @@ -39,3 +40,4 @@ cuda = ["cudarc", "dep:candle-kernels"] cudnn = ["cuda", "cudarc/cudnn"] mkl = ["dep:libc", "dep:intel-mkl-src"] accelerate = ["dep:libc", "dep:accelerate-src"] +metal = ["dep:metal"] diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 9dfcd7d50b..de57c03ac6 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -8,12 +8,14 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType}; pub enum DeviceLocation { Cpu, Cuda { gpu_id: usize }, + Metal, } #[derive(Debug, Clone)] pub enum Device { Cpu, Cuda(crate::CudaDevice), + Metal(crate::MetalDevice), } pub trait NdArray { @@ -128,10 +130,15 @@ impl Device { Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?)) } + pub fn new_metal(ordinal: usize) -> Result { + Ok(Self::Metal(crate::MetalDevice::new(ordinal)?)) + } + pub fn set_seed(&self, seed: u64) -> Result<()> { match self { - Self::Cpu => crate::cpu_backend::CpuDevice.set_seed(seed), + Self::Cpu => CpuDevice.set_seed(seed), Self::Cuda(c) => c.set_seed(seed), + Self::Metal(m) => m.set_seed(seed), } } @@ -147,21 +154,20 @@ impl Device { match self { Self::Cpu => DeviceLocation::Cpu, Self::Cuda(device) => device.location(), + Device::Metal(device) => device.location(), } } pub fn is_cpu(&self) -> bool { - match self { - Self::Cpu => true, - Self::Cuda(_) => false, - } + matches!(self, Self::Cpu) } pub fn is_cuda(&self) -> bool { - match self { - Self::Cpu => false, - Self::Cuda(_) => true, - } + matches!(self, Self::Cuda(_)) + } + + pub fn is_metal(&self) -> bool { + matches!(self, Self::Metal(_)) } pub fn cuda_if_available(ordinal: usize) -> Result { @@ -194,6 +200,11 @@ impl Device { Ok(Storage::Cuda(storage)) } } + Device::Metal(_device) => { + // let storage = device.rand_uniform(shape, dtype, lo, up)?; + // Ok(Storage::Metal(storage)) + crate::bail!("Metal rand_uniform not implemented") + } } } @@ -228,6 +239,10 @@ impl Device { Ok(Storage::Cuda(storage)) } } + Device::Metal(device) => { + let storage = device.rand_normal(shape, dtype, mean, std)?; + Ok(Storage::Metal(storage)) + } } } @@ -250,6 +265,10 @@ impl Device { let storage = device.ones_impl(shape, dtype)?; Ok(Storage::Cuda(storage)) } + Device::Metal(device) => { + let storage = device.ones_impl(shape, dtype)?; + Ok(Storage::Metal(storage)) + } } } @@ -263,6 +282,10 @@ impl Device { let storage = device.zeros_impl(shape, dtype)?; Ok(Storage::Cuda(storage)) } + Device::Metal(device) => { + let storage = device.zeros_impl(shape, dtype)?; + Ok(Storage::Metal(storage)) + } } } @@ -274,6 +297,11 @@ impl Device { let storage = device.storage_from_cpu_storage(&storage)?; Ok(Storage::Cuda(storage)) } + Device::Metal(device) => { + let storage = array.to_cpu_storage(); + let storage = device.storage_from_cpu_storage(&storage)?; + Ok(Storage::Metal(storage)) + } } } @@ -285,6 +313,11 @@ impl Device { let storage = device.storage_from_cpu_storage(&storage)?; Ok(Storage::Cuda(storage)) } + Device::Metal(device) => { + let storage = S::to_cpu_storage_owned(data); + let storage = device.storage_from_cpu_storage(&storage)?; + Ok(Storage::Metal(storage)) + } } } } diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs index b497699b77..215c28f64f 100644 --- a/candle-core/src/display.rs +++ b/candle-core/src/display.rs @@ -14,6 +14,7 @@ impl Tensor { crate::DeviceLocation::Cuda { gpu_id } => { format!(", cuda:{}", gpu_id) } + _ => todo!(), }; write!(f, "Tensor[")?; @@ -476,6 +477,7 @@ impl std::fmt::Display for Tensor { crate::DeviceLocation::Cuda { gpu_id } => { format!(", cuda:{}", gpu_id) } + crate::DeviceLocation::Metal => todo!(), }; write!( diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs new file mode 100644 index 0000000000..e9d923310b --- /dev/null +++ b/candle-core/src/dummy_metal_backend.rs @@ -0,0 +1,223 @@ +#![allow(dead_code)] +use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; +use crate::{CpuStorage, DType, Error, Layout, Result, Shape}; + +#[derive(Debug, Clone)] +pub struct MetalDevice; + +#[derive(Debug)] +pub struct MetalStorage; + +#[derive(thiserror::Error, Debug)] +pub enum MetalError { + #[error("{0}")] + Message(String), +} + +impl From for MetalError { + fn from(e: String) -> Self { + MetalError::Message(e) + } +} + +macro_rules! fail { + () => { + unimplemented!("metal support has not been enabled, add `metal` feature to enable.") + }; +} + +impl crate::backend::BackendStorage for MetalStorage { + type Device = MetalDevice; + + fn try_clone(&self, _: &Layout) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn dtype(&self) -> DType { + fail!() + } + + fn device(&self) -> &Self::Device { + fail!() + } + + fn to_cpu_storage(&self) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn affine(&self, _: &Layout, _: f64, _: f64) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn powf(&self, _: &Layout, _: f64) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn elu(&self, _: &Layout, _: f64) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn to_dtype(&self, _: &Layout, _: DType) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn unary_impl(&self, _: &Layout) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn binary_impl(&self, _: &Self, _: &Layout, _: &Layout) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn conv1d( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &crate::conv::ParamsConv1D, + ) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn conv_transpose1d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &crate::conv::ParamsConvTranspose1D, + ) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn conv2d( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &crate::conv::ParamsConv2D, + ) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn conv_transpose2d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &crate::conv::ParamsConvTranspose2D, + ) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn scatter_add( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &Self, + _: &Layout, + _: usize, + ) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn index_add( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &Self, + _: &Layout, + _: usize, + ) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn matmul( + &self, + _: &Self, + _: (usize, usize, usize, usize), + _: &Layout, + _: &Layout, + ) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> { + Err(Error::NotCompiledWithMetalSupport) + } + + fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } +} + +impl crate::backend::BackendDevice for MetalDevice { + type Storage = MetalStorage; + fn new(_: usize) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn set_seed(&self, _: u64) -> Result<()> { + Err(Error::NotCompiledWithMetalSupport) + } + + fn location(&self) -> crate::DeviceLocation { + fail!() + } + + fn same_device(&self, _: &Self) -> bool { + fail!() + } + + fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } +} diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index 96a2b8096b..60ddea1149 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -1,4 +1,4 @@ -use crate::{DType, DeviceLocation, Layout, Shape}; +use crate::{DType, DeviceLocation, Layout, MetalError, Shape}; #[derive(Debug, Clone)] pub struct MatMulUnexpectedStriding { @@ -152,6 +152,9 @@ pub enum Error { #[error("the candle crate has not been built with cuda support")] NotCompiledWithCudaSupport, + #[error("the candle crate has not been built with metal support")] + NotCompiledWithMetalSupport, + #[error("cannot find tensor {path}")] CannotFindTensor { path: String }, @@ -159,6 +162,9 @@ pub enum Error { #[error(transparent)] Cuda(Box), + #[error("Metal error {0}")] + Metal(#[from] MetalError), + #[error(transparent)] TryFromIntError(#[from] core::num::TryFromIntError), diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 73830229cf..da61bdb574 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -49,6 +49,7 @@ mod device; pub mod display; mod dtype; mod dummy_cuda_backend; +mod dummy_metal_backend; pub mod error; mod indexer; pub mod layout; @@ -87,6 +88,12 @@ pub use cuda_backend::{CudaDevice, CudaStorage}; #[cfg(not(feature = "cuda"))] pub use dummy_cuda_backend::{CudaDevice, CudaStorage}; +#[cfg(feature = "metal")] +pub use metal_backend::{MetalDevice, MetalError, MetalStorage}; + +#[cfg(not(feature = "metal"))] +pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage}; + #[cfg(feature = "mkl")] extern crate intel_mkl_src; diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 1345078c7c..fbb20f6ce6 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -1,5 +1,5 @@ #![allow(clippy::redundant_closure_call)] -use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor}; +use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor}; use half::{bf16, f16}; use num_traits::float::Float; @@ -184,6 +184,18 @@ pub trait CustomOp1 { )) } + /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn metal_fwd( + &self, + _storage: &MetalStorage, + _layout: &Layout, + ) -> Result<(MetalStorage, Shape)> { + Err(crate::Error::Metal( + format!("no metal implementation for {}", self.name()).into(), + )) + } + /// This function takes as argument the argument `arg` used in the forward pass, the result /// produced by the forward operation `res` and the gradient of the result `grad_res`. /// The function should return the gradient of the argument. @@ -219,6 +231,20 @@ pub trait CustomOp2 { )) } + /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn metal_fwd( + &self, + _: &MetalStorage, + _: &Layout, + _: &MetalStorage, + _: &Layout, + ) -> Result<(MetalStorage, Shape)> { + Err(crate::Error::Metal( + format!("no metal implementation for {}", self.name()).into(), + )) + } + fn bwd( &self, _arg1: &Tensor, @@ -261,6 +287,22 @@ pub trait CustomOp3 { )) } + /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn metal_fwd( + &self, + _: &MetalStorage, + _: &Layout, + _: &MetalStorage, + _: &Layout, + _: &MetalStorage, + _: &Layout, + ) -> Result<(MetalStorage, Shape)> { + Err(crate::Error::Metal( + format!("no metal implementation for {}", self.name()).into(), + )) + } + fn bwd( &self, _arg1: &Tensor, diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index dc75c02ccb..9e1a2c1dee 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -1,6 +1,6 @@ use crate::backend::BackendStorage; use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp}; -use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape}; +use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape}; // We do not want to implement Clone on Storage as cloning may fail because of // out of memory. Instead try_clone should be used. @@ -8,6 +8,7 @@ use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape pub enum Storage { Cpu(CpuStorage), Cuda(CudaStorage), + Metal(MetalStorage), } impl Storage { @@ -18,6 +19,10 @@ impl Storage { let storage = storage.try_clone(layout)?; Ok(Self::Cuda(storage)) } + Self::Metal(storage) => { + let storage = storage.try_clone(layout)?; + Ok(Self::Metal(storage)) + } } } @@ -25,6 +30,7 @@ impl Storage { match self { Self::Cpu(_) => Device::Cpu, Self::Cuda(storage) => Device::Cuda(storage.device().clone()), + Self::Metal(storage) => Device::Metal(storage.device().clone()), } } @@ -32,6 +38,7 @@ impl Storage { match self { Self::Cpu(storage) => storage.dtype(), Self::Cuda(storage) => storage.dtype(), + Self::Metal(storage) => storage.dtype(), } } @@ -65,6 +72,10 @@ impl Storage { let storage = storage.affine(layout, mul, add)?; Ok(Self::Cuda(storage)) } + Self::Metal(storage) => { + let storage = storage.affine(layout, mul, add)?; + Ok(Self::Metal(storage)) + } } } @@ -78,6 +89,10 @@ impl Storage { let storage = storage.powf(layout, alpha)?; Ok(Self::Cuda(storage)) } + Self::Metal(storage) => { + let storage = storage.powf(layout, alpha)?; + Ok(Self::Metal(storage)) + } } } @@ -91,6 +106,10 @@ impl Storage { let storage = storage.elu(layout, alpha)?; Ok(Self::Cuda(storage)) } + Self::Metal(storage) => { + let storage = storage.elu(layout, alpha)?; + Ok(Self::Metal(storage)) + } } } @@ -112,6 +131,10 @@ impl Storage { let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?; Ok(Self::Cuda(storage)) } + (Self::Metal(lhs), Self::Metal(rhs)) => { + let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?; + Ok(Self::Metal(storage)) + } (lhs, rhs) => { // Should not happen because of the same device check above but we're defensive // anyway. @@ -135,6 +158,10 @@ impl Storage { let storage = storage.reduce_op(op, layout, s)?; Ok(Self::Cuda(storage)) } + Self::Metal(storage) => { + let storage = storage.reduce_op(op, layout, s)?; + Ok(Self::Metal(storage)) + } } } @@ -148,6 +175,10 @@ impl Storage { let storage = storage.to_dtype(layout, dtype)?; Ok(Self::Cuda(storage)) } + Self::Metal(storage) => { + let storage = storage.to_dtype(layout, dtype)?; + Ok(Self::Metal(storage)) + } } } @@ -161,6 +192,10 @@ impl Storage { let (storage, shape) = c.cuda_fwd(storage, l)?; Ok((Self::Cuda(storage), shape)) } + Self::Metal(storage) => { + let (storage, shape) = c.metal_fwd(storage, l)?; + Ok((Self::Metal(storage), shape)) + } } } @@ -181,6 +216,10 @@ impl Storage { let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?; Ok((Self::Cuda(s), shape)) } + (Self::Metal(s1), Self::Metal(s2)) => { + let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?; + Ok((Self::Metal(s), shape)) + } _ => unreachable!(), } } @@ -205,6 +244,10 @@ impl Storage { let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?; Ok((Self::Cuda(s), shape)) } + (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => { + let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?; + Ok((Self::Metal(s), shape)) + } _ => unreachable!(), } } @@ -219,6 +262,10 @@ impl Storage { let storage = storage.unary_impl::(layout)?; Ok(Self::Cuda(storage)) } + Self::Metal(storage) => { + let storage = storage.unary_impl::(layout)?; + Ok(Self::Metal(storage)) + } } } @@ -239,6 +286,10 @@ impl Storage { let storage = lhs.binary_impl::(rhs, lhs_layout, rhs_layout)?; Ok(Self::Cuda(storage)) } + (Self::Metal(lhs), Self::Metal(rhs)) => { + let storage = lhs.binary_impl::(rhs, lhs_layout, rhs_layout)?; + Ok(Self::Metal(storage)) + } (lhs, rhs) => { // Should not happen because of the same device check above but we're defensive // anyway. @@ -270,6 +321,10 @@ impl Storage { let s = inp.conv1d(l, kernel, kernel_l, params)?; Ok(Self::Cuda(s)) } + (Storage::Metal(inp), Storage::Metal(kernel)) => { + let s = inp.conv1d(l, kernel, kernel_l, params)?; + Ok(Self::Metal(s)) + } (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { lhs: lhs.device().location(), rhs: rhs.device().location(), @@ -324,6 +379,10 @@ impl Storage { let s = inp.conv2d(l, kernel, kernel_l, params)?; Ok(Self::Cuda(s)) } + (Storage::Metal(inp), Storage::Metal(kernel)) => { + let s = inp.conv2d(l, kernel, kernel_l, params)?; + Ok(Self::Metal(s)) + } (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { lhs: lhs.device().location(), rhs: rhs.device().location(), @@ -351,6 +410,10 @@ impl Storage { let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?; Ok(Self::Cuda(s)) } + (Storage::Metal(inp), Storage::Metal(kernel)) => { + let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?; + Ok(Self::Metal(s)) + } (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { lhs: lhs.device().location(), rhs: rhs.device().location(), @@ -375,6 +438,10 @@ impl Storage { let storage = storage.avg_pool2d(layout, kernel_size, stride)?; Ok(Self::Cuda(storage)) } + Self::Metal(storage) => { + let storage = storage.avg_pool2d(layout, kernel_size, stride)?; + Ok(Self::Metal(storage)) + } } } @@ -393,6 +460,10 @@ impl Storage { let storage = storage.max_pool2d(layout, kernel_size, stride)?; Ok(Self::Cuda(storage)) } + Self::Metal(storage) => { + let storage = storage.max_pool2d(layout, kernel_size, stride)?; + Ok(Self::Metal(storage)) + } } } @@ -406,6 +477,10 @@ impl Storage { let storage = storage.upsample_nearest1d(layout, sz)?; Ok(Self::Cuda(storage)) } + Self::Metal(storage) => { + let storage = storage.upsample_nearest1d(layout, sz)?; + Ok(Self::Metal(storage)) + } } } @@ -419,6 +494,10 @@ impl Storage { let storage = storage.upsample_nearest2d(layout, h, w)?; Ok(Self::Cuda(storage)) } + Self::Metal(storage) => { + let storage = storage.upsample_nearest2d(layout, h, w)?; + Ok(Self::Metal(storage)) + } } } @@ -442,6 +521,10 @@ impl Storage { let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?; Ok(Self::Cuda(storage)) } + (Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => { + let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?; + Ok(Self::Metal(storage)) + } (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { lhs: lhs.device().location(), rhs: rhs.device().location(), @@ -468,6 +551,10 @@ impl Storage { let storage = s.gather(l, indexes, indexes_l, d)?; Ok(Self::Cuda(storage)) } + (Self::Metal(s), Self::Metal(indexes)) => { + let storage = s.gather(l, indexes, indexes_l, d)?; + Ok(Self::Metal(storage)) + } _ => unreachable!(), } } @@ -492,6 +579,10 @@ impl Storage { let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?; Ok(Self::Cuda(storage)) } + (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => { + let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?; + Ok(Self::Metal(storage)) + } _ => unreachable!(), } } @@ -516,6 +607,10 @@ impl Storage { let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?; Ok(Self::Cuda(storage)) } + (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => { + let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?; + Ok(Self::Metal(storage)) + } _ => unreachable!(), } } @@ -537,6 +632,10 @@ impl Storage { let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?; Ok(Self::Cuda(storage)) } + (Self::Metal(lhs), Self::Metal(rhs)) => { + let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?; + Ok(Self::Metal(storage)) + } (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { lhs: lhs.device().location(), rhs: rhs.device().location(), @@ -564,6 +663,10 @@ impl Storage { let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?; Ok(Self::Cuda(storage)) } + (Self::Metal(lhs), Self::Metal(rhs)) => { + let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?; + Ok(Self::Metal(storage)) + } (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { lhs: lhs.device().location(), rhs: rhs.device().location(), @@ -583,6 +686,9 @@ impl Storage { match (self, dst) { (Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l), (Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?), + (Self::Metal(src), Self::Metal(dst)) => { + Ok(src.copy_strided_src(dst, dst_offset, src_l)?) + } (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { lhs: lhs.device().location(), rhs: rhs.device().location(), diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 133b278229..f032a89633 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -6,7 +6,7 @@ use crate::op::{ }; use crate::scalar::TensorOrScalar; use crate::shape::{Dim, Dims}; -use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape}; +use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape}; use std::sync::{Arc, RwLock}; /// Unique identifier for tensors. @@ -529,6 +529,7 @@ impl Tensor { match &*self.storage() { Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage), Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), + Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?), } } @@ -1454,6 +1455,7 @@ impl Tensor { match &*self.storage() { Storage::Cpu(storage) => from_cpu_storage(storage), Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), + Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?), } } @@ -1484,6 +1486,7 @@ impl Tensor { match &*self.storage() { Storage::Cpu(storage) => from_cpu_storage(storage), Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), + Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?), } } @@ -1524,6 +1527,7 @@ impl Tensor { match &*self.storage() { Storage::Cpu(storage) => from_cpu_storage(storage), Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), + Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?), } } @@ -1849,6 +1853,9 @@ impl Tensor { Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?) } (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()), + _ => { + bail!("not implemented yet") + } }; let op = BackpropOp::new1(self, Op::ToDevice); let tensor_ = Tensor_ { diff --git a/candle-core/src/utils.rs b/candle-core/src/utils.rs index a9c2df0b2e..78c45a9a9d 100644 --- a/candle-core/src/utils.rs +++ b/candle-core/src/utils.rs @@ -23,6 +23,10 @@ pub fn cuda_is_available() -> bool { cfg!(feature = "cuda") } +pub fn metal_is_available() -> bool { + cfg!(feature = "metal") +} + pub fn with_avx() -> bool { cfg!(target_feature = "avx") } diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs index 4ef97f8862..dff31b8552 100644 --- a/candle-examples/src/lib.rs +++ b/candle-examples/src/lib.rs @@ -2,17 +2,28 @@ pub mod coco_classes; pub mod imagenet; pub mod token_output_stream; +use candle::utils::{cuda_is_available, metal_is_available}; use candle::{Device, Result, Tensor}; pub fn device(cpu: bool) -> Result { if cpu { Ok(Device::Cpu) + } else if cuda_is_available() { + Ok(Device::new_cuda(0)?) + } else if metal_is_available() { + Ok(Device::new_metal(0)?) } else { - let device = Device::cuda_if_available(0)?; - if !device.is_cuda() { + #[cfg(all(target_os = "macos", target_arch = "aarch64"))] + { + println!( + "Running on CPU, to run on GPU(metal), build this example with `--features metal`" + ); + } + #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))] + { println!("Running on CPU, to run on GPU, build this example with `--features cuda`"); } - Ok(device) + Ok(Device::Cpu) } } diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 05a786efa0..b0c623d3f1 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -71,11 +71,13 @@ impl PyDType { } static CUDA_DEVICE: std::sync::Mutex> = std::sync::Mutex::new(None); +static METAL_DEVICE: std::sync::Mutex> = std::sync::Mutex::new(None); #[derive(Clone, Copy, Debug, PartialEq, Eq)] enum PyDevice { Cpu, Cuda, + Metal, } impl PyDevice { @@ -83,6 +85,7 @@ impl PyDevice { match device { Device::Cpu => Self::Cpu, Device::Cuda(_) => Self::Cuda, + Device::Metal(_) => Self::Metal, } } @@ -98,6 +101,15 @@ impl PyDevice { *device = Some(d.clone()); Ok(d) } + Self::Metal => { + let mut device = METAL_DEVICE.lock().unwrap(); + if let Some(device) = device.as_ref() { + return Ok(device.clone()); + }; + let d = Device::new_metal(0).map_err(wrap_err)?; + *device = Some(d.clone()); + Ok(d) + } } } } @@ -119,6 +131,7 @@ impl ToPyObject for PyDevice { let str = match self { PyDevice::Cpu => "cpu", PyDevice::Cuda => "cuda", + PyDevice::Metal => "metal", }; str.to_object(py) }