
Commit

Merge pull request #127 from SilasMarvin/silas-apply-chat-template
Added Apply Chat Template
MarcusDunn authored Apr 6, 2024
2 parents 89f73ec + 6f9fa32 commit 636da79
Showing 2 changed files with 98 additions and 6 deletions.
22 changes: 22 additions & 0 deletions llama-cpp-2/src/lib.rs
@@ -207,6 +207,28 @@ pub enum StringToTokenError {
CIntConversionError(#[from] std::num::TryFromIntError),
}

/// Failed to create a new `LlamaChatMessage`.
#[derive(Debug, thiserror::Error)]
pub enum NewLlamaChatMessageError {
/// The string contained a null byte and thus could not be converted to a C string.
#[error("{0}")]
NulError(#[from] NulError),
}

/// Failed to apply the model's chat template.
#[derive(Debug, thiserror::Error)]
pub enum ApplyChatTemplateError {
/// The buffer was too small.
#[error("The buffer was too small. Please contact a maintainer and we will update it.")]
BuffSizeError,
/// The string contained a null byte and thus could not be converted to a C string.
#[error("{0}")]
NulError(#[from] NulError),
/// The string could not be converted to UTF-8.
#[error("{0}")]
FromUtf8Error(#[from] FromUtf8Error),
}

/// Get the time in microseconds according to ggml
///
/// ```
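For illustration only (not part of this commit), a caller might branch on the new error variants like the sketch below. It assumes the `apply_chat_template` method added in model.rs further down, and uses the variant names exactly as declared above.

use llama_cpp_2::ApplyChatTemplateError;

fn describe(err: &ApplyChatTemplateError) -> &'static str {
    match err {
        ApplyChatTemplateError::BuffSizeError => "formatted chat did not fit in the internal buffer",
        ApplyChatTemplateError::NulError(_) => "a role or message contained an interior null byte",
        ApplyChatTemplateError::FromUtf8Error(_) => "the formatted chat was not valid UTF-8",
    }
}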
82 changes: 76 additions & 6 deletions llama-cpp-2/src/model.rs
@@ -11,8 +11,8 @@ use crate::model::params::LlamaModelParams;
use crate::token::LlamaToken;
use crate::token_type::LlamaTokenType;
use crate::{
- ChatTemplateError, LlamaContextLoadError, LlamaModelLoadError, StringToTokenError,
- TokenToStringError,
+ ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaModelLoadError,
+ NewLlamaChatMessageError, StringToTokenError, TokenToStringError,
};

pub mod params;
@@ -25,6 +25,23 @@ pub struct LlamaModel {
pub(crate) model: NonNull<llama_cpp_sys_2::llama_model>,
}

/// A safe wrapper around `llama_chat_message`.
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct LlamaChatMessage {
role: CString,
content: CString,
}

impl LlamaChatMessage {
/// Create a new `LlamaChatMessage`
pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
Ok(Self {
role: CString::new(role)?,
content: CString::new(content)?,
})
}
}

/// How to determine if we should prepend a bos token to tokens
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddBos {
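As an aside (illustrative, not part of the diff), constructing messages with the new type might look like the sketch below. The module path `llama_cpp_2::model` is assumed from the file being edited, and `NewLlamaChatMessageError` comes from the lib.rs change above.

use llama_cpp_2::model::LlamaChatMessage;
use llama_cpp_2::NewLlamaChatMessageError;

fn build_chat() -> Result<Vec<LlamaChatMessage>, NewLlamaChatMessageError> {
    // Roles follow the usual chat convention: "system", "user", "assistant".
    Ok(vec![
        LlamaChatMessage::new("system".into(), "You are a helpful assistant.".into())?,
        LlamaChatMessage::new("user".into(), "Summarise this commit.".into())?,
    ])
}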
@@ -312,17 +329,16 @@ impl LlamaModel {
/// Get chat template from model.
///
/// # Errors
///
///
/// * If the model has no chat template
/// * If the chat template is not a valid [`CString`].
#[allow(clippy::missing_panics_doc)] // we statically know this will not panic as
pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {

// longest known template is about 1200 bytes from llama.cpp
let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null");
let chat_ptr = chat_temp.into_raw();
let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");

let chat_template: String = unsafe {
let ret = llama_cpp_sys_2::llama_model_meta_val_str(
self.model.as_ptr(),
Expand All @@ -337,7 +353,7 @@ impl LlamaModel {
debug_assert_eq!(usize::try_from(ret).unwrap(), template.len(), "llama.cpp guarantees that the returned int {ret} is the length of the string {} but that was not the case", template.len());
template
};

Ok(chat_template)
}
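For context (not in the diff), retrieving the raw template string might look like this sketch; `model` is assumed to be an already-loaded `LlamaModel`, and 2000 bytes is an arbitrary size chosen to exceed the roughly 1200-byte longest known template mentioned in the comment above.

use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::ChatTemplateError;

fn print_template(model: &LlamaModel) -> Result<(), ChatTemplateError> {
    // 2000 bytes comfortably exceeds the longest known template.
    let template = model.get_chat_template(2000)?;
    println!("embedded chat template:\n{template}");
    Ok(())
}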

@@ -388,6 +404,60 @@ impl LlamaModel {

Ok(LlamaContext::new(self, context, params.embeddings()))
}

/// Apply the model's chat template to some messages.
/// See <https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template>
///
/// A `tmpl` of `None` means the default template provided by llama.cpp for the model is used.
///
/// # Errors
/// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information.
#[tracing::instrument(skip_all)]
pub fn apply_chat_template(
&self,
tmpl: Option<String>,
chat: Vec<LlamaChatMessage>,
add_ass: bool,
) -> Result<String, ApplyChatTemplateError> {
// Buffer is twice the total message length, per the llama.cpp recommendation
let message_length = chat.iter().fold(0, |acc, c| {
acc + c.role.to_bytes().len() + c.content.to_bytes().len()
});
let mut buff: Vec<i8> = vec![0_i8; message_length * 2];

// Build our llama_cpp_sys_2 chat messages
let chat: Vec<llama_cpp_sys_2::llama_chat_message> = chat
.iter()
.map(|c| llama_cpp_sys_2::llama_chat_message {
role: c.role.as_ptr(),
content: c.content.as_ptr(),
})
.collect();
// Set the tmpl pointer. Keep the CString bound to a local so the pointer
// stays valid for the duration of the FFI call below; taking `as_ptr()` of a
// temporary would leave it dangling.
let tmpl = tmpl.map(CString::new).transpose()?;
let tmpl_ptr = match &tmpl {
Some(str) => str.as_ptr(),
None => std::ptr::null(),
};
let formatted_chat = unsafe {
let res = llama_cpp_sys_2::llama_chat_apply_template(
self.model.as_ptr(),
tmpl_ptr,
chat.as_ptr(),
chat.len(),
add_ass,
buff.as_mut_ptr().cast::<std::os::raw::c_char>(),
buff.len() as i32,
);
// A buffer twice the size should be sufficient for all models; if this is not the case for a new model, we can increase it.
// The error message asks the user to contact a maintainer.
if res > buff.len() as i32 {
return Err(ApplyChatTemplateError::BuffSizeError);
}
// Keep only the bytes llama.cpp actually wrote; filtering out everything that
// is not a positive i8 would also drop non-ASCII bytes and corrupt multi-byte UTF-8.
buff.truncate(usize::try_from(res).unwrap_or(0));
String::from_utf8(buff.into_iter().map(|c| c as u8).collect())
}?;
Ok(formatted_chat)
}
}

impl Drop for LlamaModel {
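To close, an end-to-end sketch of the new API (illustrative, not part of the commit). Loading of `model` is omitted, the message contents are invented, and the error is boxed only for brevity.

use llama_cpp_2::model::{LlamaChatMessage, LlamaModel};

fn format_prompt(model: &LlamaModel) -> Result<String, Box<dyn std::error::Error>> {
    let chat = vec![
        LlamaChatMessage::new("system".into(), "You are a terse assistant.".into())?,
        LlamaChatMessage::new("user".into(), "Name the largest planet.".into())?,
    ];
    // `None` selects the template embedded in the model; `add_ass = true` appends
    // the assistant prefix so generation starts at the assistant's turn.
    let prompt = model.apply_chat_template(None, chat, true)?;
    Ok(prompt)
}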

