Skip to content

Commit

Permalink
test: fix test that tried to use tensorflowjs
Browse files Browse the repository at this point in the history
  • Loading branch information
unmonoqueteclea committed Feb 13, 2024
1 parent 4ee2231 commit cce25d0
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ jobs:
- name: Run sensenet tests
run: |
pytest -sv
pytest -xsv
6 changes: 5 additions & 1 deletion sensenet/importers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from sensenet import __tree_ext_prefix__

logger = logging.getLogger(__name__)
logging.getLogger("tensorflow").setLevel(logging.ERROR)

warnings.filterwarnings("ignore", message=".*binary incompatibility.*")
Expand All @@ -33,7 +34,10 @@
# but it is not mandatory for Sensenet to work
try:
import tensorflowjs
except: # noqa: E722
except Exception as e: # noqa: E722
logger.info(
f"tensorflowjs not found, you can't export models to JS: {e}"
)
tensorflowjs = None

bigml_tf_module = None
Expand Down
2 changes: 1 addition & 1 deletion sensenet/models/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def write_tfjs_files(self, model_path, save_path):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message=".*alias for the.*")
with suppress_stdout():
if tfjs:
if tfjs and tfjs.converters:
tfjs.converters.convert_tf_saved_model(
model_path, save_path, skip_op_check=True
)
Expand Down
20 changes: 11 additions & 9 deletions tests/test_export.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
# flake8: noqa: E402
import sensenet.importers

np = sensenet.importers.import_numpy()
tf = sensenet.importers.import_tensorflow()

import os
import shutil
import gzip
import json

from PIL import Image
import os
import shutil

from sensenet.constants import WARP
from sensenet.models.wrappers import Deepnet, ObjectDetector
from sensenet.models.wrappers import convert, tflite_predict
from sensenet.models.wrappers import (
Deepnet,
ObjectDetector,
convert,
tflite_predict,
)

from .utils import TEST_DATA_DIR, TEST_IMAGE_DATA
from .test_pretrained import create_image_model
from .utils import TEST_DATA_DIR, TEST_IMAGE_DATA

MOBILENET_PATH = os.path.join(TEST_DATA_DIR, "mobilenetv2.json.gz")
TEST_SAVE_MODEL = os.path.join(TEST_DATA_DIR, "test_model_save")
Expand Down Expand Up @@ -118,8 +121,7 @@ def test_all_conversions():
for aformat in ["tflite", "tfjs", "smbundle", "h5"]:
outpath = TEST_SAVE_MODEL + "." + aformat
convert(jmodel, None, outpath, aformat)

if aformat == "tfjs":
shutil.rmtree(outpath)
shutil.rmtree(outpath, ignore_errors=True)
else:
os.remove(outpath)

0 comments on commit cce25d0

Please sign in to comment.