We're going to go with either PyTorch or NumPyro. The framework will have the following skeleton:
model.py (mmvec.py)
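```python
import torch
import torch.nn as nn
from torch.distributions import Multinomial


class MMvec(nn.Module):
    def __init__(self, num_microbes, num_metabolites, latent_dim):
        super().__init__()
        # One latent vector per microbe (an embedding lookup equals one-hot @ weight matrix)
        self.encoder = nn.Embedding(num_microbes, latent_dim)
        # Map each latent vector to a probability vector over metabolites
        self.decoder = nn.Sequential(nn.Linear(latent_dim, num_metabolites), nn.Softmax(dim=-1))
        # TODO : may want to have a better softmax

    def forward(self, X, Y):
        """X holds microbe indices (B,), the integer form of one-hot encodings (B x num_microbes).
        Y is metabolite abundances (B x num_metabolites). B is the batch size."""
        z = self.encoder(X)
        pred_y = self.decoder(z)
        # Mean multinomial log-likelihood of the observed counts. log_prob uses each row's
        # own count total, so validation is disabled to allow different totals per row.
        lp = Multinomial(probs=pred_y, validate_args=False).log_prob(Y).mean()
        return lp
```

train.py (could use PyTorch Lightning)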
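A minimal sketch of what train.py could contain, written as a plain PyTorch loop rather than with PyTorch Lightning. So that it runs on its own, it re-declares a compact MMvec whose decoder returns logits instead of softmax probabilities (one possible answer to the softmax TODO). The Adam optimizer, learning rate, epoch count, latent size, and toy data are all assumptions, not decisions from this issue.

```python
import torch
import torch.nn as nn
from torch.distributions import Multinomial
from torch.utils.data import DataLoader, TensorDataset


class MMvec(nn.Module):
    """Compact copy of the skeleton model (decoder returns logits) so this sketch runs alone."""

    def __init__(self, num_microbes: int, num_metabolites: int, latent_dim: int):
        super().__init__()
        self.encoder = nn.Embedding(num_microbes, latent_dim)
        self.decoder = nn.Linear(latent_dim, num_metabolites)

    def forward(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
        logits = self.decoder(self.encoder(X))
        # log_prob normalizes the logits and uses each row's own count total.
        return Multinomial(logits=logits, validate_args=False).log_prob(Y).mean()


def train(model: nn.Module, loader: DataLoader, epochs: int = 30, lr: float = 1e-2) -> list:
    """Maximize the mean log-likelihood by minimizing its negative with Adam (assumed optimizer)."""
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    history = []
    for _ in range(epochs):
        total, n = 0.0, 0
        for xb, yb in loader:
            loss = -model(xb, yb)  # negative mean log-likelihood of the batch
            opt.zero_grad()
            loss.backward()
            opt.step()
            total += loss.item() * len(xb)
            n += len(xb)
        history.append(total / n)  # average training loss for this epoch
    return history


# Hypothetical toy data: each microbe index has its own metabolite profile.
torch.manual_seed(0)
num_microbes, num_metabolites = 5, 8
true_probs = torch.softmax(torch.randn(num_microbes, num_metabolites), dim=-1)
X = torch.randint(0, num_microbes, (200,))
Y = Multinomial(total_count=50, probs=true_probs[X]).sample()

loader = DataLoader(TensorDataset(X, Y), batch_size=32, shuffle=True)
model = MMvec(num_microbes, num_metabolites, latent_dim=3)
history = train(model, loader)
```

In a Lightning version, the body of the inner loop would move into a `LightningModule.training_step` that returns the loss, and the `Trainer` would handle the epochs.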
The wishlist
- Early stopping (see video for example)
- ArviZ for diagnostics
- Typing would be great. See torchtyping.
- Tests would be cool too. See torchtest.
- Being Bayesian would be nice. SWAG (SWA-Gaussian) is the lowest-effort approach.
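For the early-stopping item, a framework-agnostic sketch: a hypothetical `EarlyStopping` helper that watches a validation loss and signals a stop once `patience` epochs pass without an improvement larger than `min_delta`. The class and parameter names are illustrative assumptions. If PyTorch Lightning is adopted, its built-in `EarlyStopping` callback (monitoring a logged metric such as `val_loss`) covers the same need.

```python
import math


class EarlyStopping:
    """Hypothetical helper: stop after `patience` epochs without sufficient improvement."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = math.inf  # best validation loss seen so far
        self.bad_epochs = 0   # consecutive epochs without improvement

    def step(self, val_loss: float) -> bool:
        """Record one validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Usage with made-up validation losses.
stopper = EarlyStopping(patience=3)
losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.69, 0.70, 0.71, 0.72]
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.step(loss):
        stopped_at = epoch
        break
```

Here training stops at epoch 8, three epochs after the best loss of 0.69; in practice the model weights from the best epoch would also be checkpointed and restored.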
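For the Bayesian item, a sketch of the diagonal variant of SWAG (SWAG-Diagonal): keep running averages of flattened parameter snapshots collected during SGD and of their squares, then sample weights from a Gaussian with that mean and per-parameter variance. Full SWAG also adds a low-rank covariance term, which is omitted here. The class name is hypothetical and NumPy is used only to keep the bookkeeping readable; with PyTorch, snapshots could come from `torch.nn.utils.parameters_to_vector` and samples could be loaded back with `torch.nn.utils.vector_to_parameters`.

```python
import numpy as np


class DiagonalSWAG:
    """Hypothetical SWAG-Diagonal tracker over flattened parameter vectors."""

    def __init__(self, num_params: int):
        self.n = 0
        self.mean = np.zeros(num_params)     # running mean of theta
        self.sq_mean = np.zeros(num_params)  # running mean of theta**2

    def collect(self, theta) -> None:
        """Fold one flattened parameter snapshot into the running moments."""
        theta = np.asarray(theta, dtype=float)
        self.n += 1
        self.mean += (theta - self.mean) / self.n
        self.sq_mean += (theta**2 - self.sq_mean) / self.n

    def variance(self) -> np.ndarray:
        # Clip small negative values caused by floating-point round-off.
        return np.clip(self.sq_mean - self.mean**2, 0.0, None)

    def sample(self, rng: np.random.Generator, num_samples: int = 1) -> np.ndarray:
        """Draw weight vectors from N(mean, diag(variance))."""
        std = np.sqrt(self.variance())
        z = rng.standard_normal((num_samples, self.mean.size))
        return self.mean + z * std


# Usage with two made-up snapshots of a 2-parameter model.
swag = DiagonalSWAG(num_params=2)
for snapshot in ([1.0, 2.0], [3.0, 2.0]):
    swag.collect(snapshot)
samples = swag.sample(np.random.default_rng(0), num_samples=1000)
```

Predictions would then be averaged over several sampled weight vectors (Bayesian model averaging) instead of using a single point estimate.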