Skip to content

Commit

Permalink
include custom safetensor and gguf format checking based on implement…
Browse files Browse the repository at this point in the history
…ation in instructlab.utils

In our case, .safetensor file validation is not needed since we don't read it to load a tokenizer

Signed-off-by: Khaled Sulayman <[email protected]>
  • Loading branch information
khaledsulayman committed Nov 14, 2024
1 parent ba9984f commit 1a47348
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/instructlab/sdg/utils/chunkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ def create_tokenizer(model_name: Optional[str]):
# pylint: disable=import-outside-toplevel
# Third Party
from transformers import AutoTokenizer

if model_name is None:
raise TypeError("No model path provided")

Expand Down
83 changes: 83 additions & 0 deletions src/instructlab/sdg/utils/model_formats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Standard
import json
import logging
import pathlib
import struct

# Third Party
from gguf.constants import GGUF_MAGIC

Check failure on line 8 in src/instructlab/sdg/utils/model_formats.py

View workflow job for this annotation

GitHub Actions / pylint

E0401: Unable to import 'gguf.constants' (import-error)

logger = logging.getLogger(__name__)


def is_model_safetensors(model_path: pathlib.Path) -> bool:
"""Check if model_path is a valid safe tensors directory
Directory must contain a specific set of files to qualify as a safetensors model directory
Args:
model_path (Path): The path to the model directory
Returns:
bool: True if the model is a safetensors model, False otherwise.
"""
try:
files = list(model_path.iterdir())
except (FileNotFoundError, NotADirectoryError, PermissionError) as e:
logger.debug("Failed to read directory: %s", e)
return False

# directory should contain either .safetensors or .bin files to be considered valid
filetypes = [file.suffix for file in files]
if not ".safetensors" in filetypes and not ".bin" in filetypes:
logger.debug("'%s' has no .safetensors or .bin files", model_path)
return False

basenames = {file.name for file in files}
requires_files = {
"config.json",
"tokenizer.json",
"tokenizer_config.json",
}
diff = requires_files.difference(basenames)
if diff:
logger.debug("'%s' is missing %s", model_path, diff)
return False

for file in model_path.glob("*.json"):
try:
with file.open(encoding="utf-8") as f:
json.load(f)
except (PermissionError, json.JSONDecodeError) as e:
logger.debug("'%s' is not a valid JSON file: e", file, e)
return False

return True


def is_model_gguf(model_path: pathlib.Path) -> bool:
"""
Check if the file is a GGUF file.
Args:
model_path (Path): The path to the file.
Returns:
bool: True if the file is a GGUF file, False otherwise.
"""
try:
with model_path.open("rb") as f:
first_four_bytes = f.read(4)

# Convert the first four bytes to an integer
first_four_bytes_int = int(struct.unpack("<I", first_four_bytes)[0])

return first_four_bytes_int == GGUF_MAGIC
except struct.error as e:
logger.debug(
f"Failed to unpack the first four bytes of {model_path}. "
f"The file might not be a valid GGUF file or is corrupted: {e}"
)
return False
except IsADirectoryError as e:
logger.debug(f"GGUF Path {model_path} is a directory, returning {e}")
return False
except OSError as e:
logger.debug(f"An unexpected error occurred while processing {model_path}: {e}")
return False
3 changes: 0 additions & 3 deletions tests/test_chunkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@

# Standard
from pathlib import Path
<<<<<<< HEAD
from unittest.mock import MagicMock, patch
=======
import os
>>>>>>> 192e500 (Increase Exception specificity for invalid model paths)
import tempfile

# Third Party
Expand Down

0 comments on commit 1a47348

Please sign in to comment.