Skip to content

Commit

Permalink
Try decoding the whole token first?
Browse files Browse the repository at this point in the history
  • Loading branch information
gbenson committed Jun 3, 2024
1 parent a97cc52 commit 4e7f9ae
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 58 deletions.
16 changes: 16 additions & 0 deletions src/dom_tokenizers/pre_tokenizers/sniffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,18 @@ class FileType(Enum):
RIFF = auto()
SVG = auto()
WEBP = auto()
WOFF = auto()


MIN_BYTES_FOR_SNIFF = 33 # Smallest I've seen was a 35 byte GIF
MIN_BASE64_FOR_SNIFF = (MIN_BYTES_FOR_SNIFF * 8) // 6

_MAGIC = {
"GIF": b"GIF8",
"PNG": b"\x89PNG",
"RIFF": b"RIFF",
"SVG": b"<svg",
"WOFF": b"wOFF",
}

MAGIC = dict(
Expand All @@ -39,7 +44,18 @@ class FileType(Enum):


def sniff_base64(encoded: str) -> Optional[FileType]:
if len(encoded) < MIN_BASE64_FOR_SNIFF:
return None
filetype = BASE64_MAGIC.get(encoded[:5])
if filetype != FileType.RIFF:
return filetype
return RIFF_MAGIC.get(b64decode(encoded[:16])[-4:])


def sniff_bytes(data: bytes) -> Optional[FileType]:
if len(data) < MIN_BYTES_FOR_SNIFF:
return None
filetype = MAGIC.get(data[:4])
if filetype != FileType.RIFF:
return filetype
return RIFF_MAGIC.get(data[8:12])
143 changes: 85 additions & 58 deletions src/dom_tokenizers/pre_tokenizers/splitter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import re

from base64 import b64decode
from base64 import binascii, b64decode
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
Expand All @@ -12,7 +12,7 @@
from unidecode import unidecode

from ..internal import json
from .sniffer import sniff_base64
from .sniffer import sniff_bytes

logger = logging.getLogger(__name__)
debug = logger.debug
Expand Down Expand Up @@ -65,9 +65,6 @@ def special_tokens(self) -> Iterable[str]:
ENTITY_STARTS = {"&", "&#"}
ESCAPE_START_RE = re.compile(r".([&%\\])")

MIN_BYTES_FOR_SNIFF = 32 # Smallest I've seen is a 35 byte GIF
MIN_BASE64_FOR_SNIFF = (MIN_BYTES_FOR_SNIFF * 8) // 6

# XXX older bits
MAXWORDLEN = 32
WORD_RE = re.compile(r"(?:\w+['’]?)+")
Expand Down Expand Up @@ -164,38 +161,55 @@ def split(self, text: str) -> Iterable[str]:
continue

# Are we looking at something that might be base64?
if len(curr) >= self.MIN_BASE64_FOR_SNIFF:
filetype = sniff_base64(curr)
if filetype:
splits[cursor:cursor+1] = [
self.base64_token,
filetype.name.lower(),
SPLIT
]
cursor += 3
continue

if self.BASE64_RE.match(curr):
if curr.isdecimal():
if VERBOSE: # pragma: no cover
debug("it's a decimal number")
cursor += 1
continue

match = self.B64_HEX_RE.match(curr)
if match:
if VERBOSE: # pragma: no cover
debug("it's hex")
new_splits = match.groups()
if new_splits[0] is not None:
splits[cursor:cursor+1] = new_splits
cursor += 1
cursor += 1
continue

cursor = self._sub_base64(splits, cursor)
if len(curr) > 4 and ( # XXX 4? too short??
new_splits := self._split_base64(curr)):
new_splits.append(SPLIT)
splits[cursor:cursor+1] = new_splits
cursor += len(new_splits)
continue

#if len(curr) < self.MIN_BASE64_FOR_SNIFF:
# if self.BASE64_RE.match(curr):
# data = b64decode(curr)
# try:
# text = data.decode("utf-8")
# except UnicodeDecodeError:
# text = None
#if text is not None:
#return self._enter_base64_utf8(text)
#return self._enter_base64_binary(data, encoded)
#else:
# filetype = sniff_base64(curr)
# if filetype:
# splits[cursor:cursor+1] = [
# self.base64_token,
# filetype.name.lower(),
# SPLIT
# ]
# cursor += 3
# continue
#
#if self.BASE64_RE.match(curr):
# if curr.isdecimal():
# if VERBOSE: # pragma: no cover
# debug("it's a decimal number")
# cursor += 1
# continue
#
# match = self.B64_HEX_RE.match(curr)
# if match:
# if VERBOSE: # pragma: no cover
# debug("it's hex")
# new_splits = match.groups()
# if new_splits[0] is not None:
# splits[cursor:cursor+1] = new_splits
# cursor += 1
# cursor += 1
# continue
#
# cursor = self._sub_base64(splits, cursor)
# continue

# Is the whole thing one word?
words = self.WORD_RE.findall(curr)
if len(words) == 1 and words[0] == curr:
Expand Down Expand Up @@ -408,6 +422,42 @@ def _sub_urlencoded(self, splits, cursor):
splits.insert(cursor, "".join(parts))
return cursor

def _split_base64(self, encoded):
try:
encoded = encoded.encode("ascii")
except UnicodeEncodeError:
return None
try:
data = b64decode(encoded, validate=True)
except binascii.Error:
return None
try:
text = data.decode("utf-8")
except UnicodeDecodeError:
return self._split_base64_binary(data)
return self._split_base64_utf8(text)

def _split_base64_utf8(self, text):
match = self.XML_HDR_RE.match(text)
if match is not None:
if match.group(1) == "svg":
return [self.base64_token, "svg"]
return [self.base64_token, "xml"]
try:
_ = json.loads(text)
return [self.base64_token, "json"]
except json.JSONDecodeError:
pass
return [self.base64_token, "text"]

def _split_base64_binary(self, data):
filetype = sniff_bytes(data)
if not filetype:
return None
return [self.base64_token, filetype.name.lower()]

# XXX junk?

def _sub_base64(self, splits, cursor):
curr = splits[cursor]
try:
Expand Down Expand Up @@ -482,29 +532,6 @@ def _is_urlish_looking_base64(self, splits, cursor):

return False

def _enter_base64(self, encoded):
# Lots of false-positives here, try sniffing
data = b64decode(encoded)
try:
text = data.decode("utf-8")
except UnicodeDecodeError:
text = None
if text is not None:
return self._enter_base64_utf8(text)
return self._enter_base64_binary(data, encoded)

def _enter_base64_utf8(self, text):
# XXX recurse??
match = self.XML_HDR_RE.match(text)
if match is not None:
return [self.base64_token, "xml"]
try:
_ = json.loads(text)
return [self.base64_token, "json"]
except json.JSONDecodeError:
pass
return [self.base64_token, "utf-8"]

def _enter_base64_binary(self, data, encoded):
# Not out of false-positive territory yet
full_magic = magic.from_buffer(data)
Expand Down

0 comments on commit 4e7f9ae

Please sign in to comment.