Skip to content

Commit

Permalink
Reworked ARAD wrapper.
Browse files Browse the repository at this point in the history
  • Loading branch information
tymorrow committed Nov 13, 2023
1 parent 13b37bb commit 700d9a2
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 52 deletions.
11 changes: 6 additions & 5 deletions examples/modeling/arad.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""
import sys
import tensorflow as tf
from riid.models.neural_nets.arad import ARADv1TF, ARADv2TF
from riid.models.neural_nets.arad import ARAD, ARADv1TF, ARADv2TF

if len(sys.argv) == 2:
import matplotlib
Expand Down Expand Up @@ -33,7 +33,8 @@ def show_summaries(model):
pass


v1_model = ARADv1TF()
show_summaries(v1_model)
v2_model = ARADv2TF()
show_summaries(v2_model)
arad_v1 = ARAD(model=ARADv1TF())
show_summaries(arad_v1.model)

arad_v2 = ARAD(model=ARADv2TF())
show_summaries(arad_v2.model)
73 changes: 26 additions & 47 deletions riid/models/neural_nets/arad.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
# the U.S. Government retains certain rights in this software.
"""This module contains implementations of the ARAD model architecture."""
import tensorflow as tf
from keras.activations import sigmoid
from keras.initializers import GlorotNormal, HeNormal
from keras.layers import (BatchNormalization, Concatenate, Conv1D,
Conv1DTranspose, Dense, Dropout, Flatten, Input,
MaxPool1D, Reshape, UpSampling1D)
from keras.models import Model
from keras.activations import sigmoid
from keras.regularizers import L1L2, L2, Regularizer, _check_penalty_number

from riid.data.sampleset import SampleSet
Expand All @@ -29,29 +29,6 @@ def get_config(self):
return {"sparsity": float(self.sparsity)}


class ARADv1(TFModelBase):
def __init__(self, latent_dim: int = 5):
super().__init__()

self.latent_dim = latent_dim
self.model = None

# TODO: save as ONNX

def fit(self, ss: SampleSet):
if not self.model:
self.model = ARADv1TF(latent_dim=self.latent_dim)

# TODO: fit

def predict(self, ss: SampleSet, ood_threshold: float = 0.5):
pass
# TODO: predict
# TODO: save results in as:
# SampleSet.info.ood
# SampleSet.info.recon_error


class ARADv1TF(Model):
"""TensorFlow Implementation of ARAD v1.
Expand Down Expand Up @@ -138,29 +115,6 @@ def call(self, x):
return decoded


class ARADv2(TFModelBase):
def __init__(self, latent_dim: int = 5):
super().__init__()

self.latent_dim = latent_dim
self.model = None

# TODO: save as ONNX

def fit(self, ss: SampleSet):
if not self.model:
self.model = ARADv2TF(latent_dim=self.latent_dim)

# TODO: fit

def predict(self, ss: SampleSet, ood_threshold: float = 0.5):
pass
# TODO: predict
# TODO: save results in as:
# SampleSet.info.ood
# SampleSet.info.recon_error


class ARADv2TF(Model):
"""TensorFlow Implementation of ARAD v2.
Expand Down Expand Up @@ -269,3 +223,28 @@ def __init__(self, latent_dim: int = 5):
def call(self, x):
decoded = self.autoencoder(x)
return decoded


class ARAD(TFModelBase):
"""PyRIID-compatible wrapper around ARAD models.
"""
def __init__(self, model: Model = ARADv2TF()):
"""
Args:
model: instantiated model of the desired version of ARAD to use.
"""
super().__init__()

self.model = model

# TODO: enable saving as ONNX

def fit(self, ss: SampleSet):
pass # TODO: fit

def predict(self, ss: SampleSet, ood_threshold: float = 0.5):
pass
# TODO: predict
# TODO: save results in as:
# SampleSet.info.ood
# SampleSet.info.recon_error

0 comments on commit 700d9a2

Please sign in to comment.