Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgeantonio21 committed Oct 7, 2024
1 parent e4a96f9 commit 089bc8c
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ And then head over to
<!--- ANCHOR: useful_libraries --->

## Useful External Resources

- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): A
very detailed tutorial showing how to convert a PyTorch model to Candle.
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and
Expand All @@ -187,6 +188,7 @@ And then head over to
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
- [`candle-einops`](https://github.com/tomsanbear/candle-einops): A pure rust implementation of the python [einops](https://github.com/arogozhnikov/einops) library.
- [`atoma-infer`](https://github.com/atoma-network/atoma-infer): A Rust library for fast inference at scale, leveraging FlashAttention2 for efficient attention computation, PagedAttention for efficient KV-cache memory management, and multi-GPU support. It is OpenAI api compatible.

If you have an addition to this list, please submit a pull request.

Expand Down
45 changes: 45 additions & 0 deletions candle-transformers/src/generation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,49 @@ impl LogitsProcessor {
};
Ok(next_token)
}

pub fn sample_f_with_probs(
&mut self,
logits: &Tensor,
f: impl FnOnce(&mut [f32]),
) -> Result<(u32, f32)> {
let logits = logits.to_dtype(DType::F32)?;
let prs = |temperature: f64| -> Result<Vec<f32>> {
let logits = (&logits / temperature)?;
let prs = candle_nn::ops::softmax_last_dim(&logits)?;
let mut prs = prs.to_vec1()?;
f(&mut prs);
Ok(prs)
};

let next_token = match &self.sampling {
Sampling::ArgMax => {
let next_token = self.sample_argmax(logits)?
},
Sampling::All { temperature } => {
let prs = prs(*temperature)?;
self.sample_multinomial(&prs)?
}
Sampling::TopP { p, temperature } => {
let mut prs = prs(*temperature)?;
if *p <= 0.0 || *p >= 1.0 {
// simply sample from the predicted probability distribution
self.sample_multinomial(&prs)?
} else {
// top-p (nucleus) sampling, clamping the least likely tokens to zero
self.sample_topp(&mut prs, *p as f32)?
}
}
Sampling::TopK { k, temperature } => {
let mut prs = prs(*temperature)?;
self.sample_topk(&mut prs, *k)?
}
Sampling::TopKThenTopP { k, p, temperature } => {
let mut prs = prs(*temperature)?;
self.sample_topk_topp(&mut prs, *k, *p as f32)?
}
};

Ok((next_token, prob))
}
}

0 comments on commit 089bc8c

Please sign in to comment.