Skip to content

Commit

Permalink
Big spring cleaning (#326)
Browse files (browse the repository at this point in the history)
  • Loading branch information
tmke8 authored Feb 23, 2024
2 parents e05fe71 + 4cf7dd3 commit e8d0e1b
Show file tree
Hide file tree
Showing 30 changed files with 183 additions and 238 deletions.
44 changes: 43 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ on:
- main

jobs:
format_with_black:
format_with_ruff:

runs-on: ubuntu-latest

Expand Down Expand Up @@ -40,3 +40,45 @@ jobs:
- name: Lint with ruff
run: |
ruff check --output-format=github .
run_type_checking:
needs:
- format_with_ruff
- lint_with_ruff
runs-on: ubuntu-latest

steps:
# ----------------------------------------------
# ---- check-out repo and set-up python ----
# ----------------------------------------------
- name: Check out repository
uses: actions/checkout@v3
# ----------------------------------------------
# ----- install & configure poetry -----
# ----------------------------------------------
- name: Install poetry
run: pipx install poetry
- name: Set up Python 3.10
uses: actions/setup-python@v4
with:
python-version: '3.10'
cache: 'poetry'

# ----------------------------------------------
# install dependencies (always runs; reuses the poetry cache when present)
# ----------------------------------------------
- name: Install dependencies
run: |
poetry env use 3.10
poetry install --no-interaction --no-root --with torch
- name: Set python path for all subsequent actions
run: echo "$(poetry env info --path)/bin" >> $GITHUB_PATH

# ----------------------------------------------
# ----- install and run pyright -----
# ----------------------------------------------
- uses: jakebailey/pyright-action@v2
with:
# don't show warnings
level: error
extra-args: src
110 changes: 1 addition & 109 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ torchvision = ">=0.15.2"

[tool.poetry.group.dev.dependencies]
ruff = "*"
mypy = "*"
pytest = "*"
types-tqdm = "*"
pandas-stubs = "*"

Expand Down Expand Up @@ -124,6 +122,7 @@ reportUnknownLambdaType = "none"
reportUnknownVariableType = "none"
reportUnknownMemberType = "none"
reportMissingTypeArgument = "none"
reportUnnecessaryCast = "warning"
reportUnnecessaryTypeIgnoreComment = "warning"
exclude = [
"outputs",
Expand All @@ -136,5 +135,7 @@ exclude = [
"hydra_plugins",
"external_confs",
"conf",
"scripts",
"experiments",
".venv",
]
6 changes: 3 additions & 3 deletions src/algs/adv/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class Components(DcModule, Generic[D]):
pred_y: Optional[Classifier]
pred_s: Optional[Classifier]

@torch.no_grad()
@torch.no_grad() # pyright: ignore
def train_ae(self) -> None:
self.ae.train()
if self.pred_y is not None:
Expand All @@ -52,7 +52,7 @@ def train_ae(self) -> None:
if isinstance(self.disc, nn.Module):
self.disc.eval()

@torch.no_grad()
@torch.no_grad() # pyright: ignore
def train_disc(self) -> None:
self.ae.eval()
if self.pred_y is not None:
Expand Down Expand Up @@ -172,7 +172,7 @@ def training_step(
self.log_recons(x=x_dep, dm=dm, ae=comp.ae, itr=itr, split="deployment")
return logging_dict

@torch.no_grad()
@torch.no_grad() # pyright: ignore
def log_recons(
self,
x: Tensor,
Expand Down
25 changes: 11 additions & 14 deletions src/algs/adv/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from conduit.data import TernarySample
from conduit.data.datasets import CdtDataLoader, CdtDataset
from conduit.data.datasets.vision import CdtVisionDataset
from loguru import logger
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
Expand All @@ -22,14 +21,12 @@
from src.arch.predictors import Fcn
from src.data import DataModule, Dataset, group_id_to_label, labels_to_group_id, resolve_device
from src.evaluation.metrics import EmEvalPair, compute_metrics
from src.logging import log_images
from src.models import Classifier, OptimizerCfg, SplitLatentAe

__all__ = [
"Evaluator",
"InvariantDatasets",
"encode_dataset",
"log_sample_images",
"visualize_clusters",
]

Expand All @@ -51,17 +48,17 @@ class InvariantDatasets(Generic[DY, DS]):
zy: DS


def log_sample_images(
*,
data: CdtVisionDataset[TernarySample[Tensor], Tensor, Tensor],
dm: DataModule,
name: str,
step: int,
num_samples: int = 64,
) -> None:
inds: list[int] = torch.randperm(len(data))[:num_samples].tolist()
images = data[inds]
log_images(images=images, dm=dm, name=f"Samples from {name}", prefix="eval", step=step)
# def log_sample_images(
# *,
# data: CdtVisionDataset[TernarySample[Tensor], Tensor, Tensor],
# dm: DataModule,
# name: str,
# step: int,
# num_samples: int = 64,
# ) -> None:
# inds: list[int] = torch.randperm(len(data))[:num_samples].tolist()
# images = data[inds]
# log_images(images=images, dm=dm, name=f"Samples from {name}", prefix="eval", step=step)


InvariantAttr = Literal["zy", "zs", "both"]
Expand Down
4 changes: 2 additions & 2 deletions src/algs/adv/scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
_PBAR_COL: Final[str] = "#ffe252"


@torch.no_grad()
@torch.no_grad() # pyright: ignore
def _encode_and_score_recons(
dl: CdtDataLoader[TernarySample],
*,
Expand Down Expand Up @@ -60,7 +60,7 @@ def _encode_and_score_recons(
return CdtDataset(x=zy, y=y, s=s), recon_score


@torch.no_grad()
@torch.no_grad() # pyright: ignore
def balanced_accuracy(y_pred: Tensor, *, y_true: Tensor) -> Tensor:
return cdtm.subclass_balanced_accuracy(y_pred=y_pred, y_true=y_true, s=y_true)

Expand Down
6 changes: 3 additions & 3 deletions src/algs/adv/supmatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,9 @@ def _get_data_iterators(self, dm: DataModule) -> tuple[IterTr, IterDep]:
dl_tr = dm.train_dataloader(balance=True)
# The batch size needs to be consistent for the aggregation layer in the setwise neural
# discriminator
batch_size: int = dl_tr.batch_sampler.batch_size # type: ignore
dl_dep = dm.deployment_dataloader(
batch_size=dl_tr.batch_sampler.batch_size
if dm.deployment_ids is None
else dm.batch_size_tr
batch_size=batch_size if dm.deployment_ids is None else dm.batch_size_tr
)
return iter(dl_tr), iter(dl_dep)

Expand Down Expand Up @@ -161,6 +160,7 @@ def fit_evaluate_score(
disc_model_sd0 = None
if isinstance(disc, NeuralDiscriminator) and isinstance(disc.model, SetPredictor):
disc_model_sd0 = disc.model.state_dict()
assert isinstance(disc, Model)
super().fit_and_evaluate(dm=dm, ae=ae, disc=disc, evaluator=evaluator)
# TODO: Generalise this to other discriminator types and architectures
if disc_model_sd0 is not None:
Expand Down
8 changes: 2 additions & 6 deletions src/algs/fs/gdro.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,18 +232,14 @@ def update_stats(
self.avg_acc = group_frac @ self.avg_group_acc


@dataclass
class _LcMixin:
@dataclass(kw_only=True, repr=False, eq=False, frozen=True)
class GdroClassifier(Classifier):
loss_computer: LossComputer


@dataclass(repr=False, eq=False)
class GdroClassifier(Classifier, _LcMixin):
def __post_init__(self) -> None:
# LossComputer requires that the criterion return per-sample (unreduced) losses.
if self.criterion is not None:
self.criterion.reduction = ReductionType.none
super().__post_init__()

@override
def training_step(self, batch: TernarySample[Tensor], *, pred_s: bool = False) -> Tensor:
Expand Down
Loading

0 comments on commit e8d0e1b

Please sign in to comment.