Add GPTQ Marlin support for 4 and 8 bit (#856)
* Add GPTQ Marlin support for 4 and 8 bit

* GPTQ marlin backend is working

* Typo

* Update docs
EricLBuehler authored Oct 16, 2024
1 parent 100a660 commit 751be3d
Showing 13 changed files with 2,021 additions and 96 deletions.
3 changes: 2 additions & 1 deletion .typos.toml
@@ -4,7 +4,8 @@ extend-ignore-identifiers-re = [
"mmaped",
"arange",
"Nd",
"nin"
"nin",
"cudaDevAttrMaxSharedMemoryPerBlockOptin"
]

[files]
2 changes: 1 addition & 1 deletion README.md
@@ -94,7 +94,7 @@ Mistral.rs supports several model categories:
**Quantization**:
- [Details](docs/QUANTS.md)
- GGML: 2-bit, 3-bit, 4-bit, 5-bit, 6-bit and 8-bit, with ISQ support.
- GPTQ: 2-bit, 3-bit, 4-bit and 8-bit
- GPTQ: 2-bit, 3-bit, 4-bit and 8-bit, with [Marlin](https://github.com/IST-DASLab/marlin) kernel support in 4-bit and 8-bit.
- HQQ: 4-bit and 8-bit, with ISQ support

**Powerful**:
2 changes: 2 additions & 0 deletions docs/QUANTS.md
@@ -12,6 +12,7 @@ Mistral.rs supports the following quantization:
- Supported in all plain and adapter models
- CUDA only
- 2, 3, 4, 8 bit
- [Marlin](https://github.com/IST-DASLab/marlin) kernel support in 4-bit and 8-bit.
- HQQ
- Supported in all plain and adapter models via ISQ
- CUDA and CPU only
@@ -41,6 +42,7 @@ cargo run --features cuda -- -i --isq Q4K plain -m microsoft/Phi-3-mini-4k-instr
- Use the `plain` (cli) / `Plain` (Python) model selector
- Provide the model ID for the GPTQ model
- Mistral.rs will automatically detect and use GPTQ quantization.
- The [Marlin](https://github.com/IST-DASLab/marlin) kernel will automatically be used in 4-bit and 8-bit.

```
cargo run --features cuda -- -i plain -m kaitchup/Phi-3-mini-4k-instruct-gptq-4bit -a phi3
1 change: 1 addition & 0 deletions mistralrs-quant/build.rs
@@ -11,6 +11,7 @@ fn main() {
"kernels/gptq/q_gemm.cu",
"kernels/hqq/hqq.cu",
"kernels/ops/ops.cu",
"kernels/marlin/marlin_kernel.cu",
];
for lib_file in lib_files.iter() {
println!("cargo:rerun-if-changed={lib_file}");
118 changes: 118 additions & 0 deletions mistralrs-quant/kernels/marlin/marlin/marlin.cuh
@@ -0,0 +1,118 @@
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#include <cassert>

// #define CHECK(cond, ...) \
// assert(cond); \
#define CHECK(cond, ...)

namespace marlin {

// Marlin params

// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per scheduler allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.

static constexpr int repack_threads = 256;
static constexpr int repack_stages = 8;
static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;

static constexpr int tile_size = 16;
static constexpr int max_par = 16;
static constexpr int tile_k_size = tile_size;
static constexpr int tile_n_size = tile_k_size * 4;

__device__ inline constexpr int ceildiv(int a, int b) {
return (a + b - 1) / b;
}

// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batch sizes that are not multiples of 16.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}

// Asynchronous global->shared copy
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
}

// Async copy fence.
__device__ inline void cp_async_fence() {
asm volatile("cp.async.commit_group;\n" ::);
}

// Wait until at most `n` async copy stages are still pending.
template <int n>
__device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}

// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
if (threadIdx.x == 0) {
int state = -1;
do
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
: "=r"(state)
: "l"(lock));
while (state != count);
}
__syncthreads();
}

// Release barrier and increment visitation count.
__device__ inline void barrier_release(int* lock, bool reset = false) {
__syncthreads();
if (threadIdx.x == 0) {
if (reset) {
lock[0] = 0;
return;
}
int val = 1;
// Make sure that all writes since acquiring this barrier are visible
// globally, while releasing the barrier.
asm volatile("fence.acq_rel.gpu;\n");
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
:
: "l"(lock), "r"(val));
}
}

// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template <typename T, int n>
struct Vec {
T elems[n];
__device__ T& operator[](int i) { return elems[i]; }
};

using I4 = Vec<int, 4>;

} // namespace marlin
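
The helpers above are building blocks for a software-pipelined load: `cp_async4_pred`/`cp_async4` issue 16-byte global-to-shared copies, `cp_async_fence` closes a commit group, and `cp_async_wait<n>` blocks until at most `n` groups are still in flight. As a rough, hypothetical sketch of how they compose (not code from this commit; `STAGES`, `pipeline_demo`, and the single-block setup are our own simplifications), a toy kernel that streams data through a rotating set of shared-memory slots might look like this:

```cuda
// Hypothetical illustration only: not part of this commit.
#include "marlin.cuh"

namespace marlin_demo {

using namespace marlin;

constexpr int STAGES = 4;  // commit groups kept in flight

// Toy kernel: stream `n_chunks` chunks of 16 bytes per thread from `in` to
// `out` through a rotating set of shared-memory slots. Launch with one block
// and shared memory of STAGES * blockDim.x * sizeof(int4) bytes; assumes
// n_chunks >= STAGES so the wait depth below is always meaningful.
__global__ void pipeline_demo(const int4* in, int4* out, int n_chunks) {
  extern __shared__ int4 smem[];

  int fetch = 0;
  // Fill the pipeline: one cp.async plus one commit group per stage.
  for (; fetch < STAGES; ++fetch) {
    cp_async4_pred(&smem[fetch * blockDim.x + threadIdx.x],
                   &in[fetch * blockDim.x + threadIdx.x]);
    cp_async_fence();
  }

  for (int consume = 0; consume < n_chunks; ++consume, ++fetch) {
    // At most STAGES-1 groups may still be pending, so the group that filled
    // the slot we are about to read has definitely landed in shared memory.
    cp_async_wait<STAGES - 1>();
    __syncthreads();

    int slot = consume % STAGES;
    out[consume * blockDim.x + threadIdx.x] =
        smem[slot * blockDim.x + threadIdx.x];
    __syncthreads();

    // Refill the drained slot; the predicate turns the copy off past the end
    // of the input while still committing an (empty) group, which keeps the
    // wait-depth bookkeeping uniform.
    bool in_range = fetch < n_chunks;
    cp_async4_pred(&smem[slot * blockDim.x + threadIdx.x],
                   &in[(in_range ? fetch : 0) * blockDim.x + threadIdx.x],
                   in_range);
    cp_async_fence();
  }

  // Drain any remaining (possibly empty) groups before exiting.
  cp_async_wait<0>();
}

}  // namespace marlin_demo
```

In the real Marlin GEMM the consume step is tensor-core math rather than a copy back out, and the `barrier_acquire`/`barrier_release` pair above serializes the cross-threadblock reduction of partial results.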
79 changes: 79 additions & 0 deletions mistralrs-quant/kernels/marlin/marlin/marlin_dtypes.cuh
@@ -0,0 +1,79 @@

#ifndef _data_types_cuh
#define _data_types_cuh
#include "marlin.cuh"
#include <cuda_fp16.h>
#include <cuda_bf16.h>

namespace marlin {

template <typename scalar_t>
class ScalarType {};

template <>
class ScalarType<half> {
public:
using scalar_t = half;
using scalar_t2 = half2;

// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>;
using FragZP = Vec<half2, 4>;

static __device__ float inline num2float(const half x) {
return __half2float(x);
}

static __device__ half2 inline num2num2(const half x) {
return __half2half2(x);
}

static __device__ half2 inline nums2num2(const half x1, const half x2) {
return __halves2half2(x1, x2);
}

static __host__ __device__ half inline float2num(const float x) {
return __float2half(x);
}
};

template <>
class ScalarType<nv_bfloat16> {
public:
using scalar_t = nv_bfloat16;
using scalar_t2 = nv_bfloat162;

using FragA = Vec<nv_bfloat162, 4>;
using FragB = Vec<nv_bfloat162, 2>;
using FragC = Vec<float, 4>;
using FragS = Vec<nv_bfloat162, 1>;
using FragZP = Vec<nv_bfloat162, 4>;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static __device__ float inline num2float(const nv_bfloat16 x) {
return __bfloat162float(x);
}

static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
return __bfloat162bfloat162(x);
}

static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
const nv_bfloat16 x2) {
return __halves2bfloat162(x1, x2);
}

static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
return __float2bfloat16(x);
}
#endif
};

} // namespace marlin

#endif
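
`ScalarType` is a traits shim that lets a single templated kernel body serve both fp16 and bf16, with every scalar-to-float conversion routed through the specialization. A minimal sketch of the pattern follows; `scale_demo`, `scale_demo_fp16`, and the include path are our own assumptions, not code from this commit.

```cuda
// Hypothetical illustration only: not part of this commit.
#include "marlin_dtypes.cuh"

// One template body serves both storage dtypes: every scalar<->float
// conversion goes through the ScalarType<> specialization, so the fp16 and
// bf16 instantiations differ only in the traits they pull in.
template <typename scalar_t>
__global__ void scale_demo(const scalar_t* __restrict__ in,
                           scalar_t* __restrict__ out, float scale, int n) {
  using St = marlin::ScalarType<scalar_t>;
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Do the arithmetic in fp32, then convert back to the storage dtype.
    float x = St::num2float(in[i]);
    out[i] = St::float2num(x * scale);
  }
}

// Host-side launcher for the fp16 instantiation; the bf16 one is identical
// but only compiles for sm_80+, which is what the #if guard in the header
// encodes.
void scale_demo_fp16(const half* in, half* out, float scale, int n,
                     cudaStream_t stream) {
  int threads = 256;
  int blocks = (n + threads - 1) / threads;
  scale_demo<half><<<blocks, threads, 0, stream>>>(in, out, scale, n);
}
```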