From 0fc8e1e61c73fa60c7514b2d432269c5da02ec2c Mon Sep 17 00:00:00 2001 From: jiabochao Date: Mon, 1 Apr 2024 12:54:51 +0800 Subject: [PATCH] fix: multi-byte utf8 decoding error --- Cargo.lock | 10 ++++++++++ Cargo.toml | 1 + llama-cpp-2/src/model.rs | 40 ++++++++++++++++++++++++++++++++++++---- simple/Cargo.toml | 1 + simple/src/main.rs | 9 ++++++++- 5 files changed, 56 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 44cd5409..65109367 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -281,6 +281,15 @@ version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" +[[package]] +name = "encoding_rs" +version = "0.8.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7268b386296a025e474d5140678f75d6de9493ae55a5d709eeb9dd08149945e1" +dependencies = [ + "cfg-if", +] + [[package]] name = "errno" version = "0.3.8" @@ -907,6 +916,7 @@ version = "0.1.46" dependencies = [ "anyhow", "clap", + "encoding_rs", "hf-hub", "llama-cpp-2", ] diff --git a/Cargo.toml b/Cargo.toml index 1c6eba10..18547cc6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ bindgen = "0.69.4" cc = "1.0.90" anyhow = "1.0.81" clap = "4.5.3" +encoding_rs = "0.8.33" [workspace.lints.rust] missing_docs = { level = "warn" } diff --git a/llama-cpp-2/src/model.rs b/llama-cpp-2/src/model.rs index 5f412c25..9f01ac24 100644 --- a/llama-cpp-2/src/model.rs +++ b/llama-cpp-2/src/model.rs @@ -90,6 +90,15 @@ impl LlamaModel { self.token_to_str_with_size(token, 32) } + /// Convert single token to bytes. + /// + /// # Errors + /// + /// See [`TokenToStringError`] for more information. + pub fn token_to_bytes(&self, token: LlamaToken) -> Result<Vec<u8>, TokenToStringError> { + self.token_to_bytes_with_size(token, 32) + } + /// Convert a vector of tokens to a single string. 
/// /// # Errors @@ -211,22 +220,45 @@ impl LlamaModel { token: LlamaToken, buffer_size: usize, ) -> Result<String, TokenToStringError> { + let bytes = self.token_to_bytes_with_size(token, buffer_size)?; + Ok(String::from_utf8(bytes)?) + } + + /// Convert a token to bytes with a specified buffer size. + /// + /// Generally you should use [`LlamaModel::token_to_bytes`] instead as 32 bytes is enough for most words and + /// the extra bytes do not really matter. + /// + /// # Errors + /// + /// - if the token type is unknown + /// - the resultant token is larger than `buffer_size`. + /// + /// # Panics + /// + /// - if `buffer_size` does not fit into a [`c_int`]. + /// - if the returned size from llama-cpp does not fit into a [`usize`]. (this should never happen) + pub fn token_to_bytes_with_size( + &self, + token: LlamaToken, + buffer_size: usize, + ) -> Result<Vec<u8>, TokenToStringError> { if token == self.token_nl() { - return Ok(String::from("\n")); + return Ok(String::from("\n").into_bytes()); } match self.token_type(token) { LlamaTokenType::Normal | LlamaTokenType::UserDefined => {} LlamaTokenType::Control => { if token == self.token_bos() || token == self.token_eos() { - return Ok(String::new()); + return Ok(Vec::new()); } } LlamaTokenType::Unknown | LlamaTokenType::Undefined | LlamaTokenType::Byte | LlamaTokenType::Unused => { - return Ok(String::new()); + return Ok(Vec::new()); } } @@ -246,7 +278,7 @@ impl LlamaModel { let mut bytes = string.into_bytes(); let len = usize::try_from(size).expect("size is positive and fits into usize"); bytes.truncate(len); - Ok(String::from_utf8(bytes)?) 
+ Ok(bytes) } } } diff --git a/simple/Cargo.toml b/simple/Cargo.toml index 4e23632a..66008226 100644 --- a/simple/Cargo.toml +++ b/simple/Cargo.toml @@ -10,6 +10,7 @@ llama-cpp-2 = { path = "../llama-cpp-2", version = "0.1.46" } hf-hub = { workspace = true } clap = { workspace = true , features = ["derive"] } anyhow = { workspace = true } +encoding_rs = { workspace = true } [features] cublas = ["llama-cpp-2/cublas"] diff --git a/simple/src/main.rs b/simple/src/main.rs index 8f7451f7..98f863aa 100644 --- a/simple/src/main.rs +++ b/simple/src/main.rs @@ -240,6 +240,9 @@ either reduce n_len or increase n_ctx" let t_main_start = ggml_time_us(); + // Incremental UTF-8 `Decoder`: carries incomplete multi-byte sequences across token boundaries + let mut decoder = encoding_rs::UTF_8.new_decoder(); + while n_cur <= n_len { // sample the next token { @@ -256,7 +259,11 @@ either reduce n_len or increase n_ctx" break; } - print!("{}", model.token_to_str(new_token_id)?); + let output_bytes = model.token_to_bytes(new_token_id)?; + // use `Decoder.decode_to_string()` to avoid the intermediate buffer + let mut output_string = String::with_capacity(32); + let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false); + print!("{}", output_string); std::io::stdout().flush()?; batch.clear();