Commit
Full fluent API implementation for ARADv1TF.
tymorrow committed Nov 13, 2023
1 parent bb0238d commit ea75a5c
Showing 1 changed file with 29 additions and 26 deletions.
55 changes: 29 additions & 26 deletions riid/models/neural_nets/arad.py
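The refactor below moves the ARADv1TF sub-models from piecewise Sequential construction to the Keras functional ("fluent") API, in which each sub-network is wired explicitly from an Input through its layers and wrapped in a Model. For orientation only, a minimal sketch of that pattern, assuming standard TensorFlow/Keras imports; the layer sizes and names here are placeholders, not the ARAD architecture:

    from tensorflow.keras.layers import Dense, Input
    from tensorflow.keras.models import Model

    # Functional-API pattern: declare an Input, chain layer calls, wrap in a Model.
    inputs = Input(shape=(128,), name="example_input")
    hidden = Dense(16, activation="relu", name="example_hidden")(inputs)
    outputs = Dense(128, name="example_output")(hidden)
    example_model = Model(inputs, outputs, name="example_model")
    example_model.summary()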
@@ -52,6 +52,7 @@ def __init__(self, latent_dim: int = 5):
latent_dim: dimension of internal latent representation.
5 was the final one in the paper, but 4 to 8 were found to work well.
"""
input_size = (128, 1)
# Encoder
b1_config = (
(5, 1, 32),
@@ -74,32 +75,39 @@ def __init__(self, latent_dim: int = 5):
(4, 2, 4),
(1, 1, 2),
)
self.b1_encoder = self._get_branch(b1_config, 0.1, "softplus", "B1", 5)
self.b2_encoder = self._get_branch(b2_config, 0.1, "softplus", "B2", 5)
self.b3_encoder = self._get_branch(b3_config, 0.1, "softplus", "B3", 5)

self.branch_concat = Concatenate(axis=1)

concat_encoder = Sequential(name="Final Encoder")
d1 = Dense(units=latent_dim, kernel_regularizer="KLD", name="D1") # Latent space
concat_encoder.add(d1)
d1_batch_norm = BatchNormalization(name="D1 Batch Norm")
concat_encoder.add(d1_batch_norm)
self.concat_encoder = concat_encoder
encoder_input = Input(shape=input_size, name="Encoder Input")
b1 = self._get_branch(b1_config, 0.1, "softplus", "B1", 5)(encoder_input)
b2 = self._get_branch(b2_config, 0.1, "softplus", "B2", 5)(encoder_input)
b3 = self._get_branch(b3_config, 0.1, "softplus", "B3", 5)(encoder_input)

x = Concatenate(axis=1)([b1, b2, b3])
x = Dense(units=latent_dim, kernel_regularizer="KLD",
name="D1 (latent space)")(x)
encoder_output = BatchNormalization(name="D1 Batch Norm")(x)
encoder = Model(encoder_input, encoder_output, name="Encoder")
encoder.summary()

# Decoder
decoder = Sequential(name="Decoder")
d2 = Dense(units=40, name="D2")
decoder.add(d2)
d2_dropout = Dropout(rate=0.1, name="D2 Dropout")
decoder.add(d2_dropout)
d3 = Dense(units=128, name="D3")
decoder.add(d3)
decoder_input = Input(shape=(latent_dim, 1), name="Decoder Input")
x = Dense(units=40, name="D2")(decoder_input)
x = Dropout(rate=0.1, name="D2 Dropout")(x)
decoder_output = Dense(units=128, name="D3")(x)
decoder = Model(decoder_input, decoder_output, name="Decoder")
decoder.summary()

# Autoencoder
autoencoder_input = Input(shape=input_size, name="Spectrum")
encoded_spectrum = encoder(autoencoder_input)
decoded_spectrum = decoder(encoded_spectrum)
autoencoder = Model(autoencoder_input, decoded_spectrum, name="autoencoder")
autoencoder.summary()

self.encoder = encoder
self.decoder = decoder
self.autoencoder = autoencoder

def _get_branch(self, config, dropout_rate, activation, branch_name, dense_units) -> Sequential:
branch = Sequential(name=branch_name)
branch.add(Input(shape=(128, 1,), name=f"{branch_name} Input Spectrum"))
for i, (kernel_size, strides, filters) in enumerate(config, start=1):
layer_name = f"{branch_name}_C{i}"
branch.add(Conv1D(kernel_size=kernel_size, strides=strides, filters=filters,
@@ -112,12 +120,7 @@ def _get_branch(self, config, dropout_rate, activation, branch_name, dense_units
return branch

def call(self, x):
b1_output = self.b1_encoder(x)
b2_output = self.b2_encoder(x)
b3_output = self.b3_encoder(x)
concat_output = self.branch_concat([b1_output, b2_output, b3_output])
encoded = self.concat_encoder(concat_output)
decoded = self.decoder(encoded)
decoded = self.autoencoder(x)
return decoded
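
For context, a hedged usage sketch of how the refactored pieces would be exercised, assuming the __init__ shown above belongs to the ARADv1TF class named in the commit message, that it is importable from riid.models.neural_nets.arad, and that it behaves as a tf.keras.Model subclass (it defines call); the batch size and random spectra are illustrative only:

    import numpy as np

    from riid.models.neural_nets.arad import ARADv1TF  # assumed import path

    # Build the model; input_size in __init__ is (128, 1) and latent_dim defaults to 5.
    model = ARADv1TF(latent_dim=5)

    # Illustrative batch of 8 spectra with 128 channels each.
    spectra = np.random.rand(8, 128, 1).astype(np.float32)

    # call() routes the batch through the full autoencoder (self.autoencoder(x)).
    reconstruction = model(spectra)

    # The encoder and decoder halves are also exposed individually.
    latent = model.encoder(spectra)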


