diff --git a/src/opencl/error.rs b/src/opencl/error.rs index 6bb1488..c3bda06 100644 --- a/src/opencl/error.rs +++ b/src/opencl/error.rs @@ -18,8 +18,8 @@ pub enum GPUError { KernelNotFound(String), #[error("IO Error: {0}")] IO(#[from] std::io::Error), - #[error("Cannot get bus ID for device with vendor {0}")] - MissingBusId(String), + #[error("Vendor {0} is not supported.")] + UnsupportedVendor(String), } #[allow(clippy::upper_case_acronyms)] diff --git a/src/opencl/mod.rs b/src/opencl/mod.rs index 879c22f..420fa80 100644 --- a/src/opencl/mod.rs +++ b/src/opencl/mod.rs @@ -2,6 +2,7 @@ mod error; mod utils; use std::collections::HashMap; +use std::convert::TryFrom; use std::fmt; use std::hash::{Hash, Hasher}; use std::ptr; @@ -19,36 +20,37 @@ use opencl3::types::CL_BLOCKING; pub type BusId = u32; +const AMD_DEVICE_VENDOR_STRING: &str = "AMD"; +const NVIDIA_DEVICE_VENDOR_STRING: &str = "NVIDIA Corporation"; + #[allow(non_camel_case_types)] pub type cl_device_id = opencl3::types::cl_device_id; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum Brand { +pub enum Vendor { Amd, - Apple, Nvidia, } -impl Brand { - /// Returns a brand by name if it exists - fn by_name(name: &str) -> Option { - match name { - "NVIDIA CUDA" => Some(Self::Nvidia), - "AMD Accelerated Parallel Processing" => Some(Self::Amd), - "Apple" => Some(Self::Apple), - _ => None, +impl TryFrom<&str> for Vendor { + type Error = GPUError; + + fn try_from(vendor: &str) -> GPUResult { + match vendor { + AMD_DEVICE_VENDOR_STRING => Ok(Self::Amd), + NVIDIA_DEVICE_VENDOR_STRING => Ok(Self::Nvidia), + _ => Err(GPUError::UnsupportedVendor(vendor.to_string())), } } } -impl fmt::Display for Brand { +impl fmt::Display for Vendor { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let brand = match self { - Brand::Nvidia => "NVIDIA CUDA", - Brand::Amd => "AMD Accelerated Parallel Processing", - Brand::Apple => "Apple", + let vendor = match self { + Self::Amd => AMD_DEVICE_VENDOR_STRING, + Self::Nvidia => NVIDIA_DEVICE_VENDOR_STRING, }; - write!(f, "{}", brand) + write!(f, "{}", vendor) } } @@ -59,7 +61,7 @@ pub struct Buffer { #[derive(Debug, Clone)] pub struct Device { - brand: Brand, + vendor: Vendor, name: String, memory: u64, bus_id: Option, @@ -81,9 +83,10 @@ impl PartialEq for Device { impl Eq for Device {} impl Device { - pub fn brand(&self) -> Brand { - self.brand + pub fn vendor(&self) -> Vendor { + self.vendor } + pub fn name(&self) -> String { self.name.clone() } @@ -101,7 +104,7 @@ impl Device { self.bus_id } - /// Return all available GPU devices of supported brands. + /// Return all available GPU devices of supported vendors. pub fn all() -> Vec<&'static Device> { Self::all_iter().collect() } @@ -384,7 +387,8 @@ impl<'a> Kernel<'a> { #[cfg(test)] mod test { - use super::Device; + use super::{Device, GPUError, Vendor, AMD_DEVICE_VENDOR_STRING, NVIDIA_DEVICE_VENDOR_STRING}; + use std::convert::TryFrom; #[test] fn test_device_all() { @@ -393,4 +397,22 @@ mod test { dbg!(&devices.len()); } } + + #[test] + fn test_vendor_from_str() { + assert_eq!( + Vendor::try_from(AMD_DEVICE_VENDOR_STRING).unwrap(), + Vendor::Amd, + "AMD vendor string can be converted." + ); + assert_eq!( + Vendor::try_from(NVIDIA_DEVICE_VENDOR_STRING).unwrap(), + Vendor::Nvidia, + "Nvidia vendor string can be converted." + ); + assert!(matches!( + Vendor::try_from("unknown vendor"), + Err(GPUError::UnsupportedVendor(_)) + )); + } } diff --git a/src/opencl/utils.rs b/src/opencl/utils.rs index 3cc17e7..26d6b8a 100644 --- a/src/opencl/utils.rs +++ b/src/opencl/utils.rs @@ -1,3 +1,4 @@ +use std::convert::{TryFrom, TryInto}; use std::fmt::Write; use lazy_static::lazy_static; @@ -5,17 +6,13 @@ use log::{debug, warn}; use opencl3::device::DeviceInfo::CL_DEVICE_GLOBAL_MEM_SIZE; use sha2::{Digest, Sha256}; -use super::{Brand, Device, GPUError, GPUResult}; - -const AMD_DEVICE_VENDOR_STRING: &str = "AMD"; -const NVIDIA_DEVICE_VENDOR_STRING: &str = "NVIDIA Corporation"; +use super::{Device, GPUError, GPUResult, Vendor}; fn get_bus_id(d: &opencl3::device::Device) -> Result { - let vendor = d.vendor()?; - match vendor.as_str() { - AMD_DEVICE_VENDOR_STRING => d.pci_bus_id_amd().map_err(Into::into), - NVIDIA_DEVICE_VENDOR_STRING => d.pci_bus_id_nv().map_err(Into::into), - _ => Err(GPUError::MissingBusId(vendor)), + let vendor = Vendor::try_from(d.vendor()?.as_str())?; + match vendor { + Vendor::Amd => d.pci_bus_id_amd().map_err(Into::into), + Vendor::Nvidia => d.pci_bus_id_nv().map_err(Into::into), } } @@ -58,52 +55,46 @@ fn build_device_list() -> Vec { let platforms: Vec<_> = opencl3::platform::get_platforms().unwrap_or_default(); for platform in platforms.iter() { - let platform_name = match platform.name() { - Ok(pn) => pn, - Err(error) => { - warn!("Cannot get platform name: {:?}", error); - continue; - } - }; - if let Some(brand) = Brand::by_name(&platform_name) { - let devices = platform - .get_devices(opencl3::device::CL_DEVICE_TYPE_GPU) - .map_err(Into::into) - .and_then(|devices| { - devices - .into_iter() - .map(opencl3::device::Device::new) - .filter(|d| { - if let Ok(vendor) = d.vendor() { - match vendor.as_str() { - // Only use devices from the accepted vendors ... - AMD_DEVICE_VENDOR_STRING | NVIDIA_DEVICE_VENDOR_STRING => { - // ... which are available. - return d.available().unwrap_or(0) != 0; - } - _ => (), - } + let devices = platform + .get_devices(opencl3::device::CL_DEVICE_TYPE_GPU) + .map_err(Into::into) + .and_then(|devices| { + devices + .into_iter() + .map(opencl3::device::Device::new) + .filter(|d| { + if let Ok(vendor) = d.vendor() { + // Only use devices from the accepted vendors ... + if Vendor::try_from(vendor.as_str()).is_ok() { + // ... which are available. + return d.available().unwrap_or(0) != 0; } - false + } + false + }) + .map(|d| -> GPUResult<_> { + Ok(Device { + vendor: d.vendor()?.as_str().try_into()?, + name: d.name()?, + memory: get_memory(&d)?, + bus_id: get_bus_id(&d).ok(), + device: d, }) - .map(|d| -> GPUResult<_> { - Ok(Device { - brand, - name: d.name()?, - memory: get_memory(&d)?, - bus_id: get_bus_id(&d).ok(), - device: d, - }) - }) - .collect::>>() - }); - match devices { - Ok(mut devices) => { - all_devices.append(&mut devices); - } - Err(err) => { - warn!("Unable to retrieve devices for {:?}: {:?}", brand, err); - } + }) + .collect::>>() + }); + match devices { + Ok(mut devices) => { + all_devices.append(&mut devices); + } + Err(err) => { + let platform_name = platform + .name() + .unwrap_or_else(|_| "".to_string()); + warn!( + "Unable to retrieve devices for {}: {:?}", + platform_name, err + ); } } }