import json
import re

from base64 import b64decode
from collections import defaultdict
from collections.abc import Iterable
from functools import cached_property
from itertools import chain
from posixpath import commonprefix
from typing import List
from xml.dom import Node

import magic

from tokenizers import NormalizedString, PreTokenizedString


class DOMSnapshotPreTokenizer:
    """Pre-tokenizer that consumes JSON-serialized DOM snapshots
    and emits tokenized representations of the snapshotted DOMs.

    A usage sketch follows the class definition.
    """
    bos_token = "[BOS]"          # beginning of sequence
    eos_token = "[EOS]"          # end of sequence
    sep_token = "[SEP]"          # separator between documents
    elem_token = "[TAG]"         # beginning of element name
    attr_token = "[ATTR]"        # beginning of attribute
    comm_token = "[COMMENT]"     # beginning of comment
    base64_token = "[BASE64]"    # beginning of some base64
    long_token = "[LONG]"        # elided long token

    @property
    def special_tokens(self):
        return [
            value
            for attr, value in self.__class__.__dict__.items()
            if attr.endswith("token")
        ]

    def pre_tokenize(self, pretok: PreTokenizedString):
        """Pre-tokenize a :class:`~tokenizers.PreTokenizedString` in-place."""
        pretok.split(self._split_json)

    def _split_json(self, i: int, s: NormalizedString) -> List[NormalizedString]:
        snapshot = json.loads(s.normalized)
        return list(chain.from_iterable(self._split_serialized(snapshot)))

    def _split_serialized(self, snapshot: dict) -> Iterable[List[NormalizedString]]:
        emitter = TokenEmitter(self, snapshot)
        elem_token = [NormalizedString(self.elem_token)]
        attr_token = [NormalizedString(self.attr_token)]

        for document_index, document in enumerate(snapshot["documents"]):
            token = self.bos_token if document_index == 0 else self.sep_token
            yield [NormalizedString(token)]

            # Node properties arrive as parallel arrays of indexes into the
            # snapshot's shared string table.
            nodes = document["nodes"]
            for node_index, node_values in enumerate(zip(
                    nodes["nodeType"],
                    nodes["nodeName"],
                    nodes["nodeValue"],
                    nodes["attributes"])):
                node_type, name_index, value_index, attr_indexes = node_values

                match node_type:
                    case Node.ELEMENT_NODE:
                        yield elem_token
                        yield emitter.emit(name_index)
                        for attr_index in range(0, len(attr_indexes), 2):
                            yield attr_token
                            yield emitter.emit(attr_indexes[attr_index])
                            yield emitter.emit(attr_indexes[attr_index + 1])

                    case Node.TEXT_NODE:
                        yield emitter.emit(value_index)

                    case Node.COMMENT_NODE:
                        yield [NormalizedString(self.comm_token)]
                        yield emitter.emit(value_index)

        yield [NormalizedString(self.eos_token)]

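
# Usage sketch (not part of the original module): the tokenizers library
# accepts any plain Python object with a `pre_tokenize(self, pretok)` method
# once wrapped in `PreTokenizer.custom`, so wiring this class up to a
# tokenizer looks roughly like the following.  The WordLevel model and the
# "[UNK]" token are assumptions made purely for illustration.
#
#     from tokenizers import Tokenizer
#     from tokenizers.models import WordLevel
#     from tokenizers.pre_tokenizers import PreTokenizer
#
#     dom_pretok = DOMSnapshotPreTokenizer()
#     tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
#     tokenizer.pre_tokenizer = PreTokenizer.custom(dom_pretok)
#     tokenizer.add_special_tokens(dom_pretok.special_tokens)
#
# Each input to such a tokenizer would be one JSON-serialized DOM snapshot,
# i.e. the kind of {"documents": ..., "strings": ...} structure that
# _split_serialized walks above.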

# Head and tail of a base64 matcher; base64_matcher() interpolates the
# minimum number of four-character groups between them.
_B64_RE_S = r"(?:[A-Za-z0-9+/]{4}){"
_B64_RE_E = r",}(?:[A-Za-z0-9+/]{3}=|[A-Za-z0-9+/]{2}==)?"


def base64_matcher(min_encoded_len=24):
    min_groups, extra = divmod(min_encoded_len, 4)
    if extra:
        min_groups += 1
    return re.compile(f"{_B64_RE_S}{min_groups}{_B64_RE_E}")
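
# For illustration (not part of the original module): with the default
# min_encoded_len of 24 the compiled pattern requires at least six groups of
# four base64 characters, plus optional "=" / "==" padding, so
#
#     bool(base64_matcher().search("iVBORw0KGgoAAAANSUhEUgAAAAEAAAAB"))  # True
#     bool(base64_matcher().search("hello world"))                       # False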


class TokenEmitter:
    """Splits strings from a DOM snapshot's string table into tokens."""

    MAXWORDLEN = 32
    WORD_RE = re.compile(
        r"[a-z0-9]+(?:[a-z0-9']*[a-z0-9])?")  # XXX English only :(
    ESCAPED_RE = re.compile(
        r"((?:%|\\x|\\u[0-9a-f]{2})[0-9a-f]{2})", re.I)
    HEX_RE = re.compile(r"^(?:0x|[0-9a-f]{2})[0-9a-f]{6,}$")
    DIGIT_RE = re.compile(r"\d")
    URLISH_RE = re.compile(r"(?:[a-z]+|[0-9a-f]+|[A-Z0-9]+)")
    SHORTEST_URLISH = 16
    LONGEST_PHITEST = 85
    BASE64_RE = base64_matcher()
    B64_PNG_RE = re.compile(r"iVBORw0KGg[o-r]")
    XML_HDR_RE = re.compile(r"<([a-z]{3,})\s+[a-z]+")

    def __init__(self, pretokenizer: DOMSnapshotPreTokenizer, snapshot: dict):
        self._pt = pretokenizer
        self._strings = snapshot["strings"]
        self._tokens = {}  # cache: string-table index -> emitted tokens

    @cached_property
    def base64_token(self):
        return self._pt.base64_token

    @cached_property
    def long_token(self):
        return self._pt.long_token

    def emit(self, string_index: int) -> Iterable[NormalizedString]:
        """Emit tokens for one string in a DOM snapshot's string table.

        It splits on any non-alphanumeric character, but also tries
        to detect (and recurse into) base64-encoded data, of which
        there's a lot in just the 295 `interesting-dom-snapshots`.
        (Not dealing with base64 results in a whole load of "words"
        which are just fragments of base64.  It isn't easy though:
        lots of regular text is valid base64, so we have to sniff.)
        """
        if string_index < 0:
            return []
        tokens = self._tokens.get(string_index)
        if tokens is not None:
            return tokens
        tokens = [
            NormalizedString(token)
            for token in self._postprocess(
                chain.from_iterable(
                    self._split(
                        self._preprocess(
                            self._strings[string_index]))))
        ]
        self._tokens[string_index] = tokens
        return tokens
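
    # Worked example (assumed inputs, not from the original commit): with a
    # string table holding "Hello, world! %41", emit(0) first unescapes "%41"
    # to "A", then lowercases and word-splits, yielding NormalizedStrings for
    # "hello", "world" and "a":
    #
    #     emitter = TokenEmitter(DOMSnapshotPreTokenizer(),
    #                            {"strings": ["Hello, world! %41"]})
    #     [t.normalized for t in emitter.emit(0)]  # -> ["hello", "world", "a"]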

    def _preprocess(self, text):
        return "".join(
            self._unescape_char(s) if i & 1 else s
            for i, s in enumerate(self.ESCAPED_RE.split(text))
        )

    def _unescape_char(self, escaped):
        # Decode a single %XX, \xXX or \uXXXX escape by handing it to
        # Python's string-literal parser.
        if escaped[0] == "%":
            escaped = "\\x" + escaped[1:]
        return eval(f'"{escaped}"')
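
    # For example (assumed inputs): percent- and backslash-escapes are folded
    # back into the characters they encode before word-splitting, so
    #
    #     emitter._preprocess("%3cdiv%3e")    # -> "<div>"
    #     emitter._preprocess(r"\x41\u0042")  # -> "AB"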

    def _split(self, text):
        # Alternate between runs of ordinary text (split into lowercase
        # words) and spans that look like base64 (sniffed and summarized).
        while text:
            match = self.BASE64_RE.search(text)
            if match is not None:
                start, limit = match.span()
            else:
                start = limit = len(text)
            if start > 0:
                yield self._split_words(text[:start])
            if limit > start:
                encoded = text[start:limit]
                matched = self._match_urlish_base64(encoded)
                if matched is not None:
                    limit = start + len(matched)
                    yield self._split_words(text[start:limit])
                else:
                    yield self._enter_base64(encoded)
            if limit == len(text):
                break
            text = text[limit:]

    def _split_words(self, text):
        return self.WORD_RE.findall(text.lower())

    def _match_urlish_base64(self, encoded):
        # Slash-separated paths and identifiers are made of base64-legal
        # characters too; reclaim them as ordinary words when the match
        # starts with a sufficiently long URL-ish prefix.
        urlish = "/".join(self.URLISH_RE.findall(encoded))
        result = commonprefix((encoded, urlish))
        if len(result) < self.SHORTEST_URLISH:
            return None
        return result
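
    # Illustration (assumed input): a path like "static/images/logo/primary"
    # is made entirely of base64-legal characters, so BASE64_RE matches a
    # 24-character span of it, but that span equals its own slash-joined word
    # runs, so it is handed back to _split_words rather than _enter_base64:
    #
    #     emitter._match_urlish_base64("static/images/logo/prima")
    #     # -> "static/images/logo/prima"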

    def _enter_base64(self, encoded):
        # Lots of false-positives here, try sniffing
        if self.B64_PNG_RE.match(encoded):
            return [self.base64_token, "png"]
        data = b64decode(encoded)
        try:
            text = data.decode("utf-8")
        except UnicodeDecodeError:
            text = None
        if text is not None:
            return self._enter_base64_utf8(text)
        return self._enter_base64_binary(data, encoded)

    def _enter_base64_utf8(self, text):
        # XXX recurse??
        match = self.XML_HDR_RE.match(text)
        if match is not None:
            if match.group(1) == "svg":
                return [self.base64_token, "svg"]
            return [self.base64_token, "xml"]
        try:
            _ = json.loads(text)
            return [self.base64_token, "json"]
        except json.JSONDecodeError:
            pass
        return [self.base64_token, "utf", "8"]

    def _enter_base64_binary(self, data, encoded):
        # Not out of false-positive territory yet
        full_magic = magic.from_buffer(data)
        easy_magic = full_magic.split(maxsplit=1)[0]
        if easy_magic in {"GIF", "zlib", "JPEG"}:
            return [self.base64_token, easy_magic.lower()]
        if " Web/P image" in full_magic:
            return [self.base64_token, "webp"]
        if full_magic.startswith("Web Open Font Format"):
            return [self.base64_token, "woff"]
        if len(encoded) > self.LONGEST_PHITEST:
            return [self.base64_token]
        # phi test for monoalphabeticity
        hist = defaultdict(int)
        for symbol in encoded:
            hist[symbol] += 1
        phi_o = sum(freq * (freq - 1) for freq in hist.values())
        N = len(encoded)
        phi_r = N * (N - 1) / 64
        # non-standard comparison (observed phi > twice random)
        if phi_o > phi_r * 2:
            return self._split_words(encoded)
        return [self.base64_token]
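
    # The phi test in brief (background note, not from the original commit):
    # for a string of N symbols with per-symbol frequencies f, the observed
    # coincidence count is phi_o = sum(f * (f - 1)), while uniformly random
    # base64 would give about phi_r = N * (N - 1) / 64.  Text that reuses a
    # small set of characters scores far higher: if the 24-character string
    # "aaaabbbbccccddddeeeeffff" reached this test, phi_o would be 6 * 4 * 3
    # = 72 against 2 * phi_r ≈ 17, so it would be split into words rather
    # than reduced to a bare [BASE64] marker.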

    def _postprocess(self, tokens: Iterable[str]) -> Iterable[str]:
        for token in tokens:
            # Long runs of hex (or purely decimal) digits are elided to a
            # marker plus a hint about what kind of run they were.
            if self.HEX_RE.match(token):
                yield self.long_token
                try:
                    _ = int(token)
                except ValueError:
                    yield "hex"
                yield "digits"
                continue

            if len(token) <= self.MAXWORDLEN:
                yield token
                continue

            # Other overlong tokens are likewise elided.
            yield self.long_token
            if self.DIGIT_RE.search(token):
                yield "alphanumeric"
            else:
                yield "alphabetic"
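
    # Examples of the elision (assumed inputs): a 16-character hex run such as
    # "3a7bd3e2360a3d29" becomes [LONG] "hex" "digits"; a purely decimal run
    # such as "12345678901234567890" becomes [LONG] "digits"; and a 40-letter
    # token with no digits at all becomes [LONG] "alphabetic".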