Make the Dataset type alias generic (#319)
And specialize more of the generic types.
tmke8 authored Oct 2, 2023
1 parent 4ea90d1 commit 17c9204
Showing 3 changed files with 26 additions and 27 deletions.
13 changes: 3 additions & 10 deletions src/algs/adv/base.py
@@ -128,7 +128,7 @@ def discriminator_step(
         raise NotImplementedError()

     def encoder_step(
-        self, comp: Components, *, batch_tr: TernarySample, x_dep: Tensor, warmup: bool
+        self, comp: Components, *, batch_tr: TernarySample[Tensor], x_dep: Tensor, warmup: bool
     ) -> defaultdict[str, float]:
         logging_dict: defaultdict[str, float] = defaultdict(float)
         for _ in range(self.ga_steps):
@@ -247,7 +247,7 @@ def _predictor_loss(

     @abstractmethod
     def _encoder_loss(
-        self, comp: Components, *, x_dep: Tensor, batch_tr: TernarySample, warmup: bool
+        self, comp: Components, *, x_dep: Tensor, batch_tr: TernarySample[Tensor], warmup: bool
     ) -> tuple[Tensor, dict[str, float]]:
         raise NotImplementedError()

@@ -297,14 +297,7 @@ def fit(self, dm: DataModule, *, ae: SplitLatentAe, disc: Model, evaluator: Eval
         comp = Components(ae=ae, disc=disc, pred_y=pred_y, pred_s=pred_s)
         comp.to(self.device)

-        val_freq = max(
-            (
-                self.val_freq
-                if isinstance(self.val_freq, int)
-                else round(self.val_freq * self.steps)
-            ),
-            1,
-        )
+        val_freq = max(f if isinstance(f := self.val_freq, int) else round(f * self.steps), 1)
         with tqdm(total=self.steps, desc="Training", colour=self._PBAR_COL) as pbar:
             for step in range(1, self.steps + 1):
                 logging_dict = self.training_step(
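A note on the collapsed val_freq expression: it uses an assignment expression (the walrus operator) to bind f inside the isinstance check so that both branches can reuse the bound value. A minimal standalone sketch of the pattern follows; the function name and example values are hypothetical, not taken from the repository.

# Sketch of the walrus-operator pattern used in fit() above (hypothetical helper).
# An int val_freq is an absolute step count; a float is a fraction of total steps.
from typing import Union

def resolve_val_freq(val_freq: Union[int, float], steps: int) -> int:
    # f is bound by `:=` inside isinstance() and reused in both branches.
    return max(f if isinstance(f := val_freq, int) else round(f * steps), 1)

assert resolve_val_freq(500, 10_000) == 500    # absolute frequency passes through
assert resolve_val_freq(0.1, 10_000) == 1_000  # fraction of total steps
assert resolve_val_freq(0.0, 10_000) == 1      # clamped to at least 1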
33 changes: 19 additions & 14 deletions src/algs/adv/evaluator.py
@@ -40,8 +40,8 @@
 ]


-DY = TypeVar("DY", bound=Optional[Dataset])
-DS = TypeVar("DS", bound=Optional[Dataset])
+DY = TypeVar("DY", bound=Optional[Dataset[Tensor]])
+DS = TypeVar("DS", bound=Optional[Dataset[Tensor]])


 class EvalTrainData(Enum):
@@ -58,10 +58,15 @@ class InvariantDatasets(Generic[DY, DS]):


 def log_sample_images(
-    *, data: CdtVisionDataset, dm: DataModule, name: str, step: int, num_samples: int = 64
+    *,
+    data: CdtVisionDataset[TernarySample[Tensor], Tensor, Tensor],
+    dm: DataModule,
+    name: str,
+    step: int,
+    num_samples: int = 64,
 ) -> None:
-    inds = torch.randperm(len(data))[:num_samples]
-    images = data[inds.tolist()]
+    inds: list[int] = torch.randperm(len(data))[:num_samples].tolist()
+    images = data[inds]
     log_images(images=images, dm=dm, name=f"Samples from {name}", prefix="eval", step=step)


@@ -73,39 +78,39 @@ def log_sample_images(

 @overload
 def encode_dataset(
-    dl: CdtDataLoader[TernarySample],
+    dl: CdtDataLoader[TernarySample[Tensor]],
     *,
     encoder: SplitLatentAe,
     device: Union[str, torch.device],
     invariant_to: Literal["y"] = ...,
-) -> InvariantDatasets[Dataset, None]:
+) -> InvariantDatasets[Dataset[Tensor], None]:
     ...


 @overload
 def encode_dataset(
-    dl: CdtDataLoader[TernarySample],
+    dl: CdtDataLoader[TernarySample[Tensor]],
     *,
     encoder: SplitLatentAe,
     device: Union[str, torch.device],
     invariant_to: Literal["s"] = ...,
-) -> InvariantDatasets[None, Dataset]:
+) -> InvariantDatasets[None, Dataset[Tensor]]:
     ...


 @overload
 def encode_dataset(
-    dl: CdtDataLoader[TernarySample],
+    dl: CdtDataLoader[TernarySample[Tensor]],
     *,
     encoder: SplitLatentAe,
     device: Union[str, torch.device],
     invariant_to: Literal["both"],
-) -> InvariantDatasets[Dataset, Dataset]:
+) -> InvariantDatasets[Dataset[Tensor], Dataset[Tensor]]:
     ...


 def encode_dataset(
-    dl: CdtDataLoader[TernarySample],
+    dl: CdtDataLoader[TernarySample[Tensor]],
     *,
     encoder: SplitLatentAe,
     device: Union[str, torch.device],
@@ -145,7 +150,7 @@ def encode_dataset(
     return InvariantDatasets(inv_y=inv_y, inv_s=inv_s)


-def _log_enc_statistics(encoded: Dataset, *, step: Optional[int], s_count: int) -> None:
+def _log_enc_statistics(encoded: Dataset[Tensor], *, step: Optional[int], s_count: int) -> None:
     """Compute and log statistics about the encoding."""
     x, y, s = encoded.x, encoded.y, encoded.s
     class_ids = labels_to_group_id(s=s, y=y, s_count=s_count)
@@ -154,7 +159,7 @@ def _log_enc_statistics(encoded: Dataset, *, step: Optional[int], s_count: int)
     mapper = umap.UMAP(n_neighbors=25, n_components=2)  # type: ignore
     umap_z = mapper.fit_transform(x.numpy())
     umap_plot = visualize_clusters(umap_z, labels=class_ids, s_count=s_count)
-    to_log = {"umap": wandb.Image(umap_plot)}
+    to_log: dict[str, Union[wandb.Image, float]] = {"umap": wandb.Image(umap_plot)}
     logger.info("Done.")

     for y_value in y.unique():
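Why DY and DS are bounded by Optional[Dataset[Tensor]]: the overloads above tie the Literal value of invariant_to to which field of InvariantDatasets is populated, so a type checker can rule out None on the field that is guaranteed to be present. A self-contained sketch of that pattern, using simplified stand-in types rather than the conduit/repository ones:

# Self-contained sketch of the Literal-driven overload pattern; every name
# here is a simplified stand-in for the repository's types.
from dataclasses import dataclass
from typing import Generic, Literal, Optional, TypeVar, overload

DY = TypeVar("DY", bound=Optional[list])
DS = TypeVar("DS", bound=Optional[list])

@dataclass
class Invariant(Generic[DY, DS]):
    inv_y: DY
    inv_s: DS

@overload
def encode(invariant_to: Literal["y"]) -> "Invariant[list, None]": ...
@overload
def encode(invariant_to: Literal["s"]) -> "Invariant[None, list]": ...
def encode(invariant_to: Literal["y", "s"]) -> "Invariant":
    return Invariant(
        inv_y=[] if invariant_to == "y" else None,
        inv_s=[] if invariant_to == "s" else None,
    )

# A checker now treats encode("y").inv_y as list (never None), and
# encode("s").inv_s likewise, with no narrowing needed at the call site.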
7 changes: 4 additions & 3 deletions src/data/common.py
@@ -3,10 +3,10 @@
 from dataclasses import dataclass
 from pathlib import Path
 import platform
-from typing import Any, Final, Generic, TypeVar, Union
+from typing import Final, Generic, TypeVar, Union
 from typing_extensions import TypeAlias

-from conduit.data import LoadedData, TernarySample
+from conduit.data import LoadedData, TernarySample, UnloadedData
 from conduit.data.datasets import CdtDataset
 from conduit.data.datasets.vision import CdtVisionDataset
 from hydra.utils import to_absolute_path
@@ -44,7 +44,8 @@ def process_data_dir(root: Union[Path, str, None]) -> Path:
     return Path(to_absolute_path(str(root))).resolve()


-Dataset: TypeAlias = CdtDataset[TernarySample[LoadedData], Any, Tensor, Tensor]
+X = TypeVar("X", bound=UnloadedData)
+Dataset: TypeAlias = CdtDataset[TernarySample[LoadedData], X, Tensor, Tensor]
 D = TypeVar("D", bound=Dataset)


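This is the change the commit title describes: the concrete Any slot in the alias gives way to the TypeVar X (bounded by UnloadedData), which makes Dataset itself a generic, subscriptable alias; the Dataset[Tensor] annotations in the files above fill that slot. A minimal sketch of the mechanism, with an illustrative stand-in class rather than conduit's CdtDataset:

# Minimal sketch of a generic type alias: leaving a TypeVar on the right-hand
# side keeps the alias parameterized, so call sites can subscript it.
# SimpleDataset is an illustrative stand-in, not conduit's CdtDataset.
from typing import Generic, List, TypeVar
from typing_extensions import TypeAlias

X = TypeVar("X")

class SimpleDataset(Generic[X]):
    def __init__(self, x: X) -> None:
        self.x = x

Dataset: TypeAlias = SimpleDataset[X]  # still generic in X

def head(ds: Dataset[List[int]]) -> int:  # X is fixed to List[int] here
    return ds.x[0]

print(head(SimpleDataset([1, 2, 3])))  # prints 1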
