From b1b3173d129e7c44c29662fffdea712979c26b6d Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Fri, 29 Dec 2023 00:11:07 -0800 Subject: [PATCH 1/2] Support XPU device --- bindings/python/src/lib.rs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 3ea62ff1..8ec8dad7 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -226,6 +226,7 @@ enum Device { Cuda(usize), Mps, Npu(usize), + Xpu(usize), } impl<'source> FromPyObject<'source> for Device { @@ -236,6 +237,7 @@ impl<'source> FromPyObject<'source> for Device { "cuda" => Ok(Device::Cuda(0)), "mps" => Ok(Device::Mps), "npu" => Ok(Device::Npu(0)), + "xpu" => Ok(Device::Xpu(0)), name if name.starts_with("cuda:") => { let tokens: Vec<_> = name.split(':').collect(); if tokens.len() == 2 { @@ -258,6 +260,17 @@ impl<'source> FromPyObject<'source> for Device { ))) } } + name if name.starts_with("xpu:") => { + let tokens: Vec<_> = name.split(':').collect(); + if tokens.len() == 2 { + let device: usize = tokens[1].parse()?; + Ok(Device::Xpu(device)) + } else { + Err(SafetensorError::new_err(format!( + "device {name} is invalid" + ))) + } + } name => Err(SafetensorError::new_err(format!( "device {name} is invalid" ))), @@ -277,6 +290,7 @@ impl IntoPy for Device { Device::Cuda(n) => format!("cuda:{n}").into_py(py), Device::Mps => "mps".into_py(py), Device::Npu(n) => format!("npu:{n}").into_py(py), + Device::Xpu(n) => format!("xpu:{n}").into_py(py), } } } From 4c0b5dc136b192ada8622eda956dff49cec99668 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 18 Jan 2024 13:31:40 +0100 Subject: [PATCH 2/2] Fmt. --- bindings/python/src/lib.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 8ec8dad7..d4cd89fb 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -261,15 +261,15 @@ impl<'source> FromPyObject<'source> for Device { } } name if name.starts_with("xpu:") => { - let tokens: Vec<_> = name.split(':').collect(); - if tokens.len() == 2 { - let device: usize = tokens[1].parse()?; - Ok(Device::Xpu(device)) - } else { - Err(SafetensorError::new_err(format!( - "device {name} is invalid" - ))) - } + let tokens: Vec<_> = name.split(':').collect(); + if tokens.len() == 2 { + let device: usize = tokens[1].parse()?; + Ok(Device::Xpu(device)) + } else { + Err(SafetensorError::new_err(format!( + "device {name} is invalid" + ))) + } } name => Err(SafetensorError::new_err(format!( "device {name} is invalid"