From b9ab050b2ff21565c2e4a3512830c6d4072631ca Mon Sep 17 00:00:00 2001 From: Thomas M Kehrenberg Date: Mon, 2 Oct 2023 16:34:26 +0200 Subject: [PATCH] Make the Dataset type alias generic And specialize more of the generic types. --- src/algs/adv/base.py | 13 +++---------- src/algs/adv/evaluator.py | 33 +++++++++++++++++++-------------- src/data/common.py | 7 ++++--- 3 files changed, 26 insertions(+), 27 deletions(-) diff --git a/src/algs/adv/base.py b/src/algs/adv/base.py index f4d056af..96c96d16 100644 --- a/src/algs/adv/base.py +++ b/src/algs/adv/base.py @@ -128,7 +128,7 @@ def discriminator_step( raise NotImplementedError() def encoder_step( - self, comp: Components, *, batch_tr: TernarySample, x_dep: Tensor, warmup: bool + self, comp: Components, *, batch_tr: TernarySample[Tensor], x_dep: Tensor, warmup: bool ) -> defaultdict[str, float]: logging_dict: defaultdict[str, float] = defaultdict(float) for _ in range(self.ga_steps): @@ -247,7 +247,7 @@ def _predictor_loss( @abstractmethod def _encoder_loss( - self, comp: Components, *, x_dep: Tensor, batch_tr: TernarySample, warmup: bool + self, comp: Components, *, x_dep: Tensor, batch_tr: TernarySample[Tensor], warmup: bool ) -> tuple[Tensor, dict[str, float]]: raise NotImplementedError() @@ -297,14 +297,7 @@ def fit(self, dm: DataModule, *, ae: SplitLatentAe, disc: Model, evaluator: Eval comp = Components(ae=ae, disc=disc, pred_y=pred_y, pred_s=pred_s) comp.to(self.device) - val_freq = max( - ( - self.val_freq - if isinstance(self.val_freq, int) - else round(self.val_freq * self.steps) - ), - 1, - ) + val_freq = max(f if isinstance(f := self.val_freq, int) else round(f * self.steps), 1) with tqdm(total=self.steps, desc="Training", colour=self._PBAR_COL) as pbar: for step in range(1, self.steps + 1): logging_dict = self.training_step( diff --git a/src/algs/adv/evaluator.py b/src/algs/adv/evaluator.py index 9c939cbd..c6ba9814 100644 --- a/src/algs/adv/evaluator.py +++ b/src/algs/adv/evaluator.py @@ -40,8 +40,8 @@ ] -DY = TypeVar("DY", bound=Optional[Dataset]) -DS = TypeVar("DS", bound=Optional[Dataset]) +DY = TypeVar("DY", bound=Optional[Dataset[Tensor]]) +DS = TypeVar("DS", bound=Optional[Dataset[Tensor]]) class EvalTrainData(Enum): @@ -58,10 +58,15 @@ class InvariantDatasets(Generic[DY, DS]): def log_sample_images( - *, data: CdtVisionDataset, dm: DataModule, name: str, step: int, num_samples: int = 64 + *, + data: CdtVisionDataset[TernarySample[Tensor], Tensor, Tensor], + dm: DataModule, + name: str, + step: int, + num_samples: int = 64, ) -> None: - inds = torch.randperm(len(data))[:num_samples] - images = data[inds.tolist()] + inds: list[int] = torch.randperm(len(data))[:num_samples].tolist() + images = data[inds] log_images(images=images, dm=dm, name=f"Samples from {name}", prefix="eval", step=step) @@ -73,39 +78,39 @@ def log_sample_images( @overload def encode_dataset( - dl: CdtDataLoader[TernarySample], + dl: CdtDataLoader[TernarySample[Tensor]], *, encoder: SplitLatentAe, device: Union[str, torch.device], invariant_to: Literal["y"] = ..., -) -> InvariantDatasets[Dataset, None]: +) -> InvariantDatasets[Dataset[Tensor], None]: ... @overload def encode_dataset( - dl: CdtDataLoader[TernarySample], + dl: CdtDataLoader[TernarySample[Tensor]], *, encoder: SplitLatentAe, device: Union[str, torch.device], invariant_to: Literal["s"] = ..., -) -> InvariantDatasets[None, Dataset]: +) -> InvariantDatasets[None, Dataset[Tensor]]: ... @overload def encode_dataset( - dl: CdtDataLoader[TernarySample], + dl: CdtDataLoader[TernarySample[Tensor]], *, encoder: SplitLatentAe, device: Union[str, torch.device], invariant_to: Literal["both"], -) -> InvariantDatasets[Dataset, Dataset]: +) -> InvariantDatasets[Dataset[Tensor], Dataset[Tensor]]: ... def encode_dataset( - dl: CdtDataLoader[TernarySample], + dl: CdtDataLoader[TernarySample[Tensor]], *, encoder: SplitLatentAe, device: Union[str, torch.device], @@ -145,7 +150,7 @@ def encode_dataset( return InvariantDatasets(inv_y=inv_y, inv_s=inv_s) -def _log_enc_statistics(encoded: Dataset, *, step: Optional[int], s_count: int) -> None: +def _log_enc_statistics(encoded: Dataset[Tensor], *, step: Optional[int], s_count: int) -> None: """Compute and log statistics about the encoding.""" x, y, s = encoded.x, encoded.y, encoded.s class_ids = labels_to_group_id(s=s, y=y, s_count=s_count) @@ -154,7 +159,7 @@ def _log_enc_statistics(encoded: Dataset, *, step: Optional[int], s_count: int) mapper = umap.UMAP(n_neighbors=25, n_components=2) # type: ignore umap_z = mapper.fit_transform(x.numpy()) umap_plot = visualize_clusters(umap_z, labels=class_ids, s_count=s_count) - to_log = {"umap": wandb.Image(umap_plot)} + to_log: dict[str, Union[wandb.Image, float]] = {"umap": wandb.Image(umap_plot)} logger.info("Done.") for y_value in y.unique(): diff --git a/src/data/common.py b/src/data/common.py index fbbd7359..9406d08f 100644 --- a/src/data/common.py +++ b/src/data/common.py @@ -3,10 +3,10 @@ from dataclasses import dataclass from pathlib import Path import platform -from typing import Any, Final, Generic, TypeVar, Union +from typing import Final, Generic, TypeVar, Union from typing_extensions import TypeAlias -from conduit.data import LoadedData, TernarySample +from conduit.data import LoadedData, TernarySample, UnloadedData from conduit.data.datasets import CdtDataset from conduit.data.datasets.vision import CdtVisionDataset from hydra.utils import to_absolute_path @@ -44,7 +44,8 @@ def process_data_dir(root: Union[Path, str, None]) -> Path: return Path(to_absolute_path(str(root))).resolve() -Dataset: TypeAlias = CdtDataset[TernarySample[LoadedData], Any, Tensor, Tensor] +X = TypeVar("X", bound=UnloadedData) +Dataset: TypeAlias = CdtDataset[TernarySample[LoadedData], X, Tensor, Tensor] D = TypeVar("D", bound=Dataset)