Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Enable PluginManager::nn_preload #74

Merged
merged 7 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ license = "Apache-2.0"
name = "wasmedge-sdk"
readme = "README.md"
repository = "https://github.com/WasmEdge/wasmedge-rust-sdk"
version = "0.12.2"
version = "0.12.3-dev"

[dependencies]
anyhow = "1.0"
Expand Down
2 changes: 1 addition & 1 deletion crates/wasmedge-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ links = "wasmedge"
name = "wasmedge-sys"
readme = "README.md"
repository = "https://github.com/WasmEdge/wasmedge-rust-sdk"
version = "0.17.2"
version = "0.17.3"

[dependencies]
fiber-for-wasmedge = { version = "8.0.1", optional = true }
Expand Down
22 changes: 11 additions & 11 deletions crates/wasmedge-sys/src/plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,17 @@ impl PluginManager {
Ok(())
}

// #[cfg(feature = "wasi_nn")]
// #[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
// pub fn nn_preload(preloads: Vec<&str>) {
// let c_args: Vec<CString> = preloads
// .iter()
// .map(|&x| std::ffi::CString::new(x).unwrap())
// .collect();
// let c_strs: Vec<*const i8> = c_args.iter().map(|x| x.as_ptr()).collect();
// let len = c_strs.len() as u32;
// unsafe { ffi::WasmEdge_PluginInitWASINN(c_strs.as_ptr(), len) }
// }
#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
pub fn nn_preload(preloads: Vec<&str>) {
let c_args: Vec<CString> = preloads
.iter()
.map(|&x| std::ffi::CString::new(x).unwrap())
.collect();
let c_strs: Vec<*const i8> = c_args.iter().map(|x| x.as_ptr()).collect();
let len = c_strs.len() as u32;
unsafe { ffi::WasmEdge_PluginInitWASINN(c_strs.as_ptr(), len) }
}

/// Returns the count of loaded plugins.
pub fn count() -> u32 {
Expand Down
107 changes: 102 additions & 5 deletions src/plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,96 @@ pub mod ffi {
};
}

#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
#[derive(Debug)]
pub struct NNPreload {
/// The alias of the model in the WASI-NN environment.
pub alias: String,
/// The inference backend.
pub backend: NNBackend,
/// The execution target, on which the inference runs.
pub target: ExecutionTarget,
/// The path to the model file. Note that the path is the guest path instead of the host path.
pub path: std::path::PathBuf,
}
#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
impl NNPreload {
pub fn new(
alias: impl AsRef<str>,
backend: NNBackend,
target: ExecutionTarget,
path: impl AsRef<std::path::Path>,
) -> Self {
Self {
alias: alias.as_ref().to_owned(),
backend,
target,
path: path.as_ref().to_owned(),
}
}
}
#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
impl std::fmt::Display for NNPreload {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{alias}:{backend}:{target}:{path}",
alias = self.alias,
backend = self.backend,
target = self.target,
path = self.path.to_string_lossy().into_owned()
)
}
}

#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
#[allow(non_camel_case_types)]
pub enum NNBackend {
PyTorch,
TensorFlowLite,
OpenVINO,
GGML,
}
#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
impl std::fmt::Display for NNBackend {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
NNBackend::PyTorch => write!(f, "PyTorch"),
NNBackend::TensorFlowLite => write!(f, "TensorFlowLite"),
NNBackend::OpenVINO => write!(f, "OpenVINO"),
NNBackend::GGML => write!(f, "GGML"),
}
}
}

/// Define where the graph should be executed.
#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
#[allow(non_camel_case_types)]
pub enum ExecutionTarget {
CPU,
GPU,
TPU,
}
#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
impl std::fmt::Display for ExecutionTarget {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ExecutionTarget::CPU => write!(f, "CPU"),
ExecutionTarget::GPU => write!(f, "GPU"),
ExecutionTarget::TPU => write!(f, "TPU"),
}
}
}

/// Defines the API to manage plugins.
#[derive(Debug)]
pub struct PluginManager {}
Expand Down Expand Up @@ -46,11 +136,18 @@ impl PluginManager {
}
}

// #[cfg(feature = "wasi_nn")]
// #[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
// pub fn nn_preload(preloads: Vec<&str>) {
// sys::plugin::PluginManager::nn_preload(preloads);
// }
#[cfg(feature = "wasi_nn")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasi_nn")))]
pub fn nn_preload(preloads: Vec<NNPreload>) {
let mut nn_preloads = Vec::new();
for preload in preloads {
nn_preloads.push(preload.to_string());
}

let nn_preloads_str: Vec<&str> = nn_preloads.iter().map(|s| s.as_str()).collect();

sys::plugin::PluginManager::nn_preload(nn_preloads_str);
}

/// Returns the count of loaded plugins.
pub fn count() -> u32 {
Expand Down