diff --git a/src/dom_tokenizers/train.py b/src/dom_tokenizers/train.py index 8724f2f..0d9aa56 100644 --- a/src/dom_tokenizers/train.py +++ b/src/dom_tokenizers/train.py @@ -80,7 +80,16 @@ def get_training_corpus(): length=corpus_size, show_progress=True, ) + + # Post-training fixups. new_tokenizer.name_or_path = _pretty_name(new_tokenizer) + for token in new_special_tokens: + attr = f"{token[1:-1].lower()}_token" + if not hasattr(new_tokenizer, attr): + continue # not a "special" special token + if getattr(new_tokenizer, attr) is not None: + continue # already got this one + setattr(new_tokenizer, attr, token) return new_tokenizer