
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed May 21, 2024
1 parent 75e7688 commit ba91b08
Showing 8 changed files with 170 additions and 126 deletions.
148 changes: 74 additions & 74 deletions .pre-commit-config.yaml
@@ -1,79 +1,79 @@
fail_fast: false
default_language_version:
  python: python3
default_stages:
  - commit
  - push
minimum_pre_commit_version: 2.16.0
repos:
  - repo: https://github.com/psf/black
    rev: 24.4.2
    hooks:
      - id: black
  - repo: https://github.com/pre-commit/mirrors-prettier
    rev: v4.0.0-alpha.8
    hooks:
      - id: prettier
  - repo: https://github.com/asottile/blacken-docs
    rev: 1.16.0
    hooks:
      - id: blacken-docs
  - repo: https://github.com/PyCQA/isort
    rev: 5.13.2
    hooks:
      - id: isort
  - repo: https://github.com/asottile/yesqa
    rev: v1.5.0
    hooks:
      - id: yesqa
        additional_dependencies:
          - flake8-tidy-imports
          - flake8-docstrings
          - flake8-rst-docstrings
          - flake8-comprehensions
          - flake8-bugbear
          - flake8-blind-except
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.6.0
    hooks:
      - id: detect-private-key
      - id: check-ast
      - id: end-of-file-fixer
      - id: mixed-line-ending
        args: [--fix=lf]
      - id: trailing-whitespace
      - id: check-case-conflict
  - repo: https://github.com/PyCQA/autoflake
    rev: v2.3.1
    hooks:
      - id: autoflake
        args:
          - --in-place
          - --remove-all-unused-imports
          - --remove-unused-variable
          - --ignore-init-module-imports
  - repo: https://github.com/PyCQA/flake8
    rev: 7.0.0
    hooks:
      - id: flake8
        additional_dependencies:
          - flake8-tidy-imports
          - flake8-docstrings
          - flake8-rst-docstrings
          - flake8-comprehensions
          - flake8-bugbear
          - flake8-blind-except
  - repo: https://github.com/asottile/pyupgrade
    rev: v3.15.2
    hooks:
      - id: pyupgrade
        args: [--py3-plus, --py38-plus, --keep-runtime-typing]
  - repo: local
    hooks:
      - id: forbid-to-commit
        name: Don't commit rej files
        entry: |
          Cannot commit .rej files. These indicate merge conflicts that arise during automated template updates.
          Fix the merge conflicts manually and remove the .rej files.
        language: fail
        files: '.*\.rej$'
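Several of the hooks above rewrite files rather than merely flag problems. As a hedged illustration of what the autoflake flags and pyupgrade do (a hypothetical module, not taken from this repository):

# Hypothetical input module, before the hooks run.
import os  # never used below: removed by autoflake --remove-all-unused-imports


def greet(name: str) -> str:
    unused = 1  # dead local: removed by autoflake --remove-unused-variable
    return "Hello, {}!".format(name)  # pyupgrade rewrites this to an f-string


# Roughly what the fixers leave behind:
#
# def greet(name: str) -> str:
#     return f"Hello, {name}!"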
22 changes: 11 additions & 11 deletions .readthedocs.yaml
@@ -1,16 +1,16 @@
# https://docs.readthedocs.io/en/stable/config-file/v2.html
version: 2
build:
  os: ubuntu-20.04
  tools:
    python: "3.10"
sphinx:
  configuration: docs/conf.py
  # disable this for more lenient docs builds
  fail_on_warning: false
python:
  install:
    - method: pip
      path: .
      extra_requirements:
        - doc
4 changes: 2 additions & 2 deletions docs/index.md.rej
@@ -1,9 +1,9 @@
diff a/docs/index.md b/docs/index.md (rejected hunks)
@@ -8,7 +8,6 @@

api.md
changelog.md
-template_usage.md
contributing.md
references.md

4 changes: 2 additions & 2 deletions pyproject.toml.rej
@@ -6,7 +6,7 @@ diff a/pyproject.toml b/pyproject.toml (rejected hunks)
- "session-info"
+ "session-info",
]

[project.optional-dependencies]
dev = [
- # CLI for bumping the version number
@@ -36,5 +36,5 @@ diff a/pyproject.toml b/pyproject.toml (rejected hunks)
- "pytest-cov",
+ "coverage",
]

[tool.coverage.run]
24 changes: 14 additions & 10 deletions src/simple_scvi/_mymodel.py
@@ -4,12 +4,8 @@
 from anndata import AnnData
 from scvi import REGISTRY_KEYS
 from scvi.data import AnnDataManager
-from scvi.data.fields import (
-    CategoricalJointObsField,
-    CategoricalObsField,
-    LayerField,
-    NumericalJointObsField,
-)
+from scvi.data.fields import (CategoricalJointObsField, CategoricalObsField,
+                              LayerField, NumericalJointObsField)
 from scvi.model._utils import _init_library_size
 from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin, VAEMixin
 from scvi.utils import setup_anndata_dsp
Expand Down Expand Up @@ -57,7 +53,9 @@ def __init__(
):
super().__init__(adata)

library_log_means, library_log_vars = _init_library_size(self.adata_manager, self.summary_stats["n_batch"])
library_log_means, library_log_vars = _init_library_size(
self.adata_manager, self.summary_stats["n_batch"]
)

# self.summary_stats provides information about anndata dimensions and other tensor info

@@ -108,9 +106,15 @@ def setup_anndata(
             LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
             CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
             CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
-            CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys),
-            NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys),
+            CategoricalJointObsField(
+                REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys
+            ),
+            NumericalJointObsField(
+                REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys
+            ),
         ]
-        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
+        adata_manager = AnnDataManager(
+            fields=anndata_fields, setup_method_args=setup_method_args
+        )
         adata_manager.register_fields(adata, **kwargs)
         cls.register_manager(adata_manager)
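For orientation, here is a minimal usage sketch of the model this file defines, assuming the package re-exports MyModel and that an AnnData file with raw counts and a batch column exists (the file name and column name below are placeholders, not from this diff):

import anndata as ad

from simple_scvi import MyModel

adata = ad.read_h5ad("counts.h5ad")  # hypothetical input; raw counts expected in .X

# Register the fields declared in setup_anndata above.
MyModel.setup_anndata(adata, batch_key="batch")

model = MyModel(adata)
model.train()  # training loop comes from UnsupervisedTrainingMixin
latent = model.get_latent_representation()  # from VAEMixin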
66 changes: 50 additions & 16 deletions src/simple_scvi/_mymodule.py
@@ -59,8 +59,12 @@ def __init__(
         # this is needed to comply with some requirement of the VAEMixin class
         self.latent_distribution = "normal"

-        self.register_buffer("library_log_means", torch.from_numpy(library_log_means).float())
-        self.register_buffer("library_log_vars", torch.from_numpy(library_log_vars).float())
+        self.register_buffer(
+            "library_log_means", torch.from_numpy(library_log_means).float()
+        )
+        self.register_buffer(
+            "library_log_vars", torch.from_numpy(library_log_vars).float()
+        )

         # setup the parameters of your generative model, as well as your inference model
         self.px_r = torch.nn.Parameter(torch.randn(n_input))
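An aside on the pattern black is rewrapping here: register_buffer attaches a non-trainable tensor to the module, so it follows .to(device) moves and is saved in the state_dict, unlike a plain attribute. A minimal standalone sketch of the difference:

import torch


class Demo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        means = torch.zeros(1, 3)
        self.register_buffer("library_log_means", means)  # tracked by the module
        self.plain_attribute = means.clone()  # invisible to .to() and state_dict


demo = Demo()
print("library_log_means" in demo.state_dict())  # True
print("plain_attribute" in demo.state_dict())  # False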
@@ -129,7 +133,9 @@ def generative(self, z, library):
         px_scale, _, px_rate, px_dropout = self.decoder("gene", z, library)
         px_r = torch.exp(self.px_r)

-        return dict(px_scale=px_scale, px_r=px_r, px_rate=px_rate, px_dropout=px_dropout)
+        return dict(
+            px_scale=px_scale, px_r=px_r, px_rate=px_rate, px_dropout=px_dropout
+        )

     def loss(
         self,
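The dictionary returned by generative above parameterizes a zero-inflated negative binomial observation model. In the usual formulation (with the dropout logits mapped through a sigmoid), the likelihood of a count x is

p(x \mid \mu, \theta, \pi) \;=\; \pi\,\mathbf{1}[x = 0] \;+\; (1 - \pi)\,\mathrm{NB}(x \mid \mu, \theta), \qquad \pi = \sigma(\texttt{px\_dropout}),

where NB(mu, theta) is a negative binomial with mean mu (px_rate) and inverse dispersion theta (px_r), and px_scale is the library-size-normalized mean.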
@@ -151,20 +157,28 @@ def loss(
         mean = torch.zeros_like(qz_m)
         scale = torch.ones_like(qz_v)

-        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)
+        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(
+            dim=1
+        )

         batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
         n_batch = self.library_log_means.shape[1]
-        local_library_log_means = F.linear(one_hot(batch_index, n_batch), self.library_log_means)
-        local_library_log_vars = F.linear(one_hot(batch_index, n_batch), self.library_log_vars)
+        local_library_log_means = F.linear(
+            one_hot(batch_index, n_batch), self.library_log_means
+        )
+        local_library_log_vars = F.linear(
+            one_hot(batch_index, n_batch), self.library_log_vars
+        )

         kl_divergence_l = kl(
             Normal(ql_m, torch.sqrt(ql_v)),
             Normal(local_library_log_means, torch.sqrt(local_library_log_vars)),
         ).sum(dim=1)

         reconst_loss = (
-            -ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout).log_prob(x).sum(dim=-1)
+            -ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout)
+            .log_prob(x)
+            .sum(dim=-1)
         )

         kl_local_for_warmup = kl_divergence_z
Expand All @@ -174,8 +188,12 @@ def loss(

loss = torch.mean(reconst_loss + weighted_kl_local)

kl_local = dict(kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z)
return LossOutput(loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_local)
kl_local = dict(
kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z
)
return LossOutput(
loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_local
)

@torch.no_grad()
def sample(
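For reference, the quantity assembled in loss above is the standard scVI negative ELBO with KL warmup; modulo exactly how kl_weight (kappa below) enters in the surrounding unchanged lines, per cell it reads

\mathcal{L} \;=\; -\log p_\theta(x \mid z, \ell) \;+\; \kappa\,\mathrm{KL}\!\big(q_\phi(z \mid x)\,\big\|\,\mathcal{N}(0, I)\big) \;+\; \mathrm{KL}\!\big(q_\phi(\ell \mid x)\,\big\|\,p(\ell)\big),

averaged over the minibatch, where the three terms correspond to reconst_loss, kl_divergence_z, and kl_divergence_l respectively.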
@@ -216,10 +234,14 @@ def sample(
         px_rate = generative_outputs["px_rate"]
         px_dropout = generative_outputs["px_dropout"]

-        dist = ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout)
+        dist = ZeroInflatedNegativeBinomial(
+            mu=px_rate, theta=px_r, zi_logits=px_dropout
+        )

         if n_samples > 1:
-            exprs = dist.sample().permute([1, 2, 0])  # Shape : (n_cells_batch, n_genes, n_samples)
+            exprs = dist.sample().permute(
+                [1, 2, 0]
+            )  # Shape : (n_cells_batch, n_genes, n_samples)
         else:
             exprs = dist.sample()

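The permute in the multi-sample branch only moves the Monte Carlo dimension last, as the trailing comment says. A quick standalone shape check (sizes arbitrary):

import torch

samples = torch.zeros(5, 128, 200)  # (n_samples, n_cells_batch, n_genes)
print(samples.permute([1, 2, 0]).shape)  # torch.Size([128, 200, 5])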
@@ -249,11 +271,23 @@ def marginal_ll(self, tensors: TensorDict, n_mc_samples: int):

         # Log-probabilities
         n_batch = self.library_log_means.shape[1]
-        local_library_log_means = F.linear(one_hot(batch_index, n_batch), self.library_log_means)
-        local_library_log_vars = F.linear(one_hot(batch_index, n_batch), self.library_log_vars)
-        p_l = Normal(local_library_log_means, local_library_log_vars.sqrt()).log_prob(library).sum(dim=-1)
+        local_library_log_means = F.linear(
+            one_hot(batch_index, n_batch), self.library_log_means
+        )
+        local_library_log_vars = F.linear(
+            one_hot(batch_index, n_batch), self.library_log_vars
+        )
+        p_l = (
+            Normal(local_library_log_means, local_library_log_vars.sqrt())
+            .log_prob(library)
+            .sum(dim=-1)
+        )

-        p_z = Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v)).log_prob(z).sum(dim=-1)
+        p_z = (
+            Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v))
+            .log_prob(z)
+            .sum(dim=-1)
+        )
         p_x_zl = -reconst_loss
         q_z_x = Normal(qz_m, qz_v.sqrt()).log_prob(z).sum(dim=-1)
         q_l_x = Normal(ql_m, ql_v.sqrt()).log_prob(library).sum(dim=-1)
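Both loss and marginal_ll select per-batch library priors with the same one-hot trick: F.linear(x, W) computes x @ W.T, so multiplying a one-hot encoding of the batch index by the (1, n_batch) buffer picks out each cell's batch-specific value. A standalone sketch with plain torch (scvi's one_hot helper behaves like torch.nn.functional.one_hot up to input shape and dtype):

import torch
import torch.nn.functional as F

n_batch = 3
library_log_means = torch.tensor([[2.0, 5.0, 9.0]])  # buffer of shape (1, n_batch)
batch_index = torch.tensor([0, 2, 1, 2])  # one batch label per cell

encoded = F.one_hot(batch_index, n_batch).float()  # (4, n_batch)
local_means = F.linear(encoded, library_log_means)  # (4, 1): one value per cell
print(local_means.squeeze(-1))  # tensor([2., 9., 5., 9.])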