Set up more CelebA experiments #321

Merged · 4 commits · Oct 2, 2023
15 changes: 15 additions & 0 deletions external_confs/alg/only_pred_y_loss.yaml

@@ -0,0 +1,15 @@
+---
+disc_loss_w: 0
+enc_loss_w: 0.0
+num_disc_updates: 0
+pred_s_loss_w: 0
+pred_y_loss_w: 1.0
+prior_loss_w: null
+twoway_disc_loss: false
+warmup_steps: 0
+pred:
+  scheduler_cls: ranzen.torch.schedulers.CosineLRWithLinearWarmup
+  scheduler_kwargs:
+    total_iters: ${ alg.steps }
+    lr_min: 5.e-7
+    warmup_iters: 0.05
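Note: this new alg preset zeroes out every adversarial term (disc_loss_w, enc_loss_w, pred_s_loss_w) so that only the y-prediction loss trains, and the predictor's scheduler takes its horizon from the experiment via interpolation. A minimal sketch (not repo code) of how `${ alg.steps }` resolves under OmegaConf, using the 30000 steps set by the nicopp experiments below:

from omegaconf import OmegaConf

# The alg preset is mounted under the `alg` key of the global config, so the
# absolute interpolation `${alg.steps}` resolves against the experiment's value.
cfg = OmegaConf.create(
    {
        "alg": {
            "steps": 30000,  # set by the experiment config
            "pred": {"scheduler_kwargs": {"total_iters": "${alg.steps}", "lr_min": 5e-7}},
        }
    }
)
print(cfg.alg.pred.scheduler_kwargs.total_iters)  # -> 30000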
8 changes: 4 additions & 4 deletions external_confs/alg/supmatch_no_disc.yaml

@@ -1,8 +1,8 @@
 ---
-disc_loss_w: 0
 num_disc_updates: 0
-twoway_disc_loss: false
-prior_loss_w: 0
-pred_y_loss_w: 0
 pred_s_loss_w: 0
+pred_y_loss_w: 0
+prior_loss_w: 0
+twoway_disc_loss: false
 warmup_steps: 0
+disc_loss_w: 0
2 changes: 1 addition & 1 deletion external_confs/dm/nicopp.yaml

@@ -1,4 +1,4 @@
 stratified_sampler: approx_class
 num_workers: 4
 batch_size_tr: 1
-batch_size_te: 10
+batch_size_te: 20
7 changes: 7 additions & 0 deletions external_confs/ds/celeba/gender_blond.yaml

@@ -0,0 +1,7 @@
+---
+defaults:
+  - celeba
+  - _self_
+download: false
+superclass: BLOND_HAIR
+subclass: MALE
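Note: because `_self_` comes after `celeba` in the defaults list, the keys in this file override those inherited from the base config. A sketch of that merge order (the base values are placeholders, not the actual contents of celeba.yaml):

from omegaconf import OmegaConf

base = OmegaConf.create({"superclass": "SMILING", "subclass": "MALE", "download": True})
this_file = OmegaConf.create({"superclass": "BLOND_HAIR", "subclass": "MALE", "download": False})
merged = OmegaConf.merge(base, this_file)  # `_self_` last => this file's keys win
print(merged.superclass)  # -> BLOND_HAIR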
2 changes: 0 additions & 2 deletions external_confs/ds/celeba/gender_smiling.yaml

@@ -5,5 +5,3 @@ defaults:
 download: false
 superclass: SMILING
 subclass: MALE
-transform: null
-split: null
58 changes: 58 additions & 0 deletions external_confs/experiment/celeba/rn50/pretrained_enc.yaml

@@ -0,0 +1,58 @@
+# @package _global_
+
+defaults:
+  - override /ds: celeba
+  - override /split: celeba/random/no_nonblond_females
+  - override /labeller: gt
+  - override /ae_arch: resnet/rn50_256_pre
+  - _self_
+
+ae:
+  lr: 1.e-5
+  zs_dim: 6
+  zs_transform: none
+
+alg:
+  use_amp: true
+  pred:
+    lr: ${ ae.lr }
+  log_freq: ${ alg.steps }
+  val_freq: 200
+  num_disc_updates: 5
+  # enc_loss_w: 0.0001
+  enc_loss_w: 1
+  disc_loss_w: 0.03
+  # prior_loss_w: 0.01
+  prior_loss_w: null
+  pred_y_loss_w: 1
+  pred_s_loss_w: 0
+  pred_y:
+    num_hidden: 1 # for decoding the pre-trained RN50 output
+    dropout_prob: 0.1
+  s_pred_with_bias: false
+  s_as_zs: false
+
+disc:
+  lr: 1.e-4
+
+# disc_arch:
+#   dropout_prob: 0.1
+
+dm:
+  stratified_sampler: exact
+  num_workers: 4
+  batch_size_tr: 10
+  batch_size_te: 20
+
+eval:
+  batch_size: 10
+  balanced_sampling: true
+  hidden_dim: null
+  num_hidden: 1
+  steps: 10000
+  opt:
+    lr: 1.e-4
+    scheduler_cls: torch.optim.lr_scheduler.CosineAnnealingLR
+    scheduler_kwargs:
+      T_max: ${ eval.steps }
+      eta_min: 5e-7
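Note: the eval block anneals the probe's learning rate from 1e-4 down to eta_min = 5e-7 over eval.steps iterations. A self-contained sketch of that schedule on a dummy parameter (not code from this repo):

import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=1.0e-4)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10_000, eta_min=5.0e-7)
for _ in range(10_000):
    opt.step()
    sched.step()
print(opt.param_groups[0]["lr"])  # ~5e-7 once the cosine has run its course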
11 changes: 2 additions & 9 deletions external_confs/experiment/nicopp/rn18/only_pred_y_loss.yaml

@@ -2,7 +2,7 @@

 defaults:
   - /ae: cosine_annealing
-  - /alg: supmatch_no_disc
+  - /alg: only_pred_y_loss
   - /eval: nicopp
   - override /ae_arch: resnet
   - override /dm: nicopp
@@ -19,17 +19,10 @@ ae_arch:
   pretrained_enc: true

 alg:
-  pred_y_loss_w: 1.0
-  enc_loss_w: 0.0
   steps: 30000
   use_amp: true
   pred:
-    lr: 5.e-5
-    scheduler_cls: ranzen.torch.schedulers.CosineLRWithLinearWarmup
-    scheduler_kwargs:
-      total_iters: ${ alg.steps }
-      lr_min: 5.e-7
-      warmup_iters: 0.05
+    lr: ${ ae.lr }
   log_freq: 100000000000 # never
   val_freq: 1000
7 changes: 2 additions & 5 deletions external_confs/experiment/nicopp/rn50/only_pred_y_loss.yaml

@@ -2,8 +2,7 @@

 defaults:
   - /ae: cosine_annealing
-  - /alg: supmatch_no_disc
-  - /alg/pred: cosine_annealing
+  - /alg: only_pred_y_loss
   - /eval: nicopp
   - override /ae_arch: resnet/rn50_256_pre
   - override /dm: nicopp
@@ -13,12 +12,10 @@ defaults:
   - _self_

 alg:
-  pred_y_loss_w: 1.0
-  enc_loss_w: 0.0
   steps: 30000
   use_amp: true
   pred:
-    lr: 5.e-5
+    lr: ${ ae.lr }
   log_freq: 100000000000 # never
   val_freq: 1000
4 changes: 2 additions & 2 deletions external_confs/experiment/nicopp/rn50/pretrained_enc.yaml

@@ -10,15 +10,15 @@ defaults:
   - _self_

 ae:
-  lr: 5.e-5
+  lr: 1.e-5
   zs_dim: 6
   zs_transform: none

 alg:
   use_amp: true
   pred:
     lr: ${ ae.lr }
-  log_freq: 100000000000 # never
+  log_freq: ${ alg.steps }
   val_freq: 200
   num_disc_updates: 5
   # enc_loss_w: 0.0001
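Note: replacing the 100000000000 sentinel ("never") with ${ alg.steps } presumably makes logging fire exactly once, at the final step, assuming log_freq is compared against the training-step counter.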
2 changes: 1 addition & 1 deletion external_confs/hydra/launcher/slurm/ada.yaml

@@ -5,7 +5,7 @@ defaults:
   - submitit_slurm

 partition: ada
-cpus_per_task: 8 # on ada, we have 3 CPUs per GPU
+cpus_per_task: 10 # on ada, there are 8 CPUs per GPU, but we request 10
 timeout_min: 99999 # 99999 minutes ≈ 69 days, i.e. effectively unlimited

 additional_parameters:
10 changes: 10 additions & 0 deletions external_confs/split/celeba/random/no_nonblond_females.yaml

@@ -0,0 +1,10 @@
+---
+defaults:
+  - random
+  - /split/celeba/random/base@_here_
+  - _self_
+
+dep_prop: 0.1
+test_prop: 0.1
+artifact_name: split_celeba_no_nonblond_females_${oc.env:SLURM_NODELIST}_${.seed}
+train_subsampling_props: {0: {0: 0}} # Drop all nonblond females
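Note: `train_subsampling_props: {0: {0: 0}}` reads as "keep 0% of the training cell whose outer label is 0 and inner label is 0", i.e. the non-blond females under the gender_blond config above. An illustrative sketch of that semantics on synthetic labels (the subgroup-then-class key order and the keep-proportion interpretation are assumptions, not taken from the splitter code):

import numpy as np

rng = np.random.default_rng(0)
s = rng.integers(0, 2, size=1000)  # subgroup label, e.g. MALE
y = rng.integers(0, 2, size=1000)  # target label, e.g. BLOND_HAIR
props = {0: {0: 0.0}}  # keep 0% of the (s=0, y=0) cell

keep = np.ones(len(y), dtype=bool)
for s_val, y_props in props.items():
    for y_val, prop in y_props.items():
        cell = np.flatnonzero((s == s_val) & (y == y_val))
        n_drop = round(len(cell) * (1 - prop))
        keep[rng.choice(cell, size=n_drop, replace=False)] = False
print(keep.sum())  # ~750 of the 1000 examples survive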
7 changes: 7 additions & 0 deletions external_confs/split/nicopp/change_is_hard_seeded.yaml

@@ -0,0 +1,7 @@
+---
+defaults:
+  - artifact
+  - /split/nicopp/base@_here_
+  - _self_
+
+artifact_name: split_nicopp_change_is_hard_kyiv_${ seed }
38 changes: 38 additions & 0 deletions scripts/save_nicopp_splits.py

@@ -0,0 +1,38 @@
+import platform
+from sys import argv
+
+from conduit.data.datasets.vision import NICOPP
+import numpy as np
+import torch
+import wandb
+
+from src.data.common import find_data_dir
+from src.data.splitter import save_split_inds_as_artifact
+
+
+def main(seed: int) -> None:
+    assert seed >= 0
+    run = wandb.init(
+        project="support-matching", entity="predictive-analytics-lab", dir="local_logging"
+    )
+    NICOPP.data_split_seed = seed
+    ds = NICOPP(root=find_data_dir())
+    split_ids = ds.metadata["split"]
+    train_inds = torch.as_tensor(np.nonzero(split_ids == NICOPP.Split.TRAIN.value)[0])
+    test_inds = torch.as_tensor(np.nonzero(split_ids == NICOPP.Split.TEST.value)[0])
+    dep_inds = torch.as_tensor(np.nonzero(split_ids == NICOPP.Split.VAL.value)[0])
+    name_of_machine = platform.node()
+    save_split_inds_as_artifact(
+        run=run,
+        train_inds=train_inds,
+        test_inds=test_inds,
+        dep_inds=dep_inds,
+        ds=ds,
+        seed=seed,
+        artifact_name=f"split_nicopp_change_is_hard_{name_of_machine}_{seed}",
+    )
+    run.finish()
+
+
+if __name__ == "__main__":
+    main(int(argv[1]))
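Usage, as implied by the argv handling above: `python scripts/save_nicopp_splits.py <seed>`, where the single positional argument is a non-negative split seed. The machine's hostname is baked into the artifact name via platform.node(), which is what the hard-coded `kyiv` in change_is_hard_seeded.yaml refers back to.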
6 changes: 5 additions & 1 deletion src/algs/adv/supmatch.py

@@ -33,7 +33,11 @@ def _get_data_iterators(self, dm: DataModule) -> tuple[IterTr, IterDep]:
         dl_tr = dm.train_dataloader(balance=True)
         # The batch size needs to be consistent for the aggregation layer in the setwise neural
         # discriminator
-        dl_dep = dm.deployment_dataloader(batch_size=dm.batch_size_tr)
+        dl_dep = dm.deployment_dataloader(
+            batch_size=dl_tr.batch_sampler.batch_size
+            if dm.deployment_ids is None
+            else dm.batch_size_tr
+        )
         return iter(dl_tr), iter(dl_dep)

     @override
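Note on the fix above: under balanced sampling, the train loader's effective batch size is determined by its batch sampler (presumably one quota per subgroup) and can differ from the nominal dm.batch_size_tr, so the deployment loader now copies it from dl_tr to keep the set sizes equal for the discriminator's aggregation layer. A minimal sketch of where that number lives on a vanilla PyTorch loader (the DataModule internals are assumptions):

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(100))
# e.g. 4 subgroups x 2 instances per subgroup under a stratified sampler
dl = DataLoader(ds, batch_size=4 * 2)
print(dl.batch_sampler.batch_size)  # -> 8, the size the deployment loader must match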