Fix or silence the remaining type errors (#327)
tmke8 authored Feb 24, 2024
2 parents e8d0e1b + 831304a commit acc7c8a
Showing 10 changed files with 111 additions and 86 deletions.
135 changes: 77 additions & 58 deletions poetry.lock

Some generated files are not rendered by default.

5 changes: 3 additions & 2 deletions pyproject.toml
@@ -24,13 +24,13 @@ neoconfigen = ">=2.3.3"
 numpy = { version = ">=1.23.2" }
 pandas = { version = ">=1.5.0" }
 pillow = "*"
-python = ">=3.10,<3.12"
+python = ">=3.10,<3.13"
 ranzen = { version = "^2.1.2" }
 scikit-image = ">=0.14"
 scikit_learn = { version = ">=0.20.1" }
 scipy = { version = ">=1.2.1" }
 seaborn = { version = ">=0.9.0" }
-torch-conduit = { version = "^0.3.4", extras = ["image"] }
+torch-conduit = { version = ">=0.3.4", extras = ["image"] }

 tqdm = { version = ">=4.31.1" }
 typer = "*"
@@ -56,6 +56,7 @@ torchvision = ">=0.15.2"
 ruff = "*"
 types-tqdm = "*"
 pandas-stubs = "*"
+python-type-stubs = {git = "https://github.com/wearepal/python-type-stubs.git", rev = "8d5f608"}

 [build-system]
 requires = ["poetry-core>=1.0.0"]
6 changes: 3 additions & 3 deletions src/algs/adv/scorer.py
@@ -7,6 +7,7 @@
 from conduit.data.datasets import CdtDataLoader, CdtDataset
 import conduit.metrics as cdtm
 from conduit.models.utils import prefix_keys
+from conduit.types import Loss
 from loguru import logger
 from ranzen.misc import gcopy
 from ranzen.torch.loss import CrossEntropyLoss, ReductionType
@@ -141,9 +142,8 @@ def run(
         score = recon_score = self.recon_score_w * 0.5 * (recon_score_tr + recon_score_dep)
         logger.info(f"Aggregate reconstruction score: {recon_score}")

-        classifier = SetClassifier(
-            model=disc, opt=self.opt, criterion=CrossEntropyLoss(reduction=ReductionType.mean)
-        )
+        cross_entropy: Loss = CrossEntropyLoss(reduction=ReductionType.mean)  # type: ignore
+        classifier = SetClassifier(model=disc, opt=self.opt, criterion=cross_entropy)
         logger.info("Training invariance-scorer")
         classifier.fit(
             dm.train_dataloader(batch_size=self.batch_size_tr),
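The `cross_entropy: Loss` binding above is the pattern this commit applies in `scorer.py`, `dro.py`, and `sd.py`: instead of ignoring the error at every call site, the loss is bound once to a variable annotated with `conduit.types.Loss`, and the structural mismatch (presumably between `ranzen`'s `CrossEntropyLoss` and that protocol) is silenced at the single assignment. A minimal sketch of the idea, using hypothetical stand-ins for the protocol and the loss class rather than the real `conduit`/`ranzen` types:

```python
from typing import Protocol, Union

import torch
from torch import Tensor


class Loss(Protocol):
    """Stand-in for conduit.types.Loss (the real protocol may differ)."""

    def __call__(self, input: Tensor, *, target: Tensor) -> Tensor: ...


class MyCrossEntropy:
    """Stand-in loss whose declared return type is wider than the protocol's."""

    def __call__(self, input: Tensor, *, target: Tensor) -> Union[Tensor, float]:
        # Always returns a Tensor at runtime; only the declared type is wider,
        # which is what makes this an imperfect structural match for Loss.
        return torch.nn.functional.cross_entropy(input, target)


def fit(criterion: Loss) -> None:
    """Consumer that only needs the Loss protocol."""
    logits = torch.randn(8, 3)
    labels = torch.randint(0, 3, (8,))
    print(criterion(logits, target=labels))


# Annotate once and silence the mismatch at the binding site; every later
# use of `cross_entropy` then type-checks as a plain Loss.
cross_entropy: Loss = MyCrossEntropy()  # type: ignore
fit(cross_entropy)
```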
5 changes: 3 additions & 2 deletions src/algs/fs/dro.py
@@ -29,15 +29,16 @@ def __init__(
         reduction = str_to_enum(str_=reduction, enum=ReductionType)
         self.reduction = reduction
         if loss_fn is None:
-            loss_fn = CrossEntropyLoss(reduction=ReductionType.none)
+            cross_entropy: Loss = CrossEntropyLoss(reduction=ReductionType.none)  # type: ignore
+            loss_fn = cross_entropy
         else:
             loss_fn.reduction = ReductionType.none
         self.reduction = reduction
         self.loss_fn = loss_fn
         self.eta = eta

     @override
-    def forward(self, input: Tensor, *, target: Tensor) -> Tensor:  # type: ignore
+    def forward(self, input: Tensor, *, target: Tensor) -> Tensor:
         sample_losses = (self.loss_fn(input, target=target) - self.eta).relu().pow(2)
         return reduce(sample_losses, reduction_type=self.reduction)

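For context on the `forward` kept above: it turns per-sample losses into the DRO objective by shifting each loss down by `self.eta`, clamping at zero, and squaring, so only samples whose loss exceeds the threshold contribute, and large excesses are weighted quadratically. A small numeric sketch of that transformation with made-up values (not the class's actual configuration):

```python
import torch

# Hypothetical per-sample cross-entropy losses and threshold.
sample_losses = torch.tensor([0.2, 0.9, 1.5, 3.0])
eta = 1.0

# Same expression as the forward above: shift by eta, clamp at zero, square.
dro_losses = (sample_losses - eta).relu().pow(2)
print(dro_losses)         # tensor([0.0000, 0.0000, 0.2500, 4.0000])
print(dro_losses.mean())  # tensor(1.0625)
```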
8 changes: 5 additions & 3 deletions src/algs/fs/lff.py
@@ -5,8 +5,9 @@
 from typing_extensions import Self, override

 from conduit.data.datasets.base import CdtDataset
-from conduit.data.structures import XI, LoadedData, SampleBase, SizedDataset, TernarySample, X
+from conduit.data.structures import LoadedData, SampleBase, SizedDataset, TernarySample, X
 from conduit.types import Indexable, IndexType
+import numpy as np
 from ranzen.misc import gcopy
 from ranzen.torch import CrossEntropyLoss
 import torch
@@ -74,7 +75,8 @@ def __add__(self, other: Self) -> Self:
         return copy

     @override
-    def __getitem__(self: "IndexedSample[XI]", index: IndexType) -> "IndexedSample[XI]":
+    def __getitem__(self, index: IndexType) -> Self:
+        assert isinstance(self.x, (Tensor, np.ndarray)), "x is not indexable"
         return gcopy(
             self, deep=False, x=self.x[index], y=self.y[index], s=self.s[index], idx=self.idx[index]
         )
@@ -101,7 +103,7 @@ def __len__(self) -> int:

 @dataclass(kw_only=True, repr=False, eq=False, frozen=True)
 class LfFClassifier(Classifier):
-    criterion: CrossEntropyLoss
+    criterion: CrossEntropyLoss  # type: ignore
     sample_loss_ema_b: LabelEma
     sample_loss_ema_d: LabelEma
     q: float = 0.7
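The `assert isinstance(self.x, (Tensor, np.ndarray))` added to `__getitem__` is a type-narrowing guard: `self.x` is typed as `LoadedData`, a union that evidently also admits members the checker cannot index with `index`, so the assert narrows it before `self.x[index]` is evaluated (and fails fast at runtime otherwise). A generic sketch of the technique, independent of the conduit types:

```python
from typing import Union

import numpy as np
import torch
from torch import Tensor

# A union that also contains a non-indexable member, as a stand-in for LoadedData.
Data = Union[Tensor, np.ndarray, dict[str, Tensor]]


def first_row(x: Data) -> Union[Tensor, np.ndarray]:
    # Without this assert, `x[0]` would be flagged because dict[str, Tensor]
    # cannot be indexed with an int; the assert narrows x for the checker
    # and raises immediately on unsupported inputs at runtime.
    assert isinstance(x, (Tensor, np.ndarray)), "x is not indexable"
    return x[0]


print(first_row(torch.arange(6).reshape(2, 3)))  # tensor([0, 1, 2])
print(first_row(np.eye(2)))                      # [1. 0.]
```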
5 changes: 3 additions & 2 deletions src/algs/fs/sd.py
@@ -33,7 +33,8 @@ def __init__(
         if isinstance(gamma, ListConfig):
             gamma = list(gamma)
         if loss_fn is None:
-            loss_fn = CrossEntropyLoss(reduction=ReductionType.mean)
+            cross_entropy: Loss = CrossEntropyLoss(reduction=ReductionType.mean)  # type: ignore
+            loss_fn = cross_entropy
         self.loss_fn = loss_fn
         if isinstance(lambda_, (tuple, list)):
             self.register_buffer("lambda_", torch.as_tensor(lambda_, dtype=torch.float))
@@ -49,7 +50,7 @@ def reduction(self) -> Union[ReductionType, str]:
         return self.loss_fn.reduction

     @reduction.setter
-    def reduction(self, value: Union[ReductionType, str]) -> None:
+    def reduction(self, value: Union[ReductionType, str]) -> None:  # type: ignore
         self.loss_fn.reduction = value

     @override
