diff --git a/candle-examples/examples/flux/main.rs b/candle-examples/examples/flux/main.rs
index 943db1121c..12439892ce 100644
--- a/candle-examples/examples/flux/main.rs
+++ b/candle-examples/examples/flux/main.rs
@@ -250,7 +250,11 @@ fn run(args: Args) -> Result<()> {
     };
     println!("img\n{img}");
     let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
-    candle_examples::save_image(&img.i(0)?, "out.jpg")?;
+    let filename = match args.seed {
+        None => "out.jpg".to_string(),
+        Some(s) => format!("out-{s}.jpg"),
+    };
+    candle_examples::save_image(&img.i(0)?, filename)?;
     Ok(())
 }
diff --git a/candle-examples/examples/glm4/README.org b/candle-examples/examples/glm4/README.org
index 364f61e8eb..a584f6c745 100644
--- a/candle-examples/examples/glm4/README.org
+++ b/candle-examples/examples/glm4/README.org
@@ -7,48 +7,25 @@ GLM-4-9B is the open-source version of the latest generation of pre-trained mode
 ** Running with ~cuda~

 #+begin_src shell
-  cargo run --example glm4 --release --features cuda
+  cargo run --example glm4 --release --features cuda -- --prompt "Hello world"
 #+end_src

 ** Running with ~cpu~
 #+begin_src shell
-  cargo run --example glm4 --release -- --cpu
+  cargo run --example glm4 --release -- --cpu --prompt "Hello world"
 #+end_src

 ** Output Example
 #+begin_src shell
-cargo run --example glm4 --release --features cuda -- --sample-len 500 --cache .
-    Finished release [optimized] target(s) in 0.24s
-     Running `/root/candle/target/release/examples/glm4 --sample-len 500 --cache .`
+cargo run --features cuda -r --example glm4 -- --prompt "Hello "
+
 avx: true, neon: false, simd128: false, f16c: true
 temp: 0.60 repeat-penalty: 1.20 repeat-last-n: 64
-cache path .
-retrieved the files in 6.88963ms
-loaded the model in 6.113752297s
+retrieved the files in 6.454375ms
+loaded the model in 3.652383779s
 starting the inference loop
-[欢迎使用GLM-4,请输入prompt]
-请你告诉我什么是FFT
-266 tokens generated (34.50 token/s)
-Result:
-。Fast Fourier Transform (FFT) 是一种快速计算离散傅里叶变换(DFT)的方法,它广泛应用于信号处理、图像处理和数据分析等领域。
-
-具体来说,FFT是一种将时域数据转换为频域数据的算法。在数字信号处理中,我们通常需要知道信号的频率成分,这就需要进行傅立叶变换。传统的傅立叶变换的计算复杂度较高,而 FFT 则大大提高了计算效率,使得大规模的 DFT 换成为可能。
-
-以下是使用 Python 中的 numpy 进行 FFT 的简单示例:
-
-```python
-import numpy as np
-
-# 创建一个时域信号
-t = np.linspace(0, 1, num=100)
-f = np.sin(2*np.pi*5*t) + 3*np.cos(2*np.pi*10*t)
-
-# 对该信号做FFT变换,并计算其幅值谱
-fft_result = np.fft.fftshift(np.abs(np.fft.fft(f)))
-
-```
-
-在这个例子中,我们首先创建了一个时域信号 f。然后我们对这个信号进行了 FFT 换,得到了一个频域结果 fft_result。
+Hello 2018, hello new year! I’m so excited to be back and sharing with you all my favorite things from the past month. This is a monthly series where I share what’s been inspiring me lately in hopes that it will inspire you too!
+...
 #+end_src

 This example will read prompt from stdin
diff --git a/candle-examples/examples/glm4/main.rs b/candle-examples/examples/glm4/main.rs
index 55a27f349e..ced3841d8e 100644
--- a/candle-examples/examples/glm4/main.rs
+++ b/candle-examples/examples/glm4/main.rs
@@ -12,120 +12,97 @@ struct TextGeneration {
     device: Device,
     tokenizer: Tokenizer,
     logits_processor: LogitsProcessor,
-    repeat_penalty: f32,
-    repeat_last_n: usize,
-    verbose_prompt: bool,
+    args: Args,
     dtype: DType,
 }

 impl TextGeneration {
     #[allow(clippy::too_many_arguments)]
-    fn new(
-        model: Model,
-        tokenizer: Tokenizer,
-        seed: u64,
-        temp: Option<f64>,
-        top_p: Option<f64>,
-        repeat_penalty: f32,
-        repeat_last_n: usize,
-        verbose_prompt: bool,
-        device: &Device,
-        dtype: DType,
-    ) -> Self {
-        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
+    fn new(model: Model, tokenizer: Tokenizer, args: Args, device: &Device, dtype: DType) -> Self {
+        let logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
         Self {
             model,
             tokenizer,
             logits_processor,
-            repeat_penalty,
-            repeat_last_n,
-            verbose_prompt,
+            args,
             device: device.clone(),
             dtype,
         }
     }

-    fn run(&mut self, sample_len: usize) -> anyhow::Result<()> {
-        use std::io::BufRead;
-        use std::io::BufReader;
+    fn run(&mut self) -> anyhow::Result<()> {
         use std::io::Write;
+        let args = &self.args;
         println!("starting the inference loop");
-        println!("[欢迎使用GLM-4,请输入prompt]");
-        let stdin = std::io::stdin();
-        let reader = BufReader::new(stdin);
-        for line in reader.lines() {
-            let line = line.expect("Failed to read line");
-
-            let tokens = self.tokenizer.encode(line, true).expect("tokens error");
-            if tokens.is_empty() {
-                panic!("Empty prompts are not supported in the chatglm model.")
-            }
-            if self.verbose_prompt {
-                for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
-                    let token = token.replace('▁', " ").replace("<0x0A>", "\n");
-                    println!("{id:7} -> '{token}'");
-                }
+
+        let tokens = self
+            .tokenizer
+            .encode(args.prompt.to_string(), true)
+            .expect("tokens error");
+        if tokens.is_empty() {
+            panic!("Empty prompts are not supported in the chatglm model.")
+        }
+        if args.verbose {
+            for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
+                let token = token.replace('▁', " ").replace("<0x0A>", "\n");
+                println!("{id:7} -> '{token}'");
             }
-            let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
-                Some(token) => *token,
-                None => panic!("cannot find the endoftext token"),
+        } else {
+            print!("{}", &args.prompt);
+            std::io::stdout().flush()?;
+        }
+        let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
+            Some(token) => *token,
+            None => panic!("cannot find the endoftext token"),
+        };
+        let mut tokens = tokens.get_ids().to_vec();
+        let mut generated_tokens = 0usize;
+
+        std::io::stdout().flush().expect("output flush error");
+        let start_gen = std::time::Instant::now();
+
+        for index in 0..args.sample_len {
+            let context_size = if index > 0 { 1 } else { tokens.len() };
+            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
+            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+            let logits = self.model.forward(&input)?;
+            let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
+            let logits = if args.repeat_penalty == 1. {
+                logits
+            } else {
+                let start_at = tokens.len().saturating_sub(args.repeat_last_n);
+                candle_transformers::utils::apply_repeat_penalty(
+                    &logits,
+                    args.repeat_penalty,
+                    &tokens[start_at..],
+                )?
             };
-            let mut tokens = tokens.get_ids().to_vec();
-            let mut generated_tokens = 0usize;
-
-            std::io::stdout().flush().expect("output flush error");
-            let start_gen = std::time::Instant::now();
-
-            let mut count = 0;
-            let mut result = vec![];
-            for index in 0..sample_len {
-                count += 1;
-                let context_size = if index > 0 { 1 } else { tokens.len() };
-                let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
-                let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
-                let logits = self.model.forward(&input)?;
-                let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
-                let logits = if self.repeat_penalty == 1. {
-                    logits
-                } else {
-                    let start_at = tokens.len().saturating_sub(self.repeat_last_n);
-                    candle_transformers::utils::apply_repeat_penalty(
-                        &logits,
-                        self.repeat_penalty,
-                        &tokens[start_at..],
-                    )?
-                };
-
-                let next_token = self.logits_processor.sample(&logits)?;
-                tokens.push(next_token);
-                generated_tokens += 1;
-                if next_token == eos_token {
-                    break;
-                }
-                let token = self
-                    .tokenizer
-                    .decode(&[next_token], true)
-                    .expect("Token error");
-                if self.verbose_prompt {
-                    println!(
-                        "[Count: {}] [Raw Token: {}] [Decode Token: {}]",
-                        count, next_token, token
-                    );
-                }
-                result.push(token);
-                std::io::stdout().flush()?;
+
+            let next_token = self.logits_processor.sample(&logits)?;
+            tokens.push(next_token);
+            generated_tokens += 1;
+            if next_token == eos_token {
+                break;
             }
-            let dt = start_gen.elapsed();
-            println!(
-                "\n{generated_tokens} tokens generated ({:.2} token/s)",
-                generated_tokens as f64 / dt.as_secs_f64(),
-            );
-            println!("Result:");
-            for tokens in result {
-                print!("{tokens}");
+            let token = self
+                .tokenizer
+                .decode(&[next_token], true)
+                .expect("token decode error");
+            if args.verbose {
+                println!(
+                    "[Count: {}] [Raw Token: {}] [Decode Token: {}]",
+                    generated_tokens, next_token, token
+                );
+            } else {
+                print!("{token}");
+                std::io::stdout().flush()?;
             }
-            self.model.reset_kv_cache(); // clean the cache
         }
+        let dt = start_gen.elapsed();
+        println!(
+            "\n{generated_tokens} tokens generated ({:.2} token/s)",
+            generated_tokens as f64 / dt.as_secs_f64(),
+        );
         Ok(())
     }
 }
@@ -141,7 +118,11 @@ struct Args {

     /// Display the token for the specified prompt.
     #[arg(long)]
-    verbose_prompt: bool,
+    prompt: String,
+
+    /// Display the tokens for the specified prompt and outputs.
+    #[arg(long)]
+    verbose: bool,

     /// The temperature used to generate samples.
     #[arg(long)]
@@ -197,28 +178,29 @@ fn main() -> anyhow::Result<()> {
     );

     let start = std::time::Instant::now();
-    println!("cache path {}", args.cache_path);
-    let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
-        .build()
-        .map_err(anyhow::Error::msg)?;
+    let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(
+        args.cache_path.to_string().into(),
+    ))
+    .build()
+    .map_err(anyhow::Error::msg)?;

-    let model_id = match args.model_id {
+    let model_id = match args.model_id.as_ref() {
         Some(model_id) => model_id.to_string(),
         None => "THUDM/glm-4-9b".to_string(),
     };
-    let revision = match args.revision {
+    let revision = match args.revision.as_ref() {
         Some(rev) => rev.to_string(),
         None => "main".to_string(),
     };
     let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
-    let tokenizer_filename = match args.tokenizer {
+    let tokenizer_filename = match args.tokenizer.as_ref() {
         Some(file) => std::path::PathBuf::from(file),
         None => api
             .model("THUDM/codegeex4-all-9b".to_string())
             .get("tokenizer.json")
             .map_err(anyhow::Error::msg)?,
     };
-    let filenames = match args.weight_file {
+    let filenames = match args.weight_file.as_ref() {
         Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
         None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
     };
@@ -238,18 +220,7 @@ fn main() -> anyhow::Result<()> {

     println!("loaded the model in {:?}", start.elapsed());

-    let mut pipeline = TextGeneration::new(
-        model,
-        tokenizer,
-        args.seed,
-        args.temperature,
-        args.top_p,
-        args.repeat_penalty,
-        args.repeat_last_n,
-        args.verbose_prompt,
-        &device,
-        dtype,
-    );
-    pipeline.run(args.sample_len)?;
+    let mut pipeline = TextGeneration::new(model, tokenizer, args, &device, dtype);
+    pipeline.run()?;
     Ok(())
 }
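As a side illustration, the seed-dependent output filename introduced in the flux change boils down to a small pure function. The sketch below is a minimal standalone version of that logic; the `output_filename` helper name is hypothetical and not part of the patch or the candle codebase.

```rust
// Minimal sketch of the seed-based filename selection used in the flux example.
// `output_filename` is a hypothetical helper introduced here for illustration only.
fn output_filename(seed: Option<u64>) -> String {
    match seed {
        // No seed given: keep the original fixed name.
        None => "out.jpg".to_string(),
        // Seed given: include it in the filename so runs with different seeds
        // do not overwrite each other's output.
        Some(s) => format!("out-{s}.jpg"),
    }
}

fn main() {
    assert_eq!(output_filename(None), "out.jpg");
    assert_eq!(output_filename(Some(42)), "out-42.jpg");
    println!("{}", output_filename(Some(42)));
}
```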