Skip to content

Commit

Permalink
Avoid duplicate metal command buffer encodings (#861)
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler authored Oct 17, 2024
1 parent d82306d commit 300cbd5
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions mistralrs-core/src/pipeline/isq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,21 +243,26 @@ pub trait IsqModel {

use rayon::iter::IntoParallelRefIterator;

let current_rayon_threads = rayon::current_num_threads();
// Get the MINIMUM of the max isq threads the quant method allows
let minimum_max_threads = tensors
.iter()
.map(|(q, _)| {
if let Some(dtype) = dtype {
q.get_max_isq_cpu_threads(dtype)
.map(usize::from)
.unwrap_or(current_rayon_threads)
} else {
current_rayon_threads
}
})
.min()
.unwrap_or(current_rayon_threads);
#[cfg(not(feature = "metal"))]
let minimum_max_threads = {
let current_rayon_threads = rayon::current_num_threads();
tensors
.iter()
.map(|(q, _)| {
if let Some(dtype) = dtype {
q.get_max_isq_cpu_threads(dtype)
.map(usize::from)
.unwrap_or(current_rayon_threads)
} else {
current_rayon_threads
}
})
.min()
.unwrap_or(current_rayon_threads)
};
#[cfg(feature = "metal")]
let minimum_max_threads = 1;

info!("Applying ISQ on {minimum_max_threads} threads.");

Expand Down

0 comments on commit 300cbd5

Please sign in to comment.