Implement DRY penalty (#637)
* Implement dry penalty

* Add dry sampling params to requests

* Handle it

* Clippy

* Review: "Implement DRY penalty" (#645)

* Silence bogus Clippy warning

Clippy's suggestion cannot be implemented because of borrowing issues

* Get rid of unnecessary type annotations

Interesting that Clippy doesn't catch this

* Store default sequence breakers in a slice

It's nicer when the length is not hardcoded

* Make default sequence breakers private

No need to leak this as it's not used elsewhere

* Limit match length

Avoids quadratic runtime and potential DoS with adversarial inputs

Ref oobabooga/text-generation-webui#6047

* "Fix" sequence breaker tokenization

Most tokenizers encode punctuation tokens differently depending on where they occur in the input, and which tokens surround them. With the default sequence breakers, the appropriate encoding usually corresponds to the encoding produced when the token occurs after a word, rather than by itself. To emulate this, prefix the token with "a" before encoding, and extract the final token of the result.

See LostRuins/koboldcpp#982 for a correct solution to this problem.
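
A minimal sketch of that workaround, using the same `tokenizers` crate API the sampler depends on (the helper name is hypothetical; the actual implementation is in the sampler.rs diff below):

// Hypothetical helper: encode a sequence breaker as it would appear
// after a word, then keep only the final token ID of the encoding.
fn breaker_token_id(tokenizer: &tokenizers::Tokenizer, breaker: &str) -> Option<u32> {
    // Prefixing with "a" makes the tokenizer produce the word-final encoding.
    let enc = tokenizer.encode(format!("a{breaker}"), true).ok()?;
    enc.get_ids().last().copied()
}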

* Nicer

* Even better

* Complete merge

* Fix saturating sub

* Handle when no context

* Make context the entire sequence and refactor

* Remove slicing for all

* Fix the bug with penalty

Credit to @p-e-w for finding this!

Co-authored-by: Philipp Emanuel Weidmann <[email protected]>

* Add custom logits processor API (#702)

* Add custom logits processor api

* Typos

* Nicer interface and update example

* Fix doctest

* Update docs

* Update exports

* Add Gemma 2 PagedAttention support (#704)

* Add gemma2 paged attn support

* Non-CUDA support?

* Remove error

* It works

* Faster RmsNorm in gemma/gemma2 (#703)

* Fix bug in metal isq (#706)

* Support GGUF BF16 tensors (#691)

* Support GGUF bf16 tensors

* Fix loading of bf16 ggml tensor

* Fix dequant of bf16

* Use merged rev

* Softcapping, real batching + sliding window support for Flash Attention (#707)

* Flash attention varlen kind of works

* Seems to work

* Now it's nice

* Sliding window support and clippy

* Remove warning

* Support smollm

* Update rev to match merged

* Remove some usages of 'pub' in models (#708)

* Support the Phi 3.5 V model (#710)

* Update image_seq_len

* Update the examples

* Format

* Implement the Phi 3.5 MoE model (#709)

* Copy the model

* Add most of it

* Add the blocksparse moe parts

* Clippy

* Fix mscales

* A batch of fixes

* Correctly cast it

* Handle isq on gate

* Even more progress

* Runs now

* Clippy

* Fix to use layernorm

* Remove unused

* Add docs

* Add more docs

* Apply review comments

* Update readme

---------

Co-authored-by: Philipp Emanuel Weidmann <[email protected]>
EricLBuehler and p-e-w authored Aug 27, 2024
1 parent 91a423e commit d35f62e
Showing 12 changed files with 335 additions and 31 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -101,10 +101,12 @@ Mistral.rs supports several model categories:
- [Paper](https://arxiv.org/abs/2405.19076)
- [Docs](docs/ANYMOE.md)
- PagedAttention: [docs](docs/PAGED_ATTENTION.md)
- Various sampling techniques:
- Various sampling and penalty techniques:
- Top K
- Top P
- Min P
- [DRY Penalty](https://github.com/oobabooga/text-generation-webui/pull/5677)
- Frequency and Presence Penalty
- Please suggest more by raising an issue!
- Tool calling: [docs](docs/TOOL_CALLING.md)
- Prompt chunking (only without PagedAttention for now): handle larger prompts where the activation size would cause an OOM by sending chunks
9 changes: 6 additions & 3 deletions mistralrs-bench/src/main.rs
Expand Up @@ -3,9 +3,10 @@ use clap::Parser;
use cli_table::{format::Justify, print_stdout, Cell, CellStruct, Style, Table};
use mistralrs_core::{
initialize_logging, paged_attn_supported, Constraint, DefaultSchedulerMethod,
DeviceLayerMapMetadata, DeviceMapMetadata, Loader, LoaderBuilder, MemoryGpuConfig, MistralRs,
MistralRsBuilder, ModelDType, ModelSelected, NormalRequest, PagedAttentionConfig, Request,
RequestMessage, Response, SamplingParams, SchedulerConfig, TokenSource, Usage,
DeviceLayerMapMetadata, DeviceMapMetadata, DrySamplingParams, Loader, LoaderBuilder,
MemoryGpuConfig, MistralRs, MistralRsBuilder, ModelDType, ModelSelected, NormalRequest,
PagedAttentionConfig, Request, RequestMessage, Response, SamplingParams, SchedulerConfig,
TokenSource, Usage,
};
use std::sync::Arc;
use std::{fmt::Display, num::NonZeroUsize};
@@ -64,6 +65,7 @@ fn run_bench(
stop_toks: None,
logits_bias: None,
n_choices: 1,
dry_params: Some(DrySamplingParams::default()),
};
let sender = mistralrs.get_sender().unwrap();
let (tx, mut rx) = channel(10_000);
@@ -227,6 +229,7 @@ fn warmup_run(mistralrs: Arc<MistralRs>) {
stop_toks: None,
logits_bias: None,
n_choices: 1,
dry_params: Some(DrySamplingParams::default()),
};
let sender = mistralrs.get_sender().unwrap();
let (tx, mut rx) = channel(10_000);
2 changes: 2 additions & 0 deletions mistralrs-core/src/engine/mod.rs
@@ -647,11 +647,13 @@ impl Engine {
tokenizer,
request.sampling_params.frequency_penalty,
request.sampling_params.presence_penalty,
request.sampling_params.dry_params,
topk,
topp,
minp,
request.logits_processors.unwrap_or_default(),
);
let sampler = handle_seq_error!(sampler, request.response);

if request.sampling_params.n_choices == 0 {
request
4 changes: 3 additions & 1 deletion mistralrs-core/src/lib.rs
@@ -78,7 +78,9 @@ pub use pipeline::{
pub use request::{Constraint, MessageContent, NormalRequest, Request, RequestMessage};
pub use response::Response;
pub use response::*;
pub use sampler::{CustomLogitsProcessor, SamplingParams, StopTokens, TopLogprob};
pub use sampler::{
CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
};
pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
use serde::Serialize;
use tokio::runtime::Runtime;
15 changes: 13 additions & 2 deletions mistralrs-core/src/pipeline/amoe.rs
@@ -354,8 +354,19 @@ impl AnyMoePipelineMixin for AnyMoePipeline {

// Create several dummy objects for the sequences. No custom logits processors.
let (dummy_sender, _) = tokio::sync::mpsc::channel(10000);
let dummy_sampler =
Sampler::new(None, 0, tokenizer.clone(), None, None, -1, 0.0, 0.0, vec![]);
let dummy_sampler = Sampler::new(
None,
0,
tokenizer.clone(),
None,
None,
None,
-1,
0.0,
0.0,
vec![],
)
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;

let dummy_group = Arc::new(tokio::sync::Mutex::new(SequenceGroup::new(
1, false, false, 0,
196 changes: 187 additions & 9 deletions mistralrs-core/src/sampler.rs
@@ -1,7 +1,7 @@
#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{
collections::HashMap,
collections::{HashMap, HashSet},
iter::zip,
sync::{Arc, Mutex},
};
@@ -12,9 +12,14 @@ use pyo3::pyclass;

use rand::distributions::{Distribution, WeightedIndex};
use rand_isaac::Isaac64Rng;
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
use serde::{Deserialize, Serialize};
use std::sync::LazyLock;
use tokenizers::Tokenizer;

static DRY_SEQUENCE_BREAKERS: LazyLock<Vec<String>> =
LazyLock::new(|| ["\n", ":", "\"", "*"].map(String::from).to_vec());

#[derive(Clone, Debug)]
/// Stop sequences or ids.
pub enum StopTokens {
@@ -36,6 +41,7 @@ pub struct SamplingParams {
pub max_len: Option<usize>,
pub logits_bias: Option<HashMap<u32, f32>>,
pub n_choices: usize,
pub dry_params: Option<DrySamplingParams>,
}

impl Default for SamplingParams {
@@ -52,10 +58,91 @@ impl Default for SamplingParams {
max_len: None,
logits_bias: None,
n_choices: 1,
dry_params: None,
}
}
}

#[derive(Clone, Debug)]
pub struct DrySamplingParams {
pub sequence_breakers: Vec<String>,
pub multiplier: f32,
pub base: f32,
pub allowed_length: usize,
}

impl DrySamplingParams {
pub fn new_with_defaults(
multiplier: f32,
sequence_breakers: Option<Vec<String>>,
base: Option<f32>,
allowed_length: Option<usize>,
) -> anyhow::Result<Self> {
Ok(Self {
base: base.unwrap_or(1.75),
allowed_length: allowed_length.unwrap_or(2),
sequence_breakers: sequence_breakers.unwrap_or(DRY_SEQUENCE_BREAKERS.clone()),
multiplier,
})
}
}
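
For illustration, a caller inside a function returning anyhow::Result could enable DRY like this (the 0.8 multiplier is an arbitrary example value; the remaining fields fall back to the defaults above):

// Base, allowed length, and sequence breakers default to
// 1.75, 2, and ["\n", ":", "\"", "*"] respectively.
let dry = DrySamplingParams::new_with_defaults(0.8, None, None, None)?;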

impl Default for DrySamplingParams {
fn default() -> Self {
Self {
multiplier: 0.0,
base: 1.75,
allowed_length: 2,
sequence_breakers: DRY_SEQUENCE_BREAKERS.clone(),
}
}
}

#[derive(Clone, Debug)]
struct DrySamplingParamsInner {
pub sequence_breakers: HashSet<u32>,
pub multiplier: f32,
pub base: f32,
pub allowed_length: usize,
}

impl DrySamplingParamsInner {
pub fn from(other: DrySamplingParams, tokenizer: &Tokenizer) -> anyhow::Result<Self> {
Ok(Self {
base: other.base,
allowed_length: other.allowed_length,
sequence_breakers: HashSet::from_iter(
other
.sequence_breakers
.into_iter()
.map(|breaker| {
tokenizer
// Prefix with 'a' to get the correct encoding of the token at the end of a text.
//
// FIXME: This is a hack. See https://github.com/LostRuins/koboldcpp/pull/982
// for the correct solution which covers multi-token sequence breakers
// and ambiguous encodings.
.encode(["a", &breaker].concat(), true)
.map_err(anyhow::Error::msg)
.map(|enc| {
let ids = enc.get_ids();
if ids.is_empty() {
None
} else {
Some(ids[ids.len() - 1])
}
})
})
.collect::<anyhow::Result<Vec<_>>>()?
.into_iter()
.flatten()
.collect::<Vec<_>>(),
),
multiplier: other.multiplier,
})
}
}

/// Customizable logits processor
pub trait CustomLogitsProcessor: Send + Sync {
/// Takes the logits and the sequence context (prompt and generated tokens), returning the modified logits.
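
A rough sketch of implementing this trait (the `apply` signature is inferred from the call site `processor.apply(&logits, context)` further down, with `Tensor` and `Result` from candle_core; illustrative only):

// Illustrative processor that masks out a single token ID.
struct BanToken(u32);

impl CustomLogitsProcessor for BanToken {
    fn apply(&self, logits: &Tensor, _context: &[u32]) -> Result<Tensor> {
        let mut values = logits.to_vec1::<f32>()?;
        values[self.0 as usize] = f32::NEG_INFINITY;
        Tensor::from_vec(values, logits.shape(), logits.device())
    }
}
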
@@ -76,6 +163,7 @@ pub struct Sampler {
tokenizer: Arc<Tokenizer>,
frequency_penalty: Option<f32>,
presence_penalty: Option<f32>,
dry_params: Option<DrySamplingParamsInner>,
top_k: i64,
top_p: f64,
min_p: f64,
@@ -112,27 +200,34 @@ impl Sampler {
tokenizer: Arc<Tokenizer>,
frequency_penalty: Option<f32>,
presence_penalty: Option<f32>,
dry_params: Option<DrySamplingParams>,
top_k: i64,
top_p: f64,
min_p: f64,
logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
) -> Self {
) -> anyhow::Result<Self> {
let temperature = if temperature.map_or(true, |v| v < 1e-7) {
None
} else {
temperature
};
Self {
let dry_params = dry_params.map(|params| DrySamplingParamsInner::from(params, &tokenizer));
let dry_params = match dry_params {
Some(fallible) => Some(fallible?),
None => None,
};
Ok(Self {
temperature,
top_n_logprobs,
tokenizer,
frequency_penalty,
presence_penalty,
dry_params,
top_k,
top_p,
min_p,
logits_processors,
}
})
}

fn get_top_logprobs(
@@ -372,6 +467,21 @@ impl Sampler {
}

fn apply_penalties(&self, mut logits: Vec<f32>, context: &[u32]) -> Result<Tensor> {
if context.is_empty() {
candle_core::bail!("Penalty context is empty, this should not happen.");
}

// Dry penalty
self.apply_dry_penalty(&mut logits, context)?;

// Frequency and Presence penalty
self.apply_freq_presc_penalty(&mut logits, context)?;

let vocab_size = logits.len();
Tensor::from_vec(logits, vocab_size, &Device::Cpu)
}

fn apply_freq_presc_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
if self.frequency_penalty.is_some() || self.presence_penalty.is_some() {
let frequency_penalty = self.frequency_penalty.unwrap_or(0.);
let presence_penalty = self.presence_penalty.unwrap_or(0.);
@@ -390,8 +500,71 @@ impl Sampler {
- if count > 0.0 { 1. } else { 0. } * presence_penalty;
}
}
let vocab_size = logits.len();
Tensor::from_vec(logits, vocab_size, &Device::Cpu)
Ok(())
}

fn apply_dry_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
if let Some(ref params) = self.dry_params {
let match_indices = context
.par_iter()
.enumerate()
.take(context.len() - 1)
.filter(|(_i, x)| *context.last().unwrap() == **x)
.map(|(i, _)| i)
.collect::<Vec<_>>();

let mut match_lengths = HashMap::new();

for i in match_indices {
let next_token = context[i + 1];

if params.sequence_breakers.contains(&next_token) {
continue;
}

let mut match_length = 1;

// Limit match length to avoid quadratic runtime and potential DoS with adversarial inputs.
while match_length < 50 {
if match_length > i {
// Start of input
break;
}

let j = i - match_length;

let prev_tok = context[context.len() - (match_length + 1)];
if context[j] != prev_tok {
// Start of match reached
break;
}

if params.sequence_breakers.contains(&prev_tok) {
// Seq breaking tok reached
break;
}

match_length += 1;
}

#[allow(clippy::map_entry)]
if match_lengths.contains_key(&next_token) {
match_lengths.insert(next_token, match_length.max(match_lengths[&next_token]));
} else {
match_lengths.insert(next_token, match_length);
}
}

// Actually apply penalties
for (tok, match_len) in match_lengths {
if match_len >= params.allowed_length {
let penalty = params.multiplier
* params.base.powf((match_len - params.allowed_length) as f32);
logits[tok as usize] -= penalty;
}
}
}
Ok(())
}
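
To make the penalty formula concrete, a quick check with arbitrary example values (multiplier 0.8, base 1.75, allowed length 2):

// penalty = multiplier * base^(match_len - allowed_length)
//   match_len = 2  ->  0.8 * 1.75^0 = 0.8
//   match_len = 3  ->  0.8 * 1.75^1 = 1.4
//   match_len = 5  ->  0.8 * 1.75^3 = 4.2875
let penalty = 0.8f32 * 1.75f32.powf((5usize - 2usize) as f32);
assert!((penalty - 4.2875).abs() < 1e-4);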

/// Sample the provided tokens.
@@ -406,7 +579,8 @@ impl Sampler {
rng: Arc<Mutex<Isaac64Rng>>,
sample_speculative: bool,
) -> Result<Logprobs> {
let mut logits = self.apply_penalties(logits.to_vec1()?, context)?;
let logits = logits.to_vec1()?;
let mut logits = self.apply_penalties(logits, context)?;
for processor in &self.logits_processors {
logits = processor.apply(&logits, context)?;
}
@@ -487,11 +661,13 @@ mod tests {
get_tokenizer().into(),
None,
None,
None,
32,
0.1,
0.05,
vec![],
);
)
.unwrap();
let logits = Tensor::arange(0f32, 1024f32, &Device::Cpu).unwrap();
let rng = Arc::new(Mutex::new(Isaac64Rng::seed_from_u64(42)));
let res = sampler
Expand All @@ -517,11 +693,13 @@ mod tests {
get_tokenizer().into(),
None,
None,
None,
32,
0.1,
0.05,
vec![],
);
)
.unwrap();
let logits = Tensor::arange(0f32, 1024f32, &Device::Cpu).unwrap();
let rng = Arc::new(Mutex::new(Isaac64Rng::seed_from_u64(42)));
let res = sampler
