From 176151799b74cd3c48a1d30daa533195e8999f83 Mon Sep 17 00:00:00 2001 From: Gary Benson Date: Mon, 3 Jun 2024 22:36:41 +0100 Subject: [PATCH] Start reintegrating the Oracle --- pyproject.toml | 4 +- runner.py | 53 +++++++ src/dom_tokenizers/pre_tokenizers/oracle.py | 141 ++++++++++++++++++ .../pre_tokenizers/shared_oracle.py | 18 +++ src/dom_tokenizers/pre_tokenizers/splitter.py | 23 ++- tests/test_oracle.py | 110 ++++++++++++++ tests/test_splitter.py | 2 + 7 files changed, 341 insertions(+), 10 deletions(-) create mode 100644 runner.py create mode 100644 src/dom_tokenizers/pre_tokenizers/oracle.py create mode 100644 src/dom_tokenizers/pre_tokenizers/shared_oracle.py create mode 100644 tests/test_oracle.py diff --git a/pyproject.toml b/pyproject.toml index 2ea3b9c..6e7a9d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,8 +23,10 @@ classifiers = [ "Topic :: Text Processing :: Markup :: HTML", ] dependencies = [ + "numpy", "python-magic", # XXX review "tokenizers", + "transformers", "unidecode", # XXX review ] @@ -42,12 +44,10 @@ dev = [ "pillow", "pytest", "pytest-cov", - "transformers", ] train = [ "datasets", "pillow", - "transformers", ] [project.scripts] diff --git a/runner.py b/runner.py new file mode 100644 index 0000000..24297e6 --- /dev/null +++ b/runner.py @@ -0,0 +1,53 @@ +import sys +import warnings + +from itertools import chain + +from dom_tokenizers.internal import json +from dom_tokenizers.pre_tokenizers.shared_oracle import SharedOracle + +DEFAULT_TESTCASES = [ + "overflow", + "uniqueid", + "uniqueId", + "uniqueID", + "pagewrap", + "pageWrap", + "autocompletetype", + "autocompleteType", + "backfill", + "Inauspicious", + "Llanfairpwllgwyngyllgogerychwyrndrobwllllantysiliogogogoch", + + "1655885421832", # first token is 4 chars + "8eb5e30dac7d493298287704a5f578c7", + "next/static/css/99762953f4d03581", + "org/TR/xhtml1/DTD/xhtml1", + "KFOmCnqEu92Fr1Mu4mxK", + "electronically8eb5e30dac7", # median chars/token = 1.0 (mean=2.7), + "electronically8eb5e30dac", # median chars/token = 1.5 (mean=3.0) + "electronically8eb5e30da", # median chars/token = 2.0 (mean=3.3) +] + + +def main(): + warnings.filterwarnings("ignore", message=r".*resume_download.*") + + oracle = SharedOracle() + if len(sys.argv) < 2: + lines = DEFAULT_TESTCASES + else: + lines = chain.from_iterable( + (json.loads(line)["text"] + for line in open(filename).readlines()) + for filename in sys.argv[1:]) + + for line in lines: + print("input:", line) + result = oracle.split_if_trivial(line, log_unhandled=False) + if result is not None: + print(f"\x1B[32m{result}\x1B[0m\n") + + +if __name__ == "__main__": + main() diff --git a/src/dom_tokenizers/pre_tokenizers/oracle.py b/src/dom_tokenizers/pre_tokenizers/oracle.py new file mode 100644 index 0000000..3b14644 --- /dev/null +++ b/src/dom_tokenizers/pre_tokenizers/oracle.py @@ -0,0 +1,141 @@ +import re + +from typing import Optional, Callable + +import numpy as np + +from ..internal import jsonl +from ..internal.transformers import AutoTokenizer + +_IntOrIntList = int | list[int] +_StrOrStrList = str | list[str] + + +class Oracle: + def __init__(self, *args, **kwargs): + self._tok = AutoTokenizer.from_pretrained(*args, **kwargs) + self._tok.model_max_length = 1 << 31 + self.cls_token_id = self._tok.cls_token_id + self.sep_token_id = self._tok.sep_token_id + self.unk_token_id = self._tok.unk_token_id + self.max_token_len = max( + len(token) for token in self._tok.vocab + ) + self.max_try_split_len = min(self.max_token_len * 5, 100) + self._log = jsonl.Writer(basename="oracle", with_timestamp=True) + + def close(self): + self._log.close() + + @property + def normalize_str(self) -> Callable[[str], str]: + """Normalize the given string. + """ + return self._tok.backend_tokenizer.normalizer.normalize_str + + def encode(self, *args, **kwargs) -> list[int]: + """Convert the given string to a list of integer token IDs. + """ + token_ids = self._tok.encode(*args, **kwargs) + assert token_ids[0] == self.cls_token_id + assert token_ids[-1] == self.sep_token_id + return token_ids[1:-1] + + IDsToTokensType = Callable[[_IntOrIntList], _StrOrStrList] + + @property + def convert_ids_to_tokens(self, *args, **kwargs) -> IDsToTokensType: + """Convert the given list of token IDs to a list of tokens. + """ + return self._tok.convert_ids_to_tokens + + def tokenize(self, *args, **kwargs) -> list[str]: + """Convert the given string into a list of tokens. + """ + return self.convert_ids_to_tokens(self.encode(*args, **kwargs)) + + @property + def decode(self) -> Callable[[_IntOrIntList], str]: + """Convert the given list of token IDs to a string. + """ + return self._tok.decode + + # For quick checks, see TextSplitter.BASE64_RE for the real deal + _LOOSE_BASE64_RE = re.compile(r"^[A-Za-z0-9+/]+={0,2}$") + + def split_if_trivial( + self, + text: str, + log_unhandled: bool = True, # XXX + ) -> Optional[list[str]]: + """Split a string into a list of tokens XXX IF! + + Like `tokenize()` but it only returns if XXX. Otherwise None is + returned. + """ + if len(text) > self.max_try_split_len: + return None + + # Fast path for text that's in the oracle's vocabulary. + if len(text) <= self.max_token_len and ( + (text in self._tok.vocab + or text.lower() in self._tok.vocab) + and text.isalnum()): + return [text] + + # Limit ourselves to base64-ish input, for now at least. + if not self._LOOSE_BASE64_RE.match(text): + raise NotImplementedError(text) + + token_ids = self.encode(text) + if not token_ids or self.unk_token_id in token_ids: + return None + + tokens = self.convert_ids_to_tokens(token_ids) + word_pieces = [token.lstrip("#") for token in tokens] + token_lengths = [len(token) for token in word_pieces] + + # If the tokens are mostly 2+ characters long and the + # input text splits on whitespace in the same places as + # the decoded token ID sequence then call this a match. + # Subtracting the standard deviation prevents situations + # where one long token skews the median away from a load + # of 1-2 character tokens, e.g. "electronically8eb5e30da" + # tokenizes to ["electronically", "8", "eb", "5", "e", + # "30", "da"] with bert-base-uncased, so a median token + # length of 2 characters/token and a mean of 3.3, but + # the standard deviation of 4.4 indicates at least one + # token is very far from the mean. + median_length = np.median(token_lengths) + length_stddev = np.std(token_lengths) + if median_length - length_stddev > 1: + result = text.split() + want = [token.lower() for token in result] + if self.decode(token_ids).split() == want: + return result + + print(f"tokens: {tokens}"[:80]) + + first_token_id = token_ids[0] + first_token = self.convert_ids_to_tokens(first_token_id) + assert "#" not in first_token + print(f"first_token: {first_token!r} ({first_token_id})") + + chars_per_token = len(text) / len(token_ids) + + #mean = sum(token_lengths) / len(token_ids) + print("chars_per_token:", chars_per_token) + #print("or ------> mean:", mean) + print(" median:", median_length) + print(" std.dev:", length_stddev) + print() + + # XXX now what? + if log_unhandled: + self._log.write( + text=text, token_ids=token_ids, + tokens=tokens, + decoded=self.decode(token_ids), + chars_per_token=chars_per_token, + ) + return None diff --git a/src/dom_tokenizers/pre_tokenizers/shared_oracle.py b/src/dom_tokenizers/pre_tokenizers/shared_oracle.py new file mode 100644 index 0000000..9b54ecf --- /dev/null +++ b/src/dom_tokenizers/pre_tokenizers/shared_oracle.py @@ -0,0 +1,18 @@ +import atexit + +from .oracle import Oracle + + +class SharedOracle(Oracle): + _shared_borg_state = {} + + def __new__(cls, *args, **kwargs): + obj = super().__new__(cls) + obj.__dict__ = cls._shared_borg_state + return obj + + def __init__(self, model="bert-base-uncased", *args, **kwargs): + if hasattr(self, "_tok"): + return + super().__init__(model, *args, **kwargs) + atexit.register(self.close) diff --git a/src/dom_tokenizers/pre_tokenizers/splitter.py b/src/dom_tokenizers/pre_tokenizers/splitter.py index 7729003..b90cc83 100644 --- a/src/dom_tokenizers/pre_tokenizers/splitter.py +++ b/src/dom_tokenizers/pre_tokenizers/splitter.py @@ -4,7 +4,7 @@ from base64 import binascii, b64decode from collections import defaultdict from collections.abc import Iterable -from dataclasses import dataclass +from dataclasses import dataclass, field from urllib.parse import unquote import magic @@ -12,6 +12,8 @@ from unidecode import unidecode from ..internal import json +from .oracle import Oracle +from .shared_oracle import SharedOracle from .sniffer import sniff_bytes logger = logging.getLogger(__name__) @@ -44,6 +46,7 @@ class FalseBase64Error(RuntimeError): class TextSplitter: base64_token: str = "[BASE64]" long_token: str = "[LONG]" + oracle: Oracle = field(default_factory=SharedOracle) @property def special_tokens(self) -> Iterable[str]: @@ -424,20 +427,20 @@ def _sub_urlencoded(self, splits, cursor): def _split_base64(self, encoded): try: - encoded = encoded.encode("ascii") + _encoded = encoded.encode("ascii") except UnicodeEncodeError: return None try: - data = b64decode(encoded, validate=True) + data = b64decode(_encoded, validate=True) except binascii.Error: return None try: text = data.decode("utf-8") except UnicodeDecodeError: - return self._split_base64_binary(data) - return self._split_base64_utf8(text) + return self._split_base64_binary(data, encoded) + return self._split_base64_utf8(text, encoded) - def _split_base64_utf8(self, text): + def _split_base64_utf8(self, text, encoded): match = self.XML_HDR_RE.match(text) if match is not None: if match.group(1) == "svg": @@ -448,12 +451,16 @@ def _split_base64_utf8(self, text): return [self.base64_token, "json"] except json.JSONDecodeError: pass + if self.oracle.first_is_better(encoded, text): + return None # encoded is better return [self.base64_token, "text"] - def _split_base64_binary(self, data): + def _split_base64_binary(self, data, encoded): filetype = sniff_bytes(data) if not filetype: - return None + if self.oracle.is_texty(encoded): + return None + return [self.base64_token, "data"] return [self.base64_token, filetype.name.lower()] # XXX junk? diff --git a/tests/test_oracle.py b/tests/test_oracle.py new file mode 100644 index 0000000..78e6265 --- /dev/null +++ b/tests/test_oracle.py @@ -0,0 +1,110 @@ +import pytest + +from dom_tokenizers.pre_tokenizers.shared_oracle import SharedOracle + + +@pytest.mark.parametrize( + ("text,expect_normalized"), + (("hello world", "hello world"), + ("html", "html"), + ("", ""), + ("HTML", "html"), + ("Parse error", "parse error"), + (" html", " html"), + ("html ", "html "), + (": syntax error, unexpected ')' in ", + ": syntax error, unexpected ')' in "), + ("\n", " "), + (": \t syntax error, unexpected ')' in ", + ": syntax error, unexpected ')' in "), + ("\ufeff", ""), + )) +def test_normalizer(text, expect_normalized): + """Check the backend normalizer works as we expect. + + Specifically: + - lowercasing is performed + - leading and trailing whitespace are retained + - sequences of whitespace are not compressed + - all whitespace characters become ASCII space + - punctiation is retained + - BOM is not retained + """ + assert SharedOracle().normalize_str(text) == expect_normalized + + +@pytest.mark.parametrize( + ("text,expect_tokens"), + (("hello world", ["hello", "world"]), + ("html", ["html"]), + ("", []), + ("HTML", ["html"]), + ("Parse error", ["par", "##se", "error"]), + ("宏 error", ["[UNK]", "error"]), + (" html", ["html"]), + ("html ", ["html"]), + (": syntax error, unexpected ')' in ", + [":", "syntax", "error", ",", "unexpected", "'", ")", "'", "in"]), + (": \t syntax error, unexpected ')' in ", + [":", "syntax", "error", ",", "unexpected", "'", ")", "'", "in"]), + ("\ufeff", []), + + # Testcases for split_if_trivial() + ("overflow", ["over", "##flow"]), + ("uniqueid", ["unique", "##id"]), + ("pagewrap", ["page", "##wr", "##ap"]), + ("autocompletetype", ["auto", "##com", "##ple", "##tet", "##ype"]), + ("Inauspicious", ["ina", "##us", "##pic", "##ious"]), + ("Llanfairpwllgwyngyllgogerychwyrndrobwllllantysiliogogogoch", + ["ll", "##an", "##fa", "##ir", "##pw", "##ll", "##g", "##wyn", + "##gy", "##ll", "##go", "##ger", "##ych", "##wy", "##rn", + "##dro", "##b", "##wl", "##ll", "##lan", "##ty", "##sil", + "##io", "##go", "##go", "##go", "##ch"]), + ("1655885421832", + ["1655", "##8", "##85", "##42", "##18", "##32"]), + )) +def test_tokenizer(text, expect_tokens): + """Check the backend tokenizer works as expected. + + Specifically: + - normalization is performed as per `test_normalizer()` + - whitespace causes splits but is not retained + - unhandled input is substituted with [UNK] + - result is not bracketed by [CLS], [SEP] + """ + assert SharedOracle().tokenize(text) == expect_tokens + + +@pytest.mark.parametrize( + ("text,expect_tokens"), + (("overflow", ["overflow"]), + ("uniqueid", ["uniqueid"]), + ("uniqueId", ["uniqueId"]), + ("uniqueID", ["uniqueID"]), + ("pagewrap", ["pagewrap"]), + ("autocompletetype", ["autocompletetype"]), + ("Inauspicious", ["Inauspicious"]), + ("Llanfairpwllgwyngyllgogerychwyrndrobwllllantysiliogogogoch", + ["Llanfairpwllgwyngyllgogerychwyrndrobwllllantysiliogogogoch"]), + #("8eb5e30dac7d493298287704a5f578c7", + # ["[CWCI]"]), + #("next/static/css/99762953f4d03581", + # ["next", "static", "css", "[CWCI]"]), + #("org/TR/xhtml1/DTD/xhtml1", + # ["org", "TR", "xhtml1", "DTD", "xhtml1"]), + #("KFOmCnqEu92Fr1Mu4mxK", ["[CWCI]"]), + #("pageWrap", ["page", "Wrap"]), # XXX maybe? + #("autocompleteType", ["autocomplete", "Type"]), + #("electronically8eb5e30dac7", # median chars/token = 1.0 (mean=2.7) + # ["electronically", "[CWCI]"]), + #("electronically8eb5e30dac", # median chars/token = 1.5 (mean=3.0) + # ["electronically", "[CWCI]"]), + #("electronically8eb5e30da", # median chars/token = 2.0 (mean=3.3) + # ["electronically", "[CWCI]"]), + # ("1655885421832", ["[CWCI]"]), + )) +def test_split_if_trivial(text, expect_tokens): + """Check `Oracle.split_if_trivial()` is doing what it should. + """ + assert SharedOracle().split_if_trivial( + text, log_unhandled=False) == expect_tokens diff --git a/tests/test_splitter.py b/tests/test_splitter.py index 11a96f4..f97e6f4 100644 --- a/tests/test_splitter.py +++ b/tests/test_splitter.py @@ -213,6 +213,8 @@ def test_decoding(text, expect_tokens): ["src", "url", "fonts", "gstatic", "com", "s", "roboto", "v18", "KFOmCnqEu92Fr1Mu4mxK", "woff2", "format", "woff2", "unicode", "range", "U", "0000", "00FF"]), + ("0x8eb5e30dac7d493298287704a5f578c7", + ["0x", "[LONG]", "hex", "digits"]), )) def test_regressions(text, expect_tokens): """Check that things we improve stay improved.