From cb0ecd90ff696ddc3fe13edd270bd3f0137885fb Mon Sep 17 00:00:00 2001
From: Lou Ting
Date: Thu, 21 Nov 2024 11:25:51 +0800
Subject: [PATCH 1/3] wrap llama_batch_get_one

---
 llama-cpp-2/src/llama_batch.rs | 23 ++++++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/llama-cpp-2/src/llama_batch.rs b/llama-cpp-2/src/llama_batch.rs
index e52bfa9e..3efb7965 100644
--- a/llama-cpp-2/src/llama_batch.rs
+++ b/llama-cpp-2/src/llama_batch.rs
@@ -149,6 +149,25 @@ impl LlamaBatch {
         }
     }
 
+    /// llama_batch_get_one
+    /// Return a batch for the single sequence of tokens starting at `pos_0`.
+    ///
+    /// NOTE: this is a helper function to facilitate transition to the new batch API
+    ///
+    pub fn get_one(tokens: &[LlamaToken], pos_0: llama_pos, seq_id: llama_seq_id) -> Self {
+        unsafe {
+            let ptr = tokens.as_ptr() as *mut i32;
+            let batch =
+                llama_cpp_sys_2::llama_batch_get_one(ptr, tokens.len() as i32, pos_0, seq_id);
+
+            crate::llama_batch::LlamaBatch {
+                allocated: 0,
+                initialized_logits: vec![],
+                llama_batch: batch,
+            }
+        }
+    }
+
     /// Returns the number of tokens in the batch.
     #[must_use]
     pub fn n_tokens(&self) -> i32 {
@@ -170,7 +189,9 @@ impl Drop for LlamaBatch {
     /// # }
     fn drop(&mut self) {
         unsafe {
-            llama_batch_free(self.llama_batch);
+            if self.allocated > 0 {
+                llama_batch_free(self.llama_batch);
+            }
         }
     }
 }

From 81c2b05d32fd7703f40c483c316c02b8463283ec Mon Sep 17 00:00:00 2001
From: Lou Ting
Date: Tue, 26 Nov 2024 16:25:45 +0800
Subject: [PATCH 2/3] get_one

---
 llama-cpp-2/src/llama_batch.rs | 33 +++++++++++++++++++++------------
 1 file changed, 21 insertions(+), 12 deletions(-)

diff --git a/llama-cpp-2/src/llama_batch.rs b/llama-cpp-2/src/llama_batch.rs
index 3efb7965..31fb9d54 100644
--- a/llama-cpp-2/src/llama_batch.rs
+++ b/llama-cpp-2/src/llama_batch.rs
@@ -1,6 +1,6 @@
 //! Safe wrapper around `llama_batch`.
 
-use crate::token::LlamaToken;
+use crate::token::{self, LlamaToken};
 use llama_cpp_sys_2::{llama_batch, llama_batch_free, llama_batch_init, llama_pos, llama_seq_id};
 
 /// A safe wrapper around `llama_batch`.
@@ -20,6 +20,9 @@ pub enum BatchAddError {
     /// There was not enough space in the batch to add the token.
     #[error("Insufficient Space of {0}")]
     InsufficientSpace(usize),
+    /// An empty buffer was provided to `get_one`.
+    #[error("Empty buffer")]
+    EmptyBuffer,
 }
 
 impl LlamaBatch {
@@ -154,18 +157,24 @@ impl LlamaBatch {
     ///
     /// NOTE: this is a helper function to facilitate transition to the new batch API
     ///
-    pub fn get_one(tokens: &[LlamaToken], pos_0: llama_pos, seq_id: llama_seq_id) -> Self {
-        unsafe {
-            let ptr = tokens.as_ptr() as *mut i32;
-            let batch =
-                llama_cpp_sys_2::llama_batch_get_one(ptr, tokens.len() as i32, pos_0, seq_id);
-
-            crate::llama_batch::LlamaBatch {
-                allocated: 0,
-                initialized_logits: vec![],
-                llama_batch: batch,
-            }
+    pub fn get_one(
+        tokens: &[LlamaToken],
+        pos_0: llama_pos,
+        seq_id: llama_seq_id,
+    ) -> Result<Self, BatchAddError> {
+        if tokens.is_empty() {
+            return Err(BatchAddError::EmptyBuffer);
         }
+        let batch = unsafe {
+            let ptr = tokens.as_ptr() as *mut i32;
+            llama_cpp_sys_2::llama_batch_get_one(ptr, tokens.len() as i32, pos_0, seq_id)
+        };
+        let batch = Self {
+            allocated: 0,
+            initialized_logits: vec![(tokens.len() - 1) as i32],
+            llama_batch: batch,
+        };
+        Ok(batch)
     }
 
     /// Returns the number of tokens in the batch.
From 2822e3ae86a66b571cc6de371feefb938f4f1da7 Mon Sep 17 00:00:00 2001
From: Lou Ting
Date: Tue, 26 Nov 2024 16:27:53 +0800
Subject: [PATCH 3/3] fmt

---
 llama-cpp-2/src/llama_batch.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama-cpp-2/src/llama_batch.rs b/llama-cpp-2/src/llama_batch.rs
index 31fb9d54..8a2fd376 100644
--- a/llama-cpp-2/src/llama_batch.rs
+++ b/llama-cpp-2/src/llama_batch.rs
@@ -1,6 +1,6 @@
 //! Safe wrapper around `llama_batch`.
 
-use crate::token::{self, LlamaToken};
+use crate::token::LlamaToken;
 use llama_cpp_sys_2::{llama_batch, llama_batch_free, llama_batch_init, llama_pos, llama_seq_id};
 
 /// A safe wrapper around `llama_batch`.
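
For reference, the snippet below is a minimal caller-side sketch of the `get_one` wrapper these patches introduce. It is hypothetical and not part of the patch series: the token ids are placeholder values (real ids would come from the model's tokenizer), and only `LlamaBatch::get_one`, `n_tokens`, and the `LlamaToken` tuple struct visible in the diffs above are assumed.

    use llama_cpp_2::llama_batch::LlamaBatch;
    use llama_cpp_2::token::LlamaToken;

    fn main() {
        // Placeholder token ids; in real use these come from tokenizing a prompt.
        let tokens = vec![LlamaToken(1), LlamaToken(2), LlamaToken(3)];

        // Build a single-sequence batch starting at position 0 in sequence 0.
        // After patch 2 this returns a Result instead of silently producing
        // an invalid batch for empty input.
        let batch = LlamaBatch::get_one(&tokens, 0, 0).expect("non-empty token slice");
        assert_eq!(batch.n_tokens(), 3);

        // The empty case now fails cleanly with BatchAddError::EmptyBuffer.
        assert!(LlamaBatch::get_one(&[], 0, 0).is_err());

        // Note: the batch points at `tokens` through the FFI pointer, so
        // `tokens` must outlive the batch. Because `allocated` is 0, the
        // amended Drop impl skips llama_batch_free and never frees memory
        // the batch does not own.
    }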