Skip to content

Commit

Permalink
Fix ARADv1TF issues; add example.
Browse files Browse the repository at this point in the history
  • Loading branch information
tymorrow committed Nov 13, 2023
1 parent ea75a5c commit 59885a3
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 28 deletions.
15 changes: 15 additions & 0 deletions examples/modeling/arad_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Example: inspect and visualize the ARADv1TF model architecture."""
import tensorflow as tf
from riid.models.neural_nets.arad import ARADv1TF

model = ARADv1TF()

# Print a layer-by-layer summary of each sub-model.
for submodel in (model.encoder, model.decoder, model.autoencoder):
    submodel.summary()

# The following requires `graphviz` (system software) and `pydot` (Python package)
tf.keras.utils.plot_model(
    model.autoencoder,
    to_file="ARADv1.png",
    show_shapes=True,
    expand_nested=True,
)
68 changes: 40 additions & 28 deletions riid/models/neural_nets/arad.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,26 @@
from keras.models import Model
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
from keras.layers import BatchNormalization, Conv1D, Dense, Dropout, Concatenate, Input, Flatten
from keras.models import Sequential
from keras.regularizers import Regularizer, _check_penalty_number

from riid.data.sampleset import SampleSet
from riid.losses import jensen_shannon_distance, mish
from riid.models import TFModelBase


@tf.keras.saving.register_keras_serializable(package="riid")
class KLDRegularizer(Regularizer):
    """Kullback-Leibler divergence sparsity regularizer.

    Penalizes the mean (sigmoid-squashed) value of the regularized tensor
    for deviating from a target sparsity level, as in sparse autoencoders.

    NOTE(review): the original `__call__` was a TODO stub that returned the
    input tensor itself, which violates the Keras `Regularizer` contract
    (a regularizer must return a scalar loss). The penalty below is the
    standard KL sparsity term; confirm against the ARAD paper's intent.
    """

    def __init__(self, sparsity=0.5):
        """Args:
            sparsity: target mean activation (rho) in (0, 1).
        """
        # NOTE(review): `_check_penalty_number` is a private Keras helper
        # (validates that the value is a real number); it may disappear in
        # future Keras versions — consider an explicit isinstance check.
        _check_penalty_number(sparsity)
        self.sparsity = sparsity

    def __call__(self, inputs):
        """Return the scalar KL sparsity penalty for `inputs`."""
        # Clip both terms into (0, 1) so every log() argument is positive.
        eps = 1e-7
        rho = min(max(float(self.sparsity), eps), 1.0 - eps)
        # Mean activation estimate, squashed into (0, 1) via sigmoid.
        rho_hat = tf.clip_by_value(
            tf.reduce_mean(tf.sigmoid(inputs)), eps, 1.0 - eps
        )
        return (
            rho * tf.math.log(rho / rho_hat)
            + (1.0 - rho) * tf.math.log((1.0 - rho) / (1.0 - rho_hat))
        )

    def get_config(self):
        # Required for Keras (de)serialization of the regularizer.
        return {"sparsity": float(self.sparsity)}


class ARADv1(TFModelBase):
def __init__(self, latent_dim: int = 5):
super().__init__()
Expand Down Expand Up @@ -52,6 +65,8 @@ def __init__(self, latent_dim: int = 5):
latent_dim: dimension of internal latent representation.
5 was the final one in the paper, but 4 to 8 were found to work well.
"""
super().__init__()

input_size = (128, 1)
# Encoder
b1_config = (
Expand All @@ -75,49 +90,46 @@ def __init__(self, latent_dim: int = 5):
(4, 2, 4),
(1, 1, 2),
)
encoder_input = Input(shape=input_size, name="Encoder Input")
b1 = self._get_branch(b1_config, 0.1, "softplus", "B1", 5)(encoder_input)
b2 = self._get_branch(b2_config, 0.1, "softplus", "B2", 5)(encoder_input)
b3 = self._get_branch(b3_config, 0.1, "softplus", "B3", 5)(encoder_input)
encoder_input = Input(shape=input_size, name="encoder_input")
b1 = self._get_branch(encoder_input, b1_config, 0.1, "softplus", "B1", 5)
b2 = self._get_branch(encoder_input, b2_config, 0.1, "softplus", "B2", 5)
b3 = self._get_branch(encoder_input, b3_config, 0.1, "softplus", "B3", 5)

x = Concatenate(axis=1)([b1, b2, b3])
x = Dense(units=latent_dim, kernel_regularizer="KLD",
name="D1 (latend space)")(x)
encoder_output = BatchNormalization(name="D1 Batch Norm")(x)
encoder = Model(encoder_input, encoder_output, name="Encoder")
encoder.summary()
x = Dense(units=latent_dim, kernel_regularizer=KLDRegularizer(sparsity=0.5),
name="D1_latent_space")(x)
encoder_output = BatchNormalization(name="D1_batch_norm")(x)
encoder = Model(encoder_input, encoder_output, name="encoder")

# Decoder
decoder_input = Input(shape=(latent_dim, 1), name="Decoder Input")
decoder_input = Input(shape=(latent_dim, 1), name="decoder_input")
x = Dense(units=40, name="D2")(decoder_input)
x = Dropout(rate=0.1, name="D2 Dropout")(x)
x = Dropout(rate=0.1, name="D2_dropout")(x)
decoder_output = Dense(units=128, name="D3")(x)
decoder = Model(decoder_input, decoder_output, name="Decoder")
decoder.summary()
decoder = Model(decoder_input, decoder_output, name="decoder")

# Autoencoder
autoencoder_input = Input(shape=input_size, name="Spectrum")
encoded_spectrum = self.encoder(autoencoder_input)
decoded_spectrum = self.decoder(encoded_spectrum)
autoencoder_input = Input(shape=input_size, name="spectrum")
encoded_spectrum = encoder(autoencoder_input)
decoded_spectrum = decoder(encoded_spectrum)
autoencoder = Model(autoencoder_input, decoded_spectrum, name="autoencoder")
autoencoder.summary()

self.encoder = encoder
self.decoder = decoder
self.autoencoder = autoencoder

def _get_branch(self, input_layer, config, dropout_rate, activation, branch_name, dense_units):
    """Build one convolutional encoder branch on top of `input_layer`.

    Applies a Conv1D + BatchNormalization + Dropout stack per entry in
    `config`, then flattens and projects to `dense_units` units.

    Args:
        input_layer: Keras tensor the branch is applied to.
        config: iterable of (kernel_size, strides, filters) tuples,
            one per Conv1D layer.
        dropout_rate: dropout rate applied after each conv block.
        activation: activation used by each Conv1D layer.
        branch_name: prefix used for all layer names in this branch.
        dense_units: unit count of the branch's final Dense layer.

    Returns:
        Keras tensor: output of the branch's final batch-norm layer.
    """
    x = input_layer
    for i, (kernel_size, strides, filters) in enumerate(config, start=1):
        layer_name = f"{branch_name}_C{i}"
        x = Conv1D(kernel_size=kernel_size, strides=strides, filters=filters,
                   activation=activation, name=layer_name)(x)
        x = BatchNormalization(name=f"{layer_name}_batch_norm")(x)
        x = Dropout(rate=dropout_rate, name=f"{layer_name}_dropout")(x)
    x = Flatten(name=f"{branch_name}_flatten")(x)
    x = Dense(units=dense_units, name=f"{branch_name}_D1")(x)
    # Distinct from the per-conv "{branch_name}_C{i}_batch_norm" names,
    # so no layer-name collision occurs.
    x = BatchNormalization(name=f"{branch_name}_batch_norm")(x)
    return x

def call(self, x):
decoded = self.autoencoder(x)
Expand Down

0 comments on commit 59885a3

Please sign in to comment.