feat(candle-transformers/models/codegeex4-9b): add codegeex4-9b (#2334)
* feat(candle-transformers/models/codegeex4-9b): add codegeex4-9b transformers

* change mod.rs

* feat(candle-examples/codegeex4-9b)

* Update codegeex4_9b.rs

* Update main.rs

* Update codegeex4_9b.rs

* Update main.rs

* fmt

* fix

* fmt

* Clippy fix.

* Remove some print statements.

* Avoid using unwrap.

* 1. add README
2. change the print fmt

* Another clippy fix.

---------

Co-authored-by: Laurent <[email protected]>
donjuanplatinum and LaurentMazare authored Jul 21, 2024
1 parent 3c815b1 commit 2489a60
Showing 4 changed files with 945 additions and 0 deletions.
96 changes: 96 additions & 0 deletions candle-examples/examples/codegeex4-9b/README.org
@@ -0,0 +1,96 @@
* candle-codegeex4_9b
THUDM/CodeGeeX4 is a versatile model for all AI software development scenarios, including code completion, code interpreter, web search, function calling, repository-level Q&A and much more.

- [[https://github.com/THUDM/CodeGeeX4][Github]]
- [[https://codegeex.cn/][HomePage]]
- [[https://huggingface.co/THUDM/codegeex4-all-9b][huggingface]]

** Running with ~cuda~

#+begin_src shell
cargo run --example codegeex4-9b --release --features cuda -- --prompt "please write an insertion sort in rust" --sample-len 300
#+end_src

** Running with ~cpu~
#+begin_src shell
cargo run --example codegeex4-9b --release --cpu -- --prompt "please write an insertion sort in rust" --sample-len 300
#+end_src

** Output Example
*** Input
#+begin_src shell
cargo run --release --features cuda -- --prompt 'please write an FFT in rust' --sample-len 500 --cache /root/autodl-tmp
#+end_src

*** Output
#+begin_src shell
avx: false, neon: false, simd128: false, f16c: false
temp: 0.95 repeat-penalty: 1.10 repeat-last-n: 64
cache path /root/autodl-tmp
Prompt: [please write an FFT in rust]
Using Seed 11511762269791786684
DType is BF16
transformer layers created
model loaded 4
starting the inference loop

generation started
samplelen 500

500 tokens generated (34.60 token/s)
Result:

Sure, I can help you with that. Here's an example of a Fast Fourier Transform (FFT) implementation in Rust:

```rust
use num_complex::Complex;

fn fft(input: &[Complex<f64>]) -> Vec<Complex<f64>> {
    let n = input.len();

    if n == 1 {
        return vec![input[0]];
    }

    // Split the input into even- and odd-indexed samples.
    let mut even = vec![];
    let mut odd = vec![];

    for i in 0..n {
        if i % 2 == 0 {
            even.push(input[i]);
        } else {
            odd.push(input[i]);
        }
    }

    // Recursively transform both halves.
    let even_fft = fft(&even);
    let odd_fft = fft(&odd);

    let mut output = vec![Complex::new(0.0, 0.0); n];

    for k in 0..n / 2 {
        // Twiddle factor e^(-2*pi*i*k/n).
        let t = Complex::new(0.0, -2.0 * std::f64::consts::PI * (k as f64) / (n as f64)).exp();

        output[k] = even_fft[k] + odd_fft[k] * t;
        output[k + n / 2] = even_fft[k] - odd_fft[k] * t;
    }

    output
}
```

This implementation uses the Cooley-Tukey algorithm to perform the FFT. The function takes a slice of complex samples and returns a vector holding the transform; the input length must be a power of two for the radix-2 recursion to terminate at the single-element base case.
#+end_src
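
For reference, the generated ~fft~ function can be exercised as below. This is a minimal sketch, assuming the ~num_complex~ crate is available and the ~fft~ definition above is in scope; the input length must be a power of two.

#+begin_src rust
use num_complex::Complex;

fn main() {
    // A length-4 real signal embedded in the complex plane.
    let signal = [
        Complex::new(1.0, 0.0),
        Complex::new(2.0, 0.0),
        Complex::new(3.0, 0.0),
        Complex::new(4.0, 0.0),
    ];
    // Expected bins: 10, -2+2i, -2, -2-2i.
    for (k, c) in fft(&signal).iter().enumerate() {
        println!("bin {k}: {c}");
    }
}
#+end_src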


* Citation
#+begin_src
@inproceedings{zheng2023codegeex,
title={CodeGeeX: A Pre-Trained Model for Code Generation with Multilingual Benchmarking on HumanEval-X},
author={Qinkai Zheng and Xiao Xia and Xu Zou and Yuxiao Dong and Shan Wang and Yufei Xue and Zihan Wang and Lei Shen and Andi Wang and Yang Li and Teng Su and Zhilin Yang and Jie Tang},
booktitle={Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
pages={5673--5684},
year={2023}
}
#+end_src
252 changes: 252 additions & 0 deletions candle-examples/examples/codegeex4-9b/main.rs
@@ -0,0 +1,252 @@
use candle_transformers::models::codegeex4_9b::*;
use clap::Parser;

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{Repo, RepoType};
use tokenizers::Tokenizer;

struct TextGeneration {
model: Model,
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
dtype: DType,
}

impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
device: &Device,
dtype: DType,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
logits_processor,
repeat_penalty,
repeat_last_n,
verbose_prompt,
device: device.clone(),
dtype,
}
}

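/// Run the generation loop: encode the prompt, sample up to `sample_len`
/// tokens, and report the throughput once generation finishes.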
fn run(&mut self, prompt: &str, sample_len: usize) -> anyhow::Result<()> {
use std::io::Write;
println!("starting the inference loop");
let tokens = self.tokenizer.encode(prompt, true).expect("tokens error");
if tokens.is_empty() {
panic!("Empty prompts are not supported in the codegeex4 model.")
}
if self.verbose_prompt {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
}
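// Generation stops once the model emits this end-of-text token.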
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => panic!("cannot find the endoftext token"),
};
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;

print!("{prompt}");
std::io::stdout().flush().expect("output flush error");
let start_gen = std::time::Instant::now();

println!("\n start_gen");
println!("samplelen {}", sample_len);
let mut count = 0;
let mut result = vec![];
for index in 0..sample_len {
count += 1;
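// Only the first step feeds the whole prompt; afterwards a single new token
// is passed in and the model relies on its internal cache for the context.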
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
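// Penalize tokens seen in the last `repeat_last_n` positions to discourage loops.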
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};

let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
let token = self
.tokenizer
.decode(&[next_token], true)
.expect("Token error");
if self.verbose_prompt {
println!(
"[Count: {}] [Raw Token: {}] [Decode Token: {}]",
count, next_token, token
);
}
result.push(token);
std::io::stdout().flush()?;
}
let dt = start_gen.elapsed();
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
println!("Result:");
for token in result {
print!("{token}");
}
Ok(())
}
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The cache directory for model files (defaults to the current directory).
#[arg(name = "cache", short, long, default_value = ".")]
cache_path: String,

/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,

/// Display the tokens for the specified prompt.
#[arg(long)]
verbose_prompt: bool,

#[arg(long)]
prompt: String,

/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,

/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,

/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,

/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 5000)]
sample_len: usize,

#[arg(long)]
model_id: Option<String>,

#[arg(long)]
revision: Option<String>,

#[arg(long)]
weight_file: Option<String>,

#[arg(long)]
tokenizer: Option<String>,

/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,

/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}

fn main() -> anyhow::Result<()> {
let args = Args::parse();
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.95),
args.repeat_penalty,
args.repeat_last_n
);

let start = std::time::Instant::now();
println!("cache path {}", args.cache_path);
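// Build a hub API client that stores its downloads under the given cache path.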
let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
.build()
.map_err(anyhow::Error::msg)?;

let model_id = args.model_id.unwrap_or_else(|| "THUDM/codegeex4-all-9b".to_string());
let revision = args.revision.unwrap_or_else(|| "main".to_string());
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("THUDM/codegeex4-all-9b".to_string())
.get("tokenizer.json")
.map_err(anyhow::Error::msg)?,
};
let filenames = match args.weight_file {
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error");

let start = std::time::Instant::now();
let config = Config::codegeex4();
let device = candle_examples::device(args.cpu)?;
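// Use BF16 on CUDA to reduce memory pressure; fall back to F32 on CPU.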
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
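// Memory-mapping the weights is unsafe since the backing files must not be
// modified while they are mapped.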
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;

println!("loaded the model in {:?}", start.elapsed());

let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
args.verbose_prompt,
&device,
dtype,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}