
Commit 4d178cc

Start reintegrating the Oracle

1 parent 568ac5a, commit 4d178cc

7 files changed, +341 -10 lines changed

pyproject.toml

+2 -2

@@ -23,8 +23,10 @@ classifiers = [
     "Topic :: Text Processing :: Markup :: HTML",
 ]
 dependencies = [
+    "numpy",
     "python-magic",  # XXX review
     "tokenizers",
+    "transformers",
     "unidecode",  # XXX review
 ]

@@ -42,12 +44,10 @@ dev = [
     "pillow",
     "pytest",
     "pytest-cov",
-    "transformers",
 ]
 train = [
     "datasets",
     "pillow",
-    "transformers",
 ]

 [project.scripts]

runner.py

+53

@@ -0,0 +1,53 @@
import sys
import warnings

from itertools import chain

from dom_tokenizers.internal import json
from dom_tokenizers.pre_tokenizers.shared_oracle import SharedOracle

DEFAULT_TESTCASES = [
    "overflow",
    "uniqueid",
    "uniqueId",
    "uniqueID",
    "pagewrap",
    "pageWrap",
    "autocompletetype",
    "autocompleteType",
    "backfill",
    "Inauspicious",
    "Llanfairpwllgwyngyllgogerychwyrndrobwllllantysiliogogogoch",

    "1655885421832",  # first token is 4 chars
    "8eb5e30dac7d493298287704a5f578c7",
    "next/static/css/99762953f4d03581",
    "org/TR/xhtml1/DTD/xhtml1",
    "KFOmCnqEu92Fr1Mu4mxK",
    "electronically8eb5e30dac7",  # median chars/token = 1.0 (mean=2.7)
    "electronically8eb5e30dac",   # median chars/token = 1.5 (mean=3.0)
    "electronically8eb5e30da",    # median chars/token = 2.0 (mean=3.3)
]


def main():
    warnings.filterwarnings("ignore", message=r".*resume_download.*")

    oracle = SharedOracle()
    if len(sys.argv) < 2:
        lines = DEFAULT_TESTCASES
    else:
        lines = chain.from_iterable(
            (json.loads(line)["text"]
             for line in open(filename).readlines())
            for filename in sys.argv[1:])

    for line in lines:
        print("input:", line)
        result = oracle.split_if_trivial(line, log_unhandled=False)
        if result is not None:
            print(f"\x1B[32m{result}\x1B[0m\n")


if __name__ == "__main__":
    main()
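
For reference, the non-default branch of main() reads JSON Lines input: one
JSON object per line, each with a "text" field holding a single string to
feed to the oracle. The field name comes from the code above; the file name
below is only an illustration.

{"text": "8eb5e30dac7d493298287704a5f578c7"}
{"text": "next/static/css/99762953f4d03581"}

Running `python runner.py corpus.jsonl` on a hypothetical file like that
would try to split each of those strings with SharedOracle.split_if_trivial().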

src/dom_tokenizers/pre_tokenizers/oracle.py

+141

@@ -0,0 +1,141 @@
import re

from typing import Optional, Callable

import numpy as np

from ..internal import jsonl
from ..internal.transformers import AutoTokenizer

_IntOrIntList = int | list[int]
_StrOrStrList = str | list[str]


class Oracle:
    def __init__(self, *args, **kwargs):
        self._tok = AutoTokenizer.from_pretrained(*args, **kwargs)
        self._tok.model_max_length = 1 << 31
        self.cls_token_id = self._tok.cls_token_id
        self.sep_token_id = self._tok.sep_token_id
        self.unk_token_id = self._tok.unk_token_id
        self.max_token_len = max(
            len(token) for token in self._tok.vocab
        )
        self.max_try_split_len = min(self.max_token_len * 5, 100)
        self._log = jsonl.Writer(basename="oracle", with_timestamp=True)

    def close(self):
        self._log.close()

    @property
    def normalize_str(self) -> Callable[[str], str]:
        """Normalize the given string.
        """
        return self._tok.backend_tokenizer.normalizer.normalize_str

    def encode(self, *args, **kwargs) -> list[int]:
        """Convert the given string to a list of integer token IDs.
        """
        token_ids = self._tok.encode(*args, **kwargs)
        assert token_ids[0] == self.cls_token_id
        assert token_ids[-1] == self.sep_token_id
        return token_ids[1:-1]

    IDsToTokensType = Callable[[_IntOrIntList], _StrOrStrList]

    @property
    def convert_ids_to_tokens(self) -> IDsToTokensType:
        """Convert the given list of token IDs to a list of tokens.
        """
        return self._tok.convert_ids_to_tokens

    def tokenize(self, *args, **kwargs) -> list[str]:
        """Convert the given string into a list of tokens.
        """
        return self.convert_ids_to_tokens(self.encode(*args, **kwargs))

    @property
    def decode(self) -> Callable[[_IntOrIntList], str]:
        """Convert the given list of token IDs to a string.
        """
        return self._tok.decode

    # For quick checks, see TextSplitter.BASE64_RE for the real deal
    _LOOSE_BASE64_RE = re.compile(r"^[A-Za-z0-9+/]+={0,2}$")

    def split_if_trivial(
            self,
            text: str,
            log_unhandled: bool = True,  # XXX
    ) -> Optional[list[str]]:
        """Split a string into a list of tokens, if the split is trivial.

        Like `tokenize()`, but a result is returned only when the
        oracle is confident the split is correct; otherwise None is
        returned.
        """
        if len(text) > self.max_try_split_len:
            return None

        # Fast path for text that's in the oracle's vocabulary.
        if len(text) <= self.max_token_len and (
                (text in self._tok.vocab
                 or text.lower() in self._tok.vocab)
                and text.isalnum()):
            return [text]

        # Limit ourselves to base64-ish input, for now at least.
        if not self._LOOSE_BASE64_RE.match(text):
            raise NotImplementedError(text)

        token_ids = self.encode(text)
        if not token_ids or self.unk_token_id in token_ids:
            return None

        tokens = self.convert_ids_to_tokens(token_ids)
        word_pieces = [token.lstrip("#") for token in tokens]
        token_lengths = [len(token) for token in word_pieces]

        # If the tokens are mostly 2+ characters long and the
        # input text splits on whitespace in the same places as
        # the decoded token ID sequence then call this a match.
        # Subtracting the standard deviation prevents situations
        # where one long token skews the median away from a load
        # of 1-2 character tokens, e.g. "electronically8eb5e30da"
        # tokenizes to ["electronically", "8", "eb", "5", "e",
        # "30", "da"] with bert-base-uncased, so a median token
        # length of 2 characters/token and a mean of 3.3, but
        # the standard deviation of 4.4 indicates at least one
        # token is very far from the mean.
        median_length = np.median(token_lengths)
        length_stddev = np.std(token_lengths)
        if median_length - length_stddev > 1:
            result = text.split()
            want = [token.lower() for token in result]
            if self.decode(token_ids).split() == want:
                return result

        print(f"tokens: {tokens}"[:80])

        first_token_id = token_ids[0]
        first_token = self.convert_ids_to_tokens(first_token_id)
        assert "#" not in first_token
        print(f"first_token: {first_token!r} ({first_token_id})")

        chars_per_token = len(text) / len(token_ids)

        #mean = sum(token_lengths) / len(token_ids)
        print("chars_per_token:", chars_per_token)
        #print("or ------> mean:", mean)
        print("         median:", median_length)
        print("        std.dev:", length_stddev)
        print()

        # XXX now what?
        if log_unhandled:
            self._log.write(
                text=text, token_ids=token_ids,
                tokens=tokens,
                decoded=self.decode(token_ids),
                chars_per_token=chars_per_token,
            )
        return None
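
To make the comment above concrete, the arithmetic can be checked in
isolation; this standalone sketch assumes only numpy and the token lengths
quoted in the comment.

import numpy as np

# bert-base-uncased splits "electronically8eb5e30da" into
# ["electronically", "8", "eb", "5", "e", "30", "da"],
# i.e. word-piece lengths of:
token_lengths = [14, 1, 2, 1, 1, 2, 2]

print(np.median(token_lengths))  # 2.0  (mostly 1-2 character tokens)
print(np.mean(token_lengths))    # ~3.3 (pulled up by "electronically")
print(np.std(token_lengths))     # ~4.4 (one length is far from the mean)

# The guard in split_if_trivial(): 2.0 - 4.4 > 1 is False, so this
# input is not accepted as a trivial split.
print(np.median(token_lengths) - np.std(token_lengths) > 1)  # False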

src/dom_tokenizers/pre_tokenizers/shared_oracle.py

+18

@@ -0,0 +1,18 @@
import atexit

from .oracle import Oracle


class SharedOracle(Oracle):
    _shared_borg_state = {}

    def __new__(cls, *args, **kwargs):
        obj = super().__new__(cls)
        obj.__dict__ = cls._shared_borg_state
        return obj

    def __init__(self, model="bert-base-uncased", *args, **kwargs):
        if hasattr(self, "_tok"):
            return
        super().__init__(model, *args, **kwargs)
        atexit.register(self.close)
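
SharedOracle is a Borg-style singleton: every instance gets the same
__dict__, so the tokenizer is loaded once and later constructions return
immediately. A short usage sketch (not part of the commit):

from dom_tokenizers.pre_tokenizers.shared_oracle import SharedOracle

a = SharedOracle()  # loads bert-base-uncased and opens the oracle log
b = SharedOracle()  # __init__ sees self._tok already set and returns early

assert a.__dict__ is b.__dict__        # distinct objects, one shared state
print(a.split_if_trivial("overflow"))  # either instance answers the same way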

src/dom_tokenizers/pre_tokenizers/splitter.py

+15 -8

@@ -4,14 +4,16 @@
 from base64 import binascii, b64decode
 from collections import defaultdict
 from collections.abc import Iterable
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from urllib.parse import unquote

 import magic

 from unidecode import unidecode

 from ..internal import json
+from .oracle import Oracle
+from .shared_oracle import SharedOracle
 from .sniffer import sniff_bytes

 logger = logging.getLogger(__name__)

@@ -44,6 +46,7 @@ class FalseBase64Error(RuntimeError):
 class TextSplitter:
     base64_token: str = "[BASE64]"
     long_token: str = "[LONG]"
+    oracle: Oracle = field(default_factory=SharedOracle)

     @property
     def special_tokens(self) -> Iterable[str]:

@@ -435,20 +438,20 @@ def _sub_urlencoded(self, splits, cursor):

     def _split_base64(self, encoded):
         try:
-            encoded = encoded.encode("ascii")
+            _encoded = encoded.encode("ascii")
         except UnicodeEncodeError:
             return None
         try:
-            data = b64decode(encoded, validate=True)
+            data = b64decode(_encoded, validate=True)
         except binascii.Error:
             return None
         try:
             text = data.decode("utf-8")
         except UnicodeDecodeError:
-            return self._split_base64_binary(data)
-        return self._split_base64_utf8(text)
+            return self._split_base64_binary(data, encoded)
+        return self._split_base64_utf8(text, encoded)

-    def _split_base64_utf8(self, text):
+    def _split_base64_utf8(self, text, encoded):
         match = self.XML_HDR_RE.match(text)
         if match is not None:
             if match.group(1) == "svg":

@@ -459,12 +462,16 @@ def _split_base64_utf8(self, text):
             return [self.base64_token, "json"]
         except json.JSONDecodeError:
             pass
+        if self.oracle.first_is_better(encoded, text):
+            return None  # encoded is better
         return [self.base64_token, "text"]

-    def _split_base64_binary(self, data):
+    def _split_base64_binary(self, data, encoded):
         filetype = sniff_bytes(data)
         if not filetype:
-            return None
+            if self.oracle.is_texty(encoded):
+                return None
+            return [self.base64_token, "data"]
         return [self.base64_token, filetype.name.lower()]

     # XXX junk?
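
One note on the new oracle field: assuming TextSplitter is decorated with
@dataclass (as the pre-existing dataclass import suggests),
field(default_factory=SharedOracle) defers constructing the oracle until a
TextSplitter instance is actually created, so merely importing splitter.py
does not load the tokenizer. A reduced, hypothetical sketch of the same idea:

from dataclasses import dataclass, field

from dom_tokenizers.pre_tokenizers.oracle import Oracle
from dom_tokenizers.pre_tokenizers.shared_oracle import SharedOracle


@dataclass
class MiniSplitter:  # hypothetical stand-in for TextSplitter
    base64_token: str = "[BASE64]"
    long_token: str = "[LONG]"
    # default_factory is evaluated per instance, at __init__ time;
    # the Borg pattern means every instance still shares one tokenizer.
    oracle: Oracle = field(default_factory=SharedOracle)


splitter = MiniSplitter()               # SharedOracle() is constructed here
print(type(splitter.oracle).__name__)   # "SharedOracle"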
