-
Notifications
You must be signed in to change notification settings - Fork 12
Open
Description
ucate/ucate/library/models/cevae.py
Lines 110 to 186 in f57e47b
| class Encoder(tf.keras.Model): | |
| def __init__( | |
| self, | |
| do_convolution, | |
| dim_latent, | |
| num_examples, | |
| dim_hidden, | |
| dropout_rate=0.1, | |
| beta=1.0, | |
| negative_sampling=True, | |
| *args, | |
| **kwargs | |
| ): | |
| super(Encoder, self).__init__(*args, **kwargs) | |
| self.conv = ( | |
| convolution.ConvHead( | |
| base_filters=32, | |
| num_examples=sum(num_examples), | |
| dropout_rate=dropout_rate, | |
| ) | |
| if do_convolution | |
| else dense.identity() | |
| ) | |
| self.hidden_1 = dense.Dense( | |
| units=dim_hidden, | |
| num_examples=sum(num_examples), | |
| dropout_rate=dropout_rate, | |
| activation="elu", | |
| name="encoder_hidden_1", | |
| ) | |
| self.hidden_2 = dense.Dense( | |
| units=dim_hidden, | |
| num_examples=sum(num_examples), | |
| dropout_rate=dropout_rate, | |
| activation="elu", | |
| name="encoder_hidden_2", | |
| ) | |
| self.hidden_3 = dense.Dense( | |
| units=dim_hidden, | |
| num_examples=num_examples, | |
| dropout_rate=dropout_rate, | |
| num_branches=2, | |
| activation="elu", | |
| name="encoder_hidden_3", | |
| ) | |
| self.hidden_4 = dense.Dense( | |
| units=dim_hidden, | |
| num_examples=num_examples, | |
| dropout_rate=dropout_rate, | |
| num_branches=2, | |
| activation="elu", | |
| name="encoder_hidden_4", | |
| ) | |
| self.sampler = samplers.NormalSampler( | |
| dim_output=dim_latent, | |
| num_branches=2, | |
| num_examples=num_examples, | |
| beta=beta / 2 if negative_sampling else beta, | |
| ) | |
| self.negative_sampling = negative_sampling | |
| def call(self, inputs, training=None): | |
| x, t = inputs | |
| q = self.forward([x, t], training=training) | |
| if self.negative_sampling: | |
| t_cf = 1.0 - t | |
| _ = self.forward([x, t_cf], training=training) | |
| return q | |
| def forward(self, inputs, training=None): | |
| x, t = inputs | |
| outputs = self.conv(x, training=training) | |
| outputs = self.hidden_1(outputs, training=training) | |
| outputs = self.hidden_2(outputs, training=training) | |
| outputs = self.hidden_3([outputs, t], training=training) | |
| outputs = self.hidden_4([outputs, t], training=training) | |
| return self.sampler([outputs, t], training=training) |
@anndvision, could you please help me understand the encoder in CEVAE? If I understand correctly, there are supposed to be three inference models:
- $t \sim q(t \mid x)$ — treatment
- $y \sim q(y \mid t, x)$ — outcome
- $z \sim q(z \mid y, t, x)$ — latent confounder (an embedding)
In the ucate code, however, there seems to be only a single inference model, the encoder. It takes only x and t; the outcome y is not an input. These are passed as a concatenated array into the encoder network:
ucate/ucate/library/models/cevae.py
Line 184 in f57e47b
outputs = self.hidden_3([outputs, t], training=training)
Metadata
Metadata
Assignees
Labels
No labels