Skip to content

Commit

Permalink
Ignore error in decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
tmke8 committed Feb 23, 2024
1 parent 3a8f946 commit 4cf7dd3
Show file tree
Hide file tree
Showing 11 changed files with 25 additions and 25 deletions.
6 changes: 3 additions & 3 deletions src/algs/adv/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class Components(DcModule, Generic[D]):
pred_y: Optional[Classifier]
pred_s: Optional[Classifier]

@torch.no_grad()
@torch.no_grad() # pyright: ignore
def train_ae(self) -> None:
self.ae.train()
if self.pred_y is not None:
Expand All @@ -52,7 +52,7 @@ def train_ae(self) -> None:
if isinstance(self.disc, nn.Module):
self.disc.eval()

@torch.no_grad()
@torch.no_grad() # pyright: ignore
def train_disc(self) -> None:
self.ae.eval()
if self.pred_y is not None:
Expand Down Expand Up @@ -172,7 +172,7 @@ def training_step(
self.log_recons(x=x_dep, dm=dm, ae=comp.ae, itr=itr, split="deployment")
return logging_dict

@torch.no_grad()
@torch.no_grad() # pyright: ignore
def log_recons(
self,
x: Tensor,
Expand Down
4 changes: 2 additions & 2 deletions src/algs/adv/scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
_PBAR_COL: Final[str] = "#ffe252"


@torch.no_grad()
@torch.no_grad() # pyright: ignore
def _encode_and_score_recons(
dl: CdtDataLoader[TernarySample],
*,
Expand Down Expand Up @@ -60,7 +60,7 @@ def _encode_and_score_recons(
return CdtDataset(x=zy, y=y, s=s), recon_score


@torch.no_grad()
@torch.no_grad() # pyright: ignore
def balanced_accuracy(y_pred: Tensor, *, y_true: Tensor) -> Tensor:
return cdtm.subclass_balanced_accuracy(y_pred=y_pred, y_true=y_true, s=y_true)

Expand Down
6 changes: 3 additions & 3 deletions src/algs/fs/lff.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,20 @@ def __init__(self, labels: Tensor, *, alpha: float = 0.9) -> None:
self.register_buffer("parameter", torch.zeros(len(labels)))
self.register_buffer("updated", torch.zeros(len(labels)))

@torch.no_grad()
@torch.no_grad() # pyright: ignore
def update(self, data: Tensor, *, index: Union[Tensor, int]) -> None:
self.parameter[index] = (
self.alpha * self.parameter[index] + (1 - self.alpha * self.updated[index]) * data
)
self.updated[index] = 1

@torch.no_grad()
@torch.no_grad() # pyright: ignore
def max_loss(self, label: int) -> Tensor:
label_index = self.labels == label
return self.parameter[label_index].max()

@override
@torch.no_grad()
@torch.no_grad() # pyright: ignore
def __getitem__(self, index: IndexType) -> Tensor:
return self.parameter[index].clone()

Expand Down
4 changes: 2 additions & 2 deletions src/arch/autoencoder/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
FILENAME: Final[str] = "model.pt"


@torch.no_grad()
@torch.no_grad() # pyright: ignore
def save_ae_artifact(
model: AePair, *, run: Union[Run, RunDisabled], factory_config: dict[str, Any], name: str
) -> None:
Expand Down Expand Up @@ -52,7 +52,7 @@ def _process_root_dir(root: Optional[Union[Path, str]]) -> Path:
return root


@torch.no_grad()
@torch.no_grad() # pyright: ignore
def load_ae_from_artifact(
name: str,
*,
Expand Down
2 changes: 1 addition & 1 deletion src/evaluation/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class SummaryMetric(Enum):
)


@torch.no_grad()
@torch.no_grad() # pyright: ignore
def compute_metrics(
pair: EmEvalPair,
*,
Expand Down
4 changes: 2 additions & 2 deletions src/labelling/encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def from_npz(cls, fpath: Union[Path, str]) -> Self:
return enc


@torch.no_grad()
@torch.no_grad() # pyright: ignore
def generate_encodings(
dm: DataModule,
*,
Expand Down Expand Up @@ -102,7 +102,7 @@ def generate_encodings(
return encodings


@torch.no_grad()
@torch.no_grad() # pyright: ignore
def encode_with_group_ids(
model: nn.Module, *, dl: CdtDataLoader[TernarySample[Tensor]], device: Union[str, torch.device]
) -> tuple[Tensor, Tensor]:
Expand Down
4 changes: 2 additions & 2 deletions src/labelling/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
def forward(self, x: Tensor) -> Tensor: # type: ignore
return self.encoder(x)

@torch.no_grad()
@torch.no_grad() # pyright: ignore
def load_from_path(self, fpath: Union[Path, str]) -> None:
fpath = Path(fpath)
if fpath.exists():
Expand All @@ -58,7 +58,7 @@ def load_from_path(self, fpath: Union[Path, str]) -> None:
else:
raise RuntimeError(f"Checkpoint {fpath.resolve()} does not exist.")

@torch.no_grad()
@torch.no_grad() # pyright: ignore
def encode(
self,
dm: DataModule,
Expand Down
4 changes: 2 additions & 2 deletions src/labelling/finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def train_step(
optimizer.step()
return output, loss.item()

@torch.no_grad()
@torch.no_grad() # pyright: ignore
def predict_loop(
self,
model: nn.Module,
Expand Down Expand Up @@ -159,7 +159,7 @@ def predict_loop(
y = torch.cat(all_y)
return preds, s, y

@torch.no_grad()
@torch.no_grad() # pyright: ignore
def validate(
self,
model: nn.Module,
Expand Down
6 changes: 3 additions & 3 deletions src/labelling/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
__all__ = ["ClnMetric", "centroidal_label_noise", "sample_noise_indices", "uniform_label_noise"]


@torch.no_grad()
@torch.no_grad() # pyright: ignore
def sample_noise_indices(
labels: Tensor,
*,
Expand All @@ -27,7 +27,7 @@ def sample_noise_indices(
return torch.randperm(len(labels), generator=generator)[:num_to_flip]


@torch.no_grad()
@torch.no_grad() # pyright: ignore
def uniform_label_noise(
labels: Tensor,
*,
Expand All @@ -50,7 +50,7 @@ class ClnMetric(Enum):
EUCLIDEAN = "euclidean"


@torch.no_grad()
@torch.no_grad() # pyright: ignore
def centroidal_label_noise(
labels: Tensor,
*,
Expand Down
2 changes: 1 addition & 1 deletion src/labelling/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class ClipClassifier(Labeller):

# cache_encoder: bool = False

@torch.no_grad()
@torch.no_grad() # pyright: ignore
def evaluate(
self, g_pred: Tensor, *, g_true: Tensor, use_wandb: bool, prefix: Optional[str] = None
) -> dict[str, float]:
Expand Down
8 changes: 4 additions & 4 deletions src/models/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
__all__ = ["Classifier", "SetClassifier"]


@torch.no_grad()
@torch.no_grad() # pyright: ignore
def cat_cpu_flatten(*ls: list[Tensor], dim: int = 0) -> Iterator[Tensor]:
for ls_ in ls:
yield torch.cat(ls_, dim=dim).cpu().flatten()
Expand Down Expand Up @@ -58,7 +58,7 @@ def predict(
) -> EvalTuple[Tensor, Tensor]:
...

@torch.no_grad()
@torch.no_grad() # pyright: ignore
def predict(
self,
data: CdtDataLoader[TernarySample],
Expand Down Expand Up @@ -174,7 +174,7 @@ class SetClassifier(Model):
model: SetPredictor # overriding the definition in `Model`
criterion: Optional[Loss] = None

@torch.no_grad()
@torch.no_grad() # pyright: ignore
def _fetch_train_data(
self, *args: tuple[Iterator[S], int], device: torch.device
) -> Iterator[_ScSample]:
Expand Down Expand Up @@ -233,7 +233,7 @@ def fit(
pbar.close()
logger.info("Finished training")

@torch.no_grad()
@torch.no_grad() # pyright: ignore
def predict(
self, *dls: CdtDataLoader[S], device: Union[torch.device, str], max_steps: int
) -> EvalTuple[None, None]:
Expand Down

0 comments on commit 4cf7dd3

Please sign in to comment.