
MMvec refactor #166

@mortonjt

Description


We're going to go with either PyTorch or NumPyro. The framework will have the following skeleton:

model.py (mmvec.py)

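import torch
import torch.nn as nn
from torch.distributions import Multinomial

class MMvec(nn.Module):
    def __init__(self, num_microbes, num_metabolites, latent_dim):
        super().__init__()
        self.encoder = nn.Embedding(num_microbes, latent_dim)
        self.decoder = nn.Sequential(nn.Linear(latent_dim, num_metabolites), nn.Softmax(dim=-1))
        # TODO : may want to have a better softmax (e.g. pass logits to Multinomial)

    def forward(self, X, Y):
        """ X is one-hot encodings (B x num_microbes).  Y is metabolite abundances (B x num_metabolites).  B is the batch size"""
        # nn.Embedding expects integer indices, so recover them from the one-hot rows
        z = self.encoder(X.argmax(dim=-1))
        pred_y = self.decoder(z)
        # probs= must be passed by keyword (the first positional argument is total_count);
        # validation is off because per-sample totals differ from the default total_count=1
        lp = Multinomial(probs=pred_y, validate_args=False).log_prob(Y).mean()
        return lp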
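The docstring describes X as one-hot encodings, while `nn.Embedding` looks rows up by integer index. A small sketch (toy sizes chosen arbitrarily) showing that the two views agree: multiplying a one-hot matrix by the embedding weight gives the same rows as an index lookup, and `argmax` recovers the indices from the one-hot rows.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_microbes, latent_dim = 5, 3  # toy sizes, chosen for illustration
encoder = nn.Embedding(num_microbes, latent_dim)

idx = torch.tensor([0, 3, 3, 1])  # microbe indices for a batch of B=4
X = F.one_hot(idx, num_classes=num_microbes).float()  # (B x num_microbes) one-hot

z_from_index = encoder(idx)                # embedding lookup, (B x latent_dim)
z_from_onehot = X @ encoder.weight         # one-hot matmul, (B x latent_dim)
z_from_argmax = encoder(X.argmax(dim=-1))  # indices recovered from the one-hot rows

same = torch.allclose(z_from_index, z_from_onehot) and torch.allclose(z_from_index, z_from_argmax)
```

Passing indices directly avoids materializing a dense (B x num_microbes) matrix, which matters when there are many microbes.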

train.py (could use PyTorch Lightning)
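A minimal training-loop sketch in plain PyTorch (no Lightning), assuming X holds integer microbe indices and using made-up count data. The model is a compact, self-contained stand-in for the skeleton above that passes logits to `Multinomial`; sizes, learning rate, batch size, and epoch count are arbitrary choices.

```python
import torch
import torch.nn as nn
from torch.distributions import Multinomial

class MMvec(nn.Module):
    """Compact stand-in for the model.py skeleton (assumed interface)."""
    def __init__(self, num_microbes, num_metabolites, latent_dim):
        super().__init__()
        self.encoder = nn.Embedding(num_microbes, latent_dim)
        self.decoder = nn.Linear(latent_dim, num_metabolites)  # outputs logits

    def forward(self, X, Y):
        # X: (B,) microbe indices; Y: (B x num_metabolites) counts
        logits = self.decoder(self.encoder(X))
        # validate_args=False: per-sample totals differ from the default total_count=1
        return Multinomial(logits=logits, validate_args=False).log_prob(Y).mean()

torch.manual_seed(0)
num_microbes, num_metabolites, latent_dim = 10, 20, 3
# Toy data, made up for illustration: 200 (microbe, metabolite-count) pairs
X = torch.randint(0, num_microbes, (200,))
Y = torch.randint(0, 5, (200, num_metabolites)).float()

model = MMvec(num_microbes, num_metabolites, latent_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
losses = []  # full-data negative log-likelihood after each epoch
for epoch in range(100):
    for batch in torch.randperm(len(X)).split(50):  # shuffled minibatches of B=50
        optimizer.zero_grad()
        loss = -model(X[batch], Y[batch])  # maximize log-likelihood
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        losses.append(-model(X, Y).item())
```

Returning the mean log-likelihood from `forward` keeps the loss a simple negation; a Lightning `training_step` could wrap the same inner loop body.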

The wishlist

  • Early stopping (see video for example)
  • ArviZ for diagnostics
  • Type annotations would be great. See torchtyping
  • Unit tests would be nice as well. See torchtest
  • A Bayesian treatment would be nice. SWAG (SWA-Gaussian) is the laziest approach
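For the early-stopping item, a minimal patience-based sketch in plain Python. The class name and its `patience` and `min_delta` parameters are illustrative choices rather than an existing API, and the validation losses are made up.

```python
class EarlyStopping:
    """Stop when the validation loss has not improved by more than `min_delta`
    for `patience` consecutive checks (hypothetical helper, not a library API)."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_checks = 0

    def step(self, val_loss):
        """Record one validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


# Usage with made-up validation losses: improvement stalls after the 4th value
stopper = EarlyStopping(patience=3)
val_losses = [1.0, 0.8, 0.7, 0.65, 0.66, 0.67, 0.65, 0.9]
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if stopper.step(loss):
        stopped_at = epoch
        break
```

Here training stops at epoch 6, after three checks without beating 0.65. If train.py ends up on PyTorch Lightning, its built-in `EarlyStopping` callback covers the same ground.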
