Add ISQ FP8 (#832)
* Add cublaslt matmul

* cuBLASlt matmul roughly works

* Add layer

* Tests pass

* Add isq fp8 layer forward

* Better tests

* It works

* Clippy and fix dequant matmul

* Add UQFF support

* Update docs

* All tests pass

* Improved quantization
EricLBuehler authored Oct 12, 2024
1 parent 8ab7d12 commit 9a45756
Showing 31 changed files with 2,245 additions and 60 deletions.
Cargo.lock (28 changes: 23 additions & 5 deletions)

Generated lockfile; diff not rendered.

Cargo.toml (5 changes: 3 additions & 2 deletions)
@@ -25,8 +25,8 @@ license = "MIT"

[workspace.dependencies]
anyhow = "1.0.80"
-candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "20a57c4" }
-candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "20a57c4" }
+candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "60eb251" }
+candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "60eb251" }
serde = "1.0.197"
serde_json = "1.0.114"
indexmap = { version = "2.2.5", features = ["serde"] }
@@ -49,3 +49,4 @@ rayon = "1.1.0"
url = "2.5.2"
data-url = "0.3.1"
buildstructor = "0.5.4"
float8 = "0.1.1"
docs/ISQ.md (1 change: 1 addition & 0 deletions)
@@ -21,6 +21,7 @@ To set the ISQ type for individual layers, use a model [`topology`](TOPOLOGY.md)
- Q8K (*not available on CUDA*)
- HQQ4
- HQQ8
+- FP8

When using ISQ, it will automatically load ISQ-able weights into CPU memory before applying ISQ. The ISQ application process moves the weights to device memory. This process is implemented to avoid memory spikes from loading the model in full precision.

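The ISQ.md paragraph above describes loading ISQ-able weights into CPU memory and only moving the quantized result to the device. Below is a minimal sketch of that flow against stock candle-core, using a stand-in `fake_fp8_quantize` helper (per-tensor scale, values clamped to the E4M3 finite range of ±448). It is illustrative only; the real FP8 path quantizes into an actual 8-bit storage type rather than staying in f32.

```rust
use candle_core::{Device, Result, Tensor};

/// Stand-in for real FP8 quantization: compute a per-tensor scale and clamp
/// to the E4M3 finite range. Everything stays in f32 so this runs anywhere.
fn fake_fp8_quantize(w: &Tensor) -> Result<(Tensor, f32)> {
    let amax: f32 = w.abs()?.flatten_all()?.max(0)?.to_scalar()?;
    let scale = (amax / 448.0).max(f32::MIN_POSITIVE); // 448 = max finite E4M3 value
    let q = w.affine(1.0 / scale as f64, 0.0)?.clamp(-448f32, 448f32)?;
    Ok((q, scale))
}

fn main() -> Result<()> {
    // 1. Load the full-precision weight into host memory.
    let w_cpu = Tensor::randn(0f32, 1f32, (4096, 4096), &Device::Cpu)?;
    // 2. Quantize while it is still on the CPU.
    let (q_cpu, scale) = fake_fp8_quantize(&w_cpu)?;
    // 3. Move only the quantized tensor to the accelerator, so the device
    //    never has to hold the full-precision copy.
    let device = Device::cuda_if_available(0)?;
    let q_dev = q_cpu.to_device(&device)?;
    println!("dequant scale = {scale}, tensor now on {:?}", q_dev.device());
    Ok(())
}
```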
docs/UQFF.md (3 changes: 3 additions & 0 deletions)
@@ -51,6 +51,9 @@ The following quantization formats are supported in UQFF. One can, of course, be
- HQQ4
- HQQ8

+- FP8:
+  - FP8 E4M3 (4-bit exponent, 3-bit mantissa)

## Loading a UQFF model

To load a UQFF model, one should specify the artifact path. This can either be a path to a local UQFF file, or a Hugging Face model ID with the format `<MODEL ID>/<FILE>`. For example, the following work:
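For reference, the FP8 E4M3 layout mentioned above has 1 sign bit, 4 exponent bits (bias 7), and 3 mantissa bits, with no infinities and a maximum finite value of 448. The sketch below decodes a raw E4M3 byte by hand to make the layout concrete; it is not the code path the crates use (they rely on the `float8` crate and hardware conversions).

```rust
/// Decode one FP8 E4M3 byte into f32: 1 sign bit, 4 exponent bits (bias 7),
/// 3 mantissa bits. E4M3 has no infinities; exponent=0b1111 with mantissa=0b111 is NaN.
fn f8_e4m3_to_f32(byte: u8) -> f32 {
    let sign = if byte & 0x80 != 0 { -1.0f32 } else { 1.0 };
    let exp = ((byte >> 3) & 0x0F) as i32; // 4-bit exponent
    let man = (byte & 0x07) as f32;        // 3-bit mantissa
    if exp == 0x0F && (byte & 0x07) == 0x07 {
        return f32::NAN;
    }
    if exp == 0 {
        // Subnormal: no implicit leading 1, fixed exponent of -6.
        sign * (man / 8.0) * 2f32.powi(-6)
    } else {
        sign * (1.0 + man / 8.0) * 2f32.powi(exp - 7)
    }
}

fn main() {
    // 0x7E = 0.1111.110 -> largest finite E4M3 value: 1.75 * 2^8 = 448.0
    assert_eq!(f8_e4m3_to_f32(0x7E), 448.0);
    // 0x38 = 0.0111.000 -> exponent 7 - bias 7 = 0, so exactly 1.0
    assert_eq!(f8_e4m3_to_f32(0x38), 1.0);
    println!("max finite E4M3 = {}", f8_e4m3_to_f32(0x7E));
}
```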
docs/UQFF/LAYOUT.md (14 changes: 13 additions & 1 deletion)
@@ -32,7 +32,6 @@ The following describes the exact memory layout of HQFF tensors of version 0.1.0
| **Array** Weight tensor data, see [docs](#standard-tensors) | See [docs](#standard-tensors) | See [docs](#standard-tensors) |
| **[Optional]** **Array** Bias tensor data, see [docs](#standard-tensors) | See [docs](#standard-tensors) | See [docs](#standard-tensors) |


## HQQ quantization
| ID | Element type | Endianness |
| -------- | -------- | -------- |
@@ -51,6 +50,19 @@ The following describes the exact memory layout of HQFF tensors of version 0.1.0
| CFG round zeroes (boolean) | u8 | little endian |
| CFG channel wise (boolean) | u8 | little endian |

+## FP8 layers
+| ID | Element type | Endianness |
+| -------- | -------- | -------- |
+| HQFF version | u32 | little endian |
+| ISQ type (3) | u8 | little endian |
+| Whether bias data is included (boolean) | u8 | little endian |
+| **Array** Weight tensor data, see [docs](#standard-tensors) | See [docs](#standard-tensors) | See [docs](#standard-tensors) |
+| Dequant scale W | f32 | little endian |
+| Dequant scale X | f32 | little endian |
+| Quant scale | f32 | little endian |
+| Layer dtype | u32 | little endian |
+| **[Optional]** **Array** Bias tensor data, see [docs](#standard-tensors) | See [docs](#standard-tensors) | See [docs](#standard-tensors) |

## Standard tensors
| ID | Element type | Endianness |
| -------- | -------- | -------- |
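A minimal sketch of emitting the FP8-layer fields above in order with plain little-endian writes. The weight/bias **Array** payloads follow the Standard tensors layout and are left here as a hypothetical `write_tensor` stub, and the version/dtype values passed in `main` are placeholders, not the encodings UQFF actually uses for them.

```rust
// Emit the FP8-layer header fields in the order listed in the table above.
fn write_fp8_layer_header(
    buf: &mut Vec<u8>,
    hqff_version: u32,
    has_bias: bool,
    dequant_w: f32,
    dequant_x: f32,
    quant_scale: f32,
    layer_dtype: u32,
) {
    buf.extend_from_slice(&hqff_version.to_le_bytes()); // HQFF version, u32 LE
    buf.push(3u8);                                      // ISQ type tag (3) for FP8 layers
    buf.push(has_bias as u8);                           // bias-present flag
    // write_tensor(buf, &weight);                      // weight data (Standard tensors layout)
    buf.extend_from_slice(&dequant_w.to_le_bytes());    // dequant scale W
    buf.extend_from_slice(&dequant_x.to_le_bytes());    // dequant scale X
    buf.extend_from_slice(&quant_scale.to_le_bytes());  // quant scale
    buf.extend_from_slice(&layer_dtype.to_le_bytes());  // layer dtype
    // if has_bias { write_tensor(buf, &bias); }        // optional bias data
}

fn main() {
    let mut buf = Vec::new();
    // Placeholder version/dtype values, for illustration only.
    write_fp8_layer_header(&mut buf, 1, false, 0.0123, 0.0456, 1.0, 0);
    assert_eq!(buf.len(), 4 + 1 + 1 + 4 + 4 + 4 + 4);
    println!("header bytes: {buf:02x?}");
}
```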
mistralrs-core/Cargo.toml (5 changes: 3 additions & 2 deletions)
@@ -17,7 +17,7 @@ candle-core.workspace = true
candle-nn.workspace = true
serde.workspace = true
serde_json.workspace = true
-candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "20a57c4", optional = true }
+candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.7.0", rev = "60eb251", optional = true }
dirs = "5.0.1"
hf-hub = "0.3.2"
thiserror = "1.0.57"
@@ -78,10 +78,11 @@ regex = "1.10.6"
safetensors = "0.4.5"
serde_plain = "1.0.2"
as-any = "0.3.1"
+float8.workspace = true

[features]
pyo3_macros = ["pyo3"]
-cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:bindgen_cuda", "mistralrs-quant/cuda", "dep:mistralrs-paged-attn", "mistralrs-paged-attn/cuda"]
+cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:bindgen_cuda", "mistralrs-quant/cuda", "dep:mistralrs-paged-attn", "mistralrs-paged-attn/cuda", "float8/cuda"]
cudnn = ["candle-core/cudnn"]
metal = ["candle-core/metal", "candle-nn/metal"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
mistralrs-core/src/cublaslt/api.rs (12 changes: 7 additions & 5 deletions)
@@ -1,13 +1,14 @@
-pub use candle_core::cuda_backend::cudarc::cublaslt::Activation;
+use candle_core::cuda::cudarc::driver::DevicePtr;
+use float8::F8E4M3;
use std::ffi::c_int;

use candle_core::backend::BackendStorage;
use candle_core::cuda_backend::WrapErr;
-use candle_core::{CpuStorage, Device, Layout, Result, Shape, Storage, Tensor};
+use candle_core::{CpuStorage, DType, Device, Layout, Result, Shape, Storage, Tensor};
use half::{bf16, f16};
use std::sync::Arc;

-use candle_core::cuda_backend::cudarc::cublaslt::{CudaBlasLT, Matmul, MatmulConfig};
+use super::matmul::{Activation, CudaBlasLT, Matmul, MatmulConfig};

#[derive(Debug, Clone)]
pub struct CublasLt(Arc<CudaBlasLT>);
@@ -858,11 +859,12 @@ pub fn fused_batch_matmul(
a.apply_op2(b, op)
}
}

#[cfg(test)]
mod tests {
+use std::f32::consts::PI;

use super::*;
-use candle_core::{DType, Device};
+use candle_core::{DType, Device, IndexOp};

fn to_vec2_round(t: Tensor, digits: i32) -> Result<Vec<Vec<f32>>> {
let b = 10f32.powi(digits);
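The new FP8 path stores per-tensor dequant scales ("Dequant scale W" / "Dequant scale X" in the UQFF layout above) next to the 8-bit data. The algebra that makes this cheap: with per-tensor scales, dequantization factors out of the GEMM, so a single f32 multiply in the epilogue recovers the result. This is, roughly, how FP8 GEMM APIs such as cuBLASLt consume per-tensor scale parameters. A pure-Rust illustration of the identity follows; it is not the repository's CUDA code path.

```rust
// (s_x * Xq) @ (s_w * Wq)^T == (s_x * s_w) * (Xq @ Wq^T) for per-tensor scales.
fn matmul(a: &[f32], b_t: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    // a is m x k, b_t is n x k (i.e. B^T stored row-major); returns m x n.
    let mut out = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0f32;
            for p in 0..k {
                acc += a[i * k + p] * b_t[j * k + p];
            }
            out[i * n + j] = acc;
        }
    }
    out
}

fn main() {
    let (m, k, n) = (2, 3, 2);
    // Pretend these are the integer-like values stored in FP8.
    let x_q = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let w_q = vec![1.0, 0.0, -1.0, 2.0, 2.0, 2.0];
    let (s_x, s_w) = (0.125f32, 0.25f32);

    // Dequantize first, then matmul...
    let x: Vec<f32> = x_q.iter().map(|v| v * s_x).collect();
    let w: Vec<f32> = w_q.iter().map(|v| v * s_w).collect();
    let slow = matmul(&x, &w, m, k, n);

    // ...or matmul the quantized values and apply both scales once at the end.
    let fast: Vec<f32> = matmul(&x_q, &w_q, m, k, n)
        .into_iter()
        .map(|v| v * s_x * s_w)
        .collect();

    assert_eq!(slow, fast);
    println!("{slow:?}");
}
```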
(Diffs for the remaining changed files are not shown.)
