Skip to content

Commit

Permalink
Fix NVML persistence mode API name
Browse files Browse the repository at this point in the history
  • Loading branch information
jaywonchung committed May 29, 2024
1 parent 889c034 commit ff96ff3
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 45 deletions.
2 changes: 1 addition & 1 deletion zeusd/src/devices/gpu/linux.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl GpuManager for NvmlGpu<'static> {
}

#[inline]
fn set_persistent_mode(&mut self, enabled: bool) -> Result<(), ZeusdError> {
fn set_persistence_mode(&mut self, enabled: bool) -> Result<(), ZeusdError> {
Ok(self.device.set_persistent(enabled)?)
}

Expand Down
4 changes: 2 additions & 2 deletions zeusd/src/devices/gpu/macos.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ impl NvmlGpu {

impl GpuManager for NvmlGpu {
fn device_count() -> Result<u32, ZeusdError> {
Ok(0)
Ok(1)
}

fn set_persistent_mode(&mut self, _enabled: bool) -> Result<(), ZeusdError> {
fn set_persistence_mode(&mut self, _enabled: bool) -> Result<(), ZeusdError> {
Ok(())
}

Expand Down
17 changes: 7 additions & 10 deletions zeusd/src/devices/gpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub trait GpuManager {
fn device_count() -> Result<u32, ZeusdError>
where
Self: Sized;
fn set_persistent_mode(&mut self, enabled: bool) -> Result<(), ZeusdError>;
fn set_persistence_mode(&mut self, enabled: bool) -> Result<(), ZeusdError>;
fn set_power_management_limit(&mut self, power_limit: u32) -> Result<(), ZeusdError>;
fn set_gpu_locked_clocks(
&mut self,
Expand Down Expand Up @@ -98,9 +98,6 @@ impl GpuManagementTasks {
command: GpuCommand,
request_start_time: Instant,
) -> Result<(), ZeusdError> {
if gpu_id >= self.senders.len() {
return Err(ZeusdError::GpuNotFoundError(gpu_id));
}
if gpu_id >= self.senders.len() {
return Err(ZeusdError::GpuNotFoundError(gpu_id));
}
Expand Down Expand Up @@ -152,8 +149,8 @@ async fn gpu_management_task<T: GpuManager>(
/// A GPU command that can be executed on a GPU.
#[derive(Debug)]
pub enum GpuCommand {
/// Enable or disable persistent mode.
SetPersistentMode { enabled: bool },
/// Enable or disable persistence mode.
SetPersistenceMode { enabled: bool },
/// Set the power management limit in milliwatts.
SetPowerLimit { power_limit_mw: u32 },
/// Set the GPU's locked clock range in MHz.
Expand All @@ -179,18 +176,18 @@ impl GpuCommand {
T: GpuManager,
{
match *self {
Self::SetPersistentMode { enabled } => {
let result = device.set_persistent_mode(enabled);
Self::SetPersistenceMode { enabled } => {
let result = device.set_persistence_mode(enabled);
if result.is_ok() {
tracing::info!(
elapsed = ?request_start_time.elapsed(),
"Persistent mode {}",
"Persistence mode {}",
if enabled { "enabled" } else { "disabled" },
);
} else {
tracing::warn!(
elapsed = ?request_start_time.elapsed(),
"Cannot {} persistent mode",
"Cannot {} persistence mode",
if enabled { "enable" } else { "disable" },
);
}
Expand Down
8 changes: 4 additions & 4 deletions zeusd/src/routes/gpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::error::ZeusdError;
/// Macro to generate a handler for a GPU command.
///
/// This macro takes
/// - the API name (set_power_limit, set_persistent_mode, etc.),
/// - the API name (set_power_limit, set_persistence_mode, etc.),
/// - the method and path for the request handler,
/// - and a list of `field name <type>` pairs of the corresponding `GpuCommand` variant.
///
Expand Down Expand Up @@ -80,8 +80,8 @@ macro_rules! impl_handler_for_gpu_command {
}

impl_handler_for_gpu_command!(
set_persistent_mode,
post("/{gpu_id}/set_persistent_mode"),
set_persistence_mode,
post("/{gpu_id}/set_persistence_mode"),
enabled<bool>,
);

Expand Down Expand Up @@ -116,7 +116,7 @@ impl_handler_for_gpu_command!(
);

pub fn gpu_routes(cfg: &mut web::ServiceConfig) {
cfg.service(set_persistent_mode_handler)
cfg.service(set_persistence_mode_handler)
.service(set_power_limit_handler)
.service(set_gpu_locked_clocks_handler)
.service(reset_gpu_locked_clocks_handler)
Expand Down
36 changes: 18 additions & 18 deletions zeusd/tests/gpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@ use std::collections::HashSet;
use tokio::task::JoinSet;
use zeusd::routes::gpu::{
ResetGpuLockedClocks, ResetMemLockedClocks, SetGpuLockedClocks, SetMemLockedClocks,
SetPersistentMode, SetPowerLimit,
SetPersistenceMode, SetPowerLimit,
};

use crate::helpers::{TestApp, ZeusdRequest};

#[tokio::test]
async fn test_set_persistent_mode_single() {
async fn test_set_persistence_mode_single() {
let mut app = TestApp::start().await;

let resp = app
.send(
0,
SetPersistentMode {
SetPersistenceMode {
enabled: true,
block: true,
},
Expand All @@ -25,21 +25,21 @@ async fn test_set_persistent_mode_single() {
.expect("Failed to send request");

assert_eq!(resp.status(), 200);
let history = app.persistent_mode_history_for_gpu(0);
let history = app.persistence_mode_history_for_gpu(0);
assert_eq!(history.len(), 1);
assert_eq!(history[0], true);
}

#[tokio::test]
async fn test_set_persistent_mode_multiple() {
async fn test_set_persistence_mode_multiple() {
let mut app = TestApp::start().await;

let num_requests = 10;
for i in 0..num_requests {
let resp = app
.send(
i % 4,
SetPersistentMode {
SetPersistenceMode {
enabled: (i / 4) % 2 == 0,
block: true,
},
Expand All @@ -51,23 +51,23 @@ async fn test_set_persistent_mode_multiple() {
}

assert_eq!(
app.persistent_mode_history_for_gpu(0),
app.persistence_mode_history_for_gpu(0),
vec![true, false, true]
);
assert_eq!(
app.persistent_mode_history_for_gpu(1),
app.persistence_mode_history_for_gpu(1),
vec![true, false, true]
);
assert_eq!(app.persistent_mode_history_for_gpu(2), vec![true, false]);
assert_eq!(app.persistent_mode_history_for_gpu(3), vec![true, false]);
assert_eq!(app.persistence_mode_history_for_gpu(2), vec![true, false]);
assert_eq!(app.persistence_mode_history_for_gpu(3), vec![true, false]);
}

#[tokio::test]
async fn test_set_persistent_mode_invalid() {
async fn test_set_persistence_mode_invalid() {
let app = TestApp::start().await;

let client = reqwest::Client::new();
let url = SetPersistentMode::build_url(&app, 0);
let url = SetPersistenceMode::build_url(&app, 0);
let resp = client
.post(url)
.json(&serde_json::json!(
Expand All @@ -86,7 +86,7 @@ async fn test_set_persistent_mode_invalid() {
.expect("Failed to read response")
.contains("missing field"));

let url = SetPersistentMode::build_url(&app, 0);
let url = SetPersistenceMode::build_url(&app, 0);
let resp = client
.post(url)
.json(&serde_json::json!(
Expand All @@ -105,7 +105,7 @@ async fn test_set_persistent_mode_invalid() {
.expect("Failed to read response")
.contains("invalid type"));

let url = SetPersistentMode::build_url(&app, 5); // Invalid GPU ID
let url = SetPersistenceMode::build_url(&app, 5); // Invalid GPU ID
let resp = client
.post(url)
.json(&serde_json::json!(
Expand All @@ -121,14 +121,14 @@ async fn test_set_persistent_mode_invalid() {
}

#[tokio::test]
async fn test_set_persistent_mode_bulk() {
async fn test_set_persistence_mode_bulk() {
let mut app = TestApp::start().await;

let mut set = JoinSet::new();
for i in 0..10 {
set.spawn(app.send(
0,
SetPersistentMode {
SetPersistenceMode {
enabled: i % 3 == 0,
block: false,
},
Expand All @@ -152,7 +152,7 @@ async fn test_set_persistent_mode_bulk() {
assert_eq!(
app.send(
0,
SetPersistentMode {
SetPersistenceMode {
enabled: false,
block: true,
},
Expand All @@ -163,7 +163,7 @@ async fn test_set_persistent_mode_bulk() {
200
);

let history = app.persistent_mode_history_for_gpu(0);
let history = app.persistence_mode_history_for_gpu(0);
assert_eq!(history.len(), 11);
assert_eq!(history.iter().filter(|enabled| **enabled).count(), 4);
assert_eq!(history.iter().filter(|enabled| !**enabled).count(), 6 + 1);
Expand Down
20 changes: 10 additions & 10 deletions zeusd/tests/helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,36 +25,36 @@ static TRACING: Lazy<()> = Lazy::new(|| {

#[derive(Clone)]
pub struct TestGpu {
persistent_mode_tx: UnboundedSender<bool>,
persistence_mode_tx: UnboundedSender<bool>,
power_limit_tx: UnboundedSender<u32>,
gpu_locked_clocks_tx: UnboundedSender<(u32, u32)>,
mem_locked_clocks_tx: UnboundedSender<(u32, u32)>,
valid_power_limit_range: (u32, u32),
}

pub struct TestGpuObserver {
persistent_mode_rx: UnboundedReceiver<bool>,
persistence_mode_rx: UnboundedReceiver<bool>,
power_limit_rx: UnboundedReceiver<u32>,
gpu_locked_clocks_rx: UnboundedReceiver<(u32, u32)>,
mem_locked_clocks_rx: UnboundedReceiver<(u32, u32)>,
}

impl TestGpu {
fn init() -> Result<(Self, TestGpuObserver), ZeusdError> {
let (persistent_mode_tx, persistent_mode_rx) = tokio::sync::mpsc::unbounded_channel();
let (persistence_mode_tx, persistence_mode_rx) = tokio::sync::mpsc::unbounded_channel();
let (power_limit_tx, power_limit_rx) = tokio::sync::mpsc::unbounded_channel();
let (gpu_locked_clocks_tx, gpu_locked_clocks_rx) = tokio::sync::mpsc::unbounded_channel();
let (mem_locked_clocks_tx, mem_locked_clocks_rx) = tokio::sync::mpsc::unbounded_channel();

let gpu = TestGpu {
persistent_mode_tx,
persistence_mode_tx,
power_limit_tx,
gpu_locked_clocks_tx,
mem_locked_clocks_tx,
valid_power_limit_range: (100_000, 300_000),
};
let observer = TestGpuObserver {
persistent_mode_rx,
persistence_mode_rx,
power_limit_rx,
gpu_locked_clocks_rx,
mem_locked_clocks_rx,
Expand All @@ -69,8 +69,8 @@ impl GpuManager for TestGpu {
Ok(NUM_GPUS)
}

fn set_persistent_mode(&mut self, enabled: bool) -> Result<(), ZeusdError> {
self.persistent_mode_tx.send(enabled).unwrap();
fn set_persistence_mode(&mut self, enabled: bool) -> Result<(), ZeusdError> {
self.persistence_mode_tx.send(enabled).unwrap();
Ok(())
}

Expand Down Expand Up @@ -151,7 +151,7 @@ macro_rules! impl_zeusd_request {
};
}

impl_zeusd_request!(SetPersistentMode);
impl_zeusd_request!(SetPersistenceMode);
impl_zeusd_request!(SetPowerLimit);
impl_zeusd_request!(SetGpuLockedClocks);
impl_zeusd_request!(ResetGpuLockedClocks);
Expand Down Expand Up @@ -194,8 +194,8 @@ impl TestApp {
client.post(url).json(&payload).send()
}

pub fn persistent_mode_history_for_gpu(&mut self, gpu_id: usize) -> Vec<bool> {
let rx = &mut self.observers[gpu_id].persistent_mode_rx;
pub fn persistence_mode_history_for_gpu(&mut self, gpu_id: usize) -> Vec<bool> {
let rx = &mut self.observers[gpu_id].persistence_mode_rx;
std::iter::from_fn(|| rx.try_recv().ok()).collect()
}

Expand Down

0 comments on commit ff96ff3

Please sign in to comment.