From 13b37bb4b9f6fd7ef7d9cc5a2870a652c9238cba Mon Sep 17 00:00:00 2001 From: Tyler Morrow Date: Mon, 13 Nov 2023 15:40:29 -0700 Subject: [PATCH] Rework ARAD examples. --- examples/modeling/arad.py | 39 ++++++++++++++++++++++++++++++++++++ examples/modeling/arad_v1.py | 16 --------------- examples/modeling/arad_v2.py | 16 --------------- 3 files changed, 39 insertions(+), 32 deletions(-) create mode 100644 examples/modeling/arad.py delete mode 100644 examples/modeling/arad_v1.py delete mode 100644 examples/modeling/arad_v2.py diff --git a/examples/modeling/arad.py b/examples/modeling/arad.py new file mode 100644 index 0000000..bac9cf8 --- /dev/null +++ b/examples/modeling/arad.py @@ -0,0 +1,39 @@ +# Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). +# Under the terms of Contract DE-NA0003525 with NTESS, +# the U.S. Government retains certain rights in this software. +"""This example demonstrates how to use the PyRIID implementations of ARAD. +""" +import sys +import tensorflow as tf +from riid.models.neural_nets.arad import ARADv1TF, ARADv2TF + +if len(sys.argv) == 2: + import matplotlib + matplotlib.use("Agg") + + +def show_summaries(model): + model_type = type(model).__name__ + print(model_type) + + model.encoder.summary() + model.decoder.summary() + model.autoencoder.summary() + + # The following requires `graphviz` (system software) and `pydot` (Python package) + try: + tf.keras.utils.plot_model( + model.autoencoder, + f"{model_type}.png", + show_shapes=True, + rankdir="LR", + expand_nested=True, + ) + except Exception: + pass + + +v1_model = ARADv1TF() +show_summaries(v1_model) +v2_model = ARADv2TF() +show_summaries(v2_model) diff --git a/examples/modeling/arad_v1.py b/examples/modeling/arad_v1.py deleted file mode 100644 index 804c94f..0000000 --- a/examples/modeling/arad_v1.py +++ /dev/null @@ -1,16 +0,0 @@ -import tensorflow as tf -from riid.models.neural_nets.arad import ARADv1TF - -model = ARADv1TF() -model.encoder.summary() -model.decoder.summary() -model.autoencoder.summary() - -# The following requires `graphviz` (system software) and `pydot` (Python package) -tf.keras.utils.plot_model( - model.autoencoder, - "ARADv1.png", - show_shapes=True, - rankdir="LR", - expand_nested=True -) diff --git a/examples/modeling/arad_v2.py b/examples/modeling/arad_v2.py deleted file mode 100644 index ff549b2..0000000 --- a/examples/modeling/arad_v2.py +++ /dev/null @@ -1,16 +0,0 @@ -import tensorflow as tf -from riid.models.neural_nets.arad import ARADv2TF - -model = ARADv2TF() -model.encoder.summary() -model.decoder.summary() -model.autoencoder.summary() - -# The following requires `graphviz` (system software) and `pydot` (Python package) -tf.keras.utils.plot_model( - model.autoencoder, - "ARADv2.png", - show_shapes=True, - rankdir="LR", - expand_nested=True, -)