diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 107a658..b9e7467 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,79 +1,79 @@ fail_fast: false default_language_version: - python: python3 + python: python3 default_stages: - - commit - - push + - commit + - push minimum_pre_commit_version: 2.16.0 repos: - - repo: https://github.com/psf/black - rev: 24.4.2 - hooks: - - id: black - - repo: https://github.com/pre-commit/mirrors-prettier - rev: v4.0.0-alpha.8 - hooks: - - id: prettier - - repo: https://github.com/asottile/blacken-docs - rev: 1.16.0 - hooks: - - id: blacken-docs - - repo: https://github.com/PyCQA/isort - rev: 5.13.2 - hooks: - - id: isort - - repo: https://github.com/asottile/yesqa - rev: v1.5.0 - hooks: - - id: yesqa - additional_dependencies: - - flake8-tidy-imports - - flake8-docstrings - - flake8-rst-docstrings - - flake8-comprehensions - - flake8-bugbear - - flake8-blind-except - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 - hooks: - - id: detect-private-key - - id: check-ast - - id: end-of-file-fixer - - id: mixed-line-ending - args: [--fix=lf] - - id: trailing-whitespace - - id: check-case-conflict - - repo: https://github.com/PyCQA/autoflake - rev: v2.3.1 - hooks: - - id: autoflake - args: - - --in-place - - --remove-all-unused-imports - - --remove-unused-variable - - --ignore-init-module-imports - - repo: https://github.com/PyCQA/flake8 - rev: 7.0.0 - hooks: - - id: flake8 - additional_dependencies: - - flake8-tidy-imports - - flake8-docstrings - - flake8-rst-docstrings - - flake8-comprehensions - - flake8-bugbear - - flake8-blind-except - - repo: https://github.com/asottile/pyupgrade - rev: v3.15.2 - hooks: - - id: pyupgrade - args: [--py3-plus, --py38-plus, --keep-runtime-typing] - - repo: local - hooks: - - id: forbid-to-commit - name: Don't commit rej files - entry: | - Cannot commit .rej files. These indicate merge conflicts that arise during automated template updates. - Fix the merge conflicts manually and remove the .rej files. - language: fail - files: '.*\.rej$' + - repo: https://github.com/psf/black + rev: 24.4.2 + hooks: + - id: black + - repo: https://github.com/pre-commit/mirrors-prettier + rev: v4.0.0-alpha.8 + hooks: + - id: prettier + - repo: https://github.com/asottile/blacken-docs + rev: 1.16.0 + hooks: + - id: blacken-docs + - repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + - repo: https://github.com/asottile/yesqa + rev: v1.5.0 + hooks: + - id: yesqa + additional_dependencies: + - flake8-tidy-imports + - flake8-docstrings + - flake8-rst-docstrings + - flake8-comprehensions + - flake8-bugbear + - flake8-blind-except + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: detect-private-key + - id: check-ast + - id: end-of-file-fixer + - id: mixed-line-ending + args: [--fix=lf] + - id: trailing-whitespace + - id: check-case-conflict + - repo: https://github.com/PyCQA/autoflake + rev: v2.3.1 + hooks: + - id: autoflake + args: + - --in-place + - --remove-all-unused-imports + - --remove-unused-variable + - --ignore-init-module-imports + - repo: https://github.com/PyCQA/flake8 + rev: 7.0.0 + hooks: + - id: flake8 + additional_dependencies: + - flake8-tidy-imports + - flake8-docstrings + - flake8-rst-docstrings + - flake8-comprehensions + - flake8-bugbear + - flake8-blind-except + - repo: https://github.com/asottile/pyupgrade + rev: v3.15.2 + hooks: + - id: pyupgrade + args: [--py3-plus, --py38-plus, --keep-runtime-typing] + - repo: local + hooks: + - id: forbid-to-commit + name: Don't commit rej files + entry: | + Cannot commit .rej files. These indicate merge conflicts that arise during automated template updates. + Fix the merge conflicts manually and remove the .rej files. + language: fail + files: '.*\.rej$' diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 1bf3915..23a5340 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -1,16 +1,16 @@ # https://docs.readthedocs.io/en/stable/config-file/v2.html version: 2 build: - os: ubuntu-20.04 - tools: - python: "3.10" + os: ubuntu-20.04 + tools: + python: "3.10" sphinx: - configuration: docs/conf.py - # disable this for more lenient docs builds - fail_on_warning: false + configuration: docs/conf.py + # disable this for more lenient docs builds + fail_on_warning: false python: - install: - - method: pip - path: . - extra_requirements: - - doc + install: + - method: pip + path: . + extra_requirements: + - doc diff --git a/docs/index.md.rej b/docs/index.md.rej index 14a8d31..922185b 100644 --- a/docs/index.md.rej +++ b/docs/index.md.rej @@ -1,9 +1,9 @@ diff a/docs/index.md b/docs/index.md (rejected hunks) @@ -8,7 +8,6 @@ - + api.md changelog.md -template_usage.md contributing.md references.md - + diff --git a/pyproject.toml.rej b/pyproject.toml.rej index 8906fd0..9518688 100644 --- a/pyproject.toml.rej +++ b/pyproject.toml.rej @@ -6,7 +6,7 @@ diff a/pyproject.toml b/pyproject.toml (rejected hunks) - "session-info" + "session-info", ] - + [project.optional-dependencies] dev = [ - # CLI for bumping the version number @@ -36,5 +36,5 @@ diff a/pyproject.toml b/pyproject.toml (rejected hunks) - "pytest-cov", + "coverage", ] - + [tool.coverage.run] diff --git a/src/simple_scvi/_mymodel.py b/src/simple_scvi/_mymodel.py index 69cc50d..0cb5da3 100644 --- a/src/simple_scvi/_mymodel.py +++ b/src/simple_scvi/_mymodel.py @@ -4,12 +4,8 @@ from anndata import AnnData from scvi import REGISTRY_KEYS from scvi.data import AnnDataManager -from scvi.data.fields import ( - CategoricalJointObsField, - CategoricalObsField, - LayerField, - NumericalJointObsField, -) +from scvi.data.fields import (CategoricalJointObsField, CategoricalObsField, + LayerField, NumericalJointObsField) from scvi.model._utils import _init_library_size from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin, VAEMixin from scvi.utils import setup_anndata_dsp @@ -57,7 +53,9 @@ def __init__( ): super().__init__(adata) - library_log_means, library_log_vars = _init_library_size(self.adata_manager, self.summary_stats["n_batch"]) + library_log_means, library_log_vars = _init_library_size( + self.adata_manager, self.summary_stats["n_batch"] + ) # self.summary_stats provides information about anndata dimensions and other tensor info @@ -108,9 +106,15 @@ def setup_anndata( LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), - CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), - NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), + CategoricalJointObsField( + REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys + ), + NumericalJointObsField( + REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys + ), ] - adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) + adata_manager = AnnDataManager( + fields=anndata_fields, setup_method_args=setup_method_args + ) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) diff --git a/src/simple_scvi/_mymodule.py b/src/simple_scvi/_mymodule.py index 28bf534..ed0f725 100644 --- a/src/simple_scvi/_mymodule.py +++ b/src/simple_scvi/_mymodule.py @@ -59,8 +59,12 @@ def __init__( # this is needed to comply with some requirement of the VAEMixin class self.latent_distribution = "normal" - self.register_buffer("library_log_means", torch.from_numpy(library_log_means).float()) - self.register_buffer("library_log_vars", torch.from_numpy(library_log_vars).float()) + self.register_buffer( + "library_log_means", torch.from_numpy(library_log_means).float() + ) + self.register_buffer( + "library_log_vars", torch.from_numpy(library_log_vars).float() + ) # setup the parameters of your generative model, as well as your inference model self.px_r = torch.nn.Parameter(torch.randn(n_input)) @@ -129,7 +133,9 @@ def generative(self, z, library): px_scale, _, px_rate, px_dropout = self.decoder("gene", z, library) px_r = torch.exp(self.px_r) - return dict(px_scale=px_scale, px_r=px_r, px_rate=px_rate, px_dropout=px_dropout) + return dict( + px_scale=px_scale, px_r=px_r, px_rate=px_rate, px_dropout=px_dropout + ) def loss( self, @@ -151,12 +157,18 @@ def loss( mean = torch.zeros_like(qz_m) scale = torch.ones_like(qz_v) - kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1) + kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum( + dim=1 + ) batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] n_batch = self.library_log_means.shape[1] - local_library_log_means = F.linear(one_hot(batch_index, n_batch), self.library_log_means) - local_library_log_vars = F.linear(one_hot(batch_index, n_batch), self.library_log_vars) + local_library_log_means = F.linear( + one_hot(batch_index, n_batch), self.library_log_means + ) + local_library_log_vars = F.linear( + one_hot(batch_index, n_batch), self.library_log_vars + ) kl_divergence_l = kl( Normal(ql_m, torch.sqrt(ql_v)), @@ -164,7 +176,9 @@ def loss( ).sum(dim=1) reconst_loss = ( - -ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout).log_prob(x).sum(dim=-1) + -ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout) + .log_prob(x) + .sum(dim=-1) ) kl_local_for_warmup = kl_divergence_z @@ -174,8 +188,12 @@ def loss( loss = torch.mean(reconst_loss + weighted_kl_local) - kl_local = dict(kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z) - return LossOutput(loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_local) + kl_local = dict( + kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z + ) + return LossOutput( + loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_local + ) @torch.no_grad() def sample( @@ -216,10 +234,14 @@ def sample( px_rate = generative_outputs["px_rate"] px_dropout = generative_outputs["px_dropout"] - dist = ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout) + dist = ZeroInflatedNegativeBinomial( + mu=px_rate, theta=px_r, zi_logits=px_dropout + ) if n_samples > 1: - exprs = dist.sample().permute([1, 2, 0]) # Shape : (n_cells_batch, n_genes, n_samples) + exprs = dist.sample().permute( + [1, 2, 0] + ) # Shape : (n_cells_batch, n_genes, n_samples) else: exprs = dist.sample() @@ -249,11 +271,23 @@ def marginal_ll(self, tensors: TensorDict, n_mc_samples: int): # Log-probabilities n_batch = self.library_log_means.shape[1] - local_library_log_means = F.linear(one_hot(batch_index, n_batch), self.library_log_means) - local_library_log_vars = F.linear(one_hot(batch_index, n_batch), self.library_log_vars) - p_l = Normal(local_library_log_means, local_library_log_vars.sqrt()).log_prob(library).sum(dim=-1) - - p_z = Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v)).log_prob(z).sum(dim=-1) + local_library_log_means = F.linear( + one_hot(batch_index, n_batch), self.library_log_means + ) + local_library_log_vars = F.linear( + one_hot(batch_index, n_batch), self.library_log_vars + ) + p_l = ( + Normal(local_library_log_means, local_library_log_vars.sqrt()) + .log_prob(library) + .sum(dim=-1) + ) + + p_z = ( + Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v)) + .log_prob(z) + .sum(dim=-1) + ) p_x_zl = -reconst_loss q_z_x = Normal(qz_m, qz_v.sqrt()).log_prob(z).sum(dim=-1) q_l_x = Normal(ql_m, ql_v.sqrt()).log_prob(library).sum(dim=-1) diff --git a/src/simple_scvi/_mypyromodel.py b/src/simple_scvi/_mypyromodel.py index 104483f..730040f 100644 --- a/src/simple_scvi/_mypyromodel.py +++ b/src/simple_scvi/_mypyromodel.py @@ -6,12 +6,8 @@ from anndata import AnnData from scvi import REGISTRY_KEYS from scvi.data import AnnDataManager -from scvi.data.fields import ( - CategoricalJointObsField, - CategoricalObsField, - LayerField, - NumericalJointObsField, -) +from scvi.data.fields import (CategoricalJointObsField, CategoricalObsField, + LayerField, NumericalJointObsField) from scvi.dataloaders import DataSplitter from scvi.model.base import BaseModelClass from scvi.train import PyroTrainingPlan, TrainRunner @@ -100,7 +96,9 @@ def get_latent( Low-dimensional representation for each cell """ adata = self._validate_anndata(adata) - scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) + scdl = self._make_data_loader( + adata=adata, indices=indices, batch_size=batch_size + ) latent = [] for tensors in scdl: qz_m = self.module.get_latent(tensors) @@ -197,9 +195,15 @@ def setup_anndata( LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), - CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), - NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), + CategoricalJointObsField( + REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys + ), + NumericalJointObsField( + REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys + ), ] - adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) + adata_manager = AnnDataManager( + fields=anndata_fields, setup_method_args=setup_method_args + ) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) diff --git a/src/simple_scvi/_mypyromodule.py b/src/simple_scvi/_mypyromodule.py index ad6f412..38f0045 100644 --- a/src/simple_scvi/_mypyromodule.py +++ b/src/simple_scvi/_mypyromodule.py @@ -74,7 +74,9 @@ def model(self, x: torch.Tensor, log_library: torch.Tensor, kl_weight: float = 1 # decode the latent code z px_scale, _, px_rate, px_dropout = self.decoder("gene", z, log_library) # build count distribution - nb_logits = (px_rate + self.epsilon).log() - (self.px_r.exp() + self.epsilon).log() + nb_logits = (px_rate + self.epsilon).log() - ( + self.px_r.exp() + self.epsilon + ).log() x_dist = dist.ZeroInflatedNegativeBinomial( gate_logits=px_dropout, total_count=self.px_r.exp(), logits=nb_logits )