From 2cb4498142d044116131a20779536c91ac547d6c Mon Sep 17 00:00:00 2001 From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com> Date: Sat, 25 May 2024 09:40:06 -0700 Subject: [PATCH] Fixed apply_chat_template --- llama-cpp-2/src/model.rs | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/llama-cpp-2/src/model.rs b/llama-cpp-2/src/model.rs index 8dc83a95..9fd940f2 100644 --- a/llama-cpp-2/src/model.rs +++ b/llama-cpp-2/src/model.rs @@ -113,7 +113,11 @@ impl LlamaModel { /// # Errors /// /// See [`TokenToStringError`] for more information. - pub fn token_to_str(&self, token: LlamaToken, special: Special) -> Result { + pub fn token_to_str( + &self, + token: LlamaToken, + special: Special, + ) -> Result { self.token_to_str_with_size(token, 32, special) } @@ -122,7 +126,11 @@ impl LlamaModel { /// # Errors /// /// See [`TokenToStringError`] for more information. - pub fn token_to_bytes(&self, token: LlamaToken, special: Special) -> Result, TokenToStringError> { + pub fn token_to_bytes( + &self, + token: LlamaToken, + special: Special, + ) -> Result, TokenToStringError> { self.token_to_bytes_with_size(token, 32, special) } @@ -131,9 +139,17 @@ impl LlamaModel { /// # Errors /// /// See [`TokenToStringError`] for more information. - pub fn tokens_to_str(&self, tokens: &[LlamaToken], special: Special) -> Result { + pub fn tokens_to_str( + &self, + tokens: &[LlamaToken], + special: Special, + ) -> Result { let mut builder = String::with_capacity(tokens.len() * 4); - for str in tokens.iter().copied().map(|t| self.token_to_str(t, special)) { + for str in tokens + .iter() + .copied() + .map(|t| self.token_to_str(t, special)) + { builder += &str?; } Ok(builder) @@ -451,12 +467,14 @@ impl LlamaModel { content: c.content.as_ptr(), }) .collect(); + // Set the tmpl pointer let tmpl = tmpl.map(CString::new); - let tmpl_ptr = match tmpl { - Some(str) => str?.as_ptr(), + let tmpl_ptr = match &tmpl { + Some(str) => str.as_ref().map_err(|e| e.clone())?.as_ptr(), None => std::ptr::null(), }; + let formatted_chat = unsafe { let res = llama_cpp_sys_2::llama_chat_apply_template( self.model.as_ptr(),