Clamping t5 hidden states to avoid F16 NaNs #2481

Open · wants to merge 1 commit into base: main
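Running T5 in half precision can overflow: f16 tops out at 65504, and the attention and feed-forward activations of the larger T5 checkpoints can exceed that, overflowing to infinity and then turning into NaNs in the layers that follow. This PR clamps the hidden states after each sub-layer when the dtype is F16, much like the clamping done in the upstream Python implementation. A minimal illustration of the failure mode (the `candle_core` crate name and CPU device are assumptions for the sketch, not part of the diff):

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // 60000 fits in f16, but the sum 120000 does not and overflows to +inf.
    let xs = Tensor::new(&[60_000f32, 60_000.], &dev)?.to_dtype(DType::F16)?;
    let doubled = xs.add(&xs)?;
    // inf - inf is NaN, which then propagates through every later layer.
    let nans = doubled.sub(&doubled)?;
    println!("{doubled}\n{nans}");
    Ok(())
}
```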
17 changes: 17 additions & 0 deletions candle-core/src/tensor.rs
@@ -2539,6 +2539,23 @@ impl Tensor {
    pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
        rhs.broadcast_mul(&self.log()?)?.exp()
    }

    /// Returns a U8 mask with 1 where the corresponding element is equal to
    /// positive infinity and 0 elsewhere.
    pub fn is_inf(&self) -> Result<Self> {
        self.broadcast_eq(&Tensor::new(f64::INFINITY, self.device())?.to_dtype(self.dtype)?)
    }

    /// Returns true if any element of the tensor is non-zero.
    pub fn any(&self) -> Result<bool> {
        let sum = self.sum_all()?;
        match self.dtype {
            DType::U8 => Ok(sum.to_scalar::<u8>()? != 0),
            DType::U32 => Ok(sum.to_scalar::<u32>()? != 0),
            DType::I64 => Ok(sum.to_scalar::<i64>()? != 0),
            DType::F16 => Ok(sum.to_scalar::<half::f16>()? != half::f16::from_f32_const(0.)),
            DType::BF16 => Ok(sum.to_scalar::<half::bf16>()? != half::bf16::from_f32_const(0.)),
            DType::F32 => Ok(sum.to_scalar::<f32>()? != 0.),
            DType::F64 => Ok(sum.to_scalar::<f64>()? != 0.),
        }
    }
}

macro_rules! bin_trait {
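Taken together, the two new helpers give a cheap "does this tensor contain an infinity?" check, which is what `clamp_for_f16` in t5.rs below relies on. A small usage sketch (assuming `candle_core` as the dependency name outside the workspace; not part of the diff):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let finite = Tensor::new(&[1f32, -2., 3.], &dev)?;
    let overflowed = Tensor::new(&[1f32, f32::INFINITY, 3.], &dev)?;
    // `is_inf` returns a U8 mask, `any` reports whether any mask entry is non-zero.
    assert!(!finite.is_inf()?.any()?);
    assert!(overflowed.is_inf()?.any()?);
    Ok(())
}
```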
1 change: 1 addition & 0 deletions candle-transformers/Cargo.toml
@@ -24,6 +24,7 @@ serde = { workspace = true }
serde_json = { workspace = true }
serde_plain = { workspace = true }
tracing = { workspace = true }
half = { workspace = true }

[features]
default = []
33 changes: 29 additions & 4 deletions candle-transformers/src/models/t5.rs
@@ -577,6 +577,22 @@
}
}

fn clamp_for_f16(xs: &Tensor) -> Result<Tensor> {
    let mut max = match xs.dtype() {
        DType::U8 => u8::MAX as f64 - 1000.,
        DType::U32 => u32::MAX as f64 - 1000.,
        DType::I64 => i64::MAX as f64 - 1000.,
        DType::F16 => half::f16::MAX.to_f64_const() - 1000.,
        DType::BF16 => half::bf16::MAX.to_f64_const() - 1000.,
        DType::F32 => f32::MAX as f64 - 1000.,
        DType::F64 => f64::MAX - 1000.,
    };
    // Back off further if the activations have already overflowed to infinity.
    if xs.is_inf()?.any()? {
        max -= 1000.;
    }
    xs.clamp(-max, max)
}
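For the F16 path this clamps to ±(65504 − 1000) = ±64504, or ±63504 once infinities are already present, both comfortably inside the f16 range. A sanity-check sketch, assuming the helper above and the new tensor methods are in scope (this test is not part of the PR):

```rust
#[cfg(test)]
mod clamp_for_f16_tests {
    use super::*;
    use candle::{DType, Device};

    #[test]
    fn clamping_removes_f16_infinities() -> Result<()> {
        let dev = Device::Cpu;
        // 2e5 already overflows to +inf when cast to f16; clamping pulls both
        // oversized values back to at most 63504.
        let xs = Tensor::new(&[1f32, 2e5, f32::INFINITY], &dev)?.to_dtype(DType::F16)?;
        let clamped = clamp_for_f16(&xs)?;
        assert!(!clamped.is_inf()?.any()?);
        Ok(())
    }
}
```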

#[derive(Debug, Clone)]
struct T5Block {
self_attn: T5LayerSelfAttention,
@@ -632,13 +648,22 @@
            false => None,
        };
        let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?;
        // Clamp for f16
        if xs.dtype() == DType::F16 {
            xs = clamp_for_f16(&xs)?;
        }
        if let Some(cross_attn) = &mut self.cross_attn {
            (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?;
            // Clamp for f16
            if xs.dtype() == DType::F16 {
                xs = clamp_for_f16(&xs)?;
            }
        }
        let mut xs = self.ff.forward(&xs)?;
        // Clamp for f16
        if xs.dtype() == DType::F16 {
            xs = clamp_for_f16(&xs)?;
        }
        Ok((xs, position_bias))
    }

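With the guard in place, the clamping only does work when the weights (and hence the hidden states) are actually loaded as F16; F32 or BF16 runs skip all three branches. A rough end-to-end sketch of exercising that path; the file names, token ids, and the use of `candle_nn`/`candle_transformers` outside the workspace are placeholders and assumptions, not part of this PR:

```rust
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::t5;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let device = Device::Cpu;
    let config: t5::Config = serde_json::from_str(&std::fs::read_to_string("config.json")?)?;
    // Loading the safetensors as F16 is what makes the new clamp branches fire.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F16, &device)?
    };
    let mut model = t5::T5EncoderModel::load(vb, &config)?;
    // Token ids would normally come from the tokenizer; fixed ids keep the sketch short.
    let input_ids = Tensor::new(&[[0u32, 1, 2]], &device)?;
    let hidden_states = model.forward(&input_ids)?;
    println!("{:?}", hidden_states.dims());
    Ok(())
}
```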