Skip to content

Commit 1b8c8d1

Browse files
committed
Enabled several NB and Poisson losses
1 parent bd56d27 commit 1b8c8d1

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

src/multigrate/model/_multivae.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ def train(
291291
weight_decay: float = 1e-3,
292292
eps: float = 1e-08,
293293
early_stopping: bool = True,
294+
early_stopping_patience = 50,
294295
save_best: bool = True,
295296
check_val_every_n_epoch: int | None = None,
296297
n_epochs_kl_warmup: int | None = None,
@@ -419,7 +420,7 @@ def train(
419420
early_stopping=early_stopping,
420421
check_val_every_n_epoch=check_val_every_n_epoch,
421422
early_stopping_monitor="reconstruction_loss_validation",
422-
early_stopping_patience=50,
423+
early_stopping_patience=early_stopping_patience,
423424
enable_checkpointing=True,
424425
**kwargs,
425426
)

src/multigrate/module/_multivae_torch.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -181,11 +181,13 @@ def __init__(
181181

182182
# assume for now that can only use nb/zinb once, i.e. for RNA-seq modality
183183
# TODO: add check for multiple nb/zinb losses given
184-
self.theta = None
184+
self.theta = []
185+
j = []
185186
for i, loss in enumerate(losses):
186187
if loss in ["nb", "zinb"]:
187-
self.theta = torch.nn.Parameter(torch.randn(self.input_dims[i], num_groups))
188-
break
188+
self.theta.append(torch.nn.Parameter(torch.randn(self.input_dims[i], num_groups)))
189+
else:
190+
self.theta.append([])
189191

190192
# modality encoders
191193
cond_dim_enc = cond_dim * (len(cat_covariate_dims) + len(cont_covariate_dims)) if self.condition_encoders else 0
@@ -307,6 +309,7 @@ def _h_to_x(self, h, i):
307309
return x
308310

309311
def _product_of_experts(self, mus, logvars, masks):
312+
#print(mus, logvars, masks)
310313
vars = torch.exp(logvars)
311314
masks = masks.unsqueeze(-1).repeat(1, 1, vars.shape[-1])
312315
mus_joint = torch.sum(mus * masks / vars, dim=1)
@@ -657,7 +660,7 @@ def _calc_recon_loss(self, xs, rs, losses, group, size_factor, loss_coefs, masks
657660
dec_mean = r
658661
size_factor_view = size_factor.expand(dec_mean.size(0), dec_mean.size(1))
659662
dec_mean = dec_mean * size_factor_view
660-
dispersion = self.theta.T[group.squeeze().long()]
663+
dispersion = self.theta[i].to(self.device).T[group.squeeze().long()]
661664
dispersion = torch.exp(dispersion)
662665
nb_loss = torch.sum(NegativeBinomial(mu=dec_mean, theta=dispersion).log_prob(x), dim=-1)
663666
nb_loss = loss_coefs[str(i)] * nb_loss
@@ -666,9 +669,9 @@ def _calc_recon_loss(self, xs, rs, losses, group, size_factor, loss_coefs, masks
666669
dec_mean, dec_dropout = r
667670
dec_mean = dec_mean.squeeze()
668671
dec_dropout = dec_dropout.squeeze()
669-
size_factor_view = size_factor.unsqueeze(1).expand(dec_mean.size(0), dec_mean.size(1))
672+
size_factor_view = size_factor.expand(dec_mean.size(0), dec_mean.size(1))
670673
dec_mean = dec_mean * size_factor_view
671-
dispersion = self.theta.T[group.squeeze().long()]
674+
dispersion = self.theta[i].to(self.device).T[group.squeeze().long()]
672675
dispersion = torch.exp(dispersion)
673676
zinb_loss = torch.sum(
674677
ZeroInflatedNegativeBinomial(mu=dec_mean, theta=dispersion, zi_logits=dec_dropout).log_prob(x),

0 commit comments

Comments (0)