clippy
MarcusDunn committed Feb 29, 2024
1 parent 0b2867c commit 035cb57
Showing 2 changed files with 37 additions and 19 deletions.
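The changes below line up with clippy's pedantic lints — most plausibly `must_use_candidate` (getters whose result should not be silently discarded), `semicolon_if_nothing_returned` (trailing `;` on statements of unit type), `doc_markdown` (backticks and intra-doc links for identifiers in doc comments), and `redundant_else` (in session.rs). The commit message does not name the lints, so treat this list as inference. A minimal sketch of the two recurring patterns, using a hypothetical type rather than code from this commit:

pub struct Cache {
    used: i32,
}

impl Cache {
    // must_use_candidate: a pure getter should warn callers who drop its result.
    #[must_use]
    pub fn used_cells(&self) -> i32 {
        self.used
    }

    fn reset_counters(&mut self) {
        self.used = 0;
    }

    pub fn clear(&mut self) {
        // semicolon_if_nothing_returned: `reset_counters()` yields `()`, so it
        // takes a trailing semicolon instead of sitting as a tail expression.
        self.reset_counters();
    }
}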
44 changes: 31 additions & 13 deletions llama-cpp-2/src/context/kv_cache.rs
@@ -28,7 +28,7 @@ impl LlamaContext<'_> {
         let p0 = p0.map_or(-1, i32::from);
         let p1 = p1.map_or(-1, i32::from);
         unsafe {
-            llama_cpp_sys_2::llama_kv_cache_seq_cp(self.context.as_ptr(), src, dest, p0, p1)
+            llama_cpp_sys_2::llama_kv_cache_seq_cp(self.context.as_ptr(), src, dest, p0, p1);
         }
     }

@@ -48,6 +48,7 @@ impl LlamaContext<'_> {
     }
 
     /// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
+    #[must_use]
     pub fn get_kv_cache_used_cells(&self) -> i32 {
         unsafe { llama_cpp_sys_2::llama_get_kv_cache_used_cells(self.context.as_ptr()) }
     }
@@ -68,8 +69,8 @@ impl LlamaContext<'_> {
 
     /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
     /// If the KV cache is RoPEd, the KV data is updated accordingly:
-    /// - lazily on next llama_decode()
-    /// - explicitly with llama_kv_cache_update()
+    /// - lazily on next [`LlamaContext::decode`]
+    /// - explicitly with [`Self::kv_cache_update`]
     ///
     /// # Parameters
     ///
@@ -81,14 +82,14 @@ impl LlamaContext<'_> {
         let p0 = p0.map_or(-1, i32::from);
         let p1 = p1.map_or(-1, i32::from);
         unsafe {
-            llama_cpp_sys_2::llama_kv_cache_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta)
+            llama_cpp_sys_2::llama_kv_cache_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta);
         }
     }
 
     /// Integer division of the positions by factor of `d > 1`
-    /// If the KV cache is RoPEd, the KV data is updated accordingly:
-    /// - lazily on next llama_decode()
-    /// - explicitly with llama_kv_cache_update()
+    /// If the KV cache is `RoPEd`, the KV data is updated accordingly:
+    /// - lazily on next [`LlamaContext::decode`]
+    /// - explicitly with [`Self::kv_cache_update`]
     ///
     /// # Parameters
     ///
@@ -114,14 +115,15 @@ impl LlamaContext<'_> {
     /// # Parameters
     ///
     /// * `seq_id` - The sequence id to get the max position for
+    #[must_use]
     pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 {
         unsafe { llama_cpp_sys_2::llama_kv_cache_seq_pos_max(self.context.as_ptr(), seq_id) }
     }
 
     /// Defragment the KV cache
     /// This will be applied:
-    /// - lazily on next llama_decode()
-    /// - explicitly with llama_kv_cache_update()
+    /// - lazily on next [`LlamaContext::decode`]
+    /// - explicitly with [`Self::kv_cache_update`]
     pub fn kv_cache_defrag(&mut self) {
         unsafe { llama_cpp_sys_2::llama_kv_cache_defrag(self.context.as_ptr()) }
     }
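The doc comment above gives two ways a defrag takes effect; a sketch of the eager path follows. The crate's module path and the exact `kv_cache_update` signature are assumed, since neither appears in this hunk:

use llama_cpp_2::context::LlamaContext;

/// Sketch only: force defragmentation now instead of waiting for the next decode.
fn defrag_now(ctx: &mut LlamaContext<'_>) {
    ctx.kv_cache_defrag(); // queue the defrag
    ctx.kv_cache_update(); // apply pending KV cache operations eagerly
}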
@@ -133,6 +135,7 @@ impl LlamaContext<'_> {
 
     /// Returns the number of tokens in the KV cache (slow, use only for debug)
     /// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
+    #[must_use]
     pub fn get_kv_cache_token_count(&self) -> i32 {
         unsafe { llama_cpp_sys_2::llama_get_kv_cache_token_count(self.context.as_ptr()) }
     }
@@ -143,7 +146,8 @@ impl LlamaContext<'_> {
     ///
     /// * `n_max_seq` - Maximum number of sequences that can exist in a cell. It's not an error
     ///                 if there are more sequences in a cell than this value, however they will
-    ///                 not be visible in the view cells_sequences.
+    ///                 not be visible in the view `cells_sequences`.
+    #[must_use]
     pub fn new_kv_cache_view(&self, n_max_seq: i32) -> KVCacheView {
         let view =
             unsafe { llama_cpp_sys_2::llama_kv_cache_view_init(self.context.as_ptr(), n_max_seq) };
@@ -170,39 +174,48 @@ impl<'a> KVCacheView<'a> {
     /// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
     pub fn update(&mut self) {
         unsafe {
-            llama_cpp_sys_2::llama_kv_cache_view_update(self.ctx.context.as_ptr(), &mut self.view)
+            llama_cpp_sys_2::llama_kv_cache_view_update(self.ctx.context.as_ptr(), &mut self.view);
         }
     }
 
     /// Number of KV cache cells. This will be the same as the context size.
+    #[must_use]
     pub fn n_cells(&self) -> i32 {
         self.view.n_cells
     }
 
     /// Number of tokens in the cache. For example, if there are two populated
     /// cells, the first with 1 sequence id in it and the second with 2 sequence
     /// ids then you'll have 3 tokens.
+    #[must_use]
     pub fn token_count(&self) -> i32 {
         self.view.token_count
     }
 
     /// Number of populated cache cells.
+    #[must_use]
     pub fn used_cells(&self) -> i32 {
         self.view.used_cells
     }
 
     /// Maximum contiguous empty slots in the cache.
+    #[must_use]
     pub fn max_contiguous(&self) -> i32 {
         self.view.max_contiguous
     }
 
-    /// Index to the start of the max_contiguous slot range. Can be negative
+    /// Index to the start of the `max_contiguous` slot range. Can be negative
     /// when cache is full.
+    #[must_use]
     pub fn max_contiguous_idx(&self) -> i32 {
         self.view.max_contiguous_idx
    }
 
     /// Information for individual cells.
+    ///
+    /// # Panics
+    ///
+    /// - if `n_cells` does not fit into usize.
     pub fn cells(&self) -> impl Iterator<Item = KVCacheViewCell> {
         unsafe {
             std::slice::from_raw_parts(
@@ -214,7 +227,12 @@ impl<'a> KVCacheView<'a> {
         .map(|&cell| KVCacheViewCell { pos: cell.pos })
     }
 
-    /// The sequences for each cell. There will be n_max_seq items per cell.
+    /// The sequences for each cell. There will be `n_max_seq` items per cell.
+    ///
+    /// # Panics
+    ///
+    /// - if `n_cells * n_max_seq` does not fit into usize.
+    /// - if `n_max_seq` does not fit into usize.
     pub fn cells_sequences(&self) -> impl Iterator<Item = &[llama_cpp_sys_2::llama_seq_id]> {
         unsafe {
             std::slice::from_raw_parts(
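Taken together, the view API this file touches supports a debugging dump along these lines — a sketch that assumes a fully constructed `LlamaContext`, the import path shown, and a public `pos` field on `KVCacheViewCell` (none of which are part of this diff):

use llama_cpp_2::context::LlamaContext;

/// Debug sketch: print KV cache occupancy through the view accessors above.
fn dump_kv_cache(ctx: &LlamaContext<'_>) {
    // One sequence-id slot per cell; cells holding more sequences are still
    // counted, but the extra ids stay invisible (see the `n_max_seq` docs).
    let mut view = ctx.new_kv_cache_view(1);
    view.update(); // snapshot the current cache state (debugging only)

    println!("cells:          {}", view.n_cells());
    println!("tokens:         {}", view.token_count());
    println!("used cells:     {}", view.used_cells());
    println!("max contiguous: {}", view.max_contiguous());

    // Per the new doc comment, this panics if `n_cells` does not fit in usize.
    for (i, cell) in view.cells().enumerate() {
        println!("cell {i}: pos {}", cell.pos);
    }
}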
12 changes: 6 additions & 6 deletions llama-cpp-2/src/context/session.rs
@@ -123,11 +123,10 @@ impl LlamaContext<'_> {
         if load_session_success {
             if n_out > max_tokens {
                 return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens });
-            } else {
-                // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written
-                unsafe {
-                    tokens.set_len(n_out);
-                }
             }
+            // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written
+            unsafe {
+                tokens.set_len(n_out);
+            }
             Ok(tokens)
         } else {
@@ -136,7 +135,8 @@ impl LlamaContext<'_> {
     }
 
     /// Returns the maximum size in bytes of the state (rng, logits, embedding
-    /// and kv_cache) - will often be smaller after compacting tokens
+    /// and `kv_cache`) - will often be smaller after compacting tokens
+    #[must_use]
     pub fn get_state_size(&self) -> usize {
         unsafe { llama_cpp_sys_2::llama_get_state_size(self.context.as_ptr()) }
     }
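And for the state-size getter that gained `#[must_use]`, a small sketch of the intended call pattern — the buffer use is assumed; only `get_state_size` itself appears in the diff:

use llama_cpp_2::context::LlamaContext;

/// Sketch: pre-size a buffer for state serialization.
fn state_buffer(ctx: &LlamaContext<'_>) -> Vec<u8> {
    // Upper bound covering rng, logits, embedding and kv_cache; the real
    // payload is often smaller once tokens are compacted.
    Vec::with_capacity(ctx.get_state_size())
}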
