Enforce compact JSON serialization
gbenson committed May 30, 2024
1 parent ba96074 commit 28ffd6e
Showing 12 changed files with 36 additions and 19 deletions.
1 change: 1 addition & 0 deletions .flake8
@@ -1,5 +1,6 @@
 [flake8]
 exclude = .git,__pycache__,venv*,.venv*,build,dist,.local,.#*,#*,*~
+restricted_packages = json
 inline-quotes = "
 per-file-ignores =
     # imported but unused
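The restricted_packages = json line is configuration for flake8-custom-import-rules, which this commit also adds to the dev dependencies in pyproject.toml: direct imports of the stdlib json module get flagged, steering code toward the compact-by-default wrapper added in src/dom_tokenizers/internal/json.py (note the CIR107 noqa there). A minimal sketch of the intended effect, stated as an assumption about the rule rather than the plugin's documented behaviour:

# Illustration only, not part of the commit.
#
#     import json   # expected to be reported by flake8-custom-import-rules
#                    # once restricted_packages = json is configured
#
# Project modules import the wrapper instead, which serializes compactly
# by default:
from dom_tokenizers.internal import json

print(json.dumps({"ok": True}))  # prints {"ok":true}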
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dom-tokenizers"
-version = "0.0.13"
+version = "0.0.14"
 authors = [{ name = "Gary Benson", email = "[email protected]" }]
 description = "DOM-aware tokenization for 🤗 Hugging Face language models"
 readme = "README.md"
@@ -37,6 +37,7 @@ dev = [
     "build",
     "datasets",
     "flake8",
+    "flake8-custom-import-rules",
     "flake8-quotes",
     "pillow",
     "pytest",
4 changes: 2 additions & 2 deletions src/dom_tokenizers/diff.py
@@ -1,9 +1,9 @@
-import json
 import warnings

 from argparse import ArgumentParser
 from difflib import SequenceMatcher

+from .internal import json
 from .internal.transformers import AutoTokenizer
 from .pre_tokenizers import DOMSnapshotPreTokenizer

@@ -105,7 +105,7 @@ def main():
     for line in open(args.reference).readlines():
         row = json.loads(line)
         source_index = row["source_index"]
-        serialized = json.dumps(row["dom_snapshot"], separators=(",", ":"))
+        serialized = json.dumps(row["dom_snapshot"])
         b = tokenizer.tokenize(serialized)
         a = row["tokenized"]
         if b == a:
6 changes: 3 additions & 3 deletions src/dom_tokenizers/dump.py
@@ -1,10 +1,10 @@
-import json
 import warnings

 from argparse import ArgumentParser

 from datasets import load_dataset

+from .internal import json
 from .internal.transformers import AutoTokenizer
 from .pre_tokenizers import DOMSnapshotPreTokenizer

@@ -35,11 +35,11 @@ def main():

     dataset = load_dataset(args.dataset, split=args.split)
     rows = ((row["source_index"], row["dom_snapshot"]) for row in dataset)
-    rows = ((si, ss, json.dumps(ss, separators=(",", ":"))) for si, ss in rows)
+    rows = ((si, ss, json.dumps(ss)) for si, ss in rows)
     rows = ((len(ser), si, ss, ser) for si, ss, ser in rows)
     for _, source_index, dom_snapshot, serialized in sorted(rows):
         print(json.dumps(dict(
             source_index=source_index,
             dom_snapshot=dom_snapshot,
             tokenized=tokenizer.tokenize(serialized)
-        ), separators=(",", ":")))
+        )))
18 changes: 18 additions & 0 deletions src/dom_tokenizers/internal/json.py
@@ -0,0 +1,18 @@
+from json import *  # noqa: F403, CIR107
+
+
+# Default to compact serialization.
+
+def __wrap(func):
+    def wrapper(*args, **kwargs):
+        new_kwargs = {"separators": (",", ":")}
+        new_kwargs.update(kwargs)
+        return func(*args, **new_kwargs)
+    wrapper.__name__ = func.__name__
+    wrapper.__doc__ = func.__doc__
+    return wrapper
+
+
+dump = __wrap(dump)  # noqa: F405
+dumps = __wrap(dumps)  # noqa: F405
+del __wrap
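A quick usage sketch of the wrapper above (not part of the commit; assumes the dom_tokenizers package is importable): the injected separators are only a default, because the caller's keyword arguments are applied on top of them via new_kwargs.update(kwargs).

from dom_tokenizers.internal import json

# Compact separators are injected when the caller doesn't pass any...
assert json.dumps({"a": [1, 2]}) == '{"a":[1,2]}'

# ...but an explicit separators argument still wins.
assert json.dumps({"a": [1, 2]}, separators=(", ", ": ")) == '{"a": [1, 2]}'

# Everything else (load, loads, JSONDecodeError, ...) is re-exported unchanged.
assert json.loads('{"a":[1,2]}') == {"a": [1, 2]}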
3 changes: 1 addition & 2 deletions src/dom_tokenizers/pre_tokenizers/dom_snapshot.py
@@ -1,10 +1,9 @@
-import json
-
 from dataclasses import make_dataclass
 from xml.dom import Node

 from tokenizers import NormalizedString

+from ..internal import json
 from .compat_itertools import batched
 from .html import is_void_element
 from .pre_tokenizer import PreTokenizer
2 changes: 1 addition & 1 deletion src/dom_tokenizers/pre_tokenizers/splitter.py
@@ -1,4 +1,3 @@
-import json
 import re

 from base64 import b64decode
@@ -12,6 +11,7 @@

 from unidecode import unidecode

+from ..internal import json

 _B64_RE_S = r"(?:[A-Za-z0-9+/]{4}){"
 _B64_RE_E = r",}(?:[A-Za-z0-9+/]{3}=|[A-Za-z0-9+/]{2}==)?"
4 changes: 2 additions & 2 deletions src/dom_tokenizers/profile.py
@@ -1,5 +1,4 @@
 import cProfile as profile
-import json
 import os
 import time
 import warnings
@@ -10,6 +9,7 @@
 from datasets import load_dataset
 from tokenizers import NormalizedString

+from .internal import json
 from .internal.transformers import AutoTokenizer
 from .pre_tokenizers import DOMSnapshotPreTokenizer

@@ -83,7 +83,7 @@ def main():
     os.makedirs(os.path.dirname(cache_filename), exist_ok=True)
     with open(cache_filename, "w") as fp:
         for row in training_dataset:
-            json.dump(row["dom_snapshot"], fp, separators=(",", ":"))
+            json.dump(row["dom_snapshot"], fp)
             fp.write("\n")
     del training_dataset

2 changes: 1 addition & 1 deletion src/dom_tokenizers/train.py
@@ -1,4 +1,3 @@
-import json
 import os
 import warnings

@@ -9,6 +8,7 @@
 from tokenizers import AddedToken
 from tokenizers.pre_tokenizers import WhitespaceSplit

+from .internal import json
 from .internal.transformers import AutoTokenizer
 from .pre_tokenizers import DOMSnapshotPreTokenizer

6 changes: 2 additions & 4 deletions tests/pre_tokenizers/dom_snapshot/test_pre_tokenizer.py
@@ -1,6 +1,4 @@
-import json
-
-from ...util import load_resource
+from ...util import load_resource, json


 def test_raw_response_unwrapping(pre_tokenizer):
@@ -13,7 +11,7 @@ def test_raw_response_unwrapping(pre_tokenizer):
     assert set(browser_response.keys()) == {"id", "result", "sessionId"}
     regular_snapshot = browser_response["result"]
     assert set(regular_snapshot.keys()) == {"documents", "strings"}
-    regular_snapshot = json.dumps(regular_snapshot, separators=(",", ":"))
+    regular_snapshot = json.dumps(regular_snapshot)
     assert regular_snapshot in wrapped_snapshot
     del browser_response

4 changes: 1 addition & 3 deletions tests/test_tokenization.py
@@ -1,10 +1,8 @@
-import json
-
 from datasets import Dataset

 from dom_tokenizers.train import train_tokenizer, DEFAULT_VOCAB_SIZE

-from .util import load_resource
+from .util import load_resource, json


 def test_base64(dom_snapshot_tokenizer):
2 changes: 2 additions & 0 deletions tests/util.py
@@ -1,5 +1,7 @@
 import os

+from dom_tokenizers.internal import json  # noqa: F401
+

 def get_resource_filename(filename, *, ext=None):
     if ext and not filename.endswith(ext):
