Skip to content

Commit

Permalink
Metal part 1 - Scaffolding for metal. (#1308)
Browse files Browse the repository at this point in the history
* Metal part 1 - Scaffolding for metal.

* Remove tracing.
  • Loading branch information
Narsil authored Nov 10, 2023
1 parent 18d3000 commit 26c4e5b
Show file tree
Hide file tree
Showing 13 changed files with 473 additions and 16 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }

[profile.release-with-debug]
inherits = "release"
Expand Down
2 changes: 2 additions & 0 deletions candle-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ readme = "README.md"
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
metal = { workspace = true, optional = true}
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }
Expand All @@ -39,3 +40,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal"]
51 changes: 42 additions & 9 deletions candle-core/src/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
pub enum DeviceLocation {
Cpu,
Cuda { gpu_id: usize },
Metal,
}

#[derive(Debug, Clone)]
pub enum Device {
Cpu,
Cuda(crate::CudaDevice),
Metal(crate::MetalDevice),
}

pub trait NdArray {
Expand Down Expand Up @@ -128,10 +130,15 @@ impl Device {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
}

pub fn new_metal(ordinal: usize) -> Result<Self> {
Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
}

pub fn set_seed(&self, seed: u64) -> Result<()> {
match self {
Self::Cpu => crate::cpu_backend::CpuDevice.set_seed(seed),
Self::Cpu => CpuDevice.set_seed(seed),
Self::Cuda(c) => c.set_seed(seed),
Self::Metal(m) => m.set_seed(seed),
}
}

Expand All @@ -147,21 +154,20 @@ impl Device {
match self {
Self::Cpu => DeviceLocation::Cpu,
Self::Cuda(device) => device.location(),
Device::Metal(device) => device.location(),
}
}

pub fn is_cpu(&self) -> bool {
match self {
Self::Cpu => true,
Self::Cuda(_) => false,
}
matches!(self, Self::Cpu)
}

pub fn is_cuda(&self) -> bool {
match self {
Self::Cpu => false,
Self::Cuda(_) => true,
}
matches!(self, Self::Cuda(_))
}

pub fn is_metal(&self) -> bool {
matches!(self, Self::Metal(_))
}

pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
Expand Down Expand Up @@ -194,6 +200,11 @@ impl Device {
Ok(Storage::Cuda(storage))
}
}
Device::Metal(_device) => {
// let storage = device.rand_uniform(shape, dtype, lo, up)?;
// Ok(Storage::Metal(storage))
crate::bail!("Metal rand_uniform not implemented")
}
}
}

Expand Down Expand Up @@ -228,6 +239,10 @@ impl Device {
Ok(Storage::Cuda(storage))
}
}
Device::Metal(device) => {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Metal(storage))
}
}
}

Expand All @@ -250,6 +265,10 @@ impl Device {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}

Expand All @@ -263,6 +282,10 @@ impl Device {
let storage = device.zeros_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.zeros_impl(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}

Expand All @@ -274,6 +297,11 @@ impl Device {
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = array.to_cpu_storage();
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Metal(storage))
}
}
}

Expand All @@ -285,6 +313,11 @@ impl Device {
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = S::to_cpu_storage_owned(data);
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Metal(storage))
}
}
}
}
2 changes: 2 additions & 0 deletions candle-core/src/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ impl Tensor {
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
_ => todo!(),
};

write!(f, "Tensor[")?;
Expand Down Expand Up @@ -476,6 +477,7 @@ impl std::fmt::Display for Tensor {
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
crate::DeviceLocation::Metal => todo!(),
};

write!(
Expand Down
223 changes: 223 additions & 0 deletions candle-core/src/dummy_metal_backend.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
#![allow(dead_code)]
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};

#[derive(Debug, Clone)]
pub struct MetalDevice;

#[derive(Debug)]
pub struct MetalStorage;

#[derive(thiserror::Error, Debug)]
pub enum MetalError {
#[error("{0}")]
Message(String),
}

impl From<String> for MetalError {
fn from(e: String) -> Self {
MetalError::Message(e)
}
}

macro_rules! fail {
() => {
unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
};
}

impl crate::backend::BackendStorage for MetalStorage {
type Device = MetalDevice;

fn try_clone(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn dtype(&self) -> DType {
fail!()
}

fn device(&self) -> &Self::Device {
fail!()
}

fn to_cpu_storage(&self) -> Result<CpuStorage> {
Err(Error::NotCompiledWithMetalSupport)
}

fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn conv1d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConv1D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn conv_transpose1d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn conv2d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConv2D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn conv_transpose2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn scatter_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn index_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn matmul(
&self,
_: &Self,
_: (usize, usize, usize, usize),
_: &Layout,
_: &Layout,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}

fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
}

impl crate::backend::BackendDevice for MetalDevice {
type Storage = MetalStorage;
fn new(_: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

fn set_seed(&self, _: u64) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}

fn location(&self) -> crate::DeviceLocation {
fail!()
}

fn same_device(&self, _: &Self) -> bool {
fail!()
}

fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}

fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}

fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}

fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}

fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
}
Loading

0 comments on commit 26c4e5b

Please sign in to comment.