Refactored exl2 method to add LoRA, 8bit cache, and other features supported by exllama (#729)

Refactored the exl2 function in exllamav2.py.

The new version offers the following benefits:
1. Auto split support. You no longer need to split a large model across two
GPUs manually; exllama will do it for you.
2. 8-bit cache support. The 8-bit cache can squeeze more context onto the
same GPU.
3. Additional exllamav2 options. Supports low_mem and fasttensors.
4. num_experts_per_token is now optional; you no longer need to pass it in.
5. Future support for the 4-bit cache. Once turbo updates the pip package,
uncomment the 4-bit lines to enable it.
6. Refactored the function parameters. The model_kwargs dictionary has been
replaced with individual, documented parameters, which makes it easier for
new users to see which options they can select (see the usage sketch below).
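
For illustration, a minimal usage sketch of the new interface. This is not part of the commit; the model directory and values are placeholders, and the keyword names follow the new exl2 signature shown in the diff below.

from outlines.models.exllamav2 import exl2

# Placeholder model directory; any EXL2-quantized model directory works here.
model = exl2(
    model_path="./my-model-exl2",
    device="cuda",
    max_seq_len=8192,   # optional: override the config's max sequence length
    cache_8bit=True,    # optional: 8-bit cache to fit more context on the same GPU
    gpu_split="auto",   # optional: "auto", or per-GPU VRAM allocations in GB, e.g. "16,24"
)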
psych0v0yager authored Mar 13, 2024
1 parent d47bd6b commit 03c71f7
Showing 1 changed file with 125 additions and 27 deletions.
152 changes: 125 additions & 27 deletions in outlines/models/exllamav2.py
@@ -1,13 +1,14 @@
import os
from typing import TYPE_CHECKING, Optional

import torch

from .transformers import TransformerTokenizer

if TYPE_CHECKING:
from exllamav2 import ExLlamaV2, ExLlamaV2Cache
from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Lora
from transformers import PreTrainedTokenizer

from .transformers import TransformerTokenizer


class ExLlamaV2Model:
"""Represents a `exl2` model."""
@@ -18,12 +19,14 @@ def __init__(
tokenizer: "PreTrainedTokenizer",
device,
cache: "ExLlamaV2Cache",
lora: Optional["ExLlamaV2Lora"] = None,
):
self.device = device
self.model = model
self.tokenizer = TransformerTokenizer(tokenizer)
self.cache = cache
self.past_seq = None
self.lora = lora

def forward(self, input_ids: torch.LongTensor, *_):
"""Compute a forward pass through the exl2 model."""
@@ -50,6 +53,7 @@ def forward(self, input_ids: torch.LongTensor, *_):
seq_tensor[longest_prefix:-1].view(1, -1),
self.cache,
preprocess_only=True,
loras=[self.lora],
)
elif seq_tensor.shape[0] == longest_prefix:
self.cache.current_seq_len -= 1
@@ -61,58 +65,152 @@ def forward(self, input_ids: torch.LongTensor, *_):
seq_tensor[:-1].view(1, -1),
self.cache,
preprocess_only=True,
loras=[self.lora],
)

self.past_seq = seq_tensor

return self.model.forward(seq_tensor[-1:].view(1, -1), self.cache)
return self.model.forward(
seq_tensor[-1:].view(1, -1), self.cache, loras=[self.lora]
)

def __call__(self, input_ids: torch.LongTensor, *_) -> torch.FloatTensor:
logits = self.forward(input_ids)
next_token_logits = logits[..., -1, :]

return next_token_logits, None

def update_lora(self, lora_path: Optional[str] = None):
"""
Update and apply the LoRA to the model.
Args:
lora_path (Optional[str]): The path to the LoRA directory. If None, the LoRA will be unloaded.
"""
try:
from exllamav2 import ExLlamaV2Lora
except ImportError:
raise ImportError(
"The `exllamav2` library needs to be installed in order to use `exllamav2` models."
)
if lora_path is None:
if self.lora is not None:
print(" -- Unloading LoRA...")
self.lora = None
else:
self.lora = ExLlamaV2Lora.from_directory(self.model, lora_path)
print(" -- Loading LoRA...")


def exl2(
model_path: str,
device: Optional[str] = None,
model_kwargs: dict = {},
device: str,
max_seq_len: Optional[int] = None,
scale_pos_emb: Optional[float] = None,
scale_alpha_value: Optional[float] = None,
no_flash_attn: Optional[bool] = None,
num_experts_per_token: Optional[int] = None,
cache_8bit: bool = False,
cache_q4: bool = False,
tokenizer_kwargs: dict = {},
):
gpu_split: Optional[str] = None,
low_mem: Optional[bool] = None,
verbose: Optional[bool] = None,
) -> ExLlamaV2Model:
"""
Load an ExLlamaV2 model.
Args:
model_path (str): Path to the model directory.
device (str): Device to load the model on. Pass in 'cuda' for GPU or 'cpu' for CPU
max_seq_len (Optional[int], optional): Maximum sequence length. Defaults to None.
scale_pos_emb (Optional[float], optional): Scale factor for positional embeddings. Defaults to None.
scale_alpha_value (Optional[float], optional): Scale alpha value. Defaults to None.
no_flash_attn (Optional[bool], optional): Disable flash attention. Defaults to None.
num_experts_per_token (Optional[int], optional): Number of experts per token. Defaults to None.
cache_8bit (bool, optional): Use 8-bit cache. Defaults to False.
cache_q4 (bool, optional): Use Q4 cache. Defaults to False.
tokenizer_kwargs (dict, optional): Additional keyword arguments for the tokenizer. Defaults to {}.
gpu_split (str): \"auto\", or VRAM allocation per GPU in GB. Auto will use exllama's autosplit feature
low_mem (bool, optional): Enable VRAM optimizations, potentially trading off speed
verbose (bool, optional): Enable if you want debugging statements
Returns:
ExLlamaV2Model: Loaded ExLlamaV2 model.
Raises:
ImportError: If the `exllamav2` library is not installed.
"""

try:
from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Cache,
ExLlamaV2Cache_8bit,
ExLlamaV2Cache_Q4,
ExLlamaV2Config,
)
from transformers import AutoTokenizer
except ImportError:
raise ImportError(
"The `exllamav2` library needs to be installed in order to use `exllamav2` models."
)

# Load tokenizer
if not verbose:
print(" -- Loading tokenizer...")
tokenizer_kwargs.setdefault("padding_side", "left")
tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_kwargs)
# tokenizer = TransformerTokenizer(model_path, **tokenizer_kwargs)

# Check fasttensors for config
if os.name != "nt":
use_fasttensors = True
else:
use_fasttensors = False

# Create config
config = ExLlamaV2Config()
config.model_dir = model_path
config.fasttensors = use_fasttensors
config.prepare()

config.max_seq_len = model_kwargs.pop("max_seq_len", config.max_seq_len)
config.scale_pos_emb = model_kwargs.pop("scale_pos_emb", config.scale_pos_emb)
config.scale_alpha_value = model_kwargs.pop(
"scale_alpha_value", config.scale_alpha_value
)
config.no_flash_attn = model_kwargs.pop("no_flash_attn", config.no_flash_attn)
config.num_experts_per_token = int(
model_kwargs.pop("num_experts_per_token", config.num_experts_per_token)
)

# Set config options
if max_seq_len is not None:
config.max_seq_len = max_seq_len
if scale_pos_emb is not None:
config.scale_pos_emb = scale_pos_emb
if scale_alpha_value is not None:
config.scale_alpha_value = scale_alpha_value
if no_flash_attn is not None:
config.no_flash_attn = no_flash_attn
if num_experts_per_token is not None:
config.num_experts_per_token = num_experts_per_token
if low_mem:
config.set_low_mem()

# Prepare the model from the config
model = ExLlamaV2(config)

split = None
if "gpu_split" in model_kwargs.keys():
split = [float(alloc) for alloc in model_kwargs["gpu_split"].split(",")]

model.load(split)
# Create cache
if cache_8bit:
cache = ExLlamaV2Cache_8bit(model, lazy=not model.loaded)
elif cache_q4:
cache = ExLlamaV2Cache_Q4(model, lazy=not model.loaded)
else:
cache = ExLlamaV2Cache(model, lazy=not model.loaded)

tokenizer_kwargs.setdefault("padding_side", "left")
tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_kwargs)

cache = ExLlamaV2Cache(model)
# Load the model
split = None
if gpu_split and gpu_split != "auto":
split = [float(alloc) for alloc in gpu_split.split(",")]
if not verbose:
print(" -- Loading model...")
model.load(split)

# Autoload if no GPU split was provided
if not model.loaded:
print(" -- Loading model...")
model.load_autosplit(cache)

return ExLlamaV2Model(model, tokenizer, device, cache)
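
For the new LoRA support, a minimal sketch of swapping adapters at runtime. This is not part of the commit; the adapter directory is a placeholder, and model is the ExLlamaV2Model returned by exl2 above.

# Placeholder adapter directory; update_lora loads and applies the LoRA.
model.update_lora("./my-lora-adapter")

# Passing None unloads the currently applied LoRA.
model.update_lora(None)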
