diff --git a/src/dom_tokenizers/pre_tokenizers/sniffer.py b/src/dom_tokenizers/pre_tokenizers/sniffer.py new file mode 100644 index 0000000..27d1759 --- /dev/null +++ b/src/dom_tokenizers/pre_tokenizers/sniffer.py @@ -0,0 +1,45 @@ +from base64 import b64decode, b64encode +from enum import Enum, auto +from typing import Optional + + +class FileType(Enum): + GIF = auto() + PNG = auto() + RIFF = auto() + SVG = auto() + WEBP = auto() + + +_MAGIC = { + "GIF": b"GIF8", + "PNG": b"\x89PNG", + "RIFF": b"RIFF", + "SVG": b" Optional[FileType]: + filetype = BASE64_MAGIC.get(encoded[:5]) + if filetype != FileType.RIFF: + return filetype + return RIFF_MAGIC.get(b64decode(encoded[:16])[-4:]) diff --git a/src/dom_tokenizers/pre_tokenizers/splitter.py b/src/dom_tokenizers/pre_tokenizers/splitter.py index 5abcc1f..8df9a2a 100644 --- a/src/dom_tokenizers/pre_tokenizers/splitter.py +++ b/src/dom_tokenizers/pre_tokenizers/splitter.py @@ -12,6 +12,7 @@ from unidecode import unidecode from ..internal import json +from .sniffer import sniff_base64 logger = logging.getLogger(__name__) debug = logger.debug @@ -64,6 +65,9 @@ def special_tokens(self) -> Iterable[str]: ENTITY_STARTS = {"&", "&#"} ESCAPE_START_RE = re.compile(r".([&%\\])") + MIN_BYTES_FOR_SNIFF = 32 # Smallest I've seen is a 35 byte GIF + MIN_BASE64_FOR_SNIFF = (MIN_BYTES_FOR_SNIFF * 8) // 6 + # XXX older bits MAXWORDLEN = 32 WORD_RE = re.compile(r"(?:\w+['’]?)+") @@ -84,7 +88,6 @@ def special_tokens(self) -> Iterable[str]: } LONGEST_PHITEST = 85 BASE64_RE = base64_matcher() - B64_PNG_RE = re.compile(r"iVBORw0KGg[o-r]") B64_HEX_RE = re.compile(r"^(0x)?([0-9a-f]+)$", re.I) XML_HDR_RE = re.compile(r"<([a-z]{3,})\s+[a-z]+") @@ -161,6 +164,17 @@ def split(self, text: str) -> Iterable[str]: continue # Are we looking at something that might be base64? + if len(curr) >= self.MIN_BASE64_FOR_SNIFF: + filetype = sniff_base64(curr) + if filetype: + splits[cursor:cursor+1] = [ + self.base64_token, + filetype.name.lower(), + SPLIT + ] + cursor += 3 + continue + if self.BASE64_RE.match(curr): if curr.isdecimal(): if VERBOSE: # pragma: no cover @@ -470,8 +484,6 @@ def _is_urlish_looking_base64(self, splits, cursor): def _enter_base64(self, encoded): # Lots of false-positives here, try sniffing - if self.B64_PNG_RE.match(encoded): - return [self.base64_token, "png"] data = b64decode(encoded) try: text = data.decode("utf-8") @@ -485,8 +497,6 @@ def _enter_base64_utf8(self, text): # XXX recurse?? match = self.XML_HDR_RE.match(text) if match is not None: - if match.group(1) == "svg": - return [self.base64_token, "svg"] return [self.base64_token, "xml"] try: _ = json.loads(text) @@ -499,10 +509,8 @@ def _enter_base64_binary(self, data, encoded): # Not out of false-positive territory yet full_magic = magic.from_buffer(data) easy_magic = full_magic.split(maxsplit=1)[0] - if easy_magic in {"GIF", "zlib", "JPEG"}: + if easy_magic in {"zlib", "JPEG"}: return [self.base64_token, easy_magic] - if " Web/P image" in full_magic: - return [self.base64_token, "webp"] if full_magic.startswith("Web Open Font Format"): return [self.base64_token, "woff"] if len(encoded) > self.LONGEST_PHITEST: