Commit 2719ff5
Initial pre-tokenizer and trainer
1 parent 00b4311 commit 2719ff5

File tree: 7 files changed, +394 -0 lines changed

.flake8
README.md
pyproject.toml
src/dom_tokenizers/__init__.py
src/dom_tokenizers/pre_tokenizers/__init__.py
src/dom_tokenizers/pre_tokenizers/dom_snapshot.py
src/dom_tokenizers/train.py

.flake8 (+7)

```ini
[flake8]
exclude = .git,__pycache__,venv*,.venv*,build,dist,.local,.#*,#*,*~
per-file-ignores =
    # imported but unused
    src/dom_tokenizers/**/__init__.py: F401
    # line too long
    src/dom_tokenizers/pre_tokenizers/dom_snapshot.py: E501
```

README.md (+14)

# DOM tokenizers

HTML DOM-aware tokenizers for Hugging Face language models.

## Setup for development

```sh
git clone --recursive https://github.com/gbenson/dom-tokenizers.git
cd dom-tokenizers
python3 -m venv .venv
. .venv/bin/activate
pip install --upgrade pip
pip install -e .[dev]
```

pyproject.toml (+22)

```toml
[project]
name = "dom-tokenizers"
version = "0.0.1"
dependencies = [
    "python-magic",
    "tokenizers",
    "transformers",
]

[project.optional-dependencies]
dev = [
    "datasets",
    "flake8",
    "pillow",
]

[project.scripts]
train-tokenizer = "dom_tokenizers.train:main"

[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
```
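
Not part of the commit, but for orientation: the `train-tokenizer` console script above resolves to `dom_tokenizers.train:main` (the file shown at the end of this diff), so after the editable install from the README the same run can be kicked off from Python. A minimal sketch, assuming the `dev` extras are installed:

```python
# Equivalent of running the `train-tokenizer` console script: main() trains
# on the small test dataset by default and saves the result to ./pretrained.
from dom_tokenizers.train import main

main()
```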

src/dom_tokenizers/__init__.py

Whitespace-only changes.

src/dom_tokenizers/pre_tokenizers/__init__.py (+1)

```python
from .dom_snapshot import DOMSnapshotPreTokenizer
```

src/dom_tokenizers/pre_tokenizers/dom_snapshot.py (+261)

```python
import json
import re

from base64 import b64decode
from collections import defaultdict
from collections.abc import Iterable
from functools import cached_property
from itertools import chain
from posixpath import commonprefix
from typing import List
from xml.dom import Node

import magic

from tokenizers import NormalizedString, PreTokenizedString


class DOMSnapshotPreTokenizer:
    """Pre-tokenizer that consumes JSON-serialized DOM snapshots
    and emits tokenized representations of the snapshotted DOMs.
    """
    bos_token = "[BOS]"  # beginning of sequence
    eos_token = "[EOS]"  # end of sequence
    sep_token = "[SEP]"  # separator between documents
    elem_token = "[TAG]"  # beginning of element name
    attr_token = "[ATTR]"  # beginning of attribute
    comm_token = "[COMMENT]"  # beginning of comment
    base64_token = "[BASE64]"  # beginning of some base64
    long_token = "[LONG]"  # elided long token

    @property
    def special_tokens(self):
        return [
            value
            for attr, value in self.__class__.__dict__.items()
            if attr.endswith("token")
        ]

    def pre_tokenize(self, pretok: PreTokenizedString):
        """Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place.
        """
        pretok.split(self._split_json)

    def _split_json(self, i: int, s: NormalizedString) -> List[NormalizedString]:
        snapshot = json.loads(s.normalized)
        return list(chain.from_iterable(self._split_serialized(snapshot)))

    def _split_serialized(self, snapshot: dict) -> Iterable[List[NormalizedString]]:
        emitter = TokenEmitter(self, snapshot)
        elem_token = [NormalizedString(self.elem_token)]
        attr_token = [NormalizedString(self.attr_token)]

        for document_index, document in enumerate(snapshot["documents"]):
            token = self.bos_token if document_index == 0 else self.sep_token
            yield [NormalizedString(token)]

            nodes = document["nodes"]
            for node_index, node_values in enumerate(zip(
                    nodes["nodeType"],
                    nodes["nodeName"],
                    nodes["nodeValue"],
                    nodes["attributes"])):
                node_type, name_index, value_index, attr_indexes = node_values

                match node_type:
                    case Node.ELEMENT_NODE:
                        yield elem_token
                        yield emitter.emit(name_index)
                        for attr_index in range(0, len(attr_indexes), 2):
                            yield attr_token
                            yield emitter.emit(attr_indexes[attr_index])
                            yield emitter.emit(attr_indexes[attr_index + 1])

                    case Node.TEXT_NODE:
                        yield emitter.emit(value_index)

                    case Node.COMMENT_NODE:
                        yield [NormalizedString(self.comm_token)]
                        yield emitter.emit(value_index)

        yield [NormalizedString(self.eos_token)]


_B64_RE_S = r"(?:[A-Za-z0-9+/]{4}){"
_B64_RE_E = r",}(?:[A-Za-z0-9+/]{3}=|[A-Za-z0-9+/]{2}==)?"


def base64_matcher(min_encoded_len=24):
    min_groups, extra = divmod(min_encoded_len, 4)
    if extra:
        min_groups += 1
    return re.compile(f"{_B64_RE_S}{min_groups}{_B64_RE_E}")


class TokenEmitter:
    MAXWORDLEN = 32
    WORD_RE = re.compile(
        r"[a-z0-9]+(?:[a-z0-9']*[a-z0-9])?")  # XXX English only :(
    ESCAPED_RE = re.compile(
        r"((?:%|\\x|\\u[0-9a-f]{2})[0-9a-f]{2})", re.I)
    HEX_RE = re.compile(r"^(?:0x|[0-9a-f]{2})[0-9a-f]{6,}$")
    DIGIT_RE = re.compile(r"\d")
    URLISH_RE = re.compile(r"(?:[a-z]+|[0-9a-f]+|[A-Z0-9]+)")
    SHORTEST_URLISH = 16
    LONGEST_PHITEST = 85
    BASE64_RE = base64_matcher()
    B64_PNG_RE = re.compile(r"iVBORw0KGg[o-r]")
    XML_HDR_RE = re.compile(r"<([a-z]{3,})\s+[a-z]+")

    def __init__(self, pretokenizer: DOMSnapshotPreTokenizer, snapshot: dict):
        self._pt = pretokenizer
        self._strings = snapshot["strings"]
        self._tokens = {}

    @cached_property
    def base64_token(self):
        return self._pt.base64_token

    @cached_property
    def long_token(self):
        return self._pt.long_token

    def emit(self, string_index: int) -> Iterable[NormalizedString]:
        """Emit tokens for one string in a DOM snapshot's string table.

        It splits on any non-alphanumeric character, but also tries
        to detect (and recurse into) base64-encoded data, of which
        there's a lot in just the 295 `interesting-dom-snapshots`.
        (Not dealing with base64 results in a whole load of "words"
        which are just fragments of base64.  It isn't easy though:
        lots of regular text is valid base64, so we have to sniff.)
        """
        if string_index < 0:
            return []
        tokens = self._tokens.get(string_index)
        if tokens is not None:
            return tokens
        tokens = [
            NormalizedString(token)
            for token in self._postprocess(
                chain.from_iterable(
                    self._split(
                        self._preprocess(
                            self._strings[string_index]))))
        ]
        self._tokens[string_index] = tokens
        return tokens

    def _preprocess(self, text):
        return "".join(
            self._unescape_char(s) if i & 1 else s
            for i, s in enumerate(self.ESCAPED_RE.split(text))
        )

    def _unescape_char(self, escaped):
        if escaped[0] == "%":
            escaped = "\\x" + escaped[1:]
        return eval(f'"{escaped}"')

    def _split(self, text):
        while text:
            match = self.BASE64_RE.search(text)
            if match is not None:
                start, limit = match.span()
            else:
                start = limit = len(text)
            if start > 0:
                yield self._split_words(text[:start])
            if limit > start:
                encoded = text[start:limit]
                matched = self._match_urlish_base64(encoded)
                if matched is not None:
                    limit = start + len(matched)
                    yield self._split_words(text[start:limit])
                else:
                    yield self._enter_base64(encoded)
            if limit == len(text):
                break
            text = text[limit:]

    def _split_words(self, text):
        return self.WORD_RE.findall(text.lower())

    def _match_urlish_base64(self, encoded):
        urlish = "/".join(self.URLISH_RE.findall(encoded))
        result = commonprefix((encoded, urlish))
        if len(result) < self.SHORTEST_URLISH:
            return None
        return result

    def _enter_base64(self, encoded):
        # Lots of false-positives here, try sniffing
        if self.B64_PNG_RE.match(encoded):
            return [self.base64_token, "png"]
        data = b64decode(encoded)
        try:
            text = data.decode("utf-8")
        except UnicodeDecodeError:
            text = None
        if text is not None:
            return self._enter_base64_utf8(text)
        return self._enter_base64_binary(data, encoded)

    def _enter_base64_utf8(self, text):
        # XXX recurse??
        match = self.XML_HDR_RE.match(text)
        if match is not None:
            if match.group(1) == "svg":
                return [self.base64_token, "svg"]
            return [self.base64_token, "xml"]
        try:
            _ = json.loads(text)
            return [self.base64_token, "json"]
        except json.JSONDecodeError:
            pass
        return [self.base64_token, "utf", "8"]

    def _enter_base64_binary(self, data, encoded):
        # Not out of false-positive territory yet
        full_magic = magic.from_buffer(data)
        easy_magic = full_magic.split(maxsplit=1)[0]
        if easy_magic in {"GIF", "zlib", "JPEG"}:
            return [self.base64_token, easy_magic.lower()]
        if " Web/P image" in full_magic:
            return [self.base64_token, "webp"]
        if full_magic.startswith("Web Open Font Format"):
            return [self.base64_token, "woff"]
        if len(encoded) > self.LONGEST_PHITEST:
            return [self.base64_token]
        # phi test for monoalphabeticity
        hist = defaultdict(int)
        for symbol in encoded:
            hist[symbol] += 1
        phi_o = sum(freq * (freq - 1) for freq in hist.values())
        N = len(encoded)
        phi_r = N * (N - 1) / 64
        # non-standard comparison (observed phi > twice random)
        if phi_o > phi_r * 2:
            return self._split_words(encoded)
        return [self.base64_token]

    def _postprocess(self, tokens: Iterable[str]) -> Iterable[str]:
        for token in tokens:
            if self.HEX_RE.match(token):
                yield self.long_token
                try:
                    _ = int(token)
                except ValueError:
                    yield "hex"
                yield "digits"
                continue

            if len(token) <= self.MAXWORDLEN:
                yield token
                continue

            yield self.long_token
            if self.DIGIT_RE.search(token):
                yield "alphanumeric"
            else:
                yield "alphabetic"
```

src/dom_tokenizers/train.py (+89)

```python
import json
import warnings

from datasets import load_dataset
from tokenizers.pre_tokenizers import PreTokenizer, WhitespaceSplit
from transformers import AutoTokenizer

from .pre_tokenizers import DOMSnapshotPreTokenizer

FULL_DATASET = "gbenson/webui-dom-snapshots"
TEST_DATASET = "gbenson/interesting-dom-snapshots"


def train_tokenizer(
        *args,
        training_dataset=None,
        base_tokenizer="bert-base-uncased",
        vocab_size=1024,  # XXX including all tokens and alphabet
        **kwargs):
    """
    XXX
    base_tokenizer
    all other args passed to load_dataset for XXX...
    """

    # Load the training data we'll train our new tokenizer with.
    if training_dataset is None:
        training_dataset = load_dataset(*args, **kwargs)

    # Create the base tokenizer we'll train our new tokenizer from.
    if isinstance(base_tokenizer, str):
        base_tokenizer = AutoTokenizer.from_pretrained(base_tokenizer)

    # Create the custom pretokenizer our new tokenizer will use.
    new_pretokenizer = DOMSnapshotPreTokenizer()

    # List the custom special tokens that need adding to our tokenizer.
    new_special_tokens = [
        special_token
        for special_token in new_pretokenizer.special_tokens
        if base_tokenizer.tokenize(special_token) != [special_token]
    ]

    # It's not possible to train using a custom pre-tokenizer: the Rust
    # code raises "Exception: Custom PreTokenizer cannot be serialized"
    # (see https://github.com/huggingface/tokenizers/issues/269) so we
    # have to run our pre-tokenizer manually, then join its output with
    # whitespace and hope the regular pretokenizer takes it back apart
    # how we need it to.

    base_tokenizer.backend_tokenizer.pre_tokenizer = WhitespaceSplit()
    base_pretokenizer = base_tokenizer.backend_tokenizer.pre_tokenizer
    new_pretokenizer = PreTokenizer.custom(new_pretokenizer)

    def futz_input(real_input):
        pretokenized = new_pretokenizer.pre_tokenize_str(real_input)
        want_tokens = [token for token, offsets in pretokenized]
        futzed_input = " ".join(want_tokens)
        pretokenized = base_pretokenizer.pre_tokenize_str(futzed_input)
        got_tokens = [token for token, offsets in pretokenized]
        assert got_tokens == want_tokens
        return futzed_input

    def get_training_corpus():
        for row in training_dataset:
            yield futz_input(json.dumps(row["dom_snapshot"]))

    # Train the new tokenizer.
    new_tokenizer = base_tokenizer.train_new_from_iterator(
        text_iterator=get_training_corpus(),
        vocab_size=vocab_size,
        new_special_tokens=new_special_tokens,
        length=len(training_dataset),  # used for progress tracking
        show_progress=True,
    )

    return new_tokenizer


def main(save_directory="pretrained", use_full_dataset=False):
    warnings.filterwarnings("ignore", message=r".*resume_download.*")

    if use_full_dataset:
        dataset, kwargs = FULL_DATASET, dict(streaming=True)
    else:
        dataset, kwargs = TEST_DATASET, {}

    tokenizer = train_tokenizer(dataset, split="train", **kwargs)
    tokenizer.save_pretrained(save_directory)
```
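
Again not part of the commit: a sketch of driving `train_tokenizer()` directly rather than through `main()`, with the positional arguments forwarded to `load_dataset` as in `main()` above. The dataset name and vocabulary size are taken from the constants and defaults in this file.

```python
# Train against the small test dataset (non-streaming), then save the
# result exactly as main() would.
from dom_tokenizers.train import train_tokenizer

tokenizer = train_tokenizer(
    "gbenson/interesting-dom-snapshots",  # TEST_DATASET
    split="train",
    vocab_size=1024,
)
tokenizer.save_pretrained("pretrained")
```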
