Skip to content

Commit

Permalink
Respects torch.device(0) new behavior without breaking backward compa…
Browse files Browse the repository at this point in the history
…tibility (#509)

* Respects torch.device(0) new behavior without breaking backward
compatibilty.

* Fixing anonymous device type, apparently number string is not ok.
  • Loading branch information
Narsil authored Aug 1, 2024
1 parent 8d21261 commit 74c4e16
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 45 deletions.
67 changes: 22 additions & 45 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,22 @@ enum Device {
Npu(usize),
Xpu(usize),
Xla(usize),
/// User didn't specify acceletor, torch
/// is responsible for choosing.
Anonymous(usize),
}

/// Parsing the device index.
fn parse_device(name: &str) -> PyResult<usize> {
let tokens: Vec<_> = name.split(':').collect();
if tokens.len() == 2 {
let device: usize = tokens[1].parse()?;
Ok(device)
} else {
Err(SafetensorError::new_err(format!(
"device {name} is invalid"
)))
}
}

impl<'source> FromPyObject<'source> for Device {
Expand All @@ -279,56 +295,16 @@ impl<'source> FromPyObject<'source> for Device {
"npu" => Ok(Device::Npu(0)),
"xpu" => Ok(Device::Xpu(0)),
"xla" => Ok(Device::Xla(0)),
name if name.starts_with("cuda:") => {
let tokens: Vec<_> = name.split(':').collect();
if tokens.len() == 2 {
let device: usize = tokens[1].parse()?;
Ok(Device::Cuda(device))
} else {
Err(SafetensorError::new_err(format!(
"device {name} is invalid"
)))
}
}
name if name.starts_with("npu:") => {
let tokens: Vec<_> = name.split(':').collect();
if tokens.len() == 2 {
let device: usize = tokens[1].parse()?;
Ok(Device::Npu(device))
} else {
Err(SafetensorError::new_err(format!(
"device {name} is invalid"
)))
}
}
name if name.starts_with("xpu:") => {
let tokens: Vec<_> = name.split(':').collect();
if tokens.len() == 2 {
let device: usize = tokens[1].parse()?;
Ok(Device::Xpu(device))
} else {
Err(SafetensorError::new_err(format!(
"device {name} is invalid"
)))
}
}
name if name.starts_with("xla:") => {
let tokens: Vec<_> = name.split(':').collect();
if tokens.len() == 2 {
let device: usize = tokens[1].parse()?;
Ok(Device::Xla(device))
} else {
Err(SafetensorError::new_err(format!(
"device {name} is invalid"
)))
}
}
name if name.starts_with("cuda:") => parse_device(name).map(Device::Cuda),
name if name.starts_with("npu:") => parse_device(name).map(Device::Npu),
name if name.starts_with("xpu:") => parse_device(name).map(Device::Xpu),
name if name.starts_with("xla:") => parse_device(name).map(Device::Xla),
name => Err(SafetensorError::new_err(format!(
"device {name} is invalid"
))),
}
} else if let Ok(number) = ob.extract::<usize>() {
Ok(Device::Cuda(number))
Ok(Device::Anonymous(number))
} else {
Err(SafetensorError::new_err(format!("device {ob} is invalid")))
}
Expand All @@ -344,6 +320,7 @@ impl IntoPy<PyObject> for Device {
Device::Npu(n) => format!("npu:{n}").into_py(py),
Device::Xpu(n) => format!("xpu:{n}").into_py(py),
Device::Xla(n) => format!("xla:{n}").into_py(py),
Device::Anonymous(n) => n.into_py(py),
}
}
}
Expand Down
13 changes: 13 additions & 0 deletions bindings/python/tests/test_pt_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,19 @@ def test_npu(self):
for k, v in reloaded.items():
self.assertTrue(torch.allclose(data[k], reloaded[k]))

@unittest.skipIf(not torch.cuda.is_available(), "Cuda is not available")
def test_anonymous_accelerator(self):
data = {
"test1": torch.zeros((2, 2), dtype=torch.float32).to(device=0),
"test2": torch.zeros((2, 2), dtype=torch.float16).to(device=0),
}
local = "./tests/data/out_safe_pt_mmap_small_anonymous.safetensors"
save_file(data, local)

reloaded = load_file(local, device=0)
for k, v in reloaded.items():
self.assertTrue(torch.allclose(data[k], reloaded[k]))

def test_sparse(self):
data = {"test": torch.sparse_coo_tensor(size=(2, 3))}
local = "./tests/data/out_safe_pt_sparse.safetensors"
Expand Down

0 comments on commit 74c4e16

Please sign in to comment.