feat: add HIP support
Disty0 authored and dacorvo committed Oct 4, 2024
1 parent cf0b061 commit 843b793
Showing 9 changed files with 186 additions and 5 deletions.
5 changes: 4 additions & 1 deletion optimum/quanto/library/extensions/__init__.py
@@ -19,7 +19,10 @@


 if torch.cuda.is_available():
-    from .cuda import *
+    if torch.version.cuda:
+        from .cuda import *
+    elif torch.version.hip:
+        from .hip import *

 if torch.backends.mps.is_available():
     from .mps import *
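Note on the runtime probe: ROCm builds of PyTorch report torch.cuda.is_available() as True and expose devices as "cuda", so the build flavor is distinguished via torch.version.cuda (a version string on CUDA builds, None on ROCm) and torch.version.hip (the mirror image). A minimal sketch of that probe, independent of this commit:

import torch

def gpu_backend() -> str:
    # torch.version.cuda and torch.version.hip are mutually exclusive:
    # exactly one is a version string on a GPU-enabled build.
    if not torch.cuda.is_available():
        return "none"
    return "cuda" if torch.version.cuda else "hip"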
36 changes: 36 additions & 0 deletions optimum/quanto/library/extensions/hip/__init__.py
@@ -0,0 +1,36 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch

from ..extension import Extension, register_extension


__all__ = []


ext = Extension(
    "quanto_hip",
    root_dir=os.path.dirname(__file__),
    sources=["unpack.cu", "pybind_module.cpp"],
    extra_cflags=["-std=c++17"],
)
register_extension(ext)


@torch.library.impl("quanto::unpack", ["CUDA"])
def unpack_hip(t: torch.Tensor, bits: int):
    return ext.lib.unpack(t, bits)
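The implementation registers under the "CUDA" dispatch key even though the kernels run on HIP: PyTorch's ROCm backend reuses the CUDA dispatch key, so tensors on ROCm devices dispatch through "CUDA". A hypothetical call once the extension is built, assuming the quanto::unpack op is defined elsewhere in the library as on the CUDA path:

import torch

# On a ROCm machine the device is still addressed as "cuda".
packed = torch.randint(0, 256, (16, 32), dtype=torch.uint8, device="cuda")
unpacked = torch.ops.quanto.unpack(packed, 4)  # routed to ext.lib.unpack above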
21 changes: 21 additions & 0 deletions optimum/quanto/library/extensions/hip/pybind_module.cpp
@@ -0,0 +1,21 @@
// Copyright 2024 The HuggingFace Team. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <torch/extension.h>
#include "unpack.h"


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("unpack", &unpack, "unpack");
}
97 changes: 97 additions & 0 deletions optimum/quanto/library/extensions/hip/unpack.cu
@@ -0,0 +1,97 @@
// Copyright 2024 The HuggingFace Team. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <c10/cuda/CUDAException.h>

inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b;}
#define BLOCK_SIZE 256

using namespace at;


static torch::Tensor allocate_output(const torch::Tensor& input, int bits) {
    int n_packed = 8 / bits;
    auto output_shape = input.sizes().vec();
    output_shape[0] = output_shape[0] * n_packed;
    return torch::empty(output_shape, input.options());
}

__global__ void unpack_4bit_kernel(unsigned char* input, unsigned char* output, int n) {
    int i = blockIdx.x*blockDim.x + threadIdx.x;
    if(i>=n) return;

    output[i] = (input[i] & 0x0F);
    output[i + n] = (input[i] & 0xF0) >> 4;
}

static torch::Tensor unpack_4bit(const torch::Tensor& input){

    auto output = allocate_output(input, 4);

    const auto numel = input.numel();
    int blocks = cdiv(numel, BLOCK_SIZE);
    unpack_4bit_kernel<<<blocks, BLOCK_SIZE>>>(
        input.data_ptr<unsigned char>(),
        output.data_ptr<unsigned char>(),
        numel
    );

    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}

__global__ void unpack_2bit_kernel(unsigned char* input, unsigned char* output, int n) {
    int i = blockIdx.x*blockDim.x + threadIdx.x;
    if(i>=n) return;

    output[i] = (input[i] & 0x03);
    output[i + n] = (input[i] & 0x0C) >> 2;
    output[i + n*2] = (input[i] & 0x30) >> 4;
    output[i + n*3] = (input[i] & 0xC0) >> 6;
}

static torch::Tensor unpack_2bit(const torch::Tensor& input){

    auto output = allocate_output(input, 2);

    const auto numel = input.numel();
    int blocks = cdiv(numel, BLOCK_SIZE);
    unpack_2bit_kernel<<<blocks, BLOCK_SIZE>>>(
        input.data_ptr<unsigned char>(),
        output.data_ptr<unsigned char>(),
        numel
    );

    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}

torch::Tensor unpack(torch::Tensor &t, int bits) {
    TORCH_CHECK(t.scalar_type() == torch::kUInt8, "Unsupported data type: ", t.scalar_type());
    TORCH_CHECK(t.device().is_cuda(), "t must be a CUDA tensor.");
    TORCH_CHECK(t.is_contiguous(), "t must be contiguous.");
    switch(bits) {
        case 4:
            return unpack_4bit(t);
        case 2:
            return unpack_2bit(t);
        default:
            throw std::invalid_argument("Can only unpack 2-bit or 4-bit tensors.");
    }
}
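For reference, the two kernels compute the same layout as this pure-PyTorch sketch (a hypothetical helper, handy for sanity-checking the extension against eager mode; t must be contiguous, as the op itself enforces):

import torch

def unpack_reference(t: torch.Tensor, bits: int) -> torch.Tensor:
    # Mirrors the kernels above: the lowest bit-field of every byte fills the
    # first block of rows, the next field lands numel elements later, and so on.
    if bits == 4:
        return torch.cat([t & 0x0F, (t & 0xF0) >> 4], dim=0)
    if bits == 2:
        return torch.cat(
            [t & 0x03, (t & 0x0C) >> 2, (t & 0x30) >> 4, (t & 0xC0) >> 6], dim=0
        )
    raise ValueError("Can only unpack 2-bit or 4-bit tensors.")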
17 changes: 17 additions & 0 deletions optimum/quanto/library/extensions/hip/unpack.h
@@ -0,0 +1,17 @@
// Copyright 2024 The HuggingFace Team. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <torch/extension.h>

torch::Tensor unpack(torch::Tensor &t, int bits);
4 changes: 2 additions & 2 deletions optimum/quanto/tensor/weights/qbits.py
@@ -101,15 +101,15 @@ def create(qtype, axis, group_size, size, stride, data, scale, shift, requires_g
             and axis == 0
             and group_size == 128
             and len(size) == 2
-            and data.device.type == "cuda"
+            and (data.device.type == "cuda" and torch.version.cuda)
             and torch.cuda.get_device_capability(data.device)[0] >= 8
         ):
             if type(data) is PackedTensor:
                 data = data.unpack()
             return AWQWeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift, requires_grad)
         if qtype == qint4 and scale.dtype == torch.bfloat16 and axis == 0 and group_size == 128 and len(size) == 2:
             if data.device.type == "cpu" or (
-                data.device.type == "cuda"
+                (data.device.type == "cuda" and torch.version.cuda)
                 and version.parse(torch.version.cuda).release >= (12, 1)
                 and torch.cuda.get_device_capability(data.device)[0] >= 8
             ):
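The added torch.version.cuda guard matters because ROCm devices also report device.type == "cuda": without it, version.parse(torch.version.cuda) would receive None on ROCm and raise. A minimal sketch of the combined predicate (the helper name is hypothetical, not part of the commit):

import torch
from packaging import version

def tinygemm_eligible(device: torch.device) -> bool:
    # False on ROCm builds, where torch.version.cuda is None.
    return (
        device.type == "cuda"
        and torch.version.cuda is not None
        and version.parse(torch.version.cuda).release >= (12, 1)
        and torch.cuda.get_device_capability(device)[0] >= 8
    )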
2 changes: 1 addition & 1 deletion optimum/quanto/tensor/weights/qbytes.py
@@ -124,7 +124,7 @@ def create(
             and activation_qtype is None
             and scale.dtype in [torch.float16, torch.bfloat16]
             and len(size) == 2
-            and data.device.type == "cuda"
+            and (data.device.type == "cuda" and torch.version.cuda)
             and axis == 0
             and torch.cuda.get_device_capability(data.device)[0] >= 8
         ):
5 changes: 4 additions & 1 deletion test/library/test_extensions.py
@@ -6,7 +6,10 @@

 extension_names = ["quanto_cpp"]
 if torch.cuda.is_available():
-    extension_names.append("quanto_cuda")
+    if torch.version.cuda:
+        extension_names.append("quanto_cuda")
+    if torch.version.hip:
+        extension_names.append("quanto_hip")
 if torch.backends.mps.is_available():
     extension_names.append("quanto_mps")

(ninth changed file; filename not shown in this view)
@@ -27,6 +27,8 @@
 @pytest.mark.parametrize("out_features", [128, 256, 512, 1024])
 def test_tinygemm_weight_qbits_tensor_from_qbits_tensor(in_features, out_features, device):
     if device.type == "cuda":
+        if torch.version.hip:
+            pytest.skip(reason="TinyGemm not available for ROCm devices")
         if version.parse(torch.version.cuda).release < (12, 1):
             pytest.skip(reason="CUDA runtime must be at least 12.1")
         if torch.cuda.get_device_capability()[0] < 8:
@@ -98,6 +100,8 @@ def test_tinygemm_weight_qbits_tensor_move(device):
 @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
 def test_tinygemm_weight_qbits_tensor_linear(batch_size, tokens, embeddings, use_bias, device):
     if device.type == "cuda":
+        if torch.version.hip:
+            pytest.skip(reason="TinyGemm not available for ROCm devices")
         if version.parse(torch.version.cuda).release < (12, 1):
             pytest.skip(reason="CUDA runtime must be at least 12.1")
         if torch.cuda.get_device_capability()[0] < 8:
