Skip to content

Commit

Permalink
added some safety comments and decreased unsafe scope
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcusDunn committed Feb 29, 2024
1 parent 07ab0eb commit 0b2867c
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 51 deletions.
42 changes: 13 additions & 29 deletions llama-cpp-2/src/context/kv_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,10 @@ impl LlamaContext<'_> {
/// * `p0` - The start position of the cache to copy. If `None`, the entire cache is copied up to [p1].
/// * `p1` - The end position of the cache to copy. If `None`, the entire cache is copied starting from [p0].
pub fn copy_kv_cache_seq(&mut self, src: i32, dest: i32, p0: Option<u16>, p1: Option<u16>) {
let p0 = p0.map_or(-1, i32::from);
let p1 = p1.map_or(-1, i32::from);
unsafe {
llama_cpp_sys_2::llama_kv_cache_seq_cp(
self.context.as_ptr(),
src,
dest,
p0.map_or(-1, i32::from),
p1.map_or(-1, i32::from),
)
llama_cpp_sys_2::llama_kv_cache_seq_cp(self.context.as_ptr(), src, dest, p0, p1)
}
}

Expand All @@ -44,13 +40,10 @@ impl LlamaContext<'_> {
/// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to [p1].
/// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from [p0].
pub fn clear_kv_cache_seq(&mut self, src: i32, p0: Option<u16>, p1: Option<u16>) {
let p0 = p0.map_or(-1, i32::from);
let p1 = p1.map_or(-1, i32::from);
unsafe {
llama_cpp_sys_2::llama_kv_cache_seq_rm(
self.context.as_ptr(),
src,
p0.map_or(-1, i32::from),
p1.map_or(-1, i32::from),
);
llama_cpp_sys_2::llama_kv_cache_seq_rm(self.context.as_ptr(), src, p0, p1);
}
}

Expand Down Expand Up @@ -85,14 +78,10 @@ impl LlamaContext<'_> {
/// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from [p0].
/// * `delta` - The relative position to add to the tokens
pub fn kv_cache_seq_add(&mut self, seq_id: i32, p0: Option<u16>, p1: Option<u16>, delta: i32) {
let p0 = p0.map_or(-1, i32::from);
let p1 = p1.map_or(-1, i32::from);
unsafe {
llama_cpp_sys_2::llama_kv_cache_seq_add(
self.context.as_ptr(),
seq_id,
p0.map_or(-1, i32::from),
p1.map_or(-1, i32::from),
delta,
)
llama_cpp_sys_2::llama_kv_cache_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta)
}
}

Expand All @@ -114,15 +103,10 @@ impl LlamaContext<'_> {
p1: Option<u16>,
d: NonZeroU8,
) {
unsafe {
llama_cpp_sys_2::llama_kv_cache_seq_div(
self.context.as_ptr(),
seq_id,
p0.map_or(-1, i32::from),
p1.map_or(-1, i32::from),
c_int::from(d.get()),
)
}
let p0 = p0.map_or(-1, i32::from);
let p1 = p1.map_or(-1, i32::from);
let d = c_int::from(d.get());
unsafe { llama_cpp_sys_2::llama_kv_cache_seq_div(self.context.as_ptr(), seq_id, p0, p1, d) }
}

/// Returns the largest position present in the KV cache for the specified sequence
Expand Down
52 changes: 30 additions & 22 deletions llama-cpp-2/src/context/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,29 +105,33 @@ impl LlamaContext<'_> {
.ok_or(LoadSessionError::PathToStrError(path.to_path_buf()))?;

let cstr = CString::new(path)?;
let mut tokens = Vec::with_capacity(max_tokens);
let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
let mut n_out = 0;

unsafe {
if llama_cpp_sys_2::llama_load_session_file(
// SAFETY: cast is valid as LlamaToken is repr(transparent)
let tokens_out = tokens.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>();

let load_session_success = unsafe {
llama_cpp_sys_2::llama_load_session_file(
self.context.as_ptr(),
cstr.as_ptr(),
// cast is valid as LlamaToken is repr(transparent)
Vec::<LlamaToken>::as_mut_ptr(&mut tokens).cast::<llama_cpp_sys_2::llama_token>(),
tokens_out,
max_tokens,
&mut n_out,
) {
if n_out > max_tokens {
return Err(LoadSessionError::InsufficientMaxLength {
n_out,
max_tokens,
});
}
tokens.set_len(n_out);
Ok(tokens)
)
};
if load_session_success {
if n_out > max_tokens {
return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens });
} else {
Err(LoadSessionError::FailedToLoad)
// SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written
unsafe {
tokens.set_len(n_out);
}
}
Ok(tokens)
} else {
Err(LoadSessionError::FailedToLoad)
}
}

Expand All @@ -138,19 +142,23 @@ impl LlamaContext<'_> {
}

/// Copies the state to the specified destination address.
/// Destination needs to have allocated enough memory.
///
/// Returns the number of bytes copied
///
/// # Safety
///
/// Destination needs to have allocated enough memory.
pub unsafe fn copy_state_data(&self, dest: *mut u8) -> usize {
unsafe {
llama_cpp_sys_2::llama_copy_state_data(self.context.as_ptr(), dest)
}
unsafe { llama_cpp_sys_2::llama_copy_state_data(self.context.as_ptr(), dest) }
}

/// Sets the state, reading from the specified address.
/// Returns the number of bytes read.
///
/// # Safety
///
/// help wanted: not entirely sure what the safety requirements are here.
pub unsafe fn set_state_data(&mut self, src: &[u8]) -> usize {
unsafe {
llama_cpp_sys_2::llama_set_state_data(self.context.as_ptr(), src.as_ptr())
}
unsafe { llama_cpp_sys_2::llama_set_state_data(self.context.as_ptr(), src.as_ptr()) }
}
}

0 comments on commit 0b2867c

Please sign in to comment.