Skip to content

Commit

Permalink
Merge pull request #232 from jiabochao/fix-multi-byte-decoding
Browse files Browse the repository at this point in the history
fix: multi-byte utf8 decoding error
  • Loading branch information
MarcusDunn authored Apr 1, 2024
2 parents a0eebde + f9bd213 commit 99d9563
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 5 deletions.
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ bindgen = "0.69.4"
cc = "1.0.90"
anyhow = "1.0.81"
clap = "4.5.4"
encoding_rs = "0.8.33"

[workspace.lints.rust]
missing_docs = { level = "warn" }
Expand Down
40 changes: 36 additions & 4 deletions llama-cpp-2/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,15 @@ impl LlamaModel {
self.token_to_str_with_size(token, 32)
}

/// Convert single token to bytes.
///
/// # Errors
///
/// See [`TokenToStringError`] for more information.
pub fn token_to_bytes(&self, token: LlamaToken) -> Result<Vec<u8>, TokenToStringError> {
self.token_to_bytes_with_size(token, 32)
}

/// Convert a vector of tokens to a single string.
///
/// # Errors
Expand Down Expand Up @@ -211,22 +220,45 @@ impl LlamaModel {
token: LlamaToken,
buffer_size: usize,
) -> Result<String, TokenToStringError> {
let bytes = self.token_to_bytes_with_size(token, buffer_size)?;
Ok(String::from_utf8(bytes)?)
}

/// Convert a token to bytes with a specified buffer size.
///
/// Generally you should use [`LlamaModel::token_to_bytes`] instead as 32 bytes is enough for most words and
/// the extra bytes do not really matter.
///
/// # Errors
///
/// - if the token type is unknown
/// - the resultant token is larger than `buffer_size`.
///
/// # Panics
///
/// - if `buffer_size` does not fit into a [`c_int`].
/// - if the returned size from llama-cpp does not fit into a [`usize`]. (this should never happen)
pub fn token_to_bytes_with_size(
&self,
token: LlamaToken,
buffer_size: usize,
) -> Result<Vec<u8>, TokenToStringError> {
if token == self.token_nl() {
return Ok(String::from("\n"));
return Ok(String::from("\n").into_bytes());
}

match self.token_type(token) {
LlamaTokenType::Normal | LlamaTokenType::UserDefined => {}
LlamaTokenType::Control => {
if token == self.token_bos() || token == self.token_eos() {
return Ok(String::new());
return Ok(Vec::new());
}
}
LlamaTokenType::Unknown
| LlamaTokenType::Undefined
| LlamaTokenType::Byte
| LlamaTokenType::Unused => {
return Ok(String::new());
return Ok(Vec::new());
}
}

Expand All @@ -246,7 +278,7 @@ impl LlamaModel {
let mut bytes = string.into_bytes();
let len = usize::try_from(size).expect("size is positive and fits into usize");
bytes.truncate(len);
Ok(String::from_utf8(bytes)?)
Ok(bytes)
}
}
}
Expand Down
1 change: 1 addition & 0 deletions simple/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ llama-cpp-2 = { path = "../llama-cpp-2", version = "0.1.46" }
hf-hub = { workspace = true }
clap = { workspace = true , features = ["derive"] }
anyhow = { workspace = true }
encoding_rs = { workspace = true }

[features]
cublas = ["llama-cpp-2/cublas"]
Expand Down
9 changes: 8 additions & 1 deletion simple/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,9 @@ either reduce n_len or increase n_ctx"

let t_main_start = ggml_time_us();

// The `Decoder`
let mut decoder = encoding_rs::UTF_8.new_decoder();

while n_cur <= n_len {
// sample the next token
{
Expand All @@ -256,7 +259,11 @@ either reduce n_len or increase n_ctx"
break;
}

print!("{}", model.token_to_str(new_token_id)?);
let output_bytes = model.token_to_bytes(new_token_id)?;
// use `Decoder.decode_to_string()` to avoid the intermediate buffer
let mut output_string = String::with_capacity(32);
let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false);
print!("{}", output_string);
std::io::stdout().flush()?;

batch.clear();
Expand Down

0 comments on commit 99d9563

Please sign in to comment.