From 4cd2ad2d8a3766ed7770bf9094ebca53017d9961 Mon Sep 17 00:00:00 2001
From: Gary Benson
Date: Wed, 5 Jun 2024 20:38:26 +0100
Subject: [PATCH] Limit unescaping and base64 sniffing

---
 .../pre_tokenizers/dom_snapshot.py            | 46 +++++++++----------
 src/dom_tokenizers/pre_tokenizers/splitter.py | 38 +++++++++++++--
 2 files changed, 54 insertions(+), 30 deletions(-)

diff --git a/src/dom_tokenizers/pre_tokenizers/dom_snapshot.py b/src/dom_tokenizers/pre_tokenizers/dom_snapshot.py
index a092bc5..f0849fe 100644
--- a/src/dom_tokenizers/pre_tokenizers/dom_snapshot.py
+++ b/src/dom_tokenizers/pre_tokenizers/dom_snapshot.py
@@ -1,3 +1,4 @@
+from collections import defaultdict
 from dataclasses import make_dataclass
 from xml.dom import Node
 
@@ -7,7 +8,7 @@
 from .compat_itertools import batched
 from .html import is_void_element
 from .pre_tokenizer import PreTokenizer
-from .splitter import TextSplitter
+from .splitter import TextSplitter, Flags as Split
 from .token_buffer import TokenBuffer
 
 
@@ -26,59 +27,59 @@ def pre_tokenize_dom(self, buf: TokenBuffer, serialized: str):
         if not any(key in snapshot for key in ("documents", "strings")):
             snapshot = snapshot.get("result", snapshot)
 
-        tokens = TokenCache(snapshot["strings"], self._splitter)
+        split = TokenCache(snapshot["strings"], self._splitter).get
 
         for document in snapshot["documents"]:
             stack = [self._SENTINEL]
             for node in _Node.each(document["nodes"]):
                 while stack[-1].index != node.parent_index:
-                    self._terminate(buf, tokens, stack.pop())
+                    self._terminate(buf, split, stack.pop())
 
                 match node.type:
                     case Node.ELEMENT_NODE:
                         buf.append("<")
-                        buf.extend(tokens.get(node.name_index, lowercase=True))
+                        buf.extend(split(node.name_index, Split.TAG_NAME))
                         for name_index, value_index in node.attr_indexes:
                             buf.append("_")
-                            buf.extend(tokens[name_index])
+                            buf.extend(split(name_index, Split.ATTR_NAME))
                             buf.append("=")
-                            buf.extend(tokens[value_index])
+                            buf.extend(split(value_index, Split.ATTR_VALUE))
                         buf.append(">")
                         stack.append(node)
 
                     case Node.TEXT_NODE:
-                        buf.extend(tokens[node.value_index])
+                        buf.extend(split(node.value_index, Split.TEXT))
 
                     case Node.DOCUMENT_NODE:
                         stack.append(node)
 
                     case Node.COMMENT_NODE:
                         buf.append("<!--")
-                        buf.extend(tokens[node.value_index])
+                        buf.extend(split(node.value_index, Split.COMMENT))
                         buf.append("-->")
 
                     case Node.DOCUMENT_TYPE_NODE:
                         buf.append("<!DOCTYPE")
-                        buf.extend(tokens[node.name_index])
+                        buf.extend(split(node.name_index, Split.DOCTYPE))
                         public_index = document["publicId"]
                         if public_index >= 0:
                             buf.append("PUBLIC")
-                            buf.extend(tokens[public_index])
+                            buf.extend(split(public_index, Split.DOCTYPE))
                         system_index = document["systemId"]
                         if system_index >= 0:
-                            buf.extend(tokens[system_index])
+                            buf.extend(split(system_index, Split.DOCTYPE))
                         buf.append(">")
 
         for node in reversed(stack[2:]):
-            self._terminate(buf, tokens, node)
+            self._terminate(buf, split, node)
 
     @staticmethod
-    def _terminate(buf, tokens, node):
-        tag = tokens._strings[node.name_index]
-        if is_void_element(tag):
+    def _terminate(buf, split, node):
+        tokens = split(node.name_index, Split.TAG_NAME)
+        if is_void_element(tokens[-1].original):
             return
         buf.append("</")
-        buf.extend(tokens.get(node.name_index, lowercase=True))
+        buf.extend(tokens)
         buf.append(">")
 
 
@@ -115,31 +116,26 @@ class TokenCache:
     def __init__(self, strings: list[str], splitter: TextSplitter):
         self._strings = strings
         self._splitter = splitter
-        self._tokens = {}
+        self._cache = defaultdict(dict)
         self._lowercase_tokens = {}
 
     def get(
             self,
             string_index: int,
-            *,
-            lowercase=False
+            split_flags: Split,
     ) -> list[NormalizedString]:
         """Return tokens for one string in a DOM snapshot's string table.
         """
         if string_index < 0:
             return []
-        cache = self._lowercase_tokens if lowercase else self._tokens
+        cache = self._cache[split_flags]
         tokens = cache.get(string_index)
         if tokens is not None:
             return tokens
         text = self._strings[string_index]
-        if lowercase:
-            text = text.lower()
         tokens = [
             NormalizedString(token)
-            for token in self._splitter.split(text)
+            for token in self._splitter.split(text, split_flags)
         ]
         cache[string_index] = tokens
         return tokens
-
-    __getitem__ = get
diff --git a/src/dom_tokenizers/pre_tokenizers/splitter.py b/src/dom_tokenizers/pre_tokenizers/splitter.py
index 32ebe82..bbfbef0 100644
--- a/src/dom_tokenizers/pre_tokenizers/splitter.py
+++ b/src/dom_tokenizers/pre_tokenizers/splitter.py
@@ -5,6 +5,7 @@
 from collections import defaultdict
 from collections.abc import Iterable
 from dataclasses import dataclass
+from enum import Flag, auto
 from urllib.parse import unquote
 
 import magic
@@ -17,6 +18,25 @@
 debug = logger.debug
 
 
+class Flags(Flag):
+    BASIC = 0  # No unescaping, no base64
+
+    LOWERCASE = auto()  # Lowercase before splitting
+    UNESCAPE_JS = auto()  # Decode JS backslash escapes
+    UNQUOTE_URLS = auto()  # Decode URL encoding
+    SUB_ENTITIES = auto()  # Decode HTML entities
+    SNIFF_BASE64 = auto()  # Detect and substitute base64
+
+    FULL = UNESCAPE_JS | UNQUOTE_URLS | SUB_ENTITIES | SNIFF_BASE64
+
+    TAG_NAME = LOWERCASE
+    ATTR_NAME = BASIC
+    ATTR_VALUE = FULL
+    TEXT = FULL
+    COMMENT = FULL  # XXX maybe... or BASIC? SUB_ENTITIES??
+    DOCTYPE = BASIC
+
+
 class MandatorySplit:  # pragma: no cover
     def __repr__(self):
         return "SPLIT"
@@ -88,7 +108,7 @@ def special_tokens(self) -> Iterable[str]:
     B64_PNG_RE = re.compile(r"iVBORw0KGg[o-r]")
     XML_HDR_RE = re.compile(r"<([a-z]{3,})\s+[a-z]+")
 
-    def split(self, text: str) -> Iterable[str]:
+    def split(self, text: str, flags: Flags = Flags.FULL) -> Iterable[str]:
         """Split a string into a sequence of tokens.
 
         It splits on any non-alphanumeric character, but also tries
@@ -98,6 +118,14 @@ def split(self, text: str) -> Iterable[str]:
         which are just fragments of base64. It isn't easy though,
         lots of regular text is valid base64, we have to sniff.)
         """
+        if Flags.LOWERCASE in flags:
+            text = text.lower()
+
+        unquote_urls = Flags.UNQUOTE_URLS in flags
+        unescape_js = Flags.UNESCAPE_JS in flags
+        sub_entities = Flags.SUB_ENTITIES in flags
+        sniff_base64 = Flags.SNIFF_BASE64 in flags
+
         VERBOSE = logger.isEnabledFor(logging.DEBUG)
         if VERBOSE and len(text) < 4096:  # pragma: no cover
             debug("input: \x1B[44;36m%s\x1B[0m", text)
@@ -132,21 +160,21 @@ def split(self, text: str) -> Iterable[str]:
                 continue
 
             # Are we looking at URL-encoding (`%xx` escapes)?
-            if curr == "%":
+            if unquote_urls and curr == "%":
                 if VERBOSE:  # pragma: no cover
                     debug("it's urlencoded")
                 cursor = self._sub_urlencoded(splits, cursor)
                 continue
 
             # Are we looking at Javascript escaping?
-            if curr[0] == "\\":
+            if unescape_js and curr[0] == "\\":
                 if VERBOSE:  # pragma: no cover
                     debug("it's escaped")
                 cursor = self._sub_js_escape(splits, cursor)
                 continue
 
             # Are we looking at character entities?
-            if curr in self.ENTITY_STARTS:
+            if sub_entities and curr in self.ENTITY_STARTS:
                 if VERBOSE:  # pragma: no cover
                     debug("it's an entity")
                 cursor = self._sub_html_entity(splits, cursor)
@@ -170,7 +198,7 @@ def split(self, text: str) -> Iterable[str]:
                 continue
 
             # Are we looking at something that might be base64?
-            if self.BASE64_RE.match(curr):
+            if sniff_base64 and self.BASE64_RE.match(curr):
                 cursor = self._sub_base64(splits, cursor)
                 continue