@@ -181,11 +181,13 @@ def __init__(
181
181
182
182
# assume for now that can only use nb/zinb once, i.e. for RNA-seq modality
183
183
# TODO: add check for multiple nb/zinb losses given
184
- self .theta = None
184
+ self .theta = []
185
+ j = []
185
186
for i , loss in enumerate (losses ):
186
187
if loss in ["nb" , "zinb" ]:
187
- self .theta = torch .nn .Parameter (torch .randn (self .input_dims [i ], num_groups ))
188
- break
188
+ self .theta .append (torch .nn .Parameter (torch .randn (self .input_dims [i ], num_groups )))
189
+ else :
190
+ self .theta .append ([])
189
191
190
192
# modality encoders
191
193
cond_dim_enc = cond_dim * (len (cat_covariate_dims ) + len (cont_covariate_dims )) if self .condition_encoders else 0
@@ -307,6 +309,7 @@ def _h_to_x(self, h, i):
307
309
return x
308
310
309
311
def _product_of_experts (self , mus , logvars , masks ):
312
+ #print(mus, logvars, masks)
310
313
vars = torch .exp (logvars )
311
314
masks = masks .unsqueeze (- 1 ).repeat (1 , 1 , vars .shape [- 1 ])
312
315
mus_joint = torch .sum (mus * masks / vars , dim = 1 )
@@ -657,7 +660,7 @@ def _calc_recon_loss(self, xs, rs, losses, group, size_factor, loss_coefs, masks
657
660
dec_mean = r
658
661
size_factor_view = size_factor .expand (dec_mean .size (0 ), dec_mean .size (1 ))
659
662
dec_mean = dec_mean * size_factor_view
660
- dispersion = self .theta .T [group .squeeze ().long ()]
663
+ dispersion = self .theta [ i ]. to ( self . device ) .T [group .squeeze ().long ()]
661
664
dispersion = torch .exp (dispersion )
662
665
nb_loss = torch .sum (NegativeBinomial (mu = dec_mean , theta = dispersion ).log_prob (x ), dim = - 1 )
663
666
nb_loss = loss_coefs [str (i )] * nb_loss
@@ -666,9 +669,9 @@ def _calc_recon_loss(self, xs, rs, losses, group, size_factor, loss_coefs, masks
666
669
dec_mean , dec_dropout = r
667
670
dec_mean = dec_mean .squeeze ()
668
671
dec_dropout = dec_dropout .squeeze ()
669
- size_factor_view = size_factor .unsqueeze ( 1 ). expand (dec_mean .size (0 ), dec_mean .size (1 ))
672
+ size_factor_view = size_factor .expand (dec_mean .size (0 ), dec_mean .size (1 ))
670
673
dec_mean = dec_mean * size_factor_view
671
- dispersion = self .theta .T [group .squeeze ().long ()]
674
+ dispersion = self .theta [ i ]. to ( self . device ) .T [group .squeeze ().long ()]
672
675
dispersion = torch .exp (dispersion )
673
676
zinb_loss = torch .sum (
674
677
ZeroInflatedNegativeBinomial (mu = dec_mean , theta = dispersion , zi_logits = dec_dropout ).log_prob (x ),
0 commit comments