Skip to content

Commit

Permalink
Adding the convert scripts which will now prevent converting models (#59
Browse files Browse the repository at this point in the history
)

* Adding the convert scripts which will now prevent converting models

in case they will trigger warnings on the `transformers` side.
Even if the model is perfectly fine, core maintainers fear an influx
of opened issues.

This is perfectly legit.
On the `transformers` side fixes are on the way: huggingface/transformers#20042

We can wait for this PR to hit `main` before communicating super widely.

In the meantime this conversion script will now prevent converting
models that would trigger such warnings (so the output of the script
**will** depend on the `transformers` freshness).

* Adding a nicer diff for the error when reloading.
  • Loading branch information
Narsil authored Nov 4, 2022
1 parent 1194739 commit 764ff0f
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 19 deletions.
52 changes: 38 additions & 14 deletions bindings/python/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
import json
import os
import shutil
from tempfile import TemporaryDirectory
from collections import defaultdict
from inspect import signature
from typing import Optional, List
from tempfile import TemporaryDirectory
from typing import Dict, List, Optional

import torch

from huggingface_hub import CommitOperationAdd, HfApi, hf_hub_download, get_repo_discussions
from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
from huggingface_hub.file_download import repo_folder_name
from safetensors.torch import save_file
from transformers import AutoConfig
from transformers.pipelines.base import infer_framework_load_model
from safetensors.torch import save_file


class AlreadyExists(Exception):
Expand All @@ -30,15 +30,18 @@ def shared_pointers(tensors):
failing.append(names)
return failing


def check_file_size(sf_filename: str, pt_filename: str):
sf_size = os.stat(sf_filename).st_size
pt_size = os.stat(pt_filename).st_size

if (sf_size - pt_size) / pt_size > 0.01:
raise RuntimeError(f"""The file size different is more than 1%:
raise RuntimeError(
f"""The file size different is more than 1%:
- {sf_filename}: {sf_size}
- {pt_filename}: {pt_size}
""")
"""
)


def rename(pt_filename: str) -> str:
Expand All @@ -47,12 +50,13 @@ def rename(pt_filename: str) -> str:
return local


def convert_multi(model_id: str) -> List["CommitOperationAdd"]:
def convert_multi(model_id: str, folder: str) -> List["CommitOperationAdd"]:
filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json")
with open(filename, "r") as f:
data = json.load(f)

filenames = set(data["weight_map"].values())
local_filenames = []
for filename in filenames:
cached_filename = hf_hub_download(repo_id=model_id, filename=filename)
loaded = torch.load(cached_filename)
Expand All @@ -71,7 +75,9 @@ def convert_multi(model_id: str) -> List["CommitOperationAdd"]:
json.dump(newdata, f)
local_filenames.append(index)

operations = [CommitOperationAdd(path_in_repo=local.split("/")[-1], path_or_fileobj=local) for local in local_filenames]
operations = [
CommitOperationAdd(path_in_repo=local.split("/")[-1], path_or_fileobj=local) for local in local_filenames
]

return operations

Expand All @@ -97,16 +103,34 @@ def convert_single(model_id: str, folder: str) -> List["CommitOperationAdd"]:
operations = [CommitOperationAdd(path_in_repo=sf_filename, path_or_fileobj=local)]
return operations


def _format_keys(keys) -> str:
    """Render a collection of keys like a set literal, in a stable (sorted) order."""
    return "{" + ", ".join(sorted(repr(key) for key in keys)) + "}"


def create_diff(pt_infos: Dict[str, List[str]], sf_infos: Dict[str, List[str]]) -> str:
    """Describe how two loading-info dicts differ, one line per difference.

    For each of the categories ``missing_keys``, ``mismatched_keys`` and
    ``unexpected_keys``, the entries of both dicts are compared as sets and a
    line is produced for every side that holds entries the other side lacks.

    Args:
        pt_infos: Loading info gathered for the PT model.
        sf_infos: Loading info gathered for the SF model.

    Returns:
        The difference lines joined with newlines, or an empty string when
        the three categories hold the same entries on both sides.
    """
    errors = []
    for key in ("missing_keys", "mismatched_keys", "unexpected_keys"):
        # A category absent from one of the dicts is treated as empty rather
        # than aborting the report with a KeyError.
        pt_set = set(pt_infos.get(key, ()))
        sf_set = set(sf_infos.get(key, ()))

        pt_only = pt_set - sf_set
        sf_only = sf_set - pt_set

        # Keys are rendered in sorted order so the same difference always
        # yields the same message (raw set iteration order is arbitrary).
        if pt_only:
            errors.append(
                f"{key} : PT warnings contain {_format_keys(pt_only)} which are not present in SF warnings"
            )
        if sf_only:
            errors.append(
                f"{key} : SF warnings contain {_format_keys(sf_only)} which are not present in PT warnings"
            )
    return "\n".join(errors)


def check_final_model(model_id: str, folder: str):
config = hf_hub_download(repo_id=model_id, filename="config.json")
shutil.copy(config, os.path.join(folder, "config.json"))
config = AutoConfig.from_pretrained(folder)

_, pt_model = infer_framework_load_model(model_id, config)
_, sf_model = infer_framework_load_model(folder, config)
_, (pt_model, pt_infos) = infer_framework_load_model(model_id, config, output_loading_info=True)
_, (sf_model, sf_infos) = infer_framework_load_model(folder, config, output_loading_info=True)

pt_model = pt_model
sf_model = sf_model
if pt_infos != sf_infos:
error_string = create_diff(pt_infos, sf_infos)
raise ValueError(f"Different infos when reloading the model: {error_string}")

pt_params = pt_model.state_dict()
sf_params = sf_model.state_dict()
Expand Down Expand Up @@ -134,7 +158,6 @@ def check_final_model(model_id: str, folder: str):
if "image" in sig.parameters:
kwargs["image"] = pixel_values


if torch.cuda.is_available():
pt_model = pt_model.cuda()
sf_model = sf_model.cuda()
Expand All @@ -146,6 +169,7 @@ def check_final_model(model_id: str, folder: str):
torch.testing.assert_close(sf_logits, pt_logits)
print(f"Model {model_id} is ok !")


def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
try:
discussions = api.get_repo_discussions(repo_id=model_id)
Expand All @@ -156,7 +180,7 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discuss
return discussion


def convert(api: "HfApi", model_id: str, force: bool=False) -> Optional["CommitInfo"]:
def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["CommitInfo"]:
pr_title = "Adding `safetensors` variant of this model"
info = api.model_info(model_id)
filenames = set(s.rfilename for s in info.siblings)
Expand Down
22 changes: 17 additions & 5 deletions bindings/python/convert_all.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,30 @@
"""Simple utility tool to convert automatically most downloaded models"""
from convert import AlreadyExists, convert
from huggingface_hub import HfApi, ModelFilter, ModelSearchArguments
from convert import convert, AlreadyExists
from transformers import AutoConfig


if __name__ == "__main__":
api = HfApi()
args = ModelSearchArguments()

total = 100
models = list(api.list_models(filter=ModelFilter(library=args.library.Transformers), sort="downloads", direction=-1))[:total]
total = 50
models = list(
api.list_models(filter=ModelFilter(library=args.library.Transformers), sort="downloads", direction=-1)
)[:total]

correct = 0
errors = set()
for model in models:
model = api.model_info(model.modelId, files_metadata=True)
size = None
for sibling in model.siblings:
if sibling.rfilename == "pytorch_model.bin":
size = sibling.size
if size is None or size > 2_000_000_000:
print(f"[{model.downloads}] Skipping {model.modelId} (too large {size})")
continue

model_id = model.modelId
print(f"[{model.downloads}] {model.modelId}")
try:
Expand All @@ -22,10 +34,10 @@
correct += 1
print(e)
except Exception as e:
errors.add( model_id)
config = AutoConfig.from_pretrained(model_id)
errors.add(config.__class__.__name__)
print(e)


print(f"Errors: {errors}")
print(f"File size is difference {len(errors)}")
print(f"Correct rate {correct}/{total} ({correct/total * 100:.2f}%)")

0 comments on commit 764ff0f

Please sign in to comment.