Skip to content

Commit

Permalink
updated llama.cpp to latest + numa changes
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcusDunn committed Feb 21, 2024
1 parent ff5bbae commit 1c6130c
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 5 deletions.
95 changes: 91 additions & 4 deletions llama-cpp-2/src/llama_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,29 +43,86 @@ impl LlamaBackend {
#[tracing::instrument(skip_all)]
pub fn init() -> crate::Result<LlamaBackend> {
// `?` returns early with the error from `mark_init`, so the FFI call below is
// only reached when `mark_init` succeeded.
Self::mark_init()?;
// NOTE(review): this listing is a rendered diff, so two calls appear here. Presumably the
// one taking `false` is the removed pre-change line and the zero-argument one is
// its replacement — confirm against the committed file.
unsafe { llama_cpp_sys_2::llama_backend_init(false) }
unsafe { llama_cpp_sys_2::llama_backend_init() }
Ok(LlamaBackend {})
}

/// Initialize the llama backend (with numa).
/// ```
///# use llama_cpp_2::llama_backend::LlamaBackend;
///# use std::error::Error;
///# use llama_cpp_2::llama_backend::NumaStrategy;
///
///# fn main() -> Result<(), Box<dyn Error>> {
/// let llama_backend = LlamaBackend::init_numa()?;
///
/// let llama_backend = LlamaBackend::init_numa(NumaStrategy::MIRROR)?;
///
///# Ok(())
///# }
/// ```
#[tracing::instrument(skip_all)]
// NOTE(review): this listing is a rendered diff; presumably the zero-argument signature is
// the removed line and the one taking `strategy` is its replacement — confirm.
pub fn init_numa() -> crate::Result<LlamaBackend> {
pub fn init_numa(strategy: NumaStrategy) -> crate::Result<LlamaBackend> {
// `?` returns early with the error from `mark_init` before any FFI call is made.
Self::mark_init()?;
unsafe { llama_cpp_sys_2::llama_backend_init(true) }
// The strategy is converted to the raw `ggml_numa_strategy` value and handed to
// `llama_numa_init`.
// NOTE(review): the post-change body calls only `llama_numa_init`, not
// `llama_backend_init()` as `init` does — confirm whether both are required.
unsafe {
llama_cpp_sys_2::llama_numa_init(llama_cpp_sys_2::ggml_numa_strategy::from(strategy))
}
Ok(LlamaBackend {})
}
}

/// A rusty wrapper around `numa_strategy`.
///
/// Each variant converts to the `llama_cpp_sys_2::GGML_NUMA_STRATEGY_*` constant of the same
/// name, and those raw constants convert back via `TryFrom`.
///
/// The per-variant descriptions below are hedged: they are not established by this file and
/// should be checked against the bundled llama.cpp / ggml sources.
#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
pub enum NumaStrategy {
    /// The numa strategy is disabled.
    DISABLED,
    /// Presumably spreads execution evenly over all NUMA nodes (this is how llama.cpp
    /// describes `--numa distribute`) — TODO confirm against the bundled version.
    DISTRIBUTE,
    /// Presumably only spawns threads on CPUs of the node that execution started on (this is
    /// how llama.cpp describes `--numa isolate`) — TODO confirm against the bundled version.
    ISOLATE,
    /// Presumably uses the CPU map provided by `numactl` (this is how llama.cpp describes
    /// `--numa numactl`) — TODO confirm against the bundled version.
    NUMACTL,
    /// NOTE(review): behaviour not established from this file — help wanted.
    MIRROR,
    /// NOTE(review): looks like the C enum's trailing "number of strategies" sentinel rather
    /// than a usable strategy — confirm before passing it to `LlamaBackend::init_numa`.
    COUNT,
}

/// An invalid numa strategy was provided.
///
/// This is the error type of `NumaStrategy::try_from` and carries the raw value that did not
/// match any known `GGML_NUMA_STRATEGY_*` constant. It implements [`std::error::Error`] so it
/// can be propagated with `?` into `Box<dyn Error>` and similar error containers.
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
pub struct InvalidNumaStrategy(
    /// The invalid numa strategy that was provided.
    pub llama_cpp_sys_2::ggml_numa_strategy,
);

impl std::fmt::Display for InvalidNumaStrategy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `Debug` formatting is used for the payload because the derive above already
        // requires the raw type to implement it.
        write!(f, "invalid numa strategy: {:?}", self.0)
    }
}

impl std::error::Error for InvalidNumaStrategy {}

impl TryFrom<llama_cpp_sys_2::ggml_numa_strategy> for NumaStrategy {
    type Error = InvalidNumaStrategy;

    /// Maps a raw `ggml_numa_strategy` value onto the matching [`NumaStrategy`] variant.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidNumaStrategy`] wrapping the raw value when it equals none of the
    /// `GGML_NUMA_STRATEGY_*` constants listed below.
    fn try_from(value: llama_cpp_sys_2::ggml_numa_strategy) -> Result<Self, Self::Error> {
        let strategy = match value {
            llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISABLED => Self::DISABLED,
            llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISTRIBUTE => Self::DISTRIBUTE,
            llama_cpp_sys_2::GGML_NUMA_STRATEGY_ISOLATE => Self::ISOLATE,
            llama_cpp_sys_2::GGML_NUMA_STRATEGY_NUMACTL => Self::NUMACTL,
            llama_cpp_sys_2::GGML_NUMA_STRATEGY_MIRROR => Self::MIRROR,
            llama_cpp_sys_2::GGML_NUMA_STRATEGY_COUNT => Self::COUNT,
            // Anything else is handed back to the caller untouched inside the error.
            unknown => return Err(InvalidNumaStrategy(unknown)),
        };
        Ok(strategy)
    }
}

impl From<NumaStrategy> for llama_cpp_sys_2::ggml_numa_strategy {
fn from(value: NumaStrategy) -> Self {
match value {
NumaStrategy::DISABLED => llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISABLED,
NumaStrategy::DISTRIBUTE => llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISTRIBUTE,
NumaStrategy::ISOLATE => llama_cpp_sys_2::GGML_NUMA_STRATEGY_ISOLATE,
NumaStrategy::NUMACTL => llama_cpp_sys_2::GGML_NUMA_STRATEGY_NUMACTL,
NumaStrategy::MIRROR => llama_cpp_sys_2::GGML_NUMA_STRATEGY_MIRROR,
NumaStrategy::COUNT => llama_cpp_sys_2::GGML_NUMA_STRATEGY_COUNT,
}
}
}

/// Drops the llama backend.
/// ```
///
Expand All @@ -92,3 +149,33 @@ impl Drop for LlamaBackend {
unsafe { llama_cpp_sys_2::llama_backend_free() }
}
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant must survive a round trip through the raw `ggml_numa_strategy` value.
    #[test]
    fn numa_from_and_to() {
        let numas = [
            NumaStrategy::DISABLED,
            NumaStrategy::DISTRIBUTE,
            NumaStrategy::ISOLATE,
            NumaStrategy::NUMACTL,
            NumaStrategy::MIRROR,
            NumaStrategy::COUNT,
        ];

        for numa in numas.iter().copied() {
            let from = llama_cpp_sys_2::ggml_numa_strategy::from(numa);
            let to = NumaStrategy::try_from(from).expect("Failed to convert from and to");
            assert_eq!(numa, to);
        }
    }

    /// An unknown raw value must be rejected, and the error must carry that exact value.
    #[test]
    fn check_invalid_numa() {
        let raw: llama_cpp_sys_2::ggml_numa_strategy = 800;
        // Compare against an independently built expectation; rebuilding the expected error
        // from the result itself would make the assertion pass for any `Err` payload.
        assert_eq!(NumaStrategy::try_from(raw), Err(InvalidNumaStrategy(raw)));
    }
}

0 comments on commit 1c6130c

Please sign in to comment.