From 34824c5c9c16a2267a940c94ad7f31a2aa8f2751 Mon Sep 17 00:00:00 2001
From: Thomas M Kehrenberg
Date: Fri, 23 Feb 2024 23:36:44 +0100
Subject: [PATCH] Big spring clean up

---
 .github/workflows/ci.yml      |  43 ++++++++++++-
 poetry.lock                   | 110 +---------------------------------
 pyproject.toml                |   3 +-
 src/algs/adv/evaluator.py     |  25 ++++----
 src/algs/adv/supmatch.py      |   6 +-
 src/algs/fs/gdro.py           |   8 +--
 src/algs/fs/lff.py            |  30 ++++------
 src/arch/autoencoder/vqgan.py |   4 +-
 src/arch/backbones/vision.py  |   2 +-
 src/data/common.py            |  11 ++--
 src/data/splitter.py          |   7 +--
 src/discrete.py               |   4 +-
 src/labelling/encoder.py      |   2 +-
 src/labelling/metrics.py      |  12 ++--
 src/mmd.py                    |   4 +-
 src/models/autoencoder.py     |  35 ++++++-----
 src/models/base.py            |  39 ++++++++----
 src/models/classifier.py      |   8 +--
 src/models/discriminator.py   |   3 +-
 src/relay/base.py             |   6 +-
 src/relay/mimin.py            |   3 +-
 src/relay/supmatch.py         |   3 +-
 22 files changed, 155 insertions(+), 213 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 6e5707a5..eae86c7f 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -7,7 +7,7 @@ on:
     - main
 
 jobs:
-  format_with_black:
+  format_with_ruff:
 
     runs-on: ubuntu-latest
 
@@ -40,3 +40,44 @@ jobs:
     - name: Lint with ruff
       run: |
         ruff check --output-format=github .
+
+  run_type_checking:
+    needs:
+      - format_with_ruff
+      - lint_with_ruff
+    runs-on: ubuntu-latest
+
+    steps:
+    # ----------------------------------------------
+    # ----  check-out repo and set-up python  ----
+    # ----------------------------------------------
+    - name: Check out repository
+      uses: actions/checkout@v3
+    # ----------------------------------------------
+    # -----  install & configure poetry  -----
+    # ----------------------------------------------
+    - name: Install poetry
+      run: pipx install poetry
+    - name: Set up Python 3.10
+      uses: actions/setup-python@v4
+      with:
+        python-version: '3.10'
+        cache: 'poetry'
+
+    # ----------------------------------------------
+    # install dependencies if cache does not exist
+    # ----------------------------------------------
+    - name: Install dependencies
+      run: |
+        poetry env use 3.10
+        poetry install --no-interaction --no-root --without torch
+    - name: Set python path for all subsequent actions
+      run: echo "$(poetry env info --path)/bin" >> $GITHUB_PATH
+
+    # ----------------------------------------------
+    # -----  install and run pyright  -----
+    # ----------------------------------------------
+    - uses: jakebailey/pyright-action@v1
+      with:
+        # don't show warnings
+        level: error
diff --git a/poetry.lock b/poetry.lock
index c5f8038f..9399af71 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -557,20 +557,6 @@ metrics = ["scikit-learn (>=0.20.1)"]
 models = ["GitPython (>=3.1.20,<4.0.0)", "cloudpickle (>=2.0.0,<3.0.0)", "fairlearn (==0.8.0)", "gitdb2 (==4.0.2)", "pdm (>=2.4.0,<3.0.0)", "scikit-learn (>=0.20.1)", "scipy (>=1.7.2,<2.0.0)", "smmap2 (==3.0.1)"]
 plot = ["matplotlib (>=3.0.2)", "seaborn (>=0.9.0)"]
 
-[[package]]
-name = "exceptiongroup"
-version = "1.1.1"
-description = "Backport of PEP 654 (exception groups)"
-optional = false
-python-versions = ">=3.7"
-files = [
-    {file = "exceptiongroup-1.1.1-py3-none-any.whl", hash = "sha256:232c37c63e4f682982c8b6459f33a8981039e5fb8756b2074364e5055c498c9e"},
-    {file = "exceptiongroup-1.1.1.tar.gz", hash = "sha256:d484c3090ba2889ae2928419117447a14daf3c1231d5e30d0aae34f354f01785"},
-]
-
-[package.extras]
-test = ["pytest (>=6)"]
-
 [[package]]
 name = "filelock"
 version = "3.12.0"
@@ -1054,17 +1040,6 @@ pyav = ["av"]
 test = ["fsspec[github]", "pytest", "pytest-cov"]
["fsspec[github]", "pytest", "pytest-cov"] tifffile = ["tifffile"] -[[package]] -name = "iniconfig" -version = "2.0.0" -description = "brain-dead simple config-ini parsing" -optional = false -python-versions = ">=3.7" -files = [ - {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, - {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, -] - [[package]] name = "jinja2" version = "2.11.3" @@ -1503,52 +1478,6 @@ files = [ {file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"}, ] -[[package]] -name = "mypy" -version = "1.3.0" -description = "Optional static typing for Python" -optional = false -python-versions = ">=3.7" -files = [ - {file = "mypy-1.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c1eb485cea53f4f5284e5baf92902cd0088b24984f4209e25981cc359d64448d"}, - {file = "mypy-1.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4c99c3ecf223cf2952638da9cd82793d8f3c0c5fa8b6ae2b2d9ed1e1ff51ba85"}, - {file = "mypy-1.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:550a8b3a19bb6589679a7c3c31f64312e7ff482a816c96e0cecec9ad3a7564dd"}, - {file = "mypy-1.3.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:cbc07246253b9e3d7d74c9ff948cd0fd7a71afcc2b77c7f0a59c26e9395cb152"}, - {file = "mypy-1.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:a22435632710a4fcf8acf86cbd0d69f68ac389a3892cb23fbad176d1cddaf228"}, - {file = "mypy-1.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6e33bb8b2613614a33dff70565f4c803f889ebd2f859466e42b46e1df76018dd"}, - {file = "mypy-1.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7d23370d2a6b7a71dc65d1266f9a34e4cde9e8e21511322415db4b26f46f6b8c"}, - {file = "mypy-1.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:658fe7b674769a0770d4b26cb4d6f005e88a442fe82446f020be8e5f5efb2fae"}, - {file = "mypy-1.3.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6e42d29e324cdda61daaec2336c42512e59c7c375340bd202efa1fe0f7b8f8ca"}, - {file = "mypy-1.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:d0b6c62206e04061e27009481cb0ec966f7d6172b5b936f3ead3d74f29fe3dcf"}, - {file = "mypy-1.3.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:76ec771e2342f1b558c36d49900dfe81d140361dd0d2df6cd71b3db1be155409"}, - {file = "mypy-1.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ebc95f8386314272bbc817026f8ce8f4f0d2ef7ae44f947c4664efac9adec929"}, - {file = "mypy-1.3.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:faff86aa10c1aa4a10e1a301de160f3d8fc8703b88c7e98de46b531ff1276a9a"}, - {file = "mypy-1.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:8c5979d0deb27e0f4479bee18ea0f83732a893e81b78e62e2dda3e7e518c92ee"}, - {file = "mypy-1.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c5d2cc54175bab47011b09688b418db71403aefad07cbcd62d44010543fc143f"}, - {file = "mypy-1.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:87df44954c31d86df96c8bd6e80dfcd773473e877ac6176a8e29898bfb3501cb"}, - {file = "mypy-1.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:473117e310febe632ddf10e745a355714e771ffe534f06db40702775056614c4"}, - {file = "mypy-1.3.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:74bc9b6e0e79808bf8678d7678b2ae3736ea72d56eede3820bd3849823e7f305"}, - {file = "mypy-1.3.0-cp38-cp38-win_amd64.whl", hash = 
"sha256:44797d031a41516fcf5cbfa652265bb994e53e51994c1bd649ffcd0c3a7eccbf"}, - {file = "mypy-1.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ddae0f39ca146972ff6bb4399f3b2943884a774b8771ea0a8f50e971f5ea5ba8"}, - {file = "mypy-1.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1c4c42c60a8103ead4c1c060ac3cdd3ff01e18fddce6f1016e08939647a0e703"}, - {file = "mypy-1.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e86c2c6852f62f8f2b24cb7a613ebe8e0c7dc1402c61d36a609174f63e0ff017"}, - {file = "mypy-1.3.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:f9dca1e257d4cc129517779226753dbefb4f2266c4eaad610fc15c6a7e14283e"}, - {file = "mypy-1.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:95d8d31a7713510685b05fbb18d6ac287a56c8f6554d88c19e73f724a445448a"}, - {file = "mypy-1.3.0-py3-none-any.whl", hash = "sha256:a8763e72d5d9574d45ce5881962bc8e9046bf7b375b0abf031f3e6811732a897"}, - {file = "mypy-1.3.0.tar.gz", hash = "sha256:e1f4d16e296f5135624b34e8fb741eb0eadedca90862405b1f1fde2040b9bd11"}, -] - -[package.dependencies] -mypy-extensions = ">=1.0.0" -tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} -typing-extensions = ">=3.10" - -[package.extras] -dmypy = ["psutil (>=4.0)"] -install-types = ["pip"] -python2 = ["typed-ast (>=1.4.0,<2)"] -reports = ["lxml"] - [[package]] name = "mypy-extensions" version = "1.0.0" @@ -1940,21 +1869,6 @@ files = [ docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-removed-in", "sphinxext-opengraph"] tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] -[[package]] -name = "pluggy" -version = "1.0.0" -description = "plugin and hook calling mechanisms for python" -optional = false -python-versions = ">=3.6" -files = [ - {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"}, - {file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"}, -] - -[package.extras] -dev = ["pre-commit", "tox"] -testing = ["pytest", "pytest-benchmark"] - [[package]] name = "protobuf" version = "4.23.0" @@ -2034,28 +1948,6 @@ files = [ [package.extras] diagrams = ["jinja2", "railroad-diagrams"] -[[package]] -name = "pytest" -version = "7.3.1" -description = "pytest: simple powerful testing with Python" -optional = false -python-versions = ">=3.7" -files = [ - {file = "pytest-7.3.1-py3-none-any.whl", hash = "sha256:3799fa815351fea3a5e96ac7e503a96fa51cc9942c3753cda7651b93c1cfa362"}, - {file = "pytest-7.3.1.tar.gz", hash = "sha256:434afafd78b1d78ed0addf160ad2b77a30d35d4bdf8af234fe621919d9ed15e3"}, -] - -[package.dependencies] -colorama = {version = "*", markers = "sys_platform == \"win32\""} -exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} -iniconfig = "*" -packaging = "*" -pluggy = ">=0.12,<2.0" -tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} - -[package.extras] -testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"] - [[package]] name = "python-dateutil" version = "2.8.2" @@ -3327,4 +3219,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.12" -content-hash = "3ef3211082c1cfde4656cf3c15d88d459f47101c542753b26843229f9a832c3f" +content-hash = "ab002e6142da0d933f3da4085b997984ea27afb49ad5e2206fdff392b66935a6" diff --git 
index 868456ca..486242f5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -54,8 +54,6 @@ torchvision = ">=0.15.2"
 
 [tool.poetry.group.dev.dependencies]
 ruff = "*"
-mypy = "*"
-pytest = "*"
 types-tqdm = "*"
 pandas-stubs = "*"
 
@@ -124,6 +122,7 @@ reportUnknownLambdaType = "none"
 reportUnknownVariableType = "none"
 reportUnknownMemberType = "none"
 reportMissingTypeArgument = "none"
+reportUnnecessaryCast = "warning"
 reportUnnecessaryTypeIgnoreComment = "warning"
 exclude = [
   "outputs",
diff --git a/src/algs/adv/evaluator.py b/src/algs/adv/evaluator.py
index f84ee648..bc88a332 100644
--- a/src/algs/adv/evaluator.py
+++ b/src/algs/adv/evaluator.py
@@ -5,7 +5,6 @@
 
 from conduit.data import TernarySample
 from conduit.data.datasets import CdtDataLoader, CdtDataset
-from conduit.data.datasets.vision import CdtVisionDataset
 from loguru import logger
 from matplotlib import pyplot as plt
 from matplotlib.colors import ListedColormap
@@ -22,14 +21,12 @@
 from src.arch.predictors import Fcn
 from src.data import DataModule, Dataset, group_id_to_label, labels_to_group_id, resolve_device
 from src.evaluation.metrics import EmEvalPair, compute_metrics
-from src.logging import log_images
 from src.models import Classifier, OptimizerCfg, SplitLatentAe
 
 __all__ = [
     "Evaluator",
     "InvariantDatasets",
     "encode_dataset",
-    "log_sample_images",
     "visualize_clusters",
 ]
 
@@ -51,17 +48,17 @@ class InvariantDatasets(Generic[DY, DS]):
     zy: DS
 
 
-def log_sample_images(
-    *,
-    data: CdtVisionDataset[TernarySample[Tensor], Tensor, Tensor],
-    dm: DataModule,
-    name: str,
-    step: int,
-    num_samples: int = 64,
-) -> None:
-    inds: list[int] = torch.randperm(len(data))[:num_samples].tolist()
-    images = data[inds]
-    log_images(images=images, dm=dm, name=f"Samples from {name}", prefix="eval", step=step)
+# def log_sample_images(
+#     *,
+#     data: CdtVisionDataset[TernarySample[Tensor], Tensor, Tensor],
+#     dm: DataModule,
+#     name: str,
+#     step: int,
+#     num_samples: int = 64,
+# ) -> None:
+#     inds: list[int] = torch.randperm(len(data))[:num_samples].tolist()
+#     images = data[inds]
+#     log_images(images=images, dm=dm, name=f"Samples from {name}", prefix="eval", step=step)
 
 
 InvariantAttr = Literal["zy", "zs", "both"]
diff --git a/src/algs/adv/supmatch.py b/src/algs/adv/supmatch.py
index 7768d989..13f7a14e 100644
--- a/src/algs/adv/supmatch.py
+++ b/src/algs/adv/supmatch.py
@@ -33,10 +33,9 @@ def _get_data_iterators(self, dm: DataModule) -> tuple[IterTr, IterDep]:
         dl_tr = dm.train_dataloader(balance=True)
         # The batch size needs to be consistent for the aggregation layer in the setwise neural
         # discriminator
+        batch_size: int = dl_tr.batch_sampler.batch_size  # type: ignore
         dl_dep = dm.deployment_dataloader(
-            batch_size=dl_tr.batch_sampler.batch_size
-            if dm.deployment_ids is None
-            else dm.batch_size_tr
+            batch_size=batch_size if dm.deployment_ids is None else dm.batch_size_tr
         )
         return iter(dl_tr), iter(dl_dep)
 
@@ -161,6 +160,7 @@ def fit_evaluate_score(
         disc_model_sd0 = None
         if isinstance(disc, NeuralDiscriminator) and isinstance(disc.model, SetPredictor):
             disc_model_sd0 = disc.model.state_dict()
+        assert isinstance(disc, Model)
         super().fit_and_evaluate(dm=dm, ae=ae, disc=disc, evaluator=evaluator)
         # TODO: Generalise this to other discriminator types and architectures
         if disc_model_sd0 is not None:
diff --git a/src/algs/fs/gdro.py b/src/algs/fs/gdro.py
index af9a8540..186d64e4 100644
--- a/src/algs/fs/gdro.py
+++ b/src/algs/fs/gdro.py
@@ -232,18 +232,14 @@ def update_stats(
         self.avg_acc = group_frac @ self.avg_group_acc
 
 
-@dataclass
-class _LcMixin:
+@dataclass(kw_only=True, repr=False, eq=False, frozen=True)
+class GdroClassifier(Classifier):
     loss_computer: LossComputer
 
-
-@dataclass(repr=False, eq=False)
-class GdroClassifier(Classifier, _LcMixin):
     def __post_init__(self) -> None:
         # LossComputer requires that the criterion return per-sample (unreduced) losses.
         if self.criterion is not None:
             self.criterion.reduction = ReductionType.none
-        super().__post_init__()
 
     @override
     def training_step(self, batch: TernarySample[Tensor], *, pred_s: bool = False) -> Tensor:
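
A note on the `kw_only` change above: before `kw_only`, a dataclass subclass could not declare a required field after inheriting fields with defaults, because the generated `__init__` would then have a non-default parameter following default ones; helper mixins such as `_LcMixin` were the standard workaround. With `kw_only=True`, every generated parameter is keyword-only and the ordering restriction disappears. A minimal sketch of the difference (class names hypothetical):

from dataclasses import dataclass


@dataclass
class Base:
    x: int = 0


# Without kw_only=True, adding a required field here would fail at class
# creation: "TypeError: non-default argument 'y' follows default argument".
@dataclass(kw_only=True)
class Sub(Base):
    y: int


assert Sub(y=2).x == 0
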
diff --git a/src/algs/fs/lff.py b/src/algs/fs/lff.py
index 061bda29..87dd6121 100644
--- a/src/algs/fs/lff.py
+++ b/src/algs/fs/lff.py
@@ -1,5 +1,6 @@
 from collections.abc import Iterator
-from dataclasses import dataclass, field
+from dataclasses import dataclass
+from functools import cached_property
 from typing import Any, TypeVar, Union
 from typing_extensions import Self, override
 
@@ -98,24 +99,20 @@ def __len__(self) -> int:
         return len(self.dataset)
 
 
-@dataclass
-class _LabelEmaMixin:
+@dataclass(kw_only=True, repr=False, eq=False, frozen=True)
+class LfFClassifier(Classifier):
+    criterion: CrossEntropyLoss
     sample_loss_ema_b: LabelEma
     sample_loss_ema_d: LabelEma
-
-
-@dataclass(repr=False, eq=False)
-class LfFClassifier(Classifier, _LabelEmaMixin):
     q: float = 0.7
-    biased_model: nn.Module = field(init=False)
-    biased_criterion: GeneralizedCELoss = field(init=False)
-    criterion: CrossEntropyLoss = field(init=False)
 
-    def __post_init__(self) -> None:
-        self.biased_model = gcopy(self.model, deep=True)
-        self.biased_criterion = GeneralizedCELoss(q=self.q, reduction="mean")
-        self.criterion = CrossEntropyLoss(reduction="mean")
-        super().__post_init__()
+    @cached_property
+    def biased_model(self) -> nn.Module:
+        return gcopy(self.model, deep=True)
+
+    @cached_property
+    def biased_criterion(self) -> GeneralizedCELoss:
+        return GeneralizedCELoss(q=self.q, reduction="mean")
 
     def training_step(self, batch: IndexedSample[Tensor], *, pred_s: bool = False) -> Tensor:  # type: ignore
         logit_b = self.biased_model(batch.x)
@@ -158,14 +155,13 @@ def routine(self, dm: DataModule, *, model: nn.Module) -> EvalTuple[Tensor, None
         sample_loss_ema_d = LabelEma(dm.train.y, alpha=self.alpha).to(self.device)
         dm.train = IndexedDataset(dm.train)  # type: ignore
         classifier = LfFClassifier(
+            criterion=CrossEntropyLoss(reduction="mean"),
             sample_loss_ema_b=sample_loss_ema_b,
             sample_loss_ema_d=sample_loss_ema_d,
             model=model,
             opt=self.opt,
             q=self.q,
         )
-        classifier.sample_loss_ema_b = sample_loss_ema_b
-        classifier.sample_loss_ema_d = sample_loss_ema_d
         classifier.fit(
             train_data=dm.train_dataloader(),
             test_data=dm.test_dataloader(),
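
The `cached_property` rewrite works even though `LfFClassifier` is now frozen: `cached_property` stores its result directly in the instance `__dict__` rather than going through the class's `__setattr__`, which a frozen dataclass blocks. That makes it a drop-in replacement for `field(init=False)` attributes previously filled in by `__post_init__`. A minimal sketch (names hypothetical):

from dataclasses import dataclass
from functools import cached_property


@dataclass(frozen=True)
class Lazy:
    base: int

    @cached_property
    def derived(self) -> int:
        # computed on first access; cached directly in the instance __dict__
        return 2 * self.base


obj = Lazy(base=3)
assert obj.derived == 6
assert "derived" in obj.__dict__  # later accesses bypass the property body
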
diff --git a/src/arch/autoencoder/vqgan.py b/src/arch/autoencoder/vqgan.py
index f825c91b..7f77a85c 100644
--- a/src/arch/autoencoder/vqgan.py
+++ b/src/arch/autoencoder/vqgan.py
@@ -195,7 +195,7 @@ def __init__(
 
         # end
         self.norm_out = Normalize(block_in)
-        flattened_size = np.prod((block_in, curr_res, curr_res))
+        flattened_size = np.prod((block_in, curr_res, curr_res)).item()
         self.to_latent = nn.Sequential(
             nn.Flatten(),
             nn.Linear(flattened_size, out_features=latent_dim),
@@ -253,7 +253,7 @@ def __init__(
         curr_res = resolution // 2 ** (self.num_resolutions - 1)
         unflattened_size = (block_in, curr_res, curr_res)
         self.from_latent = nn.Sequential(
-            nn.Linear(latent_dim, np.prod(unflattened_size)),
+            nn.Linear(latent_dim, np.prod(unflattened_size).item()),
             nn.Unflatten(dim=1, unflattened_size=unflattened_size),
         )
 
diff --git a/src/arch/backbones/vision.py b/src/arch/backbones/vision.py
index 4d2acfc0..51799b67 100644
--- a/src/arch/backbones/vision.py
+++ b/src/arch/backbones/vision.py
@@ -162,7 +162,7 @@ def __call__(self, input_dim: int) -> BackboneFactoryOut["tm.SwinTransformer"]:
         model: "tm.SwinTransformer" = timm.create_model(
             self.version.value, pretrained=self.pretrained, checkpoint_path=self.checkpoint_path
         )
-        model.head = nn.Identity()
+        model.head = nn.Identity()  # type: ignore
         return model, model.num_features
 
diff --git a/src/data/common.py b/src/data/common.py
index fea35efb..05733c3f 100644
--- a/src/data/common.py
+++ b/src/data/common.py
@@ -3,12 +3,13 @@
 from dataclasses import dataclass
 from pathlib import Path
 import platform
-from typing import Final, Generic, TypeAlias, TypeVar, Union
+from typing import Final, Generic, TypeVar, Union
+from typing_extensions import TypeAliasType
 
 from conduit.data import LoadedData, TernarySample, UnloadedData
 from conduit.data.datasets import CdtDataset
-from conduit.data.datasets.vision import CdtVisionDataset
 from hydra.utils import to_absolute_path
+from numpy import typing as npt
 from torch import Tensor
 
 __all__ = [
@@ -44,7 +45,9 @@ def process_data_dir(root: Union[Path, str, None]) -> Path:
 
 
 X = TypeVar("X", bound=UnloadedData)
-Dataset: TypeAlias = CdtDataset[TernarySample[LoadedData], X, Tensor, Tensor]
+Dataset = TypeAliasType(
+    "Dataset", CdtDataset[TernarySample[LoadedData], X, Tensor, Tensor], type_params=(X,)
+)
 D = TypeVar("D", bound=Dataset)
 
 
@@ -75,5 +78,5 @@ def num_samples_te(self) -> int:
 
 class DatasetFactory(ABC):
     @abstractmethod
-    def __call__(self) -> CdtVisionDataset[TernarySample, Tensor, Tensor]:
+    def __call__(self) -> Dataset[npt.NDArray]:
         raise NotImplementedError()
diff --git a/src/data/splitter.py b/src/data/splitter.py
index ef06cc15..952366ba 100644
--- a/src/data/splitter.py
+++ b/src/data/splitter.py
@@ -257,14 +257,11 @@ def load_split_inds_from_artifact(
     return split_inds
 
 
-@dataclass(eq=False)
-class _ArtifactLoaderMixin:
+@dataclass(eq=False, kw_only=True)
+class SplitFromArtifact(DataSplitter):
     artifact_name: str
     version: Optional[int] = None
 
-
-@dataclass(eq=False)
-class SplitFromArtifact(DataSplitter, _ArtifactLoaderMixin):
     @override
     def split(self, dataset: D) -> TrainDepTestSplit[D]:
         splits = load_split_inds_from_artifact(
diff --git a/src/discrete.py b/src/discrete.py
index dd39f9d1..8168b2ec 100644
--- a/src/discrete.py
+++ b/src/discrete.py
@@ -14,11 +14,11 @@ def forward(ctx: NestedIOFunction, tensor: Tensor) -> Tensor:  # type: ignore
     @staticmethod
     def backward(ctx: NestedIOFunction, *grad_outputs: Tensor) -> tuple[Tensor]:
         """Straight-through estimator"""
-        return grad_outputs
+        return grad_outputs  # type: ignore
 
 
 def round_ste(x: Tensor) -> Tensor:
-    return RoundSTE.apply(x)
+    return RoundSTE.apply(x)  # type: ignore
 
 
 def discretize(inputs: Tensor, *, dim: int = 1) -> Tensor:
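
For context, `RoundSTE` above is a straight-through estimator: the forward pass rounds (a step function whose true gradient is zero almost everywhere), while the backward pass hands gradients through unchanged, as if rounding were the identity. A self-contained check of that behaviour (an illustrative re-implementation, not the module's own code):

import torch


class _RoundSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.round()

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through: pass the incoming gradient back unchanged
        return grad_output


x = torch.tensor([0.2, 0.8], requires_grad=True)
y = _RoundSTE.apply(x)  # tensor([0., 1.])
y.sum().backward()
assert torch.equal(x.grad, torch.ones_like(x))
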
diff --git a/src/labelling/encoder.py b/src/labelling/encoder.py
index 222551fd..16fad20e 100644
--- a/src/labelling/encoder.py
+++ b/src/labelling/encoder.py
@@ -34,7 +34,7 @@ def __init__(
     ) -> None:
         super().__init__()
         logger.info("Loading CLIP model (downloading if needed)...")
-        import clip
+        import clip  # type: ignore
 
         model, self.transforms = clip.load(
             name=version.value,  # type: ignore
diff --git a/src/labelling/metrics.py b/src/labelling/metrics.py
index 87083bc2..485017a9 100644
--- a/src/labelling/metrics.py
+++ b/src/labelling/metrics.py
@@ -4,7 +4,7 @@
 from loguru import logger
 import numpy as np
 import numpy.typing as npt
-from scipy.optimize import linear_sum_assignment
+from scipy.optimize import linear_sum_assignment  # type: ignore
 from sklearn.metrics import (
     adjusted_mutual_info_score,
     adjusted_rand_score,
@@ -54,11 +54,11 @@ def evaluate(
     use_wandb: bool = True,
     prefix: Optional[str] = None,
 ) -> dict[str, float]:
-    metrics = {
-        "ARI": adjusted_rand_score(y_true, y_pred),
-        "AMI": adjusted_mutual_info_score(y_true, y_pred),
-        "NMI": normalized_mutual_info_score(y_true, y_pred),
-        "Accuracy": compute_accuracy(y_true, clusters=y_pred),
+    metrics: dict[str, float] = {
+        "ARI": float(adjusted_rand_score(y_true, y_pred)),
+        "AMI": float(adjusted_mutual_info_score(y_true, y_pred)),
+        "NMI": float(normalized_mutual_info_score(y_true, y_pred)),
+        "Accuracy": float(compute_accuracy(y_true, clusters=y_pred)),
     }
     if prefix is not None:
         metrics = prefix_keys(metrics, prefix=prefix, sep="/")
diff --git a/src/mmd.py b/src/mmd.py
index d0952d79..64f8b205 100644
--- a/src/mmd.py
+++ b/src/mmd.py
@@ -35,8 +35,8 @@ def _mix_rq_kernel(
     Rational quadratic kernel
     http://www.cs.toronto.edu/~duvenaud/cookbook/index.html
     """
-    scales = (0.1, 1.0, 10.0) or scales
-    wts = [1.0] * len(scales) or wts
+    scales = (0.1, 1.0, 10.0) if scales is None else scales
+    wts = [1.0] * len(scales) if wts is None else wts
 
     xx_gm = x @ x.t()
     xy_gm = x @ y.t()
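
The `mmd.py` change fixes a genuine bug rather than a style issue: `X or Y` evaluates to `X` whenever `X` is truthy, and a non-empty tuple is always truthy, so `(0.1, 1.0, 10.0) or scales` unconditionally discarded any caller-supplied `scales` (likewise for `wts`). A two-function demonstration (names hypothetical):

DEFAULT = (0.1, 1.0, 10.0)


def old(scales=None):
    return DEFAULT or scales  # non-empty tuple is truthy: always DEFAULT


def new(scales=None):
    return DEFAULT if scales is None else scales


assert old((5.0,)) == DEFAULT  # caller's value silently dropped (the bug)
assert new((5.0,)) == (5.0,)   # caller's value respected
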
diff --git a/src/models/autoencoder.py b/src/models/autoencoder.py
index ae5595b8..249ee639 100644
--- a/src/models/autoencoder.py
+++ b/src/models/autoencoder.py
@@ -1,6 +1,7 @@
 from collections.abc import Callable
-from dataclasses import dataclass, field, replace
+from dataclasses import dataclass, replace
 from enum import Enum, auto
+from functools import cached_property
 from typing import Literal, Optional, Union, cast
 from typing_extensions import Self, override
 
@@ -71,7 +72,7 @@ def mask(self, random: bool = False, *, detach: bool = False) -> tuple[Self, Sel
         else:
             zs_m = SplitEncoding(zs=torch.zeros_like(zs), zy=zy)
             zy_m = SplitEncoding(zs=zs, zy=torch.zeros_like(zy))
-        return zs_m, zy_m
+        return zs_m, zy_m  # type: ignore
 
 
 @dataclass
@@ -111,29 +112,35 @@ class SplitAeOptimizerCfg(OptimizerCfg):
     recon_loss: ReconstructionLoss = ReconstructionLoss.l2
 
 
-@dataclass(repr=False, eq=False)
+@dataclass(repr=False, eq=False, frozen=True)
 class SplitLatentAe(Model):
     model: AePair  # overriding the definition in `Model`
     opt: SplitAeOptimizerCfg  # overriding the definition in `Model`
     feature_group_slices: Optional[dict[str, list[slice]]] = None
-    recon_loss_fn: Callable[[Tensor, Tensor], Tensor] = field(init=False)
-    zs_dim: int = field(init=False)
-
-    def __post_init__(self) -> None:
-        zs_dim_t = self.opt.zs_dim
-        self.latent_dim: int = self.model.latent_dim
-        self.zs_dim = round(zs_dim_t * self.latent_dim) if isinstance(zs_dim_t, float) else zs_dim_t
-        self.encoding_size = EncodingSize(zs=self.zs_dim, zy=self.latent_dim - self.zs_dim)
+    @cached_property
+    def recon_loss_fn(self) -> Callable[[Tensor, Tensor], Tensor]:
         if self.opt.recon_loss is ReconstructionLoss.mixed:
             if self.feature_group_slices is None:
                 raise ValueError("'MixedLoss' requires 'feature_group_slices' to be specified.")
-            self.recon_loss_fn = self.opt.recon_loss.value(
+            return self.opt.recon_loss.value(
                 reduction="sum", feature_group_slices=self.feature_group_slices
             )
         else:
-            self.recon_loss_fn = self.opt.recon_loss.value(reduction="sum")
-        super().__post_init__()
+            return self.opt.recon_loss.value(reduction="sum")
+
+    @cached_property
+    def zs_dim(self) -> int:
+        zs_dim_t = self.opt.zs_dim
+        return round(zs_dim_t * self.latent_dim) if isinstance(zs_dim_t, float) else zs_dim_t
+
+    @cached_property
+    def encoding_size(self) -> EncodingSize:
+        return EncodingSize(zs=self.zs_dim, zy=self.latent_dim - self.zs_dim)
+
+    @property
+    def latent_dim(self) -> int:
+        return self.model.latent_dim
 
     def encode(self, inputs: Tensor, *, transform_zs: bool = True) -> SplitEncoding:
         enc = self._split_encoding(self.model.encoder(inputs))
diff --git a/src/models/base.py b/src/models/base.py
index 121b32fc..197feb67 100644
--- a/src/models/base.py
+++ b/src/models/base.py
@@ -1,12 +1,13 @@
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from enum import Enum
-from typing import Any, ClassVar, Optional
-from typing_extensions import override
+from functools import cached_property
+from typing import Any, ClassVar, Optional, cast, final
+from typing_extensions import Self, override
 
 from conduit.types import LRScheduler
 from hydra.utils import instantiate
+from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
-from ranzen.torch import DcModule
 import torch
 from torch import Tensor
 from torch.cuda.amp.grad_scaler import GradScaler
@@ -34,16 +35,24 @@ class OptimizerCfg:
     scheduler_kwargs: Optional[dict] = None
 
 
-@dataclass(repr=False, eq=False)
-class Model(DcModule):
+@dataclass(unsafe_hash=True, frozen=True)
+class FrozenDcModule(nn.Module):
+    @final
+    def __new__(cls, *args: Any, **kwargs: Any) -> Self:
+        obj = object.__new__(cls)
+        nn.Module.__init__(obj)
+        return obj
+
+
+@dataclass(repr=False, eq=False, frozen=True)
+class Model(FrozenDcModule):
     _PBAR_COL: ClassVar[str] = "#ffe252"
 
     model: nn.Module
     opt: OptimizerCfg
-    optimizer: torch.optim.Optimizer = field(init=False)
-    scheduler: Optional[LRScheduler] = field(init=False, default=None)
 
-    def __post_init__(self) -> None:
+    @cached_property
+    def optimizer(self) -> torch.optim.Optimizer:
         optimizer_config = DictConfig({"weight_decay": self.opt.weight_decay, "lr": self.opt.lr})
         if self.opt.optimizer_kwargs is not None:
             optimizer_config.update(self.opt.optimizer_kwargs)
@@ -51,12 +60,18 @@ def __post_init__(self) -> None:
         params = exclude_from_weight_decay(
             self.named_parameters(), weight_decay=optimizer_config["weight_decay"]
         )
-        self.optimizer = self.opt.optimizer_cls.value(**optimizer_config, params=params)
+        kwargs = OmegaConf.to_container(optimizer_config, resolve=True)
+        assert isinstance(kwargs, dict)
+        return self.opt.optimizer_cls.value(**cast(dict[str, Any], kwargs), params=params)
+
+    @cached_property
+    def scheduler(self) -> Optional[LRScheduler]:
         if self.opt.scheduler_cls is not None:
             scheduler_config = DictConfig({"_target_": self.opt.scheduler_cls})
             if self.opt.scheduler_kwargs is not None:
                 scheduler_config.update(self.opt.scheduler_kwargs)
-            self.scheduler = instantiate(scheduler_config, optimizer=self.optimizer)
+            return instantiate(scheduler_config, optimizer=self.optimizer)
+        return None
 
     def step(self, grad_scaler: Optional[GradScaler] = None, scaler_update: bool = True) -> None:
         if grad_scaler is None:
@@ -69,5 +84,5 @@ def step(self, grad_scaler: Optional[GradScaler] = None, scaler_update: bool = T
             self.scheduler.step()
 
     @override
-    def forward(self, inputs: Tensor) -> Any:  # type: ignore
+    def forward(self, inputs: Tensor) -> Any:
         return self.model(inputs)
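
`FrozenDcModule.__new__` deserves a comment: a dataclass-generated `__init__` never calls `super().__init__()`, so `nn.Module.__init__`, which sets up `_parameters`, `_buffers`, `_modules`, and the hook machinery, would otherwise never run. Performing it in `__new__` guarantees the module is fully initialised before the dataclass `__init__` assigns any fields. A stripped-down sketch of the same pattern (class name hypothetical):

from dataclasses import dataclass
from typing import Any

from torch import nn


@dataclass(frozen=True)
class FrozenModule(nn.Module):
    in_features: int

    def __new__(cls, *args: Any, **kwargs: Any) -> "FrozenModule":
        obj = object.__new__(cls)
        # the dataclass __init__ will not call super().__init__(), so run it here
        nn.Module.__init__(obj)
        return obj


m = FrozenModule(in_features=3)
assert m.training  # nn.Module state existed before field assignment
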
diff --git a/src/models/classifier.py b/src/models/classifier.py
index ab7028c3..7283a35a 100644
--- a/src/models/classifier.py
+++ b/src/models/classifier.py
@@ -36,7 +36,7 @@ def cat_cpu_flatten(*ls: list[Tensor], dim: int = 0) -> Iterator[Tensor]:
         yield torch.cat(ls_, dim=dim).cpu().flatten()
 
 
-@dataclass(repr=False, eq=False)
+@dataclass(repr=False, eq=False, frozen=True)
 class Classifier(Model):
     """Wrapper for classifier models equipped with training/inference routines."""
 
@@ -152,7 +152,7 @@ def fit(
                     self.model.train()
             if use_wandb:
                 wandb.log(log_dict)
-            pbar.set_postfix(**log_dict)
+            pbar.set_postfix(**log_dict)  # type: ignore
             pbar.update()
 
         pbar.close()
@@ -167,7 +167,7 @@ class _ScSample(BinarySample[Tensor]):
 
 S = TypeVar("S", bound=SampleBase[Tensor])
 
 
-@dataclass(repr=False, eq=False)
+@dataclass(repr=False, eq=False, frozen=True)
 class SetClassifier(Model):
     """Wrapper for set classifier models equipped with training/inference routines."""
 
@@ -227,7 +227,7 @@ def fit(
             loss.backward()  # type: ignore
             self.step(grad_scaler=grad_scaler, scaler_update=True)
             self.optimizer.zero_grad()
-            pbar.set_postfix(**log_dict)
+            pbar.set_postfix(**log_dict)  # type: ignore
             pbar.update()
 
         pbar.close()
diff --git a/src/models/discriminator.py b/src/models/discriminator.py
index e2c5091e..4c3d9149 100644
--- a/src/models/discriminator.py
+++ b/src/models/discriminator.py
@@ -74,14 +74,13 @@ class DiscOptimizerCfg(OptimizerCfg):
     criterion: GanLoss = GanLoss.LOGISTIC_NS
 
 
-@dataclass(repr=False, eq=False)
+@dataclass(repr=False, eq=False, frozen=True)
 class NeuralDiscriminator(BinaryDiscriminator, Model):
     opt: DiscOptimizerCfg  # overriding the definition in `Model`
 
     def __post_init__(self) -> None:
         if self.opt.criterion is GanLoss.WASSERSTEIN:
             self.model.apply(_maybe_spectral_norm)
-        super().__post_init__()
 
     @override
     def discriminator_loss(self, fake: Tensor, *, real: Tensor) -> Tensor:
diff --git a/src/relay/base.py b/src/relay/base.py
index de2f2ec5..4f9ccbe8 100644
--- a/src/relay/base.py
+++ b/src/relay/base.py
@@ -2,14 +2,12 @@
 from typing import Any, ClassVar
 
 from attrs import define, field
-from conduit.data import TernarySample
-from conduit.data.datasets.vision import CdtVisionDataset
 from loguru import logger
 from ranzen.torch import random_seed
 import torch
-from torch import Tensor
 
 from src.data import DataModule, DataModuleConf, RandomSplitter, SplitFromArtifact
+from src.data.common import Dataset
 from src.data.splitter import DataSplitter
 from src.labelling import Labeller
 from src.logging import WandbConf
@@ -30,7 +28,7 @@ class BaseRelay:
 
     def init_dm(
         self,
-        ds: CdtVisionDataset[TernarySample, Tensor, Tensor],
+        ds: Dataset,
         labeller: Labeller,
         device: torch.device,
     ) -> DataModule:
diff --git a/src/relay/mimin.py b/src/relay/mimin.py
index b3fa0fce..c4673d5c 100644
--- a/src/relay/mimin.py
+++ b/src/relay/mimin.py
@@ -72,7 +72,8 @@ def run(self, raw_config: Optional[dict[str, Any]] = None) -> None:
         ds = self.ds()
         run = self.wandb.init(raw_config, (ds, self.labeller, self.ae_arch, self.disc_arch))
         dm = self.init_dm(ds, self.labeller, device=self.alg.device)
-        ae_pair = self.ae_arch(input_shape=dm.dim_x)
+        input_shape: tuple[int, int, int] = dm.dim_x  # type: ignore
+        ae_pair = self.ae_arch(input_shape=input_shape)
         ae = SplitLatentAe(opt=self.ae, model=ae_pair, feature_group_slices=dm.feature_group_slices)
         logger.info(f"Encoding dim: {ae.latent_dim}, {ae.encoding_size}")
         card_s = dm.card_s
diff --git a/src/relay/supmatch.py b/src/relay/supmatch.py
index 7d080ecb..78831c26 100644
--- a/src/relay/supmatch.py
+++ b/src/relay/supmatch.py
@@ -102,7 +102,8 @@ def run(self, raw_config: Optional[dict[str, Any]] = None) -> Optional[float]:
         ds = self.ds()
         run = self.wandb.init(raw_config, (ds, self.labeller, self.ae_arch, self.disc_arch))
         dm = self.init_dm(ds, self.labeller, device=self.alg.device)
-        ae_pair = self.ae_arch(input_shape=dm.dim_x)
+        input_shape: tuple[int, int, int] = dm.dim_x  # type: ignore
+        ae_pair = self.ae_arch(input_shape=input_shape)
         ae = SplitLatentAe(opt=self.ae, model=ae_pair, feature_group_slices=dm.feature_group_slices)
         logger.info(f"Encoding dim: {ae.latent_dim}, {ae.encoding_size}")
         disc_net, _ = self.disc_arch(