# Standard
import json
import logging
import pathlib

logger = logging.getLogger(__name__)


def is_model_safetensors(model_path: pathlib.Path) -> bool:
    """Check whether ``model_path`` is a usable safetensors model directory.

    A directory qualifies when all of the following hold:

    * it can be listed,
    * it contains at least one ``.safetensors`` or ``.bin`` file,
    * it contains ``config.json``, ``tokenizer.json`` and
      ``tokenizer_config.json``,
    * every ``*.json`` file directly inside it is readable, UTF-8 encoded,
      well-formed JSON.

    The weight files themselves are not opened or validated.

    Args:
        model_path (Path): The path to the model directory.

    Returns:
        bool: True if the directory looks like a safetensors model,
        False otherwise. The reason for a rejection is logged at DEBUG level.
    """
    try:
        files = list(model_path.iterdir())
    except (FileNotFoundError, NotADirectoryError, PermissionError) as e:
        logger.debug("Failed to read directory: %s", e)
        return False

    # directory should contain either .safetensors or .bin files to be
    # considered valid
    suffixes = {file.suffix for file in files}
    if ".safetensors" not in suffixes and ".bin" not in suffixes:
        logger.debug("'%s' has no .safetensors or .bin files", model_path)
        return False

    basenames = {file.name for file in files}
    required_files = {
        "config.json",
        "tokenizer.json",
        "tokenizer_config.json",
    }
    missing = required_files.difference(basenames)
    if missing:
        logger.debug("'%s' is missing %s", model_path, missing)
        return False

    for file in model_path.glob("*.json"):
        try:
            with file.open(encoding="utf-8") as f:
                json.load(f)
        # OSError covers PermissionError as well as e.g. a directory that is
        # named "*.json"; UnicodeDecodeError covers files that are not UTF-8.
        # A validator that returns bool should reject these, not raise.
        except (OSError, UnicodeDecodeError, json.JSONDecodeError) as e:
            # Both arguments need a placeholder, otherwise the logging module
            # fails to format the record and the cause is never reported.
            logger.debug("'%s' is not a valid JSON file: %s", file, e)
            return False

    return True