Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve encoder #320

Merged
Merged 2 commits on Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 39 additions & 36 deletions src/algs/adv/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ class EvalTrainData(Enum):

@dataclass(frozen=True)
class InvariantDatasets(Generic[DY, DS]):
inv_y: DY
inv_s: DS
zs: DY
zy: DS


def log_sample_images(
Expand All @@ -70,7 +70,7 @@ def log_sample_images(
log_images(images=images, dm=dm, name=f"Samples from {name}", prefix="eval", step=step)


InvariantAttr = Literal["s", "y", "both"]
InvariantAttr = Literal["zy", "zs", "both"]


_PBAR_COL: Final[str] = "#ffe252"
Expand All @@ -82,7 +82,8 @@ def encode_dataset(
*,
encoder: SplitLatentAe,
device: Union[str, torch.device],
invariant_to: Literal["y"] = ...,
segment: Literal["zs"] = ...,
use_amp: bool = False,
) -> InvariantDatasets[Dataset[Tensor], None]:
...

Expand All @@ -93,7 +94,8 @@ def encode_dataset(
*,
encoder: SplitLatentAe,
device: Union[str, torch.device],
invariant_to: Literal["s"] = ...,
segment: Literal["zy"] = ...,
use_amp: bool = False,
) -> InvariantDatasets[None, Dataset[Tensor]]:
...

Expand All @@ -104,7 +106,8 @@ def encode_dataset(
*,
encoder: SplitLatentAe,
device: Union[str, torch.device],
invariant_to: Literal["both"],
segment: Literal["both"],
use_amp: bool = False,
) -> InvariantDatasets[Dataset[Tensor], Dataset[Tensor]]:
...

Expand All @@ -114,40 +117,40 @@ def encode_dataset(
*,
encoder: SplitLatentAe,
device: Union[str, torch.device],
invariant_to: InvariantAttr = "s",
segment: InvariantAttr = "zy",
use_amp: bool = False,
) -> InvariantDatasets:
device = resolve_device(device)
zy_ls, zs_ls, s_ls, y_ls = [], [], [], []
with torch.no_grad():
for batch in tqdm(dl, desc="Encoding dataset", colour=_PBAR_COL):
x = batch.x.to(device, non_blocking=True)
s_ls.append(batch.s)
y_ls.append(batch.y)
with torch.cuda.amp.autocast(enabled=use_amp): # type: ignore
with torch.no_grad():
for batch in tqdm(dl, desc="Encoding dataset", colour=_PBAR_COL):
x = batch.x.to(device, non_blocking=True)
s_ls.append(batch.s)
y_ls.append(batch.y)

# don't do the zs transform here because we might want to look at the raw distribution
encodings = encoder.encode(x, transform_zs=False)
# don't do the zs transform here because we might want to look at the raw distribution
encodings = encoder.encode(x, transform_zs=False)

if invariant_to in ("s", "both"):
zy_ls.append(encodings.zy.detach().cpu())
if segment in ("zy", "both"):
zy_ls.append(encodings.zy.detach().cpu())

if invariant_to in ("y", "both"):
zs_ls.append(encodings.zs.detach().cpu())
if segment in ("zs", "both"):
zs_ls.append(encodings.zs.detach().cpu())

s_ls = torch.cat(s_ls, dim=0)
y_ls = torch.cat(y_ls, dim=0)
inv_y = None
zs_ds = None
if zs_ls:
inv_y = torch.cat(zs_ls, dim=0)
inv_y = CdtDataset(x=inv_y, s=s_ls, y=y_ls)
zs_ds = CdtDataset(x=torch.cat(zs_ls, dim=0), s=s_ls, y=y_ls)

inv_s = None
zy_ds = None
if zy_ls:
inv_s = torch.cat(zy_ls, dim=0)
inv_s = CdtDataset(x=inv_s, s=s_ls, y=y_ls)
zy_ds = CdtDataset(x=torch.cat(zy_ls, dim=0), s=s_ls, y=y_ls)

logger.info("Finished encoding")

return InvariantDatasets(inv_y=inv_y, inv_s=inv_s)
return InvariantDatasets(zs=zs_ds, zy=zy_ds)


def _log_enc_statistics(encoded: Dataset[Tensor], *, step: Optional[int], s_count: int) -> None:
Expand Down Expand Up @@ -295,25 +298,25 @@ def run(
) -> DataModule:
device = resolve_device(device)
encoder.eval()
invariant_to = "both" if self.eval_s_from_zs is not None else "s"
segment = "both" if self.eval_s_from_zs is not None else "zy"

logger.info("Encoding training set")
train_eval = encode_dataset(
dl=dm.train_dataloader(eval=True, batch_size=dm.batch_size_te),
encoder=encoder,
device=device,
invariant_to=invariant_to,
segment=segment,
)
logger.info("Encoding test set")
test_eval = encode_dataset(
dl=dm.test_dataloader(), encoder=encoder, device=device, invariant_to=invariant_to
dl=dm.test_dataloader(), encoder=encoder, device=device, segment=segment
)

s_count = dm.dim_s if dm.dim_s > 1 else 2
if self.umap_viz:
_log_enc_statistics(test_eval.inv_s, step=step, s_count=s_count)
if test_eval.inv_y is not None and (test_eval.inv_y.x[0].size(1) == 1):
zs = test_eval.inv_y.x[:, 0].view((test_eval.inv_y.x.size(0),)).sigmoid()
_log_enc_statistics(test_eval.zy, step=step, s_count=s_count)
if test_eval.zs is not None and (test_eval.zs.x[0].size(1) == 1):
zs = test_eval.zs.x[:, 0].view((test_eval.zs.x.size(0),)).sigmoid()
zs_np = zs.detach().cpu().numpy()
fig, plot = plt.subplots(dpi=200, figsize=(6, 4))
plot.hist(zs_np, bins=20, range=(0, 1))
Expand All @@ -322,7 +325,7 @@ def run(
wandb.log({"zs_histogram": wandb.Image(fig)}, step=step)

enc_size = encoder.encoding_size
dm_zy = gcopy(dm, deep=False, train=train_eval.inv_s, test=test_eval.inv_s)
dm_zy = gcopy(dm, deep=False, train=train_eval.zy, test=test_eval.zy)
logger.info("\nComputing metrics...")
self._evaluate(
dm=dm_zy,
Expand All @@ -335,16 +338,16 @@ def run(

if self.eval_s_from_zs is not None:
if self.eval_s_from_zs is EvalTrainData.train:
train_data = train_eval.inv_y # the part that is invariant to y corresponds to zs
train_data = train_eval.zs # the part that is invariant to y corresponds to zs
else:
encoded_dep = encode_dataset(
dl=dm.deployment_dataloader(eval=True),
encoder=encoder,
device=device,
invariant_to="y",
segment="zs",
)
train_data = encoded_dep.inv_y
dm_zs = gcopy(dm, deep=False, train=train_data, test=test_eval.inv_y)
train_data = encoded_dep.zs
dm_zs = gcopy(dm, deep=False, train=train_data, test=test_eval.zs)
self._evaluate(
dm=dm_zs,
device=device,
Expand Down
2 changes: 1 addition & 1 deletion src/models/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def decode(
s_ = s.view(-1, 1)
split_encoding = replace(split_encoding, zs=s_.float())

decoding = self.model.decoder(split_encoding.join())
decoding = self.model.decode(split_encoding.join())
if mode in ("hard", "relaxed") and self.feature_group_slices:
discrete_outputs_ls: list[Tensor] = []
stop_index = 0
Expand Down