Skip to content

Commit

Permalink
feat: replace Brand with Vendor
Browse files Browse the repository at this point in the history
The `Brand` was related to the platform. On Macs with an AMD graphics
card, it would return `Apple`. This is not really helpful, therefore
change it to `Vendor`, which will always contain the Vendor of the
graphics card.

BREAKING CHANGE: `Brand` is removed, use `Vendor` instead.
  • Loading branch information
vmx committed Jul 6, 2021
1 parent 2827a11 commit e96c2f6
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 76 deletions.
4 changes: 2 additions & 2 deletions src/opencl/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ pub enum GPUError {
KernelNotFound(String),
#[error("IO Error: {0}")]
IO(#[from] std::io::Error),
#[error("Cannot get bus ID for device with vendor {0}")]
MissingBusId(String),
#[error("Vendor {0} is not supported.")]
UnsupportedVendor(String),
}

#[allow(clippy::upper_case_acronyms)]
Expand Down
64 changes: 43 additions & 21 deletions src/opencl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mod error;
mod utils;

use std::collections::HashMap;
use std::convert::TryFrom;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::ptr;
Expand All @@ -19,36 +20,37 @@ use opencl3::types::CL_BLOCKING;

pub type BusId = u32;

const AMD_DEVICE_VENDOR_STRING: &str = "AMD";
const NVIDIA_DEVICE_VENDOR_STRING: &str = "NVIDIA Corporation";

#[allow(non_camel_case_types)]
pub type cl_device_id = opencl3::types::cl_device_id;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Brand {
pub enum Vendor {
Amd,
Apple,
Nvidia,
}

impl Brand {
/// Returns a brand by name if it exists
fn by_name(name: &str) -> Option<Self> {
match name {
"NVIDIA CUDA" => Some(Self::Nvidia),
"AMD Accelerated Parallel Processing" => Some(Self::Amd),
"Apple" => Some(Self::Apple),
_ => None,
impl TryFrom<&str> for Vendor {
type Error = GPUError;

fn try_from(vendor: &str) -> GPUResult<Self> {
match vendor {
AMD_DEVICE_VENDOR_STRING => Ok(Self::Amd),
NVIDIA_DEVICE_VENDOR_STRING => Ok(Self::Nvidia),
_ => Err(GPUError::UnsupportedVendor(vendor.to_string())),
}
}
}

impl fmt::Display for Brand {
impl fmt::Display for Vendor {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let brand = match self {
Brand::Nvidia => "NVIDIA CUDA",
Brand::Amd => "AMD Accelerated Parallel Processing",
Brand::Apple => "Apple",
let vendor = match self {
Self::Amd => AMD_DEVICE_VENDOR_STRING,
Self::Nvidia => NVIDIA_DEVICE_VENDOR_STRING,
};
write!(f, "{}", brand)
write!(f, "{}", vendor)
}
}

Expand All @@ -59,7 +61,7 @@ pub struct Buffer<T> {

#[derive(Debug, Clone)]
pub struct Device {
brand: Brand,
vendor: Vendor,
name: String,
memory: u64,
bus_id: Option<BusId>,
Expand All @@ -81,9 +83,10 @@ impl PartialEq for Device {
impl Eq for Device {}

impl Device {
pub fn brand(&self) -> Brand {
self.brand
pub fn vendor(&self) -> Vendor {
self.vendor
}

pub fn name(&self) -> String {
self.name.clone()
}
Expand All @@ -101,7 +104,7 @@ impl Device {
self.bus_id
}

/// Return all available GPU devices of supported brands.
/// Return all available GPU devices of supported vendors.
pub fn all() -> Vec<&'static Device> {
Self::all_iter().collect()
}
Expand Down Expand Up @@ -384,7 +387,8 @@ impl<'a> Kernel<'a> {

#[cfg(test)]
mod test {
use super::Device;
use super::{Device, GPUError, Vendor, AMD_DEVICE_VENDOR_STRING, NVIDIA_DEVICE_VENDOR_STRING};
use std::convert::TryFrom;

#[test]
fn test_device_all() {
Expand All @@ -393,4 +397,22 @@ mod test {
dbg!(&devices.len());
}
}

#[test]
fn test_vendor_from_str() {
assert_eq!(
Vendor::try_from(AMD_DEVICE_VENDOR_STRING).unwrap(),
Vendor::Amd,
"AMD vendor string can be converted."
);
assert_eq!(
Vendor::try_from(NVIDIA_DEVICE_VENDOR_STRING).unwrap(),
Vendor::Nvidia,
"Nvidia vendor string can be converted."
);
assert!(matches!(
Vendor::try_from("unknown vendor"),
Err(GPUError::UnsupportedVendor(_))
));
}
}
97 changes: 44 additions & 53 deletions src/opencl/utils.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
use std::convert::{TryFrom, TryInto};
use std::fmt::Write;

use lazy_static::lazy_static;
use log::{debug, warn};
use opencl3::device::DeviceInfo::CL_DEVICE_GLOBAL_MEM_SIZE;
use sha2::{Digest, Sha256};

use super::{Brand, Device, GPUError, GPUResult};

const AMD_DEVICE_VENDOR_STRING: &str = "AMD";
const NVIDIA_DEVICE_VENDOR_STRING: &str = "NVIDIA Corporation";
use super::{Device, GPUError, GPUResult, Vendor};

fn get_bus_id(d: &opencl3::device::Device) -> Result<u32, GPUError> {
let vendor = d.vendor()?;
match vendor.as_str() {
AMD_DEVICE_VENDOR_STRING => d.pci_bus_id_amd().map_err(Into::into),
NVIDIA_DEVICE_VENDOR_STRING => d.pci_bus_id_nv().map_err(Into::into),
_ => Err(GPUError::MissingBusId(vendor)),
let vendor = Vendor::try_from(d.vendor()?.as_str())?;
match vendor {
Vendor::Amd => d.pci_bus_id_amd().map_err(Into::into),
Vendor::Nvidia => d.pci_bus_id_nv().map_err(Into::into),
}
}

Expand Down Expand Up @@ -58,52 +55,46 @@ fn build_device_list() -> Vec<Device> {
let platforms: Vec<_> = opencl3::platform::get_platforms().unwrap_or_default();

for platform in platforms.iter() {
let platform_name = match platform.name() {
Ok(pn) => pn,
Err(error) => {
warn!("Cannot get platform name: {:?}", error);
continue;
}
};
if let Some(brand) = Brand::by_name(&platform_name) {
let devices = platform
.get_devices(opencl3::device::CL_DEVICE_TYPE_GPU)
.map_err(Into::into)
.and_then(|devices| {
devices
.into_iter()
.map(opencl3::device::Device::new)
.filter(|d| {
if let Ok(vendor) = d.vendor() {
match vendor.as_str() {
// Only use devices from the accepted vendors ...
AMD_DEVICE_VENDOR_STRING | NVIDIA_DEVICE_VENDOR_STRING => {
// ... which are available.
return d.available().unwrap_or(0) != 0;
}
_ => (),
}
let devices = platform
.get_devices(opencl3::device::CL_DEVICE_TYPE_GPU)
.map_err(Into::into)
.and_then(|devices| {
devices
.into_iter()
.map(opencl3::device::Device::new)
.filter(|d| {
if let Ok(vendor) = d.vendor() {
// Only use devices from the accepted vendors ...
if Vendor::try_from(vendor.as_str()).is_ok() {
// ... which are available.
return d.available().unwrap_or(0) != 0;
}
false
}
false
})
.map(|d| -> GPUResult<_> {
Ok(Device {
vendor: d.vendor()?.as_str().try_into()?,
name: d.name()?,
memory: get_memory(&d)?,
bus_id: get_bus_id(&d).ok(),
device: d,
})
.map(|d| -> GPUResult<_> {
Ok(Device {
brand,
name: d.name()?,
memory: get_memory(&d)?,
bus_id: get_bus_id(&d).ok(),
device: d,
})
})
.collect::<GPUResult<Vec<_>>>()
});
match devices {
Ok(mut devices) => {
all_devices.append(&mut devices);
}
Err(err) => {
warn!("Unable to retrieve devices for {:?}: {:?}", brand, err);
}
})
.collect::<GPUResult<Vec<_>>>()
});
match devices {
Ok(mut devices) => {
all_devices.append(&mut devices);
}
Err(err) => {
let platform_name = platform
.name()
.unwrap_or_else(|_| "<unknown platform>".to_string());
warn!(
"Unable to retrieve devices for {}: {:?}",
platform_name, err
);
}
}
}
Expand Down

0 comments on commit e96c2f6

Please sign in to comment.