diff --git a/src/algs/adv/evaluator.py b/src/algs/adv/evaluator.py index c6ba9814..b4db5d59 100644 --- a/src/algs/adv/evaluator.py +++ b/src/algs/adv/evaluator.py @@ -53,8 +53,8 @@ class EvalTrainData(Enum): @dataclass(frozen=True) class InvariantDatasets(Generic[DY, DS]): - inv_y: DY - inv_s: DS + zs: DY + zy: DS def log_sample_images( @@ -70,7 +70,7 @@ def log_sample_images( log_images(images=images, dm=dm, name=f"Samples from {name}", prefix="eval", step=step) -InvariantAttr = Literal["s", "y", "both"] +InvariantAttr = Literal["zy", "zs", "both"] _PBAR_COL: Final[str] = "#ffe252" @@ -82,7 +82,8 @@ def encode_dataset( *, encoder: SplitLatentAe, device: Union[str, torch.device], - invariant_to: Literal["y"] = ..., + segment: Literal["zs"] = ..., + use_amp: bool = False, ) -> InvariantDatasets[Dataset[Tensor], None]: ... @@ -93,7 +94,8 @@ def encode_dataset( *, encoder: SplitLatentAe, device: Union[str, torch.device], - invariant_to: Literal["s"] = ..., + segment: Literal["zy"] = ..., + use_amp: bool = False, ) -> InvariantDatasets[None, Dataset[Tensor]]: ... @@ -104,7 +106,8 @@ def encode_dataset( *, encoder: SplitLatentAe, device: Union[str, torch.device], - invariant_to: Literal["both"], + segment: Literal["both"], + use_amp: bool = False, ) -> InvariantDatasets[Dataset[Tensor], Dataset[Tensor]]: ... @@ -114,40 +117,40 @@ def encode_dataset( *, encoder: SplitLatentAe, device: Union[str, torch.device], - invariant_to: InvariantAttr = "s", + segment: InvariantAttr = "zy", + use_amp: bool = False, ) -> InvariantDatasets: device = resolve_device(device) zy_ls, zs_ls, s_ls, y_ls = [], [], [], [] - with torch.no_grad(): - for batch in tqdm(dl, desc="Encoding dataset", colour=_PBAR_COL): - x = batch.x.to(device, non_blocking=True) - s_ls.append(batch.s) - y_ls.append(batch.y) + with torch.cuda.amp.autocast(enabled=use_amp): # type: ignore + with torch.no_grad(): + for batch in tqdm(dl, desc="Encoding dataset", colour=_PBAR_COL): + x = batch.x.to(device, non_blocking=True) + s_ls.append(batch.s) + y_ls.append(batch.y) - # don't do the zs transform here because we might want to look at the raw distribution - encodings = encoder.encode(x, transform_zs=False) + # don't do the zs transform here because we might want to look at the raw distribution + encodings = encoder.encode(x, transform_zs=False) - if invariant_to in ("s", "both"): - zy_ls.append(encodings.zy.detach().cpu()) + if segment in ("zy", "both"): + zy_ls.append(encodings.zy.detach().cpu()) - if invariant_to in ("y", "both"): - zs_ls.append(encodings.zs.detach().cpu()) + if segment in ("zs", "both"): + zs_ls.append(encodings.zs.detach().cpu()) s_ls = torch.cat(s_ls, dim=0) y_ls = torch.cat(y_ls, dim=0) - inv_y = None + zs_ds = None if zs_ls: - inv_y = torch.cat(zs_ls, dim=0) - inv_y = CdtDataset(x=inv_y, s=s_ls, y=y_ls) + zs_ds = CdtDataset(x=torch.cat(zs_ls, dim=0), s=s_ls, y=y_ls) - inv_s = None + zy_ds = None if zy_ls: - inv_s = torch.cat(zy_ls, dim=0) - inv_s = CdtDataset(x=inv_s, s=s_ls, y=y_ls) + zy_ds = CdtDataset(x=torch.cat(zy_ls, dim=0), s=s_ls, y=y_ls) logger.info("Finished encoding") - return InvariantDatasets(inv_y=inv_y, inv_s=inv_s) + return InvariantDatasets(zs=zs_ds, zy=zy_ds) def _log_enc_statistics(encoded: Dataset[Tensor], *, step: Optional[int], s_count: int) -> None: @@ -295,25 +298,25 @@ def run( ) -> DataModule: device = resolve_device(device) encoder.eval() - invariant_to = "both" if self.eval_s_from_zs is not None else "s" + segment = "both" if self.eval_s_from_zs is not None else "zy" logger.info("Encoding training set") train_eval = encode_dataset( dl=dm.train_dataloader(eval=True, batch_size=dm.batch_size_te), encoder=encoder, device=device, - invariant_to=invariant_to, + segment=segment, ) logger.info("Encoding test set") test_eval = encode_dataset( - dl=dm.test_dataloader(), encoder=encoder, device=device, invariant_to=invariant_to + dl=dm.test_dataloader(), encoder=encoder, device=device, segment=segment ) s_count = dm.dim_s if dm.dim_s > 1 else 2 if self.umap_viz: - _log_enc_statistics(test_eval.inv_s, step=step, s_count=s_count) - if test_eval.inv_y is not None and (test_eval.inv_y.x[0].size(1) == 1): - zs = test_eval.inv_y.x[:, 0].view((test_eval.inv_y.x.size(0),)).sigmoid() + _log_enc_statistics(test_eval.zy, step=step, s_count=s_count) + if test_eval.zs is not None and (test_eval.zs.x[0].size(1) == 1): + zs = test_eval.zs.x[:, 0].view((test_eval.zs.x.size(0),)).sigmoid() zs_np = zs.detach().cpu().numpy() fig, plot = plt.subplots(dpi=200, figsize=(6, 4)) plot.hist(zs_np, bins=20, range=(0, 1)) @@ -322,7 +325,7 @@ def run( wandb.log({"zs_histogram": wandb.Image(fig)}, step=step) enc_size = encoder.encoding_size - dm_zy = gcopy(dm, deep=False, train=train_eval.inv_s, test=test_eval.inv_s) + dm_zy = gcopy(dm, deep=False, train=train_eval.zy, test=test_eval.zy) logger.info("\nComputing metrics...") self._evaluate( dm=dm_zy, @@ -335,16 +338,16 @@ def run( if self.eval_s_from_zs is not None: if self.eval_s_from_zs is EvalTrainData.train: - train_data = train_eval.inv_y # the part that is invariant to y corresponds to zs + train_data = train_eval.zs # the part that is invariant to y corresponds to zs else: encoded_dep = encode_dataset( dl=dm.deployment_dataloader(eval=True), encoder=encoder, device=device, - invariant_to="y", + segment="zs", ) - train_data = encoded_dep.inv_y - dm_zs = gcopy(dm, deep=False, train=train_data, test=test_eval.inv_y) + train_data = encoded_dep.zs + dm_zs = gcopy(dm, deep=False, train=train_data, test=test_eval.zs) self._evaluate( dm=dm_zs, device=device, diff --git a/src/models/autoencoder.py b/src/models/autoencoder.py index bf9a4546..38369bc1 100644 --- a/src/models/autoencoder.py +++ b/src/models/autoencoder.py @@ -159,7 +159,7 @@ def decode( s_ = s.view(-1, 1) split_encoding = replace(split_encoding, zs=s_.float()) - decoding = self.model.decoder(split_encoding.join()) + decoding = self.model.decode(split_encoding.join()) if mode in ("hard", "relaxed") and self.feature_group_slices: discrete_outputs_ls: list[Tensor] = [] stop_index = 0