Merge pull request #17 from LLukas22/feat/huggingface-tokenizers
Added HuggingFace Tokenizers support
LLukas22 authored Jun 4, 2023
2 parents fb75d58 + 2eff99c commit e2925c4
Showing 9 changed files with 150 additions and 48 deletions.
Binary file modified .github/workflows/CI.yml
Binary file not shown.
39 changes: 39 additions & 0 deletions .github/workflows/MacOS-CI.yml
@@ -0,0 +1,39 @@
# This file is autogenerated by maturin v0.15.2
# To update, run
#
# maturin generate-ci --zig github
#
name: MacOS-CI

on:
  workflow_dispatch:

permissions:
  contents: read

jobs:
  macos:
    runs-on: macos-latest
    strategy:
      fail-fast: false
      matrix:
        target: [x86_64, aarch64]
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v4
        with:
          python-version: '3.10'
      - name: Build wheels
        uses: PyO3/maturin-action@v1
        with:
          target: ${{ matrix.target }}
          args: --release --out dist --find-interpreter --zig
          sccache: 'true'
        env:
          RUSTFLAGS: "-C link-arg=-undefined -C link-arg=dynamic_lookup"

      - name: Upload wheels
        uses: actions/upload-artifact@v3
        with:
          name: wheels
          path: dist
10 changes: 5 additions & 5 deletions Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "llm-rs"
version = "0.2.8"
version = "0.2.9"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -9,10 +9,10 @@ name = "llm_rs"
crate-type = ["cdylib"]

[dependencies]
pyo3 = "0.18.3"
pyo3 = {version="0.19.0", features=["extension-module", "generate-import-lib"]}
rand = "0.8.5"
rand_chacha = "0.3.1"
log = "0.4.17"
-llm = { git = "https://github.com/rustformers/llm.git", rev="ccdd2ab" }
-llm-base = { git = "https://github.com/rustformers/llm.git",rev="ccdd2ab" }
-ggml = { git = "https://github.com/rustformers/llm.git",rev="ccdd2ab" }
+llm = { git = "https://github.com/rustformers/llm.git", rev="e52a102" }
+llm-base = { git = "https://github.com/rustformers/llm.git",rev="e52a102" }
+ggml = { git = "https://github.com/rustformers/llm.git",rev="e52a102" }
39 changes: 29 additions & 10 deletions llm_rs/auto.py
@@ -4,7 +4,7 @@
from .models import Mpt,GptNeoX,GptJ,Gpt2,Bloom,Llama
from .base_model import Model
import logging
-from typing import Optional, List, Union,Type,Dict
+from typing import Optional, List, Union,Type,Dict, Callable
import os
from enum import Enum, auto
from dataclasses import dataclass
@@ -209,20 +209,31 @@ def _infer_model_type(cls,model_file:Union[str,os.PathLike],known_model:Optional
    def from_file(cls, path:Union[str,os.PathLike],
                  model_type: Optional[KnownModels] = None,
                  session_config:SessionConfig=SessionConfig(),
+                 tokenizer_path_or_repo_id: Optional[Union[str,os.PathLike]]=None,
                  lora_paths:Optional[List[Union[str,os.PathLike]]]=None,
-                 verbose:bool=False)->Model:
+                 verbose:bool=False,
+                 use_hf_tokenizer:bool=True)->Model:
+
+        tokenizer = tokenizer_path_or_repo_id
+        if use_hf_tokenizer and tokenizer is None:
+            metadata = cls.load_metadata(path)
+            tokenizer = metadata.base_model
+            if tokenizer is None or tokenizer == "":
+                raise ValueError(f"Model file '{path}' does not have a base_model specified in its metadata file but wants to use a HuggingFace tokenizer! Please specify a base_model or explicitly specify a tokenizer via `tokenizer_path_or_repo_id`.")
+
        model = cls._infer_model_type(path,model_type)
-        return model(path,session_config,lora_paths,verbose)
+        return model(path,session_config,tokenizer_path_or_repo_id,lora_paths,verbose)
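With this change, `from_file` resolves a tokenizer before constructing the model: an explicit `tokenizer_path_or_repo_id` wins, otherwise the `base_model` recorded in the model's metadata file is used. A minimal usage sketch; the import follows the package README, but the model file and repo id below are placeholder assumptions:

    from llm_rs import AutoModel

    # Placeholder paths/repos for illustration only.
    model = AutoModel.from_file(
        "models/mpt-7b-q4_0.bin",                     # local GGML model file
        tokenizer_path_or_repo_id="mosaicml/mpt-7b",  # Hub repo id or local tokenizer file
        use_hf_tokenizer=True,
    )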

    @classmethod
    def from_pretrained(cls,
                        model_path_or_repo_id: Union[str,os.PathLike],
                        model_file: Optional[str] = None,
                        model_type: Optional[KnownModels] = None,
                        session_config:SessionConfig=SessionConfig(),
+                       tokenizer_path_or_repo_id: Optional[Union[str,os.PathLike]]=None,
                        lora_paths:Optional[List[Union[str,os.PathLike]]]=None,
                        verbose:bool=False,
+                       use_hf_tokenizer:bool=True,
                        default_quantization:QuantizationType=QuantizationType.Q4_0,
                        default_container:ContainerType=ContainerType.GGJT)->Model:

@@ -231,7 +242,7 @@ def from_pretrained(cls,
        if path_type == PathType.UNKNOWN:
            raise ValueError(f"Unknown path type for '{model_path_or_repo_id}'")
        elif path_type == PathType.FILE:
-            return cls.from_file(model_path_or_repo_id,model_type,session_config,lora_paths,verbose)
+            return cls.from_file(model_path_or_repo_id,model_type,session_config,tokenizer_path_or_repo_id,lora_paths,verbose,use_hf_tokenizer)
        else:
            if path_type == PathType.REPO:

@@ -246,14 +257,14 @@

                if config.repo_type != "GGML":
                    logging.warning("Found normal HuggingFace model, starting conversion...")
-                    return cls.from_transformer(model_path_or_repo_id, session_config, lora_paths, verbose, default_quantization, default_container)
+                    return cls.from_transformer(model_path_or_repo_id, session_config, tokenizer_path_or_repo_id, lora_paths, verbose, use_hf_tokenizer, default_quantization, default_container)

                resolved_path = cls._find_model_path_from_repo(str(model_path_or_repo_id),model_file)
-                return cls.from_file(resolved_path,model_type,session_config,lora_paths,verbose)
+                return cls.from_file(resolved_path,model_type,session_config,tokenizer_path_or_repo_id,lora_paths,verbose,use_hf_tokenizer)

            elif path_type == PathType.DIR:
                resolved_path = cls._find_model_path_from_dir(str(model_path_or_repo_id),model_file)
-                return cls.from_file(resolved_path,model_type,session_config,lora_paths,verbose)
+                return cls.from_file(resolved_path,model_type,session_config,tokenizer_path_or_repo_id,lora_paths,verbose,use_hf_tokenizer)

            else:
                raise ValueError(f"Unknown path type '{path_type}'")
@@ -322,8 +333,10 @@ def _find_model_path_from_repo(
    def from_transformer(cls,
                         model_path_or_repo_id: Union[str,os.PathLike],
                         session_config:SessionConfig=SessionConfig(),
+                        tokenizer_path_or_repo_id: Optional[Union[str,os.PathLike]]=None,
                         lora_paths:Optional[List[Union[str,os.PathLike]]]=None,
                         verbose:bool=False,
+                        use_hf_tokenizer:bool=True,
                         default_quantization:QuantizationType=QuantizationType.Q4_0,
                         default_container:ContainerType=ContainerType.GGJT):

@@ -341,7 +354,7 @@ def from_transformer(cls,
        converted_model = AutoConverter.convert(model_path_or_repo_id,export_path)
        if default_quantization != QuantizationType.F16:
            converted_model = AutoQuantizer.quantize(converted_model,quantization=default_quantization,container=default_container)
-        return cls.from_file(converted_model,None,session_config,lora_paths,verbose)
+        return cls.from_file(converted_model,None,session_config,tokenizer_path_or_repo_id,lora_paths,verbose,use_hf_tokenizer)

# Hack to make the quantization type enum hashable
_APPENDIX_MAP = {
@@ -357,7 +370,13 @@ class AutoQuantizer():
    Utility to quantize models, without having to specify the model type.
    """
    @staticmethod
-    def quantize(model_file:Union[str,os.PathLike],target_path:Optional[Union[str,os.PathLike]]=None,quantization:QuantizationType=QuantizationType.Q4_0,container:ContainerType=ContainerType.GGJT)->Union[str,os.PathLike]:
+    def quantize(
+        model_file:Union[str,os.PathLike],
+        target_path:Optional[Union[str,os.PathLike]]=None,
+        quantization:QuantizationType=QuantizationType.Q4_0,
+        container:ContainerType=ContainerType.GGJT,
+        callback:Optional[Callable[[str],None]]=None
+    )->Union[str,os.PathLike]:
        metadata=AutoModel.load_metadata(model_file)
        if metadata.quantization != QuantizationType.F16:
            raise ValueError(f"Model '{model_file}' is already quantized to '{metadata.quantization}'")
@@ -391,7 +410,7 @@ def build_target_name()->str:
            return target_file

        logging.info(f"Quantizing model '{model_file}' to '{target_file}'")
-        model_type.quantize(str(model_file),target_file,quantization,container)
+        model_type.quantize(str(model_file),target_file,quantization,container,callback=callback)

        metadata_file = pathlib.Path(target_file).with_suffix(".meta")
        quantized_metadata = ModelMetadata(model=metadata.model,quantization=quantization,container=container,quantization_version=CURRENT_QUANTIZATION_VERSION,base_model=metadata.base_model)
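`AutoQuantizer.quantize` now accepts a `callback` and forwards it to the model's own `quantize`, so progress can be observed from Python. A small sketch, assuming `AutoQuantizer` is exported at the package root like `AutoModel` and an unquantized F16 model sits at a placeholder path:

    from llm_rs import AutoQuantizer

    # Each progress message from the quantizer is printed as it arrives.
    target = AutoQuantizer.quantize("models/mpt-7b-f16.bin", callback=print)
    print(f"Quantized model written to {target}")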
5 changes: 3 additions & 2 deletions llm_rs/base_model.py
@@ -26,12 +26,13 @@ def lora_paths(self)->Optional[List[str]]: ...
    def __init__(self,
                 path:Union[str,os.PathLike],
                 session_config:SessionConfig=SessionConfig(),
+                tokenizer_name_or_path:Optional[Union[str,os.PathLike]]=None,
                 lora_paths:Optional[List[Union[str,os.PathLike]]]=None,
                 verbose:bool=False) -> None: ...

    def generate(self,prompt:str,
                 generation_config:Optional[GenerationConfig]=None,
-                callback:Callable[[str],Optional[bool]]=None) -> GenerationResult:
+                callback:Optional[Callable[[str],Optional[bool]]]=None) -> GenerationResult:
        """
        Generates text from a prompt.
        """
@@ -58,7 +59,7 @@ def decode(self,tokens:List[int]) -> str:
        ...

    @staticmethod
-    def quantize(source:str,destination:str,quantization:QuantizationType=QuantizationType.Q4_0,container:ContainerType=ContainerType.GGJT)->None:
+    def quantize(source:str,destination:str,quantization:QuantizationType=QuantizationType.Q4_0,container:ContainerType=ContainerType.GGJT,callback:Optional[Callable[[str],None]]=None)->None:
        """
        Quantizes the model.
        """
14 changes: 8 additions & 6 deletions src/configs.rs
@@ -82,14 +82,16 @@ impl GenerationConfig {
impl GenerationConfig {
    pub fn to_llm_params(&self, n_threads: usize, n_batch: usize) -> InferenceParameters {
        InferenceParameters {
-            top_k: self.top_k,
-            top_p: self.top_p,
-            temperature: self.temperature,
-            repeat_penalty: self.repetition_penalty,
-            repetition_penalty_last_n: self.repetition_penalty_last_n,
-            bias_tokens: TokenBias::default(),
            n_threads,
            n_batch,
+            sampler: std::sync::Arc::new(llm::samplers::TopPTopK {
+                top_k: self.top_k,
+                top_p: self.top_p,
+                temperature: self.temperature,
+                repeat_penalty: self.repetition_penalty,
+                repetition_penalty_last_n: self.repetition_penalty_last_n,
+                bias_tokens: TokenBias::default(),
+            }),
        }
    }
}
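The sampling knobs move out of InferenceParameters into llm's TopPTopK sampler; the Python-facing GenerationConfig keeps its flat fields and is merely repacked here. A sketch of the unchanged Python side, assuming GenerationConfig accepts these keyword arguments:

    from llm_rs import GenerationConfig

    # Field names mirror what to_llm_params reads above.
    config = GenerationConfig(
        top_k=40,
        top_p=0.95,
        temperature=0.8,
        repetition_penalty=1.3,
    )
    result = model.generate("Hello", generation_config=config)  # model elided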
55 changes: 45 additions & 10 deletions src/model_base.rs
@@ -89,7 +89,7 @@ impl GenerationStreamer {
    }
}

-pub fn _tokenize(model: &dyn llm::Model, text: &str) -> Result<Vec<i32>, InferenceError> {
+pub fn _tokenize(model: &dyn llm::Model, text: &str) -> Result<Vec<u32>, InferenceError> {
    Ok(model
        .vocabulary()
        .tokenize(text, false)?
@@ -98,12 +98,9 @@ pub fn _tokenize(model: &dyn llm::Model, text: &str) -> Result<Vec<i32>, Inferen
        .collect())
}

-pub fn _decode(model: &dyn llm::Model, tokens: Vec<i32>) -> Result<String, std::str::Utf8Error> {
+pub fn _decode(model: &dyn llm::Model, tokens: Vec<u32>) -> Result<String, std::str::Utf8Error> {
    let vocab = model.vocabulary();
-    let characters: Vec<u8> = tokens
-        .into_iter()
-        .flat_map(|token| vocab.id_to_token[token as usize].to_owned())
-        .collect();
+    let characters: Vec<u8> = vocab.decode(tokens, false);

    match std::str::from_utf8(&characters) {
        Ok(text) => Ok(text.to_string()),
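Token ids switch from i32 to u32 across the bindings, and decoding now delegates to the vocabulary's decode instead of indexing id_to_token by hand. The Python roundtrip, as a sketch (model construction elided):

    tokens = model.tokenize("Hello world")  # list of non-negative token ids
    text = model.decode(tokens)             # UTF-8 string rebuilt from token bytes
    assert text == "Hello world"            # assumes the tokenizer roundtrips cleanly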
@@ -163,7 +160,7 @@ pub fn _infer_next_token(
        }

        //Buffer until a valid utf8 sequence is found
-        if let Some(s) = utf8_buf.push(token) {
+        if let Some(s) = utf8_buf.push(&token) {
            return Ok(Some(s));
        }
    }
@@ -303,6 +300,7 @@ macro_rules! wrap_model {
        fn new(
            path: String,
            session_config: Option<crate::configs::SessionConfig>,
+            tokenizer_name_or_path: Option<String>,
            lora_paths: Option<Vec<String>>,
            verbose: Option<bool>,
        ) -> Self {
@@ -320,8 +318,27 @@
                prefer_mmap: config_to_use.prefer_mmap,
                lora_adapters: lora_paths.clone(),
            };
+
+            let vocabulary_source: llm_base::VocabularySource;
+
+            if let Some(name_or_path) = tokenizer_name_or_path {
+                let tokenizer_path = std::path::Path::new(&name_or_path);
+                if tokenizer_path.is_file() && tokenizer_path.exists() {
+                    // Load tokenizer from file
+                    vocabulary_source = llm_base::VocabularySource::HuggingFaceTokenizerFile(
+                        tokenizer_path.to_owned(),
+                    );
+                } else {
+                    // Load tokenizer from HuggingFace
+                    vocabulary_source =
+                        llm_base::VocabularySource::HuggingFaceRemote(name_or_path);
+                }
+            } else {
+                vocabulary_source = llm_base::VocabularySource::Model;
+            }
+
            let llm_model: $llm_model =
-                llm_base::load(&path, model_params, None, |load_progress| {
+                llm_base::load(&path, vocabulary_source, model_params, |load_progress| {
                    if should_log {
                        llm_base::load_progress_callback_stdout(load_progress)
                    }
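Resolution order for the new tokenizer argument: an existing local file is loaded as a HuggingFace tokenizer file, any other string is treated as a Hub repo id to fetch remotely, and None falls back to the vocabulary embedded in the GGML model. Mirrored from Python with placeholder names (import path assumed; Mpt stands in for any wrapped model class):

    from llm_rs import Mpt

    m1 = Mpt("model.bin", tokenizer_name_or_path="tokenizer.json")   # local file
    m2 = Mpt("model.bin", tokenizer_name_or_path="mosaicml/mpt-7b")  # Hub repo id
    m3 = Mpt("model.bin")                                            # embedded vocabulary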
@@ -399,14 +416,14 @@ macro_rules! wrap_model {
            })
        }

-        fn tokenize(&self, text: String) -> PyResult<Vec<i32>> {
+        fn tokenize(&self, text: String) -> PyResult<Vec<u32>> {
            match crate::model_base::_tokenize(self.llm_model.as_ref(), &text) {
                Ok(tokens) => Ok(tokens),
                Err(e) => Err(pyo3::exceptions::PyException::new_err(e.to_string())),
            }
        }

-        fn decode(&self, tokens: Vec<i32>) -> PyResult<String> {
+        fn decode(&self, tokens: Vec<u32>) -> PyResult<String> {
            match crate::model_base::_decode(self.llm_model.as_ref(), tokens) {
                Ok(tokens) => Ok(tokens),
                Err(e) => Err(pyo3::exceptions::PyException::new_err(e.to_string())),
@@ -415,16 +432,34 @@

        #[staticmethod]
        fn quantize(
+            _py: Python,
            source: String,
            destination: String,
            quantization: Option<crate::quantize::QuantizationType>,
            container: Option<crate::quantize::ContainerType>,
+            callback: Option<PyObject>,
        ) -> PyResult<()> {
+            let mut callback_function: Option<&PyAny> = None;
+            let python_object: Py<PyAny>;
+
+            if let Some(unwrapped) = callback {
+                python_object = unwrapped;
+                let python_function = python_object.as_ref(_py);
+                callback_function = Some(python_function);
+                assert!(python_function.is_callable(), "Callback is not callable!");
+            }
+
            crate::quantize::_quantize::<$llm_model>(
                source.into(),
                destination.into(),
                container.unwrap_or(crate::quantize::ContainerType::GGJT),
                quantization.unwrap_or(crate::quantize::QuantizationType::Q4_0),
+                |message| {
+                    if let Some(callback) = callback_function {
+                        let args = pyo3::types::PyTuple::new(_py, &[message]);
+                        callback.call1(args).unwrap();
+                    }
+                },
            )
            .map_err(|e| pyo3::exceptions::PyException::new_err(e.to_string()))
        }
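On the Rust side the Python callable is validated up front and then invoked once per progress message from inside the quantization closure. From Python that surfaces as the optional callback on the static quantize method; a sketch with placeholder paths (import path assumed, as above):

    from llm_rs import Mpt

    Mpt.quantize(
        "models/mpt-7b-f16.bin",
        "models/mpt-7b-q4_0.bin",
        callback=lambda message: print(f"[quantize] {message}"),
    )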