diff --git a/src/algs/adv/evaluator.py b/src/algs/adv/evaluator.py
index c6ba9814..6527a52a 100644
--- a/src/algs/adv/evaluator.py
+++ b/src/algs/adv/evaluator.py
@@ -83,6 +83,7 @@ def encode_dataset(
     encoder: SplitLatentAe,
     device: Union[str, torch.device],
     invariant_to: Literal["y"] = ...,
+    use_amp: bool = False,
 ) -> InvariantDatasets[Dataset[Tensor], None]:
     ...


@@ -94,6 +95,7 @@ def encode_dataset(
     encoder: SplitLatentAe,
     device: Union[str, torch.device],
     invariant_to: Literal["s"] = ...,
+    use_amp: bool = False,
 ) -> InvariantDatasets[None, Dataset[Tensor]]:
     ...


@@ -105,6 +107,7 @@ def encode_dataset(
     encoder: SplitLatentAe,
     device: Union[str, torch.device],
     invariant_to: Literal["both"],
+    use_amp: bool = False,
 ) -> InvariantDatasets[Dataset[Tensor], Dataset[Tensor]]:
     ...


@@ -115,23 +118,25 @@ def encode_dataset(
     encoder: SplitLatentAe,
     device: Union[str, torch.device],
     invariant_to: InvariantAttr = "s",
+    use_amp: bool = False,
 ) -> InvariantDatasets:
     device = resolve_device(device)
     zy_ls, zs_ls, s_ls, y_ls = [], [], [], []
-    with torch.no_grad():
-        for batch in tqdm(dl, desc="Encoding dataset", colour=_PBAR_COL):
-            x = batch.x.to(device, non_blocking=True)
-            s_ls.append(batch.s)
-            y_ls.append(batch.y)
+    with torch.cuda.amp.autocast(enabled=use_amp):  # type: ignore
+        with torch.no_grad():
+            for batch in tqdm(dl, desc="Encoding dataset", colour=_PBAR_COL):
+                x = batch.x.to(device, non_blocking=True)
+                s_ls.append(batch.s)
+                y_ls.append(batch.y)

-            # don't do the zs transform here because we might want to look at the raw distribution
-            encodings = encoder.encode(x, transform_zs=False)
+                # don't do the zs transform here because we might want to look at the raw distribution
+                encodings = encoder.encode(x, transform_zs=False)

-            if invariant_to in ("s", "both"):
-                zy_ls.append(encodings.zy.detach().cpu())
+                if invariant_to in ("s", "both"):
+                    zy_ls.append(encodings.zy.detach().cpu())

-            if invariant_to in ("y", "both"):
-                zs_ls.append(encodings.zs.detach().cpu())
+                if invariant_to in ("y", "both"):
+                    zs_ls.append(encodings.zs.detach().cpu())

     s_ls = torch.cat(s_ls, dim=0)
     y_ls = torch.cat(y_ls, dim=0)
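
For reference, a minimal usage sketch of the new flag. The `dl` and `encoder` values are hypothetical stand-ins for a data loader and a trained `SplitLatentAe` built elsewhere in the repo; only `encode_dataset` and its new `use_amp` parameter come from this diff, and `dl` is assumed to be the loader parameter seen in the function body:

    # Hypothetical call site -- dl and encoder are constructed elsewhere.
    datasets = encode_dataset(
        dl=dl,                # loader whose batches expose .x, .s, .y
        encoder=encoder,      # trained SplitLatentAe
        device="cuda:0",
        invariant_to="both",  # collect both the zy and zs encodings
        use_amp=True,         # encode under torch.cuda.amp.autocast
    )

With the default `use_amp=False`, `autocast(enabled=False)` is a no-op, so existing callers keep their current full-precision behaviour.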