Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify subcomponent resolve in base module #2473

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 35 additions & 88 deletions src/anomalib/models/components/base/anomalib_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
import logging
import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -137,10 +137,10 @@ def __init__(
self.loss: nn.Module
self.callbacks: list[Callback]

self.pre_processor = self._resolve_pre_processor(pre_processor)
self.post_processor = self._resolve_post_processor(post_processor)
self.evaluator = self._resolve_evaluator(evaluator)
self.visualizer = self._resolve_visualizer(visualizer)
self.pre_processor = self._resolve_component(pre_processor, PreProcessor, self.configure_pre_processor)
self.post_processor = self._resolve_component(post_processor, PostProcessor, self.configure_post_processor)
self.evaluator = self._resolve_component(evaluator, Evaluator, self.configure_evaluator)
self.visualizer = self._resolve_component(visualizer, Visualizer, self.configure_visualizer)

self._input_size: tuple[int, int] | None = None
self._is_setup = False
Expand Down Expand Up @@ -299,34 +299,42 @@ def learning_type(self) -> LearningType:
"""
raise NotImplementedError

def _resolve_pre_processor(self, pre_processor: PreProcessor | bool) -> PreProcessor | None:
"""Resolve and validate the pre-processor configuration.
@staticmethod
def _resolve_component(component: nn.Module | None, component_type: type, default_callable: Callable) -> nn.Module:
djdameln marked this conversation as resolved.
Show resolved Hide resolved
"""Resolve and validate the subcomponent configuration.

This method resolves the configuration for various subcomponents like
pre-processor, post-processor, evaluator and visualizer. It validates
the configuration and returns the configured component. If the component
is a boolean, it uses the default callable to create the component. If
the component is already an instance of the component type, it returns
the component as is.

Args:
pre_processor (PreProcessor | bool): Pre-processor configuration
- ``True`` -> use default pre-processor
- ``False`` -> no pre-processor
- ``PreProcessor`` -> use provided pre-processor
component (object): Component configuration
component_type (Type): Type of the component
default_callable (Callable): Callable to create default component

Returns:
PreProcessor | None: Configured pre-processor
Component | None: Configured component

Raises:
TypeError: If pre_processor is invalid type
TypeError: If component is invalid type
"""
if isinstance(pre_processor, PreProcessor):
return pre_processor
if isinstance(pre_processor, bool):
return self.configure_pre_processor() if pre_processor else None
msg = f"Invalid pre-processor type: {type(pre_processor)}"
if isinstance(component, component_type):
return component
if isinstance(component, bool):
return default_callable() if component else None
msg = f"Passed object should be {component_type} or bool, got: {type(component)}"
raise TypeError(msg)

@classmethod
def configure_pre_processor(cls, image_size: tuple[int, int] | None = None) -> PreProcessor:
@staticmethod
def configure_pre_processor(image_size: tuple[int, int] | None = None) -> PreProcessor:
"""Configure the default pre-processor.

The default pre-processor resizes images and normalizes using ImageNet
statistics.
statistics. Override this method to provide a custom pre-processor for
the model.

Args:
image_size (tuple[int, int] | None, optional): Target size for
Expand All @@ -348,31 +356,12 @@ def configure_pre_processor(cls, image_size: tuple[int, int] | None = None) -> P
]),
)

def _resolve_post_processor(self, post_processor: PostProcessor | bool) -> PostProcessor | None:
"""Resolve and validate the post-processor configuration.

Args:
post_processor (PostProcessor | bool): Post-processor configuration
- ``True`` -> use default post-processor
- ``False`` -> no post-processor
- ``PostProcessor`` -> use provided post-processor

Returns:
PostProcessor | None: Configured post-processor

Raises:
TypeError: If post_processor is invalid type
"""
if isinstance(post_processor, PostProcessor):
return post_processor
if isinstance(post_processor, bool):
return self.configure_post_processor() if post_processor else None
msg = f"Invalid post-processor type: {type(post_processor)}"
raise TypeError(msg)

def configure_post_processor(self) -> PostProcessor | None:
"""Configure the default post-processor.

The default post-processor is based on the model's learning type. Override
this method to provide a custom post-processor for the model.

Returns:
PostProcessor | None: Configured post-processor based on learning type

Expand All @@ -394,34 +383,12 @@ def configure_post_processor(self) -> PostProcessor | None:
)
raise NotImplementedError(msg)

def _resolve_evaluator(self, evaluator: Evaluator | bool) -> Evaluator | None:
"""Resolve and validate the evaluator configuration.

Args:
evaluator (Evaluator | bool): Evaluator configuration
- ``True`` -> use default evaluator
- ``False`` -> no evaluator
- ``Evaluator`` -> use provided evaluator

Returns:
Evaluator | None: Configured evaluator

Raises:
TypeError: If evaluator is invalid type
"""
if isinstance(evaluator, Evaluator):
return evaluator
if isinstance(evaluator, bool):
return self.configure_evaluator() if evaluator else None
msg = f"evaluator must be of type Evaluator or bool, got {type(evaluator)}"
raise TypeError(msg)

@staticmethod
def configure_evaluator() -> Evaluator:
"""Configure the default evaluator.

The default evaluator includes metrics for both image-level and
pixel-level evaluation.
pixel-level evaluation. Override this method to provide custom metrics for the model.

Returns:
Evaluator: Configured evaluator with default metrics
Expand All @@ -438,32 +405,12 @@ def configure_evaluator() -> Evaluator:
test_metrics = [image_auroc, image_f1score, pixel_auroc, pixel_f1score]
return Evaluator(test_metrics=test_metrics)

def _resolve_visualizer(self, visualizer: Visualizer | bool) -> Visualizer | None:
"""Resolve and validate the visualizer configuration.

Args:
visualizer (Visualizer | bool): Visualizer configuration
- ``True`` -> use default visualizer
- ``False`` -> no visualizer
- ``Visualizer`` -> use provided visualizer

Returns:
Visualizer | None: Configured visualizer

Raises:
TypeError: If visualizer is invalid type
"""
if isinstance(visualizer, Visualizer):
return visualizer
if isinstance(visualizer, bool):
return self.configure_visualizer() if visualizer else None
msg = f"Visualizer must be of type Visualizer or bool, got {type(visualizer)}"
raise TypeError(msg)

@classmethod
def configure_visualizer(cls) -> ImageVisualizer:
"""Configure the default visualizer.

Override this method to provide a custom visualizer for the model.

Returns:
ImageVisualizer: Default image visualizer instance

Expand Down
58 changes: 58 additions & 0 deletions tests/unit/models/components/base/test_anomaly_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path

import pytest
from torch import nn

from anomalib.models.components.base import AnomalibModule

Expand Down Expand Up @@ -57,3 +58,60 @@ def test_from_config(self, model_name: str) -> None:
model = AnomalibModule.from_config(config_path=config_path)
assert model is not None
assert isinstance(model, AnomalibModule)


class TestResolveComponents:
"""Test AnomalibModule._resolve_component."""

class DummyComponent(nn.Module):
"""Dummy component class."""

def __init__(self, value: int) -> None:
self.value = value

@classmethod
def dummy_configure_component(cls) -> DummyComponent:
"""Dummy configure component method, simulates configure_<component> methods in module."""
return cls.DummyComponent(value=1)

def test_component_passed(self) -> None:
"""Test that the component is returned as is if it is an instance of the component type."""
component = self.DummyComponent(value=0)
resolved = AnomalibModule._resolve_component( # noqa: SLF001
component=component,
component_type=self.DummyComponent,
default_callable=self.dummy_configure_component,
)
assert isinstance(resolved, self.DummyComponent)
assert resolved.value == 0

def test_component_true(self) -> None:
"""Test that the default_callable is called if component is True."""
component = True
resolved = AnomalibModule._resolve_component( # noqa: SLF001
component=component,
component_type=self.DummyComponent,
default_callable=self.dummy_configure_component,
)
assert isinstance(resolved, self.DummyComponent)
assert resolved.value == 1

def test_component_false(self) -> None:
"""Test that None is returned if component is False."""
component = False
resolved = AnomalibModule._resolve_component( # noqa: SLF001
component=component,
component_type=self.DummyComponent,
default_callable=self.dummy_configure_component,
)
assert resolved is None

def test_raises_type_error(self) -> None:
"""Test that a TypeError is raised if the component is not of the correct type."""
component = 1
with pytest.raises(TypeError):
AnomalibModule._resolve_component( # noqa: SLF001
component=component,
component_type=self.DummyComponent,
default_callable=self.dummy_configure_component,
)