diff --git a/llama-cpp-2/src/lib.rs b/llama-cpp-2/src/lib.rs
index 49e333e0..95384a93 100644
--- a/llama-cpp-2/src/lib.rs
+++ b/llama-cpp-2/src/lib.rs
@@ -207,6 +207,28 @@ pub enum StringToTokenError {
     CIntConversionError(#[from] std::num::TryFromIntError),
 }
 
+/// Failed to create a new `LlamaChatMessage`.
+#[derive(Debug, thiserror::Error)]
+pub enum NewLlamaChatMessageError {
+    /// the string contained a null byte and thus could not be converted to a c string.
+    #[error("{0}")]
+    NulError(#[from] NulError),
+}
+
+/// Failed to apply model chat template.
+#[derive(Debug, thiserror::Error)]
+pub enum ApplyChatTemplateError {
+    /// the buffer was too small.
+    #[error("The buffer was too small. Please contact a maintainer and we will update it.")]
+    BuffSizeError,
+    /// the string contained a null byte and thus could not be converted to a c string.
+    #[error("{0}")]
+    NulError(#[from] NulError),
+    /// the string could not be converted to utf8.
+    #[error("{0}")]
+    FromUtf8Error(#[from] FromUtf8Error),
+}
+
 /// Get the time in microseconds according to ggml
 ///
 /// ```
diff --git a/llama-cpp-2/src/model.rs b/llama-cpp-2/src/model.rs
index 9f01ac24..a39e70e1 100644
--- a/llama-cpp-2/src/model.rs
+++ b/llama-cpp-2/src/model.rs
@@ -11,8 +11,8 @@ use crate::model::params::LlamaModelParams;
 use crate::token::LlamaToken;
 use crate::token_type::LlamaTokenType;
 use crate::{
-    ChatTemplateError, LlamaContextLoadError, LlamaModelLoadError, StringToTokenError,
-    TokenToStringError,
+    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaModelLoadError,
+    NewLlamaChatMessageError, StringToTokenError, TokenToStringError,
 };
 
 pub mod params;
@@ -25,6 +25,23 @@ pub struct LlamaModel {
     pub(crate) model: NonNull<llama_cpp_sys_2::llama_model>,
 }
 
+/// A safe wrapper around `llama_chat_message`
+#[derive(Debug, Eq, PartialEq, Clone)]
+pub struct LlamaChatMessage {
+    role: CString,
+    content: CString,
+}
+
+impl LlamaChatMessage {
+    /// Create a new `LlamaChatMessage`
+    pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
+        Ok(Self {
+            role: CString::new(role)?,
+            content: CString::new(content)?,
+        })
+    }
+}
+
 /// How to determine if we should prepend a bos token to tokens
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub enum AddBos {
@@ -312,17 +329,16 @@ impl LlamaModel {
     /// Get chat template from model.
     ///
     /// # Errors
-    /// 
+    ///
     /// * If the model has no chat template
     /// * If the chat template is not a valid [`CString`].
     #[allow(clippy::missing_panics_doc)] // we statically know this will not panic as
     pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {
-        // longest known template is about 1200 bytes from llama.cpp
         let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null");
         let chat_ptr = chat_temp.into_raw();
         let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");
-        
+
         let chat_template: String = unsafe {
             let ret = llama_cpp_sys_2::llama_model_meta_val_str(
                 self.model.as_ptr(),
                 chat_name.as_ptr(),
@@ -337,7 +353,7 @@ impl LlamaModel {
             debug_assert_eq!(usize::try_from(ret).unwrap(), template.len(), "llama.cpp guarantees that the returned int {ret} is the length of the string {} but that was not the case", template.len());
             template
         };
-        
+
         Ok(chat_template)
     }
 
@@ -388,6 +404,60 @@ impl LlamaModel {
 
         Ok(LlamaContext::new(self, context, params.embeddings()))
     }
+
+    /// Apply the model's chat template to some messages.
+    /// See <https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template>
+    ///
+    /// `tmpl` of `None` means to use the default template provided by llama.cpp for the model.
+    ///
+    /// # Errors
+    /// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information.
+    #[tracing::instrument(skip_all)]
+    pub fn apply_chat_template(
+        &self,
+        tmpl: Option<String>,
+        chat: Vec<LlamaChatMessage>,
+        add_ass: bool,
+    ) -> Result<String, ApplyChatTemplateError> {
+        // Buffer is twice the total length of the messages, per llama.cpp's recommendation
+        let message_length = chat.iter().fold(0, |acc, c| {
+            acc + c.role.to_bytes().len() + c.content.to_bytes().len()
+        });
+        let mut buff: Vec<i8> = vec![0_i8; message_length * 2];
+
+        // Build our llama_cpp_sys_2 chat messages
+        let chat: Vec<llama_cpp_sys_2::llama_chat_message> = chat
+            .iter()
+            .map(|c| llama_cpp_sys_2::llama_chat_message {
+                role: c.role.as_ptr(),
+                content: c.content.as_ptr(),
+            })
+            .collect();
+        // Convert the template to a CString up front so it stays alive for the FFI call
+        let tmpl = tmpl.map(CString::new).transpose()?;
+        let tmpl_ptr = match &tmpl {
+            Some(str) => str.as_ptr(),
+            None => std::ptr::null(),
+        };
+        let formatted_chat = unsafe {
+            let res = llama_cpp_sys_2::llama_chat_apply_template(
+                self.model.as_ptr(),
+                tmpl_ptr,
+                chat.as_ptr(),
+                chat.len(),
+                add_ass,
+                buff.as_mut_ptr().cast::<std::os::raw::c_char>(),
+                buff.len() as i32,
+            );
+            // A buffer twice the size should be sufficient for all models; if it is not for a
+            // new model, we can increase it. The error message tells the user to contact a maintainer.
+            if res > buff.len() as i32 {
+                return Err(ApplyChatTemplateError::BuffSizeError);
+            }
+            String::from_utf8(buff.iter().take_while(|c| **c != 0).map(|&c| c as u8).collect())
+        }?;
+        Ok(formatted_chat)
+    }
 }
 
 impl Drop for LlamaModel {
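A minimal usage sketch of the new API (not part of the diff): it assumes a `LlamaModel` has already been loaded elsewhere (for example via `LlamaModel::load_from_file`), and the `build_prompt` helper name and message contents are illustrative only.

```rust
use llama_cpp_2::model::{LlamaChatMessage, LlamaModel};

/// Render a chat prompt with the template embedded in the model's metadata.
/// `model` is assumed to be loaded elsewhere; `question` is arbitrary user text.
fn build_prompt(model: &LlamaModel, question: &str) -> Result<String, Box<dyn std::error::Error>> {
    // Each message is converted to C strings up front; a NUL byte in either field
    // surfaces as `NewLlamaChatMessageError`.
    let chat = vec![
        LlamaChatMessage::new("system".into(), "You are a helpful assistant.".into())?,
        LlamaChatMessage::new("user".into(), question.into())?,
    ];

    // `None` selects the model's default template; `true` (add_ass) appends the
    // assistant prefix so generation continues as the assistant.
    let prompt = model.apply_chat_template(None, chat, true)?;
    Ok(prompt)
}
```

The returned string is an ordinary Rust `String`, so it can then be tokenized and fed to a context for generation like any other prompt.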