Skip to content

Commit

Permalink
Modernize the type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
tmke8 committed Feb 29, 2024
1 parent 61f27b2 commit d3c39c2
Show file tree
Hide file tree
Showing 57 changed files with 324 additions and 341 deletions.
2 changes: 1 addition & 1 deletion analysis/wandb_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def groups(self, *groups_: str) -> pd.DataFrame:
return pd.concat(dfs, axis="index", sort=False, ignore_index=True)

def modify_config(
self, group: str, config_key: str, new_value: Union[bool, int, float, str]
self, group: str, config_key: str, new_value: bool | int | float | str
) -> None:
path = f"{self.entity}/{self.project}"
runs = self.api.runs(path, {"group": group})
Expand Down
17 changes: 9 additions & 8 deletions analysis/wandb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import math
import operator
from pathlib import Path
from typing import ClassVar, Final, NamedTuple, Optional, TypeAlias, TypedDict, TypeVar, Union
from typing import ClassVar, Final, NamedTuple, Optional, TypedDict, TypeVar, Union
from typing_extensions import TypeAliasType

from matplotlib import pyplot as plt
from matplotlib.axes import Axes
Expand Down Expand Up @@ -81,7 +82,7 @@ def __init__(self, aggregate: Aggregate):
self.aggregate = aggregate


Triplet: TypeAlias = tuple[Metrics, Aggregation | None, str]
Triplet = TypeAliasType("Triplet", tuple[Metrics, Aggregation | None, str])


class SpecialMetrics:
Expand Down Expand Up @@ -283,7 +284,7 @@ class PlotKwargs(TypedDict, total=False):
fig_dim: tuple[float, float]
file_prefix: str
sens_attr: str
output_dir: Union[Path, str]
output_dir: Path | str
separator_after: int | None


Expand All @@ -292,7 +293,7 @@ def plot(
groupby: str = "misc.log_method",
metrics: list[Metrics | Triplet] = [Metrics.acc],
sens_attr: str = "colour",
output_dir: Union[Path, str] = Path("."),
output_dir: Path | str = Path("."),
file_format: str = "png",
file_prefix: str = "",
fig_dim: tuple[float, float] = (4.0, 6.0),
Expand All @@ -301,9 +302,9 @@ def plot(
agg: Aggregation | None = None,
fillna: bool = False,
hide_left_ticks: bool = False,
x_label: Optional[str] = None,
x_label: str | None = None,
plot_style: PlotStyle = PlotStyle.boxplot,
plot_title: Optional[str] = None,
plot_title: str | None = None,
with_legend: bool = True,
separator_after: int | None = None,
) -> None:
Expand Down Expand Up @@ -407,9 +408,9 @@ def _make_plot(
x_limits: tuple[float, float],
y_limits: tuple[float, float],
hide_left_ticks: bool,
x_label: Optional[str],
x_label: str | None,
plot_style: PlotStyle,
plot_title: Optional[str] = None,
plot_title: str | None = None,
with_legend: bool = True,
separator_after: int | None = None,
) -> Figure:
Expand Down
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,6 @@ ignore = [
"F541", # f-string without placeholder
"E501", # line too long
"E741", # Ambiguous variable name
"UP006", # generic standard library
"UP007", # new-style unions
"UP038", # isinstance check with unions
]

Expand Down
22 changes: 11 additions & 11 deletions src/algs/adv/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from collections import defaultdict
from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import ClassVar, Generic, Literal, Optional, TypeAlias, TypeVar, Union
from typing_extensions import Self
from typing import ClassVar, Generic, Literal, TypeVar
from typing_extensions import Self, TypeAliasType

from conduit.data.structures import SampleBase, TernarySample
from conduit.metrics import accuracy
Expand Down Expand Up @@ -32,16 +32,16 @@

D = TypeVar("D")

IterTr: TypeAlias = Iterator[TernarySample[Tensor]]
IterDep: TypeAlias = Iterator[SampleBase[Tensor]]
IterTr = TypeAliasType("IterTr", Iterator[TernarySample[Tensor]])
IterDep = TypeAliasType("IterDep", Iterator[SampleBase[Tensor]])


@dataclass(repr=False, eq=False)
class Components(DcModule, Generic[D]):
ae: SplitLatentAe
disc: D
pred_y: Optional[Classifier]
pred_s: Optional[Classifier]
pred_y: Classifier | None
pred_s: Classifier | None

@torch.no_grad() # pyright: ignore
def train_ae(self) -> None:
Expand Down Expand Up @@ -77,7 +77,7 @@ class AdvSemiSupervisedAlg(Algorithm):

enc_loss_w: float = 1
disc_loss_w: float = 1
prior_loss_w: Optional[float] = None
prior_loss_w: float | None = None
num_disc_updates: int = 3
# Whether to use the deployment set when computing the encoder's adversarial loss
twoway_disc_loss: bool = True
Expand All @@ -91,7 +91,7 @@ class AdvSemiSupervisedAlg(Algorithm):

# Misc
validate: bool = True
val_freq: Union[int, float] = 0.1 # how often to do validation
val_freq: int | float = 0.1 # how often to do validation
log_freq: int = 150

def __post_init__(self) -> None:
Expand All @@ -107,7 +107,7 @@ def _sample_tr(self, iterator_tr: Iterator[TernarySample[Tensor]]) -> TernarySam

def _build_predictors(
self, ae: SplitLatentAe, *, y_dim: int, s_dim: int
) -> tuple[Optional[Classifier], Optional[Classifier]]:
) -> tuple[Classifier | None, Classifier | None]:
pred_y = None
if self.pred_y_loss_w > 0:
model, _ = self.pred_y(input_dim=ae.encoding_size.zy, target_dim=y_dim)
Expand Down Expand Up @@ -273,12 +273,12 @@ def _get_data_iterators(self, dm: DataModule) -> tuple[IterTr, IterDep]:
return iter(dl_tr), iter(dl_dep)

def _evaluate(
self, dm: DataModule, *, ae: SplitLatentAe, evaluator: Evaluator, step: Optional[int] = None
self, dm: DataModule, *, ae: SplitLatentAe, evaluator: Evaluator, step: int | None = None
) -> DataModule:
return evaluator(dm=dm, encoder=ae, step=step, device=self.device)

def _evaluate_pred_y(
self, dm: DataModule, *, comp: Components, step: Optional[int] = None
self, dm: DataModule, *, comp: Components, step: int | None = None
) -> None:
if comp.pred_y is not None:
et = comp.pred_y.predict(dm.test_dataloader(), device=self.device)
Expand Down
38 changes: 19 additions & 19 deletions src/algs/adv/evaluator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Sequence
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Any, Final, Generic, Literal, Optional, TypeVar, Union, overload
from typing import Any, Final, Generic, Literal, TypeVar, overload

from conduit.data import TernarySample
from conduit.data.datasets import CdtDataLoader, CdtDataset
Expand Down Expand Up @@ -31,8 +31,8 @@
]


DY = TypeVar("DY", bound=Optional[Dataset[Tensor]])
DS = TypeVar("DS", bound=Optional[Dataset[Tensor]])
DY = TypeVar("DY", bound=Dataset[Tensor] | None)
DS = TypeVar("DS", bound=Dataset[Tensor] | None)


class EvalTrainData(Enum):
Expand Down Expand Up @@ -72,7 +72,7 @@ def encode_dataset(
dl: CdtDataLoader[TernarySample[Tensor]],
*,
encoder: SplitLatentAe,
device: Union[str, torch.device],
device: str | torch.device,
segment: Literal["zs"] = ...,
use_amp: bool = False,
) -> InvariantDatasets[Dataset[Tensor], None]:
Expand All @@ -84,7 +84,7 @@ def encode_dataset(
dl: CdtDataLoader[TernarySample[Tensor]],
*,
encoder: SplitLatentAe,
device: Union[str, torch.device],
device: str | torch.device,
segment: Literal["zy"] = ...,
use_amp: bool = False,
) -> InvariantDatasets[None, Dataset[Tensor]]:
Expand All @@ -96,7 +96,7 @@ def encode_dataset(
dl: CdtDataLoader[TernarySample[Tensor]],
*,
encoder: SplitLatentAe,
device: Union[str, torch.device],
device: str | torch.device,
segment: Literal["both"],
use_amp: bool = False,
) -> InvariantDatasets[Dataset[Tensor], Dataset[Tensor]]:
Expand All @@ -107,7 +107,7 @@ def encode_dataset(
dl: CdtDataLoader[TernarySample[Tensor]],
*,
encoder: SplitLatentAe,
device: Union[str, torch.device],
device: str | torch.device,
segment: InvariantAttr = "zy",
use_amp: bool = False,
) -> InvariantDatasets:
Expand Down Expand Up @@ -144,7 +144,7 @@ def encode_dataset(
return InvariantDatasets(zs=zs_ds, zy=zy_ds)


def _log_enc_statistics(encoded: Dataset[Tensor], *, step: Optional[int], s_count: int) -> None:
def _log_enc_statistics(encoded: Dataset[Tensor], *, step: int | None, s_count: int) -> None:
"""Compute and log statistics about the encoding."""
x, y, s = encoded.x, encoded.y, encoded.s
class_ids = labels_to_group_id(s=s, y=y, s_count=s_count)
Expand All @@ -153,7 +153,7 @@ def _log_enc_statistics(encoded: Dataset[Tensor], *, step: Optional[int], s_coun
mapper = umap.UMAP(n_neighbors=25, n_components=2) # type: ignore
umap_z = mapper.fit_transform(x.numpy())
umap_plot = visualize_clusters(umap_z, labels=class_ids, s_count=s_count)
to_log: dict[str, Union[wandb.Image, float]] = {"umap": wandb.Image(umap_plot)}
to_log: dict[str, wandb.Image | float] = {"umap": wandb.Image(umap_plot)}
logger.info("Done.")

for y_value in y.unique():
Expand All @@ -164,11 +164,11 @@ def _log_enc_statistics(encoded: Dataset[Tensor], *, step: Optional[int], s_coun


def visualize_clusters(
x: Union[np.ndarray, Tensor],
x: np.ndarray | Tensor,
*,
labels: Union[np.ndarray, Tensor],
labels: np.ndarray | Tensor,
s_count: int,
title: Optional[str] = None,
title: str | None = None,
legend: bool = True,
) -> plt.Figure: # type: ignore
if x.shape[1] != 2:
Expand Down Expand Up @@ -222,9 +222,9 @@ def _flip(items: Sequence[Any], ncol: int) -> Sequence[Any]:
class Evaluator:
steps: int = 10_000
batch_size: int = 128
hidden_dim: Optional[int] = None
hidden_dim: int | None = None
num_hidden: int = 0
eval_s_from_zs: Optional[EvalTrainData] = None
eval_s_from_zs: EvalTrainData | None = None
balanced_sampling: bool = True
umap_viz: bool = False
save_summary: bool = True
Expand Down Expand Up @@ -263,7 +263,7 @@ def _evaluate(
*,
input_dim: int,
device: torch.device,
step: Optional[int] = None,
step: int | None = None,
name: str = "",
pred_s: bool = False,
) -> None:
Expand All @@ -284,8 +284,8 @@ def run(
dm: DataModule,
*,
encoder: SplitLatentAe,
device: Union[str, torch.device, int],
step: Optional[int] = None,
device: str | torch.device | int,
step: int | None = None,
) -> DataModule:
device = resolve_device(device)
encoder.eval()
Expand Down Expand Up @@ -354,7 +354,7 @@ def __call__(
dm: DataModule,
*,
encoder: SplitLatentAe,
device: Union[str, torch.device, int],
step: Optional[int] = None,
device: str | torch.device | int,
step: int | None = None,
) -> DataModule:
return self.run(dm=dm, encoder=encoder, device=device, step=step)
8 changes: 4 additions & 4 deletions src/algs/adv/scorer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Final, Optional, Union
from typing import Final
from typing_extensions import override

from conduit.data import TernarySample
Expand Down Expand Up @@ -31,7 +31,7 @@ def _encode_and_score_recons(
dl: CdtDataLoader[TernarySample],
*,
ae: SplitLatentAe,
device: Union[str, torch.device],
device: str | torch.device,
minimize: bool = False,
) -> tuple[CdtDataset[TernarySample, Tensor, Tensor, Tensor], float]:
device = resolve_device(device)
Expand Down Expand Up @@ -99,8 +99,8 @@ def run(
class NeuralScorer(Scorer):
steps: int = 5_000
batch_size_tr: int = 16
batch_size_te: Optional[int] = None
batch_size_enc: Optional[int] = None
batch_size_te: int | None = None
batch_size_enc: int | None = None

opt: OptimizerCfg = field(default_factory=OptimizerCfg)
eval_batches: int = 1000
Expand Down
4 changes: 2 additions & 2 deletions src/algs/adv/supmatch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Optional, cast
from typing import cast
from typing_extensions import Self, override

from conduit.data.structures import TernarySample
Expand Down Expand Up @@ -155,7 +155,7 @@ def fit_evaluate_score(
disc: BinaryDiscriminator,
evaluator: Evaluator,
scorer: Scorer,
) -> Optional[float]:
) -> float | None:
"""First fit, then evaluate, then score."""
disc_model_sd0 = None
if isinstance(disc, NeuralDiscriminator) and isinstance(disc.model, SetPredictor):
Expand Down
5 changes: 2 additions & 3 deletions src/algs/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Optional

from loguru import logger
import torch
Expand All @@ -20,13 +19,13 @@ class Algorithm(DcModule):

use_amp: bool = False # Whether to use mixed-precision training
gpu: int = 0 # which GPU to use (if available)
max_grad_norm: Optional[float] = None
max_grad_norm: float | None = None

def __post_init__(self) -> None:
self.device: torch.device = resolve_device(self.gpu)
use_gpu = torch.cuda.is_available() and self.gpu >= 0
self.use_amp = self.use_amp and use_gpu
self.grad_scaler: Optional[GradScaler] = GradScaler() if self.use_amp else None
self.grad_scaler: GradScaler | None = GradScaler() if self.use_amp else None
logger.info(f"{torch.cuda.device_count()} GPU(s) available - using device '{self.device}'")

def _clip_gradients(self, parameters: Iterator[Parameter]) -> None:
Expand Down
3 changes: 1 addition & 2 deletions src/algs/fs/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import Optional

from loguru import logger
from torch import Tensor
Expand Down Expand Up @@ -29,7 +28,7 @@ def alg_name(self) -> str:
def routine(self, dm: DataModule, *, model: nn.Module) -> EvalTuple[Tensor, None]:
raise NotImplementedError()

def run(self, dm: DataModule, *, model: nn.Module) -> Optional[float]:
def run(self, dm: DataModule, *, model: nn.Module) -> float | None:
if dm.deployment_ids is not None:
dm = dm.merge_deployment_into_train()
et = self.routine(dm=dm, model=model)
Expand Down
5 changes: 2 additions & 3 deletions src/algs/fs/dro.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from dataclasses import dataclass
from typing import Optional, Union
from typing_extensions import override

from conduit.types import Loss
Expand All @@ -18,10 +17,10 @@ class DroLoss(nn.Module, Loss):

def __init__(
self,
loss_fn: Optional[Loss] = None,
loss_fn: Loss | None = None,
*,
eta: float = 0.5,
reduction: Union[ReductionType, str] = ReductionType.mean,
reduction: ReductionType | str = ReductionType.mean,
) -> None:
"""Set up the loss, set which loss you want to optimize and the eta to offset by."""
super().__init__()
Expand Down
Loading

0 comments on commit d3c39c2

Please sign in to comment.