Skip to content

Commit

Permalink
test: fix test that tried to use tensorflowjs
Browse files Browse the repository at this point in the history
  • Loading branch information
unmonoqueteclea committed Feb 13, 2024
1 parent 4ee2231 commit 396c3e7
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 11 deletions.
6 changes: 5 additions & 1 deletion sensenet/importers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from sensenet import __tree_ext_prefix__

logger = logging.getLogger(__name__)
logging.getLogger("tensorflow").setLevel(logging.ERROR)

warnings.filterwarnings("ignore", message=".*binary incompatibility.*")
Expand All @@ -33,7 +34,10 @@
# but it is not mandatory for Sensenet to work
try:
import tensorflowjs
except: # noqa: E722
except Exception as e: # noqa: E722
logger.info(
f"tensorflowjs not found, you can't export models to JS: {e}"
)
tensorflowjs = None

bigml_tf_module = None
Expand Down
2 changes: 1 addition & 1 deletion sensenet/models/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def write_tfjs_files(self, model_path, save_path):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message=".*alias for the.*")
with suppress_stdout():
if tfjs:
if tfjs and tfjs.converters:
tfjs.converters.convert_tf_saved_model(
model_path, save_path, skip_op_check=True
)
Expand Down
20 changes: 11 additions & 9 deletions tests/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,22 @@
np = sensenet.importers.import_numpy()
tf = sensenet.importers.import_tensorflow()

import os
import shutil
import gzip
import json

from PIL import Image
import os
import pathlib
import shutil

from sensenet.constants import WARP
from sensenet.models.wrappers import Deepnet, ObjectDetector
from sensenet.models.wrappers import convert, tflite_predict
from sensenet.models.wrappers import (
Deepnet,
ObjectDetector,
convert,
tflite_predict,
)

from .utils import TEST_DATA_DIR, TEST_IMAGE_DATA
from .test_pretrained import create_image_model
from .utils import TEST_DATA_DIR, TEST_IMAGE_DATA

MOBILENET_PATH = os.path.join(TEST_DATA_DIR, "mobilenetv2.json.gz")
TEST_SAVE_MODEL = os.path.join(TEST_DATA_DIR, "test_model_save")
Expand Down Expand Up @@ -118,8 +121,7 @@ def test_all_conversions():
for aformat in ["tflite", "tfjs", "smbundle", "h5"]:
outpath = TEST_SAVE_MODEL + "." + aformat
convert(jmodel, None, outpath, aformat)

if aformat == "tfjs":
if aformat == "tfjs" and pathlib.Path(outpath).exists():
shutil.rmtree(outpath)
else:
os.remove(outpath)

0 comments on commit 396c3e7

Please sign in to comment.