Skip to content

Commit

Permalink
Merge pull request #316 from SilasMarvin/silas-fix-apply-chat-template
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcusDunn authored May 25, 2024
2 parents 584077b + 2cb4498 commit 99d1c41
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions llama-cpp-2/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,11 @@ impl LlamaModel {
/// # Errors
///
/// See [`TokenToStringError`] for more information.
pub fn token_to_str(&self, token: LlamaToken, special: Special) -> Result<String, TokenToStringError> {
pub fn token_to_str(
&self,
token: LlamaToken,
special: Special,
) -> Result<String, TokenToStringError> {
self.token_to_str_with_size(token, 32, special)
}

Expand All @@ -122,7 +126,11 @@ impl LlamaModel {
/// # Errors
///
/// See [`TokenToStringError`] for more information.
pub fn token_to_bytes(&self, token: LlamaToken, special: Special) -> Result<Vec<u8>, TokenToStringError> {
pub fn token_to_bytes(
&self,
token: LlamaToken,
special: Special,
) -> Result<Vec<u8>, TokenToStringError> {
self.token_to_bytes_with_size(token, 32, special)
}

Expand All @@ -131,9 +139,17 @@ impl LlamaModel {
/// # Errors
///
/// See [`TokenToStringError`] for more information.
pub fn tokens_to_str(&self, tokens: &[LlamaToken], special: Special) -> Result<String, TokenToStringError> {
pub fn tokens_to_str(
&self,
tokens: &[LlamaToken],
special: Special,
) -> Result<String, TokenToStringError> {
let mut builder = String::with_capacity(tokens.len() * 4);
for str in tokens.iter().copied().map(|t| self.token_to_str(t, special)) {
for str in tokens
.iter()
.copied()
.map(|t| self.token_to_str(t, special))
{
builder += &str?;
}
Ok(builder)
Expand Down Expand Up @@ -451,12 +467,14 @@ impl LlamaModel {
content: c.content.as_ptr(),
})
.collect();

// Set the tmpl pointer
let tmpl = tmpl.map(CString::new);
let tmpl_ptr = match tmpl {
Some(str) => str?.as_ptr(),
let tmpl_ptr = match &tmpl {
Some(str) => str.as_ref().map_err(|e| e.clone())?.as_ptr(),
None => std::ptr::null(),
};

let formatted_chat = unsafe {
let res = llama_cpp_sys_2::llama_chat_apply_template(
self.model.as_ptr(),
Expand Down

0 comments on commit 99d1c41

Please sign in to comment.