From 8ed02d55ca55691b42d965f303a65b3112eb0019 Mon Sep 17 00:00:00 2001
From: Gary Benson
Date: Wed, 15 May 2024 23:30:54 +0100
Subject: [PATCH] User-friendly trainer

---
 .flake8                     |   2 +
 README.md                   |   5 ++
 src/dom_tokenizers/train.py | 102 ++++++++++++++++++++++++++++--------
 3 files changed, 86 insertions(+), 23 deletions(-)

diff --git a/.flake8 b/.flake8
index 998689f..6c31553 100644
--- a/.flake8
+++ b/.flake8
@@ -5,3 +5,5 @@ per-file-ignores =
     src/dom_tokenizers/**/__init__.py: F401
     # line too long
    src/dom_tokenizers/pre_tokenizers/dom_snapshot.py: E501
+    # module level import not at top of file
+    src/dom_tokenizers/train.py: E402
diff --git a/README.md b/README.md
index 226519a..b809b48 100644
--- a/README.md
+++ b/README.md
@@ -29,3 +29,8 @@ python3 -m venv .venv
 pip install --upgrade pip
 pip install -e .[dev,train]
 ```
+
+## Train a tokenizer
+```sh
+train-tokenizer gbenson/interesting-dom-snapshots -n 10000
+```
diff --git a/src/dom_tokenizers/train.py b/src/dom_tokenizers/train.py
index 391a60b..5cf0579 100644
--- a/src/dom_tokenizers/train.py
+++ b/src/dom_tokenizers/train.py
@@ -1,31 +1,29 @@
+import os
 import json
 import warnings
+from argparse import ArgumentParser
+from math import log10, floor
+
 from datasets import load_dataset
 from tokenizers.pre_tokenizers import PreTokenizer, WhitespaceSplit
+
+os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = str(True)
 from transformers import AutoTokenizer
 
 from .pre_tokenizers import DOMSnapshotPreTokenizer
 
 
-FULL_DATASET = "gbenson/webui-dom-snapshots"
-TEST_DATASET = "gbenson/interesting-dom-snapshots"
+DEFAULT_BASE = "bert-base-uncased"
+DEFAULT_SPLIT = "train"
+DEFAULT_SIZE = 1024
+SEND_BUGS_TO = "https://github.com/gbenson/dom-tokenizers/issues"
 
 
 def train_tokenizer(
-        *args,
-        training_dataset=None,
-        base_tokenizer="bert-base-uncased",
-        vocab_size=1024,  # XXX including all tokens and alphabet
-        **kwargs):
-    """
-    XXX
-    base_tokenizer
-    all other args passed to load_dataset for XXX...
-    """
-
-    # Load the training data we'll train our new tokenizer with.
-    if training_dataset is None:
-        training_dataset = load_dataset(*args, **kwargs)
+        training_dataset,
+        base_tokenizer=DEFAULT_BASE,
+        vocab_size=DEFAULT_SIZE,
+        corpus_size=None):
 
     # Create the base tokenizer we'll train our new tokenizer from.
     if isinstance(base_tokenizer, str):
@@ -65,25 +63,83 @@ def get_training_corpus():
         for row in training_dataset:
             yield futz_input(json.dumps(row["dom_snapshot"]))
 
+    # Try and get a dataset length, for the progress tracker.
+    if corpus_size is None:
+        try:
+            corpus_size = len(training_dataset)
+        except TypeError:
+            pass
+
     # Train the new tokenizer.
     new_tokenizer = base_tokenizer.train_new_from_iterator(
         text_iterator=get_training_corpus(),
         vocab_size=vocab_size,
         new_special_tokens=new_special_tokens,
-        length=len(training_dataset),  # used for progress tracking
+        length=corpus_size,
         show_progress=True,
     )
 
     return new_tokenizer
 
 
-def main(save_directory="pretrained", use_full_dataset=False):
+def _round_and_prefix(value):
+    """314159 -> '314k'."""
+    whole, frac = divmod(log10(value), 1)
+    unit_index, whole = divmod(floor(whole), 3)
+    value = round(10 ** (whole + frac))
+    unit = ([""] + list("kMBTQ"))[unit_index]
+    return f"{value}{unit}"
+
+
+def main():
+    p = ArgumentParser(
+        description="Train DOM-aware tokenizers.",
+        epilog=f"Report bugs to: <{SEND_BUGS_TO}>.")
+    p.add_argument(
+        "dataset", metavar="DATASET",
+        help="dataset containing the training corpus")
+    p.add_argument(
+        "--base-tokenizer", metavar="ID", default=DEFAULT_BASE,
+        help=f"tokenizer to train ours from [default: {DEFAULT_BASE}]")
+    p.add_argument(
+        "--split", default=DEFAULT_SPLIT, metavar="SPLIT", dest="split_name",
+        help=(f"split of the training dataset to use"
+              f" [default: {DEFAULT_SPLIT}]"))
+    p.add_argument(
+        "-N", "--num-inputs", metavar="N", dest="corpus_size",
+        type=int,
+        help=("number of sequences in the training dataset, if known;"
+              " this is used to provide meaningful progress tracking"))
+    p.add_argument(
+        "-n", "--num-tokens", metavar="N", dest="vocab_size", type=int,
+        default=DEFAULT_SIZE,
+        help=(f"desired vocabulary size, including all special tokens and"
+              f" the initial alphabet [default: {DEFAULT_SIZE} tokens]"))
+    p.add_argument(
+        "-o", "--output", metavar="DIR", dest="save_directory",
+        help=("directory to save the trained tokenizer into"
+              " [default: something based on targeted vocabulary size]"))
+    args = p.parse_args()
+
+    save_directory = args.save_directory
+    if save_directory is None:
+        pretty_size = _round_and_prefix(args.vocab_size)
+        save_directory = f"dom-tokenizer-{pretty_size}"
+    print(f"Output directory: {save_directory}\n")
+
     warnings.filterwarnings("ignore", message=r".*resume_download.*")
 
-    if use_full_dataset:
-        dataset, kwargs = FULL_DATASET, dict(streaming=True)
-    else:
-        dataset, kwargs = TEST_DATASET, {}
+    tokenizer = train_tokenizer(
+        load_dataset(
+            args.dataset,
+            split=args.split_name,
+            streaming=True),
+        base_tokenizer=args.base_tokenizer,
+        vocab_size=args.vocab_size,
+        corpus_size=args.corpus_size)
+    print(f'\n{tokenizer.tokenize("training complete")}')
 
-    tokenizer = train_tokenizer(dataset, split="train", **kwargs)
     tokenizer.save_pretrained(save_directory)
+
+    print(tokenizer.tokenize("tokenizer state saved"))
+    print(tokenizer.tokenize("see you soon") + ["!!"])
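
As a quick sanity check after applying this patch and running the README
example above, the saved tokenizer can be reloaded with the stock
transformers API. This is only a sketch under assumptions: it guesses that
no -o was passed, so the default output directory would be dom-tokenizer-10k
(what _round_and_prefix produces for -n 10000), and the custom
DOMSnapshotPreTokenizer may need re-attaching after loading, since custom
Python pre-tokenizers are not round-tripped by serialization.

    from transformers import AutoTokenizer

    # Assumed path: default output directory for "-n 10000" with no "-o".
    tokenizer = AutoTokenizer.from_pretrained("dom-tokenizer-10k")

    # Eyeball the learned vocabulary on a plain-text probe.
    print(tokenizer.tokenize("see you soon"))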