Skip to content

Commit

Permalink
Add fn 'get_with_hints_dtype' in VarBuilder (#1877) (#1897)
Browse files Browse the repository at this point in the history
* Quantized models (AWQ, SqueezeLLM, ...) contain tensors of multiple data types; use 'get_with_hints_dtype' to load a tensor with a given dtype
  • Loading branch information
yinqiwen authored Apr 1, 2024
1 parent 888c09a commit 5522bbc
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions candle-nn/src/var_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,16 +178,27 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
name: &str,
hints: B::Hints,
) -> Result<Tensor> {
let path = self.path(name);
self.data
.backend
.get(s.into(), &path, hints, self.data.dtype, &self.data.device)
self.get_with_hints_dtype(s, name, hints, self.data.dtype)
}

/// Looks up the tensor registered under `name` at the current path, using default hints.
///
/// This is a convenience wrapper: it forwards the shape `s` and `name` to
/// `get_with_hints`, supplying `Default::default()` for the hints argument.
pub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Tensor> {
    let default_hints = Default::default();
    self.get_with_hints(s, name, default_hints)
}

/// Fetches the tensor stored under `name` at the current path, requesting an explicit `dtype`.
///
/// The caller-supplied `dtype` is handed to the backend's `get` call together with the
/// shape `s`, the resolved path, the `hints`, and `self.data.device`. Use this when the
/// desired dtype differs from `self.data.dtype`.
pub fn get_with_hints_dtype<S: Into<Shape>>(
    &self,
    s: S,
    name: &str,
    hints: B::Hints,
    dtype: DType,
) -> Result<Tensor> {
    // Resolve the full path first, then convert the shape (same order as before).
    let full_path = self.path(name);
    let shape = s.into();
    let data = &self.data;
    data.backend
        .get(shape, &full_path, hints, dtype, &data.device)
}
}

struct Zeros;
Expand Down

0 comments on commit 5522bbc

Please sign in to comment.