Skip to content

Commit

Permalink
Improve encoder (#320)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmke8 authored Oct 2, 2023
2 parents 17c9204 + 6bcdc20 commit a9dd3db
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 37 deletions.
75 changes: 39 additions & 36 deletions src/algs/adv/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ class EvalTrainData(Enum):

@dataclass(frozen=True)
class InvariantDatasets(Generic[DY, DS]):
inv_y: DY
inv_s: DS
zs: DY
zy: DS


def log_sample_images(
Expand All @@ -70,7 +70,7 @@ def log_sample_images(
log_images(images=images, dm=dm, name=f"Samples from {name}", prefix="eval", step=step)


InvariantAttr = Literal["s", "y", "both"]
InvariantAttr = Literal["zy", "zs", "both"]


_PBAR_COL: Final[str] = "#ffe252"
Expand All @@ -82,7 +82,8 @@ def encode_dataset(
*,
encoder: SplitLatentAe,
device: Union[str, torch.device],
invariant_to: Literal["y"] = ...,
segment: Literal["zs"] = ...,
use_amp: bool = False,
) -> InvariantDatasets[Dataset[Tensor], None]:
...

Expand All @@ -93,7 +94,8 @@ def encode_dataset(
*,
encoder: SplitLatentAe,
device: Union[str, torch.device],
invariant_to: Literal["s"] = ...,
segment: Literal["zy"] = ...,
use_amp: bool = False,
) -> InvariantDatasets[None, Dataset[Tensor]]:
...

Expand All @@ -104,7 +106,8 @@ def encode_dataset(
*,
encoder: SplitLatentAe,
device: Union[str, torch.device],
invariant_to: Literal["both"],
segment: Literal["both"],
use_amp: bool = False,
) -> InvariantDatasets[Dataset[Tensor], Dataset[Tensor]]:
...

Expand All @@ -114,40 +117,40 @@ def encode_dataset(
*,
encoder: SplitLatentAe,
device: Union[str, torch.device],
invariant_to: InvariantAttr = "s",
segment: InvariantAttr = "zy",
use_amp: bool = False,
) -> InvariantDatasets:
device = resolve_device(device)
zy_ls, zs_ls, s_ls, y_ls = [], [], [], []
with torch.no_grad():
for batch in tqdm(dl, desc="Encoding dataset", colour=_PBAR_COL):
x = batch.x.to(device, non_blocking=True)
s_ls.append(batch.s)
y_ls.append(batch.y)
with torch.cuda.amp.autocast(enabled=use_amp): # type: ignore
with torch.no_grad():
for batch in tqdm(dl, desc="Encoding dataset", colour=_PBAR_COL):
x = batch.x.to(device, non_blocking=True)
s_ls.append(batch.s)
y_ls.append(batch.y)

# don't do the zs transform here because we might want to look at the raw distribution
encodings = encoder.encode(x, transform_zs=False)
# don't do the zs transform here because we might want to look at the raw distribution
encodings = encoder.encode(x, transform_zs=False)

if invariant_to in ("s", "both"):
zy_ls.append(encodings.zy.detach().cpu())
if segment in ("zy", "both"):
zy_ls.append(encodings.zy.detach().cpu())

if invariant_to in ("y", "both"):
zs_ls.append(encodings.zs.detach().cpu())
if segment in ("zs", "both"):
zs_ls.append(encodings.zs.detach().cpu())

s_ls = torch.cat(s_ls, dim=0)
y_ls = torch.cat(y_ls, dim=0)
inv_y = None
zs_ds = None
if zs_ls:
inv_y = torch.cat(zs_ls, dim=0)
inv_y = CdtDataset(x=inv_y, s=s_ls, y=y_ls)
zs_ds = CdtDataset(x=torch.cat(zs_ls, dim=0), s=s_ls, y=y_ls)

inv_s = None
zy_ds = None
if zy_ls:
inv_s = torch.cat(zy_ls, dim=0)
inv_s = CdtDataset(x=inv_s, s=s_ls, y=y_ls)
zy_ds = CdtDataset(x=torch.cat(zy_ls, dim=0), s=s_ls, y=y_ls)

logger.info("Finished encoding")

return InvariantDatasets(inv_y=inv_y, inv_s=inv_s)
return InvariantDatasets(zs=zs_ds, zy=zy_ds)


def _log_enc_statistics(encoded: Dataset[Tensor], *, step: Optional[int], s_count: int) -> None:
Expand Down Expand Up @@ -295,25 +298,25 @@ def run(
) -> DataModule:
device = resolve_device(device)
encoder.eval()
invariant_to = "both" if self.eval_s_from_zs is not None else "s"
segment = "both" if self.eval_s_from_zs is not None else "zy"

logger.info("Encoding training set")
train_eval = encode_dataset(
dl=dm.train_dataloader(eval=True, batch_size=dm.batch_size_te),
encoder=encoder,
device=device,
invariant_to=invariant_to,
segment=segment,
)
logger.info("Encoding test set")
test_eval = encode_dataset(
dl=dm.test_dataloader(), encoder=encoder, device=device, invariant_to=invariant_to
dl=dm.test_dataloader(), encoder=encoder, device=device, segment=segment
)

s_count = dm.dim_s if dm.dim_s > 1 else 2
if self.umap_viz:
_log_enc_statistics(test_eval.inv_s, step=step, s_count=s_count)
if test_eval.inv_y is not None and (test_eval.inv_y.x[0].size(1) == 1):
zs = test_eval.inv_y.x[:, 0].view((test_eval.inv_y.x.size(0),)).sigmoid()
_log_enc_statistics(test_eval.zy, step=step, s_count=s_count)
if test_eval.zs is not None and (test_eval.zs.x[0].size(1) == 1):
zs = test_eval.zs.x[:, 0].view((test_eval.zs.x.size(0),)).sigmoid()
zs_np = zs.detach().cpu().numpy()
fig, plot = plt.subplots(dpi=200, figsize=(6, 4))
plot.hist(zs_np, bins=20, range=(0, 1))
Expand All @@ -322,7 +325,7 @@ def run(
wandb.log({"zs_histogram": wandb.Image(fig)}, step=step)

enc_size = encoder.encoding_size
dm_zy = gcopy(dm, deep=False, train=train_eval.inv_s, test=test_eval.inv_s)
dm_zy = gcopy(dm, deep=False, train=train_eval.zy, test=test_eval.zy)
logger.info("\nComputing metrics...")
self._evaluate(
dm=dm_zy,
Expand All @@ -335,16 +338,16 @@ def run(

if self.eval_s_from_zs is not None:
if self.eval_s_from_zs is EvalTrainData.train:
train_data = train_eval.inv_y # the part that is invariant to y corresponds to zs
train_data = train_eval.zs # the part that is invariant to y corresponds to zs
else:
encoded_dep = encode_dataset(
dl=dm.deployment_dataloader(eval=True),
encoder=encoder,
device=device,
invariant_to="y",
segment="zs",
)
train_data = encoded_dep.inv_y
dm_zs = gcopy(dm, deep=False, train=train_data, test=test_eval.inv_y)
train_data = encoded_dep.zs
dm_zs = gcopy(dm, deep=False, train=train_data, test=test_eval.zs)
self._evaluate(
dm=dm_zs,
device=device,
Expand Down
2 changes: 1 addition & 1 deletion src/models/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def decode(
s_ = s.view(-1, 1)
split_encoding = replace(split_encoding, zs=s_.float())

decoding = self.model.decoder(split_encoding.join())
decoding = self.model.decode(split_encoding.join())
if mode in ("hard", "relaxed") and self.feature_group_slices:
discrete_outputs_ls: list[Tensor] = []
stop_index = 0
Expand Down

0 comments on commit a9dd3db

Please sign in to comment.