Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check for tokenizer in downloaded models directory #364

Merged
merged 8 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ click>=8.1.7,<9.0.0
datasets>=2.18.0,<3.0.0
docling[tesserocr]>=2.4.2,<3.0.0
GitPython>=3.1.42,<4.0.0
gguf>=0.6.0
httpx>=0.25.0,<1.0.0
instructlab-schema>=0.4.0
langchain-text-splitters
Expand All @@ -11,6 +12,7 @@ langchain-text-splitters
# do not use 8.4.0 due to a bug in the library
# https://github.com/instructlab/instructlab/issues/1389
openai>=1.13.3,<2.0.0
sentencepiece>=0.2.0
tabulate>=0.9.0
tenacity>=8.3.0,!=8.4.0
torch>=2.3.0,<2.5.0
Expand Down
60 changes: 43 additions & 17 deletions src/instructlab/sdg/utils/chunkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections import defaultdict
from enum import Enum
from pathlib import Path
from typing import DefaultDict, Iterable, List, Tuple
from typing import DefaultDict, Iterable, List, Optional, Tuple
import json
import logging
import re
Expand All @@ -21,6 +21,9 @@
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
from tabulate import tabulate

# First Party
from instructlab.sdg.utils.model_formats import is_model_gguf, is_model_safetensors

logger = logging.getLogger(__name__)
_DEFAULT_CHUNK_OVERLAP = 100

Expand Down Expand Up @@ -89,8 +92,8 @@ def __new__(
output_dir: Path,
server_ctx_size=4096,
chunk_word_count=1024,
tokenizer_model_name: str | None = None,
docling_model_path: str | None = None,
tokenizer_model_name: Optional[str] = None,
docling_model_path: Optional[str] = None,
):
"""Insantiate the appropriate chunker for the provided document

Expand All @@ -100,7 +103,7 @@ def __new__(
output_dir (Path): directory where artifacts should be stored
server_ctx_size (int): Context window size of server
chunk_word_count (int): Maximum number of words to chunk a document
tokenizer_model_name (str): name of huggingface model to get
tokenizer_model_name (Optional[str]): name of huggingface model to get
tokenizer from
Returns:
TextSplitChunker | ContextAwareChunker: Object of the appropriate
Expand Down Expand Up @@ -220,19 +223,13 @@ def __init__(
filepaths,
output_dir: Path,
chunk_word_count: int,
tokenizer_model_name="mistralai/Mixtral-8x7B-Instruct-v0.1",
tokenizer_model_name: Optional[str],
docling_model_path=None,
):
self.document_paths = document_paths
self.filepaths = filepaths
self.output_dir = self._path_validator(output_dir)
self.chunk_word_count = chunk_word_count
self.tokenizer_model_name = (
tokenizer_model_name
if tokenizer_model_name is not None
else "mistralai/Mixtral-8x7B-Instruct-v0.1"
)

self.tokenizer = self.create_tokenizer(tokenizer_model_name)
self.docling_model_path = docling_model_path

Expand Down Expand Up @@ -350,7 +347,8 @@ def fuse_texts(

return fused_texts

def create_tokenizer(self, model_name: str):
@staticmethod
def create_tokenizer(model_name: Optional[str]):
"""
Create a tokenizer instance from a pre-trained model or a local directory.

Expand All @@ -365,13 +363,41 @@ def create_tokenizer(self, model_name: str):
# Third Party
from transformers import AutoTokenizer

if model_name is None:
raise TypeError("No model path provided")

model_path = Path(model_name)
error_info_message = (
"Please run `ilab model download {download_args}` and try again"
)
try:
tokenizer = AutoTokenizer.from_pretrained(model_name)
logger.info(f"Successfully loaded tokenizer from: {model_name}")
if is_model_safetensors(model_path):
error_info_message = error_info_message.format(
download_args=f"--repository {model_path}"
)
tokenizer = AutoTokenizer.from_pretrained(model_path)

elif is_model_gguf(model_path):
model_dir, model_filename = model_path.parent, model_path.name
error_info_message = error_info_message.format(
download_args=f"--repository {model_dir} --filename {model_filename}"
)
tokenizer = AutoTokenizer.from_pretrained(
model_dir, gguf_file=model_filename
)

else:
error_info_message = "Please provide a path to a valid model format. For help on downloading models, run `ilab model download --help`."
raise ValueError()

logger.info(f"Successfully loaded tokenizer from: {model_path}")
return tokenizer
except Exception as e:
logger.error(f"Failed to load tokenizer from {model_name}: {str(e)}")
raise

except (OSError, ValueError) as e:
logger.error(
f"Failed to load tokenizer as no valid model was not found at {model_path}. {error_info_message}"
)
raise e

def get_token_count(self, text, tokenizer):
"""
Expand Down
83 changes: 83 additions & 0 deletions src/instructlab/sdg/utils/model_formats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Standard
import json
import logging
import pathlib
import struct

# Third Party
from gguf.constants import GGUF_MAGIC

logger = logging.getLogger(__name__)


def is_model_safetensors(model_path: pathlib.Path) -> bool:
"""Check if model_path is a valid safe tensors directory
Directory must contain a specific set of files to qualify as a safetensors model directory
Args:
model_path (Path): The path to the model directory
Returns:
bool: True if the model is a safetensors model, False otherwise.
"""
try:
files = list(model_path.iterdir())
except (FileNotFoundError, NotADirectoryError, PermissionError) as e:
logger.debug("Failed to read directory: %s", e)
return False

# directory should contain either .safetensors or .bin files to be considered valid
filetypes = [file.suffix for file in files]
if not ".safetensors" in filetypes and not ".bin" in filetypes:
logger.debug("'%s' has no .safetensors or .bin files", model_path)
return False

basenames = {file.name for file in files}
requires_files = {
"config.json",
"tokenizer.json",
"tokenizer_config.json",
}
diff = requires_files.difference(basenames)
if diff:
logger.debug("'%s' is missing %s", model_path, diff)
return False

for file in model_path.glob("*.json"):
try:
with file.open(encoding="utf-8") as f:
json.load(f)
except (PermissionError, json.JSONDecodeError) as e:
logger.debug("'%s' is not a valid JSON file: e", file, e)
return False

return True


def is_model_gguf(model_path: pathlib.Path) -> bool:
"""
Check if the file is a GGUF file.
Args:
model_path (Path): The path to the file.
Returns:
bool: True if the file is a GGUF file, False otherwise.
"""
try:
with model_path.open("rb") as f:
first_four_bytes = f.read(4)

# Convert the first four bytes to an integer
first_four_bytes_int = int(struct.unpack("<I", first_four_bytes)[0])

return first_four_bytes_int == GGUF_MAGIC
except struct.error as e:
logger.debug(
f"Failed to unpack the first four bytes of {model_path}. "
f"The file might not be a valid GGUF file or is corrupted: {e}"
)
return False
except IsADirectoryError as e:
logger.debug(f"GGUF Path {model_path} is a directory, returning {e}")
return False
except OSError as e:
logger.debug(f"An unexpected error occurred while processing {model_path}: {e}")
return False
16 changes: 12 additions & 4 deletions tests/functional/test_chunkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,21 @@
from pathlib import Path
import os

# Third Party
import pytest

# First Party
from instructlab.sdg.utils.chunkers import DocumentChunker

TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "testdata")


def test_chunk_pdf(tmp_path):
@pytest.fixture
def tokenizer_model_name():
return os.path.join(TEST_DATA_DIR, "models/instructlab/granite-7b-lab")


def test_chunk_pdf(tmp_path, tokenizer_model_name):
pdf_path = Path(os.path.join(TEST_DATA_DIR, "sample_documents", "phoenix.pdf"))
leaf_node = [
{
Expand All @@ -23,7 +31,7 @@ def test_chunk_pdf(tmp_path):
output_dir=tmp_path,
server_ctx_size=4096,
chunk_word_count=500,
tokenizer_model_name="instructlab/merlinite-7b-lab",
tokenizer_model_name=tokenizer_model_name,
)
chunks = chunker.chunk_documents()
assert len(chunks) > 9
Expand All @@ -33,7 +41,7 @@ def test_chunk_pdf(tmp_path):
assert len(chunk) < 2500


def test_chunk_md(tmp_path):
def test_chunk_md(tmp_path, tokenizer_model_name):
markdown_path = Path(os.path.join(TEST_DATA_DIR, "sample_documents", "phoenix.md"))
leaf_node = [
{
Expand All @@ -48,7 +56,7 @@ def test_chunk_md(tmp_path):
output_dir=tmp_path,
server_ctx_size=4096,
chunk_word_count=500,
tokenizer_model_name="instructlab/merlinite-7b-lab",
tokenizer_model_name=tokenizer_model_name,
)
chunks = chunker.chunk_documents()
assert len(chunks) > 7
Expand Down
38 changes: 32 additions & 6 deletions tests/test_chunkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Standard
from pathlib import Path
from unittest.mock import MagicMock, patch
import os
import tempfile

# Third Party
Expand All @@ -21,10 +22,17 @@
# Local
from .testdata import testdata

TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "testdata")


@pytest.fixture
def documents_dir():
return Path(__file__).parent / "testdata" / "sample_documents"
return Path(TEST_DATA_DIR) / "sample_documents"


@pytest.fixture
def tokenizer_model_name():
return os.path.join(TEST_DATA_DIR, "models/instructlab/granite-7b-lab")


@pytest.mark.parametrize(
Expand All @@ -34,7 +42,7 @@ def documents_dir():
([Path("document.pdf")], ContextAwareChunker),
],
)
def test_chunker_factory(filepaths, chunker_type, documents_dir):
def test_chunker_factory(filepaths, chunker_type, documents_dir, tokenizer_model_name):
"""Test that the DocumentChunker factory class returns the proper Chunker type"""
leaf_node = [
{
Expand All @@ -48,12 +56,12 @@ def test_chunker_factory(filepaths, chunker_type, documents_dir):
leaf_node=leaf_node,
taxonomy_path=documents_dir,
output_dir=temp_dir,
tokenizer_model_name="instructlab/merlinite-7b-lab",
tokenizer_model_name=tokenizer_model_name,
)
assert isinstance(chunker, chunker_type)


def test_chunker_factory_unsupported_filetype(documents_dir):
def test_chunker_factory_unsupported_filetype(documents_dir, tokenizer_model_name):
"""Test that the DocumentChunker factory class fails when provided an unsupported document"""
leaf_node = [
{
Expand All @@ -68,7 +76,7 @@ def test_chunker_factory_unsupported_filetype(documents_dir):
leaf_node=leaf_node,
taxonomy_path=documents_dir,
output_dir=temp_dir,
tokenizer_model_name="instructlab/merlinite-7b-lab",
tokenizer_model_name=tokenizer_model_name,
)


Expand All @@ -87,7 +95,7 @@ def test_chunker_factory_empty_filetype(documents_dir):
leaf_node=leaf_node,
taxonomy_path=documents_dir,
output_dir=temp_dir,
tokenizer_model_name="instructlab/merlinite-7b-lab",
tokenizer_model_name=tokenizer_model_name,
)


Expand Down Expand Up @@ -138,3 +146,21 @@ def test_resolve_ocr_options_none_found_logs_error(
ocr_options = resolve_ocr_options()
assert ocr_options is None
mock_logger.assert_called()


def test_create_tokenizer(tokenizer_model_name):
ContextAwareChunker.create_tokenizer(tokenizer_model_name)


@pytest.mark.parametrize(
"model_name",
[
"models/invalid_gguf.gguf",
"models/invalid_safetensors_dir/",
"bad_path",
],
)
def test_invalid_tokenizer(model_name):
model_path = os.path.join(TEST_DATA_DIR, model_name)
with pytest.raises(ValueError):
ContextAwareChunker.create_tokenizer(model_path)
35 changes: 35 additions & 0 deletions tests/testdata/models/instructlab/granite-7b-lab/.gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
Loading
Loading