Skip to content

Commit

Permalink
Merge pull request #579 from tinglou/main
Browse files Browse the repository at this point in the history
wrap llama_batch_get_one
  • Loading branch information
MarcusDunn authored Nov 27, 2024
2 parents 42aaeeb + 2822e3a commit d3eade6
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion llama-cpp-2/src/llama_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ pub enum BatchAddError {
/// There was not enough space in the batch to add the token.
#[error("Insufficient Space of {0}")]
InsufficientSpace(usize),
/// Empty buffer is provided for get_one
#[error("Empty buffer")]
EmptyBuffer,
}

impl LlamaBatch {
Expand Down Expand Up @@ -149,6 +152,31 @@ impl LlamaBatch {
}
}

/// Wrapper over `llama_batch_get_one`: returns a batch viewing a single
/// sequence of `tokens` starting at position `pos_0` for sequence `seq_id`.
///
/// NOTE: this is a helper function to facilitate transition to the new batch API.
///
/// The returned batch does **not** own `tokens` — the C struct stores a raw
/// pointer into the slice. NOTE(review): callers must keep `tokens` alive (and
/// unmoved) for as long as the batch is used; confirm whether a lifetime
/// parameter on the return type is feasible here.
///
/// # Errors
///
/// - [`BatchAddError::EmptyBuffer`] if `tokens` is empty.
/// - [`BatchAddError::InsufficientSpace`] if `tokens.len()` does not fit in
///   the `i32` token count the C API requires.
pub fn get_one(
    tokens: &[LlamaToken],
    pos_0: llama_pos,
    seq_id: llama_seq_id,
) -> Result<Self, BatchAddError> {
    if tokens.is_empty() {
        return Err(BatchAddError::EmptyBuffer);
    }
    // The C API takes the count as i32; convert checked-ly instead of a
    // silently-truncating `as` cast.
    let n_tokens = i32::try_from(tokens.len())
        .map_err(|_| BatchAddError::InsufficientSpace(tokens.len()))?;
    // SAFETY: `tokens` is non-empty and `n_tokens` matches its length.
    // `LlamaToken` wraps `llama_token` (i32), so the pointer cast is
    // layout-compatible — TODO(review): confirm `#[repr(transparent)]`.
    // The C function takes a `*mut` pointer but only reads through it.
    let batch = unsafe {
        llama_cpp_sys_2::llama_batch_get_one(
            tokens.as_ptr().cast::<i32>().cast_mut(),
            n_tokens,
            pos_0,
            seq_id,
        )
    };
    Ok(Self {
        // Nothing was allocated by llama.cpp here; `allocated == 0` signals
        // `Drop` to skip `llama_batch_free` (freeing a borrowed view would
        // be unsound).
        allocated: 0,
        // Mirrors llama.cpp semantics: only the final token yields logits.
        initialized_logits: vec![n_tokens - 1],
        llama_batch: batch,
    })
}

/// Returns the number of tokens in the batch.
#[must_use]
pub fn n_tokens(&self) -> i32 {
Expand All @@ -170,7 +198,9 @@ impl Drop for LlamaBatch {
/// # }
/// Frees the underlying C batch, but only when this wrapper actually
/// allocated it.
fn drop(&mut self) {
    // Batches built by `get_one` merely borrow the caller's token buffer
    // (`allocated == 0`); calling `llama_batch_free` on them would free
    // memory llama.cpp never allocated. Only free heap-allocated batches.
    unsafe {
        if self.allocated > 0 {
            llama_batch_free(self.llama_batch);
        }
    }
}
}

0 comments on commit d3eade6

Please sign in to comment.