diff --git a/icicle/include/icicle/device_api.h b/icicle/include/icicle/device_api.h index 483b5f96b..0768f135c 100644 --- a/icicle/include/icicle/device_api.h +++ b/icicle/include/icicle/device_api.h @@ -188,6 +188,7 @@ namespace icicle { public: static eIcicleError set_thread_local_device(const Device& device); + static eIcicleError set_default_device(const Device& device); static const Device& get_thread_local_device(); static const DeviceAPI* get_thread_local_deviceAPI(); static DeviceTracker& get_global_memory_tracker() { return sMemTracker; } diff --git a/icicle/include/icicle/runtime.h b/icicle/include/icicle/runtime.h index a14d80bd8..3dad72986 100644 --- a/icicle/include/icicle/runtime.h +++ b/icicle/include/icicle/runtime.h @@ -36,6 +36,14 @@ extern "C" eIcicleError icicle_load_backend_from_env_or_default(); */ extern "C" eIcicleError icicle_set_device(const icicle::Device& device); +/** + * @brief Set default device for all threads + * + + * @return eIcicleError::SUCCESS if successful, otherwise throws INVALID_DEVICE + */ +extern "C" eIcicleError icicle_set_default_device(const icicle::Device& device); + /** * @brief Get active device for thread * diff --git a/icicle/src/device_api.cpp b/icicle/src/device_api.cpp index 179b8a3cb..697a05885 100644 --- a/icicle/src/device_api.cpp +++ b/icicle/src/device_api.cpp @@ -58,6 +58,15 @@ namespace icicle { const Device& get_default_device() { return m_default_device; } + eIcicleError set_default_device(const Device& dev) + { + if (!is_device_registered(dev.type)) { + THROW_ICICLE_ERR(eIcicleError::INVALID_DEVICE, "Device type " + std::string(dev.type) + " has not been registered"); + } + m_default_device = dev; + return eIcicleError::SUCCESS; + } + std::vector get_registered_devices_list() { std::vector registered_devices; @@ -116,6 +125,11 @@ namespace icicle { return default_deviceAPI.get(); } + eIcicleError DeviceAPI::set_default_device(const Device& dev) + { + return DeviceAPIRegistry::Global().set_default_device(dev); + } + /********************************************************************************** */ DeviceAPI* get_deviceAPI(const Device& device) { return DeviceAPIRegistry::Global().get_deviceAPI(device).get(); } diff --git a/icicle/src/runtime.cpp b/icicle/src/runtime.cpp index 8e9028cfc..03ab6c162 100644 --- a/icicle/src/runtime.cpp +++ b/icicle/src/runtime.cpp @@ -14,6 +14,8 @@ using namespace icicle; extern "C" eIcicleError icicle_set_device(const Device& device) { return DeviceAPI::set_thread_local_device(device); } +extern "C" eIcicleError icicle_set_default_device(const Device& device) { return DeviceAPI::set_default_device(device); } + extern "C" eIcicleError icicle_get_active_device(icicle::Device& device) { const Device& active_device = DeviceAPI::get_thread_local_device(); diff --git a/icicle/tests/test_device_api.cpp b/icicle/tests/test_device_api.cpp index f63d3c987..8a5ea6887 100644 --- a/icicle/tests/test_device_api.cpp +++ b/icicle/tests/test_device_api.cpp @@ -1,5 +1,6 @@ #include +#include #include #include "icicle/runtime.h" @@ -19,6 +20,36 @@ TEST_F(DeviceApiTest, UnregisteredDeviceError) EXPECT_ANY_THROW(get_deviceAPI(dev)); } +TEST_F(DeviceApiTest, SetDefaultDevice) +{ + icicle::Device active_dev = {UNKOWN_DEVICE, -1}; + + icicle::Device cpu_dev = {s_ref_device, 0}; + EXPECT_NO_THROW(icicle_set_device(cpu_dev)); + EXPECT_NO_THROW(icicle_get_active_device(active_dev)); + + ASSERT_EQ(cpu_dev, active_dev); + + active_dev = {UNKOWN_DEVICE, -1}; + + icicle::Device gpu_dev = {s_main_device, 0}; + EXPECT_NO_THROW(icicle_set_default_device(gpu_dev)); + + // setting a new default device doesn't override already set local thread devices + EXPECT_NO_THROW(icicle_get_active_device(active_dev)); + ASSERT_EQ(cpu_dev, active_dev); + + active_dev = {UNKOWN_DEVICE, -1}; + auto thread_func = [&active_dev, &gpu_dev]() { + EXPECT_NO_THROW(icicle_get_active_device(active_dev)); + ASSERT_EQ(gpu_dev, active_dev); + }; + + std::thread worker_thread(thread_func); + + worker_thread.join(); +} + TEST_F(DeviceApiTest, MemoryCopySync) { int input[2] = {1, 2}; diff --git a/wrappers/golang/runtime/device.go b/wrappers/golang/runtime/device.go index fac5b0f09..7f0965bab 100644 --- a/wrappers/golang/runtime/device.go +++ b/wrappers/golang/runtime/device.go @@ -50,6 +50,12 @@ func SetDevice(device *Device) EIcicleError { return EIcicleError(cErr) } +func SetDefaultDevice(device *Device) EIcicleError { + cDevice := (*C.Device)(unsafe.Pointer(device)) + cErr := C.icicle_set_default_device(cDevice) + return EIcicleError(cErr) +} + func GetActiveDevice() (*Device, EIcicleError) { device := CreateDevice("invalid", -1) cDevice := (*C.Device)(unsafe.Pointer(&device)) diff --git a/wrappers/golang/runtime/include/runtime.h b/wrappers/golang/runtime/include/runtime.h index 003bb0894..16d4d348a 100644 --- a/wrappers/golang/runtime/include/runtime.h +++ b/wrappers/golang/runtime/include/runtime.h @@ -13,6 +13,7 @@ typedef struct DeviceProperties DeviceProperties; int icicle_load_backend(const char* path, bool is_recursive); int icicle_load_backend_from_env_or_default(); int icicle_set_device(const Device* device); +int icicle_set_default_device(const Device* device); int icicle_get_active_device(Device* device); int icicle_is_host_memory(const void* ptr); int icicle_is_active_device_memory(const void* ptr); diff --git a/wrappers/golang/runtime/tests/device_test.go b/wrappers/golang/runtime/tests/device_test.go index a4c114389..3219604c2 100644 --- a/wrappers/golang/runtime/tests/device_test.go +++ b/wrappers/golang/runtime/tests/device_test.go @@ -1,70 +1,122 @@ package tests import ( + "fmt" "os/exec" + "runtime" + "strconv" + "strings" + "syscall" "testing" - "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" + icicle_runtime "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" "github.com/stretchr/testify/assert" ) func TestGetDeviceType(t *testing.T) { expectedDeviceName := "test" - config := runtime.CreateDevice(expectedDeviceName, 0) + config := icicle_runtime.CreateDevice(expectedDeviceName, 0) assert.Equal(t, expectedDeviceName, config.GetDeviceType()) expectedDeviceNameLong := "testtesttesttesttesttesttesttesttesttesttesttesttesttesttesttest" - configLargeName := runtime.CreateDevice(expectedDeviceNameLong, 1) + configLargeName := icicle_runtime.CreateDevice(expectedDeviceNameLong, 1) assert.NotEqual(t, expectedDeviceNameLong, configLargeName.GetDeviceType()) } func TestIsDeviceAvailable(t *testing.T) { - runtime.LoadBackendFromEnvOrDefault() - dev := runtime.CreateDevice("CUDA", 0) - _ = runtime.SetDevice(&dev) - res, err := runtime.GetDeviceCount() - - expectedNumDevices, error := exec.Command("nvidia-smi", "-L", "|", "wc", "-l").Output() - if error != nil { - t.Skip("Failed to get number of devices") + dev := icicle_runtime.CreateDevice("CUDA", 0) + _ = icicle_runtime.SetDevice(&dev) + res, err := icicle_runtime.GetDeviceCount() + + smiCommand := exec.Command("nvidia-smi", "-L") + smiCommandStdout, _ := smiCommand.StdoutPipe() + wcCommand := exec.Command("wc", "-l") + wcCommand.Stdin = smiCommandStdout + + smiCommand.Start() + + expectedNumDevicesRaw, wcErr := wcCommand.Output() + smiCommand.Wait() + + expectedNumDevicesAsString := strings.TrimRight(string(expectedNumDevicesRaw), " \n\r\t") + expectedNumDevices, _ := strconv.Atoi(expectedNumDevicesAsString) + if wcErr != nil { + t.Skip("Failed to get number of devices:", wcErr) } - assert.Equal(t, runtime.Success, err) + assert.Equal(t, icicle_runtime.Success, err) assert.Equal(t, expectedNumDevices, res) - err = runtime.LoadBackendFromEnvOrDefault() - assert.Equal(t, runtime.Success, err) - devCuda := runtime.CreateDevice("CUDA", 0) - assert.True(t, runtime.IsDeviceAvailable(&devCuda)) - devCpu := runtime.CreateDevice("CPU", 0) - assert.True(t, runtime.IsDeviceAvailable(&devCpu)) - devInvalid := runtime.CreateDevice("invalid", 0) - assert.False(t, runtime.IsDeviceAvailable(&devInvalid)) + assert.Equal(t, icicle_runtime.Success, err) + devCuda := icicle_runtime.CreateDevice("CUDA", 0) + assert.True(t, icicle_runtime.IsDeviceAvailable(&devCuda)) + devCpu := icicle_runtime.CreateDevice("CPU", 0) + assert.True(t, icicle_runtime.IsDeviceAvailable(&devCpu)) + devInvalid := icicle_runtime.CreateDevice("invalid", 0) + assert.False(t, icicle_runtime.IsDeviceAvailable(&devInvalid)) +} + +func TestSetDefaultDevice(t *testing.T) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + tidOuter := syscall.Gettid() + + gpuDevice := icicle_runtime.CreateDevice("CUDA", 0) + icicle_runtime.SetDefaultDevice(&gpuDevice) + + activeDevice, err := icicle_runtime.GetActiveDevice() + assert.Equal(t, icicle_runtime.Success, err) + assert.Equal(t, gpuDevice, *activeDevice) + + done := make(chan struct{}, 1) + go func() { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + // Ensure we are operating on an OS thread other than the original one + tidInner := syscall.Gettid() + for tidInner == tidOuter { + fmt.Println("Locked thread is the same as original, getting new locked thread") + runtime.UnlockOSThread() + runtime.LockOSThread() + tidInner = syscall.Gettid() + } + + activeDevice, err := icicle_runtime.GetActiveDevice() + assert.Equal(t, icicle_runtime.Success, err) + assert.Equal(t, gpuDevice, *activeDevice) + + close(done) + }() + + <-done + + cpuDevice := icicle_runtime.CreateDevice("CPU", 0) + icicle_runtime.SetDefaultDevice(&cpuDevice) } func TestRegisteredDevices(t *testing.T) { - err := runtime.LoadBackendFromEnvOrDefault() - assert.Equal(t, runtime.Success, err) - devices, _ := runtime.GetRegisteredDevices() + devices, _ := icicle_runtime.GetRegisteredDevices() assert.Equal(t, []string{"CUDA", "CPU"}, devices) } func TestDeviceProperties(t *testing.T) { - _, err := runtime.GetDeviceProperties() - assert.Equal(t, runtime.Success, err) + _, err := icicle_runtime.GetDeviceProperties() + assert.Equal(t, icicle_runtime.Success, err) } func TestActiveDevice(t *testing.T) { - runtime.SetDevice(&DEVICE) - activeDevice, err := runtime.GetActiveDevice() - assert.Equal(t, runtime.Success, err) - assert.Equal(t, DEVICE, *activeDevice) - memory1, err := runtime.GetAvailableMemory() - if err == runtime.ApiNotImplemented { - t.Skipf("GetAvailableMemory() function is not implemented on %s device", DEVICE.GetDeviceType()) + devCpu := icicle_runtime.CreateDevice("CUDA", 0) + icicle_runtime.SetDevice(&devCpu) + activeDevice, err := icicle_runtime.GetActiveDevice() + assert.Equal(t, icicle_runtime.Success, err) + assert.Equal(t, devCpu, *activeDevice) + memory1, err := icicle_runtime.GetAvailableMemory() + if err == icicle_runtime.ApiNotImplemented { + t.Skipf("GetAvailableMemory() function is not implemented on %s device", devCpu.GetDeviceType()) } - assert.Equal(t, runtime.Success, err) + assert.Equal(t, icicle_runtime.Success, err) assert.Greater(t, memory1.Total, uint(0)) assert.Greater(t, memory1.Free, uint(0)) } diff --git a/wrappers/golang/runtime/tests/main_test.go b/wrappers/golang/runtime/tests/main_test.go index 226fee783..800c34816 100644 --- a/wrappers/golang/runtime/tests/main_test.go +++ b/wrappers/golang/runtime/tests/main_test.go @@ -6,19 +6,7 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" ) -var DEVICE runtime.Device - func TestMain(m *testing.M) { runtime.LoadBackendFromEnvOrDefault() - devices, e := runtime.GetRegisteredDevices() - if e != runtime.Success { - panic("Failed to load registered devices") - } - for _, deviceType := range devices { - DEVICE = runtime.CreateDevice(deviceType, 0) - runtime.SetDevice(&DEVICE) - - // execute tests - m.Run() - } + m.Run() } diff --git a/wrappers/golang/runtime/tests/stream_test.go b/wrappers/golang/runtime/tests/stream_test.go index c32cae896..8acd7089f 100644 --- a/wrappers/golang/runtime/tests/stream_test.go +++ b/wrappers/golang/runtime/tests/stream_test.go @@ -8,19 +8,15 @@ import ( ) func TestCreateStream(t *testing.T) { - err := runtime.LoadBackendFromEnvOrDefault() - assert.Equal(t, runtime.Success, err) dev := runtime.CreateDevice("CUDA", 0) assert.True(t, runtime.IsDeviceAvailable(&dev)) - err = runtime.SetDevice(&dev) + err := runtime.SetDevice(&dev) assert.Equal(t, runtime.Success, err) _, err = runtime.CreateStream() assert.Equal(t, runtime.Success, err, "Unable to create stream due to %d", err) } func TestDestroyStream(t *testing.T) { - err := runtime.LoadBackendFromEnvOrDefault() - assert.Equal(t, runtime.Success, err) dev := runtime.CreateDevice("CUDA", 0) assert.True(t, runtime.IsDeviceAvailable(&dev)) stream, err := runtime.CreateStream() @@ -31,8 +27,6 @@ func TestDestroyStream(t *testing.T) { } func TestSyncStream(t *testing.T) { - err := runtime.LoadBackendFromEnvOrDefault() - assert.Equal(t, runtime.Success, err) dev := runtime.CreateDevice("CUDA", 0) assert.True(t, runtime.IsDeviceAvailable(&dev)) runtime.SetDevice(&dev) diff --git a/wrappers/rust/icicle-runtime/src/device.rs b/wrappers/rust/icicle-runtime/src/device.rs index 9248a805b..ee8f2de67 100644 --- a/wrappers/rust/icicle-runtime/src/device.rs +++ b/wrappers/rust/icicle-runtime/src/device.rs @@ -4,7 +4,7 @@ use std::os::raw::c_char; const MAX_TYPE_SIZE: usize = 64; -#[derive(Clone)] +#[derive(Clone, PartialEq)] #[repr(C)] pub struct Device { device_type: [c_char; MAX_TYPE_SIZE], diff --git a/wrappers/rust/icicle-runtime/src/runtime.rs b/wrappers/rust/icicle-runtime/src/runtime.rs index c1c88d162..fce1733a5 100644 --- a/wrappers/rust/icicle-runtime/src/runtime.rs +++ b/wrappers/rust/icicle-runtime/src/runtime.rs @@ -11,6 +11,7 @@ extern "C" { fn icicle_load_backend(path: *const c_char, is_recursive: bool) -> eIcicleError; fn icicle_load_backend_from_env_or_default() -> eIcicleError; fn icicle_set_device(device: &Device) -> eIcicleError; + fn icicle_set_default_device(device: &Device) -> eIcicleError; fn icicle_get_active_device(device: &mut Device) -> eIcicleError; fn icicle_is_host_memory(ptr: *const c_void) -> eIcicleError; fn icicle_is_active_device_memory(ptr: *const c_void) -> eIcicleError; @@ -66,6 +67,15 @@ pub fn set_device(device: &Device) -> Result<(), eIcicleError> { } } +pub fn set_default_device(device: &Device) -> Result<(), eIcicleError> { + let result = unsafe { icicle_set_default_device(device) }; + if result == eIcicleError::Success { + Ok(()) + } else { + Err(result) + } +} + pub fn get_active_device() -> Result { let mut device: Device = Device::new("invalid", -1); unsafe { icicle_get_active_device(&mut device).wrap_value::(device) } diff --git a/wrappers/rust/icicle-runtime/src/tests.rs b/wrappers/rust/icicle-runtime/src/tests.rs index e2a22b3c4..5555dd6d3 100644 --- a/wrappers/rust/icicle-runtime/src/tests.rs +++ b/wrappers/rust/icicle-runtime/src/tests.rs @@ -6,6 +6,7 @@ mod tests { use crate::test_utilities; use crate::*; use std::sync::Once; + use std::thread; static INIT: Once = Once::new(); @@ -28,6 +29,37 @@ mod tests { test_utilities::test_set_ref_device(); } + #[test] + fn test_set_default_device() { + initialize(); + + // block scope is necessary in order to free the mutex lock + // to be used by the spawned thread + let outer_thread_id = thread::current().id(); + { + let main_device = test_utilities::TEST_MAIN_DEVICE + .lock() + .unwrap(); + set_default_device(&main_device).unwrap(); + + let active_device = get_active_device().unwrap(); + assert_eq!(*main_device, active_device); + } + + let handle = thread::spawn(move || { + let inner_thread_id = thread::current().id(); + assert_ne!(outer_thread_id, inner_thread_id); + + let active_device = get_active_device().unwrap(); + let main_device = test_utilities::TEST_MAIN_DEVICE + .lock() + .unwrap(); + assert_eq!(*main_device, active_device); + }); + + let _ = handle.join(); + } + #[test] fn test_sync_memory_copy() { initialize();