diff --git a/.flake8 b/.flake8
index 6c31553..998689f 100644
--- a/.flake8
+++ b/.flake8
@@ -5,5 +5,3 @@ per-file-ignores =
     src/dom_tokenizers/**/__init__.py: F401
     # line too long
     src/dom_tokenizers/pre_tokenizers/dom_snapshot.py: E501
-    # module level import not at top of file
-    src/dom_tokenizers/train.py: E402
diff --git a/src/dom_tokenizers/internal/__init__.py b/src/dom_tokenizers/internal/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/dom_tokenizers/internal/transformers.py b/src/dom_tokenizers/internal/transformers.py
new file mode 100644
index 0000000..3ddc75c
--- /dev/null
+++ b/src/dom_tokenizers/internal/transformers.py
@@ -0,0 +1,18 @@
+import os
+
+# Don't print "None of PyTorch, TensorFlow >= 2.0, or Flax have been
+# found. Models won't be available and only tokenizers, configuration
+# and file/data utilities can be used" warning. Tokenizers is all we
+# want!
+
+__var_name = "TRANSFORMERS_NO_ADVISORY_WARNINGS"
+__orig_val = os.environ.get(__var_name)
+os.environ[__var_name] = "1"
+try:
+    from transformers import AutoTokenizer  # noqa: F401
+finally:
+    if __orig_val is None:
+        os.environ.pop(__var_name)
+    else:
+        os.environ[__var_name] = __orig_val
+del __var_name, __orig_val, os
diff --git a/src/dom_tokenizers/train.py b/src/dom_tokenizers/train.py
index 0d9aa56..92c7092 100644
--- a/src/dom_tokenizers/train.py
+++ b/src/dom_tokenizers/train.py
@@ -1,4 +1,3 @@
-import os
 import json
 import warnings
 
@@ -8,9 +7,7 @@
 from datasets import load_dataset
 from tokenizers.pre_tokenizers import PreTokenizer, WhitespaceSplit
 
-os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = str(True)
-from transformers import AutoTokenizer
-
+from .internal.transformers import AutoTokenizer
 from .pre_tokenizers import DOMSnapshotPreTokenizer
 
 DEFAULT_BASE = "bert-base-uncased"
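
Usage note (illustrative sketch, not part of the patch): the shim sets
TRANSFORMERS_NO_ADVISORY_WARNINGS once, around the real import, then restores
the prior environment, so callers just import AutoTokenizer from it. A minimal
example of a caller, assuming the package layout above and using the
DEFAULT_BASE checkpoint from train.py:

    # Hypothetical caller. Importing from the shim triggers the real
    # "from transformers import AutoTokenizer" with the advisory-warning
    # suppression applied, and the environment restored afterwards.
    from dom_tokenizers.internal.transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    print(tokenizer.tokenize("hello world"))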