Skip to content

Commit

Permalink
Add capability to set a default device for all threads
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremyfelder committed Dec 23, 2024
1 parent 956d3fc commit 161d84f
Show file tree
Hide file tree
Showing 13 changed files with 193 additions and 54 deletions.
1 change: 1 addition & 0 deletions icicle/include/icicle/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ namespace icicle {

public:
static eIcicleError set_thread_local_device(const Device& device);
static eIcicleError set_default_device(const Device& device);
static const Device& get_thread_local_device();
static const DeviceAPI* get_thread_local_deviceAPI();
static DeviceTracker& get_global_memory_tracker() { return sMemTracker; }
Expand Down
8 changes: 8 additions & 0 deletions icicle/include/icicle/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ extern "C" eIcicleError icicle_load_backend_from_env_or_default();
*/
extern "C" eIcicleError icicle_set_device(const icicle::Device& device);

/**
* @brief Set default device for all threads
*
* The device's type must already be registered. A thread that has already
* selected its own device with icicle_set_device keeps that selection; the
* default is what a thread without such a selection reports as active.
*
* @param device Device to install as the default
* @return eIcicleError::SUCCESS if successful, otherwise throws INVALID_DEVICE
*         (raised when the device type has not been registered)
*/
extern "C" eIcicleError icicle_set_default_device(const icicle::Device& device);

/**
* @brief Get active device for thread
*
Expand Down
14 changes: 14 additions & 0 deletions icicle/src/device_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ namespace icicle {

const Device& get_default_device() { return m_default_device; }

/**
 * Replace the registry's default device.
 *
 * Only the device *type* is validated here (it must be registered); the
 * device is then stored as the new default.
 *
 * @param dev Device to store as the default
 * @return eIcicleError::SUCCESS once the default has been replaced
 * @throws INVALID_DEVICE (via THROW_ICICLE_ERR) when dev.type is not registered
 */
eIcicleError set_default_device(const Device& dev)
{
  const bool type_is_registered = is_device_registered(dev.type);
  if (!type_is_registered) {
    THROW_ICICLE_ERR(eIcicleError::INVALID_DEVICE, "Device type " + std::string(dev.type) + " has not been registered");
  }

  m_default_device = dev;
  return eIcicleError::SUCCESS;
}

std::vector<std::string> get_registered_devices_list()
{
std::vector<std::string> registered_devices;
Expand Down Expand Up @@ -116,6 +125,11 @@ namespace icicle {
return default_deviceAPI.get();
}

// Static entry point: forwards the request to the global device registry and
// hands its status code back to the caller unchanged.
eIcicleError DeviceAPI::set_default_device(const Device& dev)
{
  const eIcicleError status = DeviceAPIRegistry::Global().set_default_device(dev);
  return status;
}

/********************************************************************************** */

DeviceAPI* get_deviceAPI(const Device& device) { return DeviceAPIRegistry::Global().get_deviceAPI(device).get(); }
Expand Down
2 changes: 2 additions & 0 deletions icicle/src/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ using namespace icicle;

extern "C" eIcicleError icicle_set_device(const Device& device) { return DeviceAPI::set_thread_local_device(device); }

// C-linkage wrapper: sets the default device by delegating to
// DeviceAPI::set_default_device and returning its status code.
extern "C" eIcicleError icicle_set_default_device(const Device& device)
{
  return DeviceAPI::set_default_device(device);
}

extern "C" eIcicleError icicle_get_active_device(icicle::Device& device)
{
const Device& active_device = DeviceAPI::get_thread_local_device();
Expand Down
31 changes: 31 additions & 0 deletions icicle/tests/test_device_api.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

#include <gtest/gtest.h>
#include <thread>
#include <iostream>

#include "icicle/runtime.h"
Expand All @@ -19,6 +20,36 @@ TEST_F(DeviceApiTest, UnregisteredDeviceError)
EXPECT_ANY_THROW(get_deviceAPI(dev));
}

// Checks how icicle_set_default_device interacts with per-thread device selection:
//  1. a thread that already called icicle_set_device keeps its own device, and
//  2. a freshly started thread that never selected a device reports the new default.
TEST_F(DeviceApiTest, SetDefaultDevice)
{
// Sentinel value so a stale result cannot be mistaken for a real device.
icicle::Device active_dev = {UNKOWN_DEVICE, -1};

// Step 1: select a device explicitly on the current thread and confirm it is active.
icicle::Device cpu_dev = {s_ref_device, 0};
EXPECT_NO_THROW(icicle_set_device(cpu_dev));
EXPECT_NO_THROW(icicle_get_active_device(active_dev));

ASSERT_EQ(cpu_dev, active_dev);

// Reset to the sentinel before the next query.
active_dev = {UNKOWN_DEVICE, -1};

// Step 2: change the default device.
icicle::Device gpu_dev = {s_main_device, 0};
EXPECT_NO_THROW(icicle_set_default_device(gpu_dev));

// setting a new default device doesn't override already set local thread devices
EXPECT_NO_THROW(icicle_get_active_device(active_dev));
ASSERT_EQ(cpu_dev, active_dev);

// Step 3: a new thread, which never called icicle_set_device, must see the default.
active_dev = {UNKOWN_DEVICE, -1};
auto thread_func = [&active_dev, &gpu_dev]() {
EXPECT_NO_THROW(icicle_get_active_device(active_dev));
ASSERT_EQ(gpu_dev, active_dev);
};

std::thread worker_thread(thread_func);

// The worker writes active_dev through the captured reference, so wait for it
// before the locals go out of scope.
worker_thread.join();
// NOTE(review): the default device stays set to gpu_dev and this thread's device
// stays cpu_dev after the test ends — confirm later tests do not rely on the
// previous values.
}

TEST_F(DeviceApiTest, MemoryCopySync)
{
int input[2] = {1, 2};
Expand Down
6 changes: 6 additions & 0 deletions wrappers/golang/runtime/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ func SetDevice(device *Device) EIcicleError {
return EIcicleError(cErr)
}

// SetDefaultDevice passes device to the C function icicle_set_default_device
// and returns the resulting status code as an EIcicleError.
func SetDefaultDevice(device *Device) EIcicleError {
	// Reinterpret the Go struct pointer as the C Device pointer the C API expects.
	devicePtr := unsafe.Pointer(device)
	status := C.icicle_set_default_device((*C.Device)(devicePtr))
	return EIcicleError(status)
}

func GetActiveDevice() (*Device, EIcicleError) {
device := CreateDevice("invalid", -1)
cDevice := (*C.Device)(unsafe.Pointer(&device))
Expand Down
1 change: 1 addition & 0 deletions wrappers/golang/runtime/include/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ typedef struct DeviceProperties DeviceProperties;
int icicle_load_backend(const char* path, bool is_recursive);
int icicle_load_backend_from_env_or_default();
int icicle_set_device(const Device* device);
int icicle_set_default_device(const Device* device);
int icicle_get_active_device(Device* device);
int icicle_is_host_memory(const void* ptr);
int icicle_is_active_device_memory(const void* ptr);
Expand Down
118 changes: 85 additions & 33 deletions wrappers/golang/runtime/tests/device_test.go
Original file line number Diff line number Diff line change
@@ -1,70 +1,122 @@
package tests

import (
"fmt"
"os/exec"
"runtime"
"strconv"
"strings"
"syscall"
"testing"

"github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime"
icicle_runtime "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime"

"github.com/stretchr/testify/assert"
)

func TestGetDeviceType(t *testing.T) {
expectedDeviceName := "test"
config := runtime.CreateDevice(expectedDeviceName, 0)
config := icicle_runtime.CreateDevice(expectedDeviceName, 0)
assert.Equal(t, expectedDeviceName, config.GetDeviceType())

expectedDeviceNameLong := "testtesttesttesttesttesttesttesttesttesttesttesttesttesttesttest"
configLargeName := runtime.CreateDevice(expectedDeviceNameLong, 1)
configLargeName := icicle_runtime.CreateDevice(expectedDeviceNameLong, 1)
assert.NotEqual(t, expectedDeviceNameLong, configLargeName.GetDeviceType())
}

func TestIsDeviceAvailable(t *testing.T) {
runtime.LoadBackendFromEnvOrDefault()
dev := runtime.CreateDevice("CUDA", 0)
_ = runtime.SetDevice(&dev)
res, err := runtime.GetDeviceCount()

expectedNumDevices, error := exec.Command("nvidia-smi", "-L", "|", "wc", "-l").Output()
if error != nil {
t.Skip("Failed to get number of devices")
dev := icicle_runtime.CreateDevice("CUDA", 0)
_ = icicle_runtime.SetDevice(&dev)
res, err := icicle_runtime.GetDeviceCount()

smiCommand := exec.Command("nvidia-smi", "-L")
smiCommandStdout, _ := smiCommand.StdoutPipe()
wcCommand := exec.Command("wc", "-l")
wcCommand.Stdin = smiCommandStdout

smiCommand.Start()

expectedNumDevicesRaw, wcErr := wcCommand.Output()
smiCommand.Wait()

expectedNumDevicesAsString := strings.TrimRight(string(expectedNumDevicesRaw), " \n\r\t")
expectedNumDevices, _ := strconv.Atoi(expectedNumDevicesAsString)
if wcErr != nil {
t.Skip("Failed to get number of devices:", wcErr)
}

assert.Equal(t, runtime.Success, err)
assert.Equal(t, icicle_runtime.Success, err)
assert.Equal(t, expectedNumDevices, res)

err = runtime.LoadBackendFromEnvOrDefault()
assert.Equal(t, runtime.Success, err)
devCuda := runtime.CreateDevice("CUDA", 0)
assert.True(t, runtime.IsDeviceAvailable(&devCuda))
devCpu := runtime.CreateDevice("CPU", 0)
assert.True(t, runtime.IsDeviceAvailable(&devCpu))
devInvalid := runtime.CreateDevice("invalid", 0)
assert.False(t, runtime.IsDeviceAvailable(&devInvalid))
assert.Equal(t, icicle_runtime.Success, err)
devCuda := icicle_runtime.CreateDevice("CUDA", 0)
assert.True(t, icicle_runtime.IsDeviceAvailable(&devCuda))
devCpu := icicle_runtime.CreateDevice("CPU", 0)
assert.True(t, icicle_runtime.IsDeviceAvailable(&devCpu))
devInvalid := icicle_runtime.CreateDevice("invalid", 0)
assert.False(t, icicle_runtime.IsDeviceAvailable(&devInvalid))
}

// TestSetDefaultDevice checks that a device installed with SetDefaultDevice is
// reported as the active device both on the calling OS thread and on a second,
// distinct OS thread. The default is put back to CPU when the test finishes.
func TestSetDefaultDevice(t *testing.T) {
	// Pin this goroutine to its OS thread so the thread id captured below keeps
	// identifying the thread the outer checks run on.
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()
	tidOuter := syscall.Gettid()

	gpuDevice := icicle_runtime.CreateDevice("CUDA", 0)
	assert.Equal(t, icicle_runtime.Success, icicle_runtime.SetDefaultDevice(&gpuDevice))

	// Restore the CPU default on every exit path; the default is process-wide,
	// so leaving CUDA installed would leak into whatever runs after this test.
	defer func() {
		cpuDevice := icicle_runtime.CreateDevice("CPU", 0)
		assert.Equal(t, icicle_runtime.Success, icicle_runtime.SetDefaultDevice(&cpuDevice))
	}()

	activeDevice, err := icicle_runtime.GetActiveDevice()
	assert.Equal(t, icicle_runtime.Success, err)
	// Guard the dereference so a failed lookup is reported as a test failure
	// instead of a nil-pointer panic.
	if assert.NotNil(t, activeDevice) {
		assert.Equal(t, gpuDevice, *activeDevice)
	}

	done := make(chan struct{}, 1)
	go func() {
		// Signal completion last, after the thread has been unlocked.
		defer close(done)

		runtime.LockOSThread()
		defer runtime.UnlockOSThread()

		// Ensure we are operating on an OS thread other than the original one
		tidInner := syscall.Gettid()
		for tidInner == tidOuter {
			fmt.Println("Locked thread is the same as original, getting new locked thread")
			runtime.UnlockOSThread()
			runtime.LockOSThread()
			tidInner = syscall.Gettid()
		}

		// Only assert.* is used here: it records failures without calling
		// FailNow, which must not be invoked outside the test goroutine.
		innerDevice, innerErr := icicle_runtime.GetActiveDevice()
		assert.Equal(t, icicle_runtime.Success, innerErr)
		if assert.NotNil(t, innerDevice) {
			assert.Equal(t, gpuDevice, *innerDevice)
		}
	}()

	<-done
}

func TestRegisteredDevices(t *testing.T) {
err := runtime.LoadBackendFromEnvOrDefault()
assert.Equal(t, runtime.Success, err)
devices, _ := runtime.GetRegisteredDevices()
devices, _ := icicle_runtime.GetRegisteredDevices()
assert.Equal(t, []string{"CUDA", "CPU"}, devices)
}

func TestDeviceProperties(t *testing.T) {
_, err := runtime.GetDeviceProperties()
assert.Equal(t, runtime.Success, err)
_, err := icicle_runtime.GetDeviceProperties()
assert.Equal(t, icicle_runtime.Success, err)
}

func TestActiveDevice(t *testing.T) {
runtime.SetDevice(&DEVICE)
activeDevice, err := runtime.GetActiveDevice()
assert.Equal(t, runtime.Success, err)
assert.Equal(t, DEVICE, *activeDevice)
memory1, err := runtime.GetAvailableMemory()
if err == runtime.ApiNotImplemented {
t.Skipf("GetAvailableMemory() function is not implemented on %s device", DEVICE.GetDeviceType())
devCpu := icicle_runtime.CreateDevice("CUDA", 0)
icicle_runtime.SetDevice(&devCpu)
activeDevice, err := icicle_runtime.GetActiveDevice()
assert.Equal(t, icicle_runtime.Success, err)
assert.Equal(t, devCpu, *activeDevice)
memory1, err := icicle_runtime.GetAvailableMemory()
if err == icicle_runtime.ApiNotImplemented {
t.Skipf("GetAvailableMemory() function is not implemented on %s device", devCpu.GetDeviceType())
}
assert.Equal(t, runtime.Success, err)
assert.Equal(t, icicle_runtime.Success, err)
assert.Greater(t, memory1.Total, uint(0))
assert.Greater(t, memory1.Free, uint(0))
}
14 changes: 1 addition & 13 deletions wrappers/golang/runtime/tests/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,7 @@ import (
"github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime"
)

var DEVICE runtime.Device

func TestMain(m *testing.M) {
runtime.LoadBackendFromEnvOrDefault()
devices, e := runtime.GetRegisteredDevices()
if e != runtime.Success {
panic("Failed to load registered devices")
}
for _, deviceType := range devices {
DEVICE = runtime.CreateDevice(deviceType, 0)
runtime.SetDevice(&DEVICE)

// execute tests
m.Run()
}
m.Run()
}
8 changes: 1 addition & 7 deletions wrappers/golang/runtime/tests/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,15 @@ import (
)

func TestCreateStream(t *testing.T) {
err := runtime.LoadBackendFromEnvOrDefault()
assert.Equal(t, runtime.Success, err)
dev := runtime.CreateDevice("CUDA", 0)
assert.True(t, runtime.IsDeviceAvailable(&dev))
err = runtime.SetDevice(&dev)
err := runtime.SetDevice(&dev)
assert.Equal(t, runtime.Success, err)
_, err = runtime.CreateStream()
assert.Equal(t, runtime.Success, err, "Unable to create stream due to %d", err)
}

func TestDestroyStream(t *testing.T) {
err := runtime.LoadBackendFromEnvOrDefault()
assert.Equal(t, runtime.Success, err)
dev := runtime.CreateDevice("CUDA", 0)
assert.True(t, runtime.IsDeviceAvailable(&dev))
stream, err := runtime.CreateStream()
Expand All @@ -31,8 +27,6 @@ func TestDestroyStream(t *testing.T) {
}

func TestSyncStream(t *testing.T) {
err := runtime.LoadBackendFromEnvOrDefault()
assert.Equal(t, runtime.Success, err)
dev := runtime.CreateDevice("CUDA", 0)
assert.True(t, runtime.IsDeviceAvailable(&dev))
runtime.SetDevice(&dev)
Expand Down
2 changes: 1 addition & 1 deletion wrappers/rust/icicle-runtime/src/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::os::raw::c_char;

const MAX_TYPE_SIZE: usize = 64;

#[derive(Clone)]
#[derive(Clone, PartialEq)]
#[repr(C)]
pub struct Device {
device_type: [c_char; MAX_TYPE_SIZE],
Expand Down
10 changes: 10 additions & 0 deletions wrappers/rust/icicle-runtime/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ extern "C" {
fn icicle_load_backend(path: *const c_char, is_recursive: bool) -> eIcicleError;
fn icicle_load_backend_from_env_or_default() -> eIcicleError;
fn icicle_set_device(device: &Device) -> eIcicleError;
fn icicle_set_default_device(device: &Device) -> eIcicleError;
fn icicle_get_active_device(device: &mut Device) -> eIcicleError;
fn icicle_is_host_memory(ptr: *const c_void) -> eIcicleError;
fn icicle_is_active_device_memory(ptr: *const c_void) -> eIcicleError;
Expand Down Expand Up @@ -66,6 +67,15 @@ pub fn set_device(device: &Device) -> Result<(), eIcicleError> {
}
}

/// Sets the default device by calling the C function `icicle_set_default_device`.
///
/// # Errors
///
/// Returns the `eIcicleError` reported by the C call whenever it is anything
/// other than `eIcicleError::Success`.
pub fn set_default_device(device: &Device) -> Result<(), eIcicleError> {
    // SAFETY: `device` is a live shared reference to a `#[repr(C)]` `Device`,
    // matching the `&Device` parameter of the extern declaration.
    let status = unsafe { icicle_set_default_device(device) };
    if status != eIcicleError::Success {
        return Err(status);
    }
    Ok(())
}

pub fn get_active_device() -> Result<Device, eIcicleError> {
let mut device: Device = Device::new("invalid", -1);
unsafe { icicle_get_active_device(&mut device).wrap_value::<Device>(device) }
Expand Down
Loading

0 comments on commit 161d84f

Please sign in to comment.