Skip to content

Commit

Permalink
Use BF16 on metal when possible. (#2378)
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare authored Aug 1, 2024
1 parent bd80078 commit 1ba87a9
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
16 changes: 16 additions & 0 deletions candle-core/src/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,22 @@ impl Device {
matches!(self, Self::Metal(_))
}

pub fn supports_bf16(&self) -> bool {
match self {
Self::Cuda(_) | Self::Metal(_) => true,
Self::Cpu => false,
}
}

/// Return `BF16` for devices that support it, otherwise default to `F32`.
pub fn bf16_default_to_f32(&self) -> DType {
if self.supports_bf16() {
DType::BF16
} else {
DType::F32
}
}

pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
if crate::utils::cuda_is_available() {
Self::new_cuda(ordinal)
Expand Down
6 changes: 1 addition & 5 deletions candle-examples/examples/mixtral/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,11 +217,7 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let config = Config::v0_1_8x7b(args.use_flash_attn);
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let dtype = device.bf16_default_to_f32();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
Expand Down

0 comments on commit 1ba87a9

Please sign in to comment.