Allow AMP in encoder
tmke8 committed Oct 2, 2023
1 parent 17c9204 commit 74a676f
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions src/algs/adv/evaluator.py
@@ -83,6 +83,7 @@ def encode_dataset(
     encoder: SplitLatentAe,
     device: Union[str, torch.device],
     invariant_to: Literal["y"] = ...,
+    use_amp: bool = False,
 ) -> InvariantDatasets[Dataset[Tensor], None]:
     ...

@@ -94,6 +95,7 @@ def encode_dataset(
     encoder: SplitLatentAe,
     device: Union[str, torch.device],
     invariant_to: Literal["s"] = ...,
+    use_amp: bool = False,
 ) -> InvariantDatasets[None, Dataset[Tensor]]:
     ...

@@ -105,6 +107,7 @@ def encode_dataset(
     encoder: SplitLatentAe,
     device: Union[str, torch.device],
     invariant_to: Literal["both"],
+    use_amp: bool = False,
 ) -> InvariantDatasets[Dataset[Tensor], Dataset[Tensor]]:
     ...

@@ -115,23 +118,25 @@ def encode_dataset(
     encoder: SplitLatentAe,
     device: Union[str, torch.device],
     invariant_to: InvariantAttr = "s",
+    use_amp: bool = False,
 ) -> InvariantDatasets:
     device = resolve_device(device)
     zy_ls, zs_ls, s_ls, y_ls = [], [], [], []
-    with torch.no_grad():
-        for batch in tqdm(dl, desc="Encoding dataset", colour=_PBAR_COL):
-            x = batch.x.to(device, non_blocking=True)
-            s_ls.append(batch.s)
-            y_ls.append(batch.y)
+    with torch.cuda.amp.autocast(enabled=use_amp): # type: ignore
+        with torch.no_grad():
+            for batch in tqdm(dl, desc="Encoding dataset", colour=_PBAR_COL):
+                x = batch.x.to(device, non_blocking=True)
+                s_ls.append(batch.s)
+                y_ls.append(batch.y)

-            # don't do the zs transform here because we might want to look at the raw distribution
-            encodings = encoder.encode(x, transform_zs=False)
+                # don't do the zs transform here because we might want to look at the raw distribution
+                encodings = encoder.encode(x, transform_zs=False)

-            if invariant_to in ("s", "both"):
-                zy_ls.append(encodings.zy.detach().cpu())
+                if invariant_to in ("s", "both"):
+                    zy_ls.append(encodings.zy.detach().cpu())

-            if invariant_to in ("y", "both"):
-                zs_ls.append(encodings.zs.detach().cpu())
+                if invariant_to in ("y", "both"):
+                    zs_ls.append(encodings.zs.detach().cpu())

     s_ls = torch.cat(s_ls, dim=0)
     y_ls = torch.cat(y_ls, dim=0)
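Note on the pattern used here: the change wraps the encoding loop in torch.cuda.amp.autocast(enabled=use_amp), so when use_amp=True on a CUDA device, eligible ops run in reduced precision, while enabled=False leaves the context manager as a no-op. The following is a minimal, self-contained sketch of the same idiom with a toy nn.Linear encoder; the helper name encode_all, the module, and the tensor shapes are placeholders for illustration, not the repository's SplitLatentAe API.

import torch
import torch.nn as nn


def encode_all(encoder: nn.Module, batches, device: torch.device, use_amp: bool = False) -> torch.Tensor:
    """Encode a sequence of batches, optionally under CUDA automatic mixed precision."""
    outputs = []
    # autocast(enabled=False) is a no-op, matching the use_amp=False default above;
    # with enabled=True on a CUDA device, eligible ops run in float16.
    with torch.cuda.amp.autocast(enabled=use_amp):
        with torch.no_grad():
            for x in batches:
                z = encoder(x.to(device, non_blocking=True))
                # cast back to float32 on the CPU so the concatenated result has a uniform dtype
                outputs.append(z.detach().float().cpu())
    return torch.cat(outputs, dim=0)


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    toy_encoder = nn.Linear(8, 4).to(device)
    data = [torch.randn(16, 8) for _ in range(3)]
    print(encode_all(toy_encoder, data, device, use_amp=torch.cuda.is_available()).shape)

On recent PyTorch releases the same behaviour can also be spelled torch.amp.autocast("cuda", enabled=use_amp).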
