Skip to content

Commit

Permalink
Fix normalization (#2130)
Browse files Browse the repository at this point in the history
* fix normalization

* precommit config...

* reset normalization metrics on validation start

* fix model loading and saving normalitzation metrics

* Update src/anomalib/callbacks/normalization/min_max_normalization.py

* Update src/anomalib/callbacks/normalization/min_max_normalization.py

---------

Co-authored-by: Samet Akcay <[email protected]>
  • Loading branch information
alexriedel1 and samet-akcay authored Jul 16, 2024
1 parent d094d4b commit d1f824a
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 30 deletions.
48 changes: 34 additions & 14 deletions src/anomalib/callbacks/normalization/min_max_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torchmetrics import MetricCollection

from anomalib.metrics import MinMax
from anomalib.models.components import AnomalyModule
Expand All @@ -27,13 +28,26 @@ def setup(self, trainer: Trainer, pl_module: AnomalyModule, stage: str | None =
del trainer, stage # These variables are not used.

if not hasattr(pl_module, "normalization_metrics"):
pl_module.normalization_metrics = MinMax().cpu()
elif not isinstance(pl_module.normalization_metrics, MinMax):
msg = f"Expected normalization_metrics to be of type MinMax, got {type(pl_module.normalization_metrics)}"
raise AttributeError(
msg,
pl_module.normalization_metrics = MetricCollection(
{
"anomaly_maps": MinMax().cpu(),
"box_scores": MinMax().cpu(),
"pred_scores": MinMax().cpu(),
},
)

elif not isinstance(pl_module.normalization_metrics, MetricCollection):
msg = (
f"Expected normalization_metrics to be of type MetricCollection"
f"got {type(pl_module.normalization_metrics)}"
)
raise TypeError(msg)

for name, metric in pl_module.normalization_metrics.items():
if not isinstance(metric, MinMax):
msg = f"Expected normalization_metric {name} to be of type MinMax, got {type(metric)}"
raise TypeError(msg)

def on_test_start(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
"""Call when the test begins."""
del trainer # `trainer` variable is not used.
Expand All @@ -42,6 +56,13 @@ def on_test_start(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
if metric is not None:
metric.set_threshold(0.5)

def on_validation_epoch_start(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
"""Call when the validation epoch begins."""
del trainer # `trainer` variable is not used.

if hasattr(pl_module, "normalization_metrics"):
pl_module.normalization_metrics.reset()

def on_validation_batch_end(
self,
trainer: Trainer,
Expand All @@ -55,14 +76,11 @@ def on_validation_batch_end(
del trainer, batch, batch_idx, dataloader_idx # These variables are not used.

if "anomaly_maps" in outputs:
pl_module.normalization_metrics(outputs["anomaly_maps"])
elif "box_scores" in outputs:
pl_module.normalization_metrics(torch.cat(outputs["box_scores"]))
elif "pred_scores" in outputs:
pl_module.normalization_metrics(outputs["pred_scores"])
else:
msg = "No values found for normalization, provide anomaly maps, bbox scores, or image scores"
raise ValueError(msg)
pl_module.normalization_metrics["anomaly_maps"](outputs["anomaly_maps"])
if "box_scores" in outputs:
pl_module.normalization_metrics["box_scores"](torch.cat(outputs["box_scores"]))
if "pred_scores" in outputs:
pl_module.normalization_metrics["pred_scores"](outputs["pred_scores"])

def on_test_batch_end(
self,
Expand Down Expand Up @@ -97,12 +115,14 @@ def _normalize_batch(outputs: Any, pl_module: AnomalyModule) -> None: # noqa: A
"""Normalize a batch of predictions."""
image_threshold = pl_module.image_threshold.value.cpu()
pixel_threshold = pl_module.pixel_threshold.value.cpu()
stats = pl_module.normalization_metrics.cpu()
if "pred_scores" in outputs:
stats = pl_module.normalization_metrics["pred_scores"].cpu()
outputs["pred_scores"] = normalize(outputs["pred_scores"], image_threshold, stats.min, stats.max)
if "anomaly_maps" in outputs:
stats = pl_module.normalization_metrics["anomaly_maps"].cpu()
outputs["anomaly_maps"] = normalize(outputs["anomaly_maps"], pixel_threshold, stats.min, stats.max)
if "box_scores" in outputs:
stats = pl_module.normalization_metrics["box_scores"].cpu()
outputs["box_scores"] = [
normalize(scores, pixel_threshold, stats.min, stats.max) for scores in outputs["box_scores"]
]
12 changes: 6 additions & 6 deletions src/anomalib/deploy/inferencers/base_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,19 +101,19 @@ def _normalize(
visualized and predicted scores.
"""
# min max normalization
if "min" in metadata and "max" in metadata:
if anomaly_maps is not None:
if "pred_scores.min" in metadata and "pred_scores.max" in metadata:
if anomaly_maps is not None and "anomaly_maps.max" in metadata:
anomaly_maps = normalize_min_max(
anomaly_maps,
metadata["pixel_threshold"],
metadata["min"],
metadata["max"],
metadata["anomaly_maps.min"],
metadata["anomaly_maps.max"],
)
pred_scores = normalize_min_max(
pred_scores,
metadata["image_threshold"],
metadata["min"],
metadata["max"],
metadata["pred_scores.min"],
metadata["pred_scores.max"],
)

return anomaly_maps, float(pred_scores)
Expand Down
26 changes: 20 additions & 6 deletions src/anomalib/models/components/base/anomaly_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch import nn
from torchmetrics import MetricCollection
from torchvision.transforms.v2 import Compose, Normalize, Resize, Transform

from anomalib import LearningType
Expand All @@ -25,7 +26,6 @@

if TYPE_CHECKING:
from lightning.pytorch.callbacks import Callback
from torchmetrics import Metric


logger = logging.getLogger(__name__)
Expand All @@ -49,7 +49,7 @@ def __init__(self) -> None:
self.image_threshold: BaseThreshold
self.pixel_threshold: BaseThreshold

self.normalization_metrics: Metric
self.normalization_metrics: MetricCollection

self.image_metrics: AnomalibMetricCollection
self.pixel_metrics: AnomalibMetricCollection
Expand Down Expand Up @@ -155,8 +155,9 @@ def _save_to_state_dict(self, destination: OrderedDict, prefix: str, keep_vars:
f"{self.pixel_threshold.__class__.__module__}.{self.pixel_threshold.__class__.__name__}"
)
if hasattr(self, "normalization_metrics"):
normalization_class = self.normalization_metrics.__class__
destination["normalization_class"] = f"{normalization_class.__module__}.{normalization_class.__name__}"
for metric in self.normalization_metrics:
metric_class = self.normalization_metrics[metric].__class__
destination[f"{metric}_normalization_class"] = f"{metric_class.__module__}.{metric_class.__name__}"

return super()._save_to_state_dict(destination, prefix, keep_vars)

Expand All @@ -166,8 +167,21 @@ def load_state_dict(self, state_dict: OrderedDict[str, Any], strict: bool = True
self.image_threshold = self._get_instance(state_dict, "image_threshold_class")
if "pixel_threshold_class" in state_dict:
self.pixel_threshold = self._get_instance(state_dict, "pixel_threshold_class")
if "normalization_class" in state_dict:
self.normalization_metrics = self._get_instance(state_dict, "normalization_class")

if "anomaly_maps_normalization_class" in state_dict:
self.anomaly_maps_normalization_metrics = self._get_instance(state_dict, "anomaly_maps_normalization_class")
if "box_scores_normalization_class" in state_dict:
self.box_scores_normalization_metrics = self._get_instance(state_dict, "box_scores_normalization_class")
if "pred_scores_normalization_class" in state_dict:
self.pred_scores_normalization_metrics = self._get_instance(state_dict, "pred_scores_normalization_class")

self.normalization_metrics = MetricCollection(
{
"anomaly_maps": self.anomaly_maps_normalization_metrics,
"box_scores": self.box_scores_normalization_metrics,
"pred_scores": self.pred_scores_normalization_metrics,
},
)
# Used to load metrics if there is any related data in state_dict
self._load_metrics(state_dict)

Expand Down
6 changes: 3 additions & 3 deletions src/anomalib/models/image/csflow/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs)
"""
del args, kwargs # These variables are not used.

anomaly_maps, anomaly_scores = self.model(batch["image"])
batch["anomaly_maps"] = anomaly_maps
batch["pred_scores"] = anomaly_scores
output = self.model(batch["image"])
batch["anomaly_maps"] = output["anomaly_map"]
batch["pred_scores"] = output["pred_score"]
return batch

@property
Expand Down
2 changes: 1 addition & 1 deletion src/anomalib/models/image/csflow/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
z_dist, _ = self.graph(features) # Ignore Jacobians
anomaly_scores = self._compute_anomaly_scores(z_dist)
anomaly_maps = self.anomaly_map_generator(z_dist)
output = anomaly_maps, anomaly_scores
output = {"anomaly_map": anomaly_maps, "pred_score": anomaly_scores}
return output

def _compute_anomaly_scores(self, z_dists: torch.Tensor) -> torch.Tensor:
Expand Down

0 comments on commit d1f824a

Please sign in to comment.