diff --git a/src/anomalib/models/components/base/anomalib_module.py b/src/anomalib/models/components/base/anomalib_module.py index b5fc6a57cf..8c056a4655 100644 --- a/src/anomalib/models/components/base/anomalib_module.py +++ b/src/anomalib/models/components/base/anomalib_module.py @@ -43,7 +43,7 @@ import logging import warnings from abc import ABC, abstractmethod -from collections.abc import Sequence +from collections.abc import Callable, Sequence from pathlib import Path from typing import Any @@ -137,10 +137,10 @@ def __init__( self.loss: nn.Module self.callbacks: list[Callback] - self.pre_processor = self._resolve_pre_processor(pre_processor) - self.post_processor = self._resolve_post_processor(post_processor) - self.evaluator = self._resolve_evaluator(evaluator) - self.visualizer = self._resolve_visualizer(visualizer) + self.pre_processor = self._resolve_component(pre_processor, PreProcessor, self.configure_pre_processor) + self.post_processor = self._resolve_component(post_processor, PostProcessor, self.configure_post_processor) + self.evaluator = self._resolve_component(evaluator, Evaluator, self.configure_evaluator) + self.visualizer = self._resolve_component(visualizer, Visualizer, self.configure_visualizer) self._input_size: tuple[int, int] | None = None self._is_setup = False @@ -299,34 +299,46 @@ def learning_type(self) -> LearningType: """ raise NotImplementedError - def _resolve_pre_processor(self, pre_processor: PreProcessor | bool) -> PreProcessor | None: - """Resolve and validate the pre-processor configuration. + @staticmethod + def _resolve_component( + component: nn.Module | None, + component_type: type, + default_callable: Callable, + ) -> nn.Module | None: + """Resolve and validate the subcomponent configuration. + + This method resolves the configuration for various subcomponents like + pre-processor, post-processor, evaluator and visualizer. It validates + the configuration and returns the configured component. If the component + is a boolean, it uses the default callable to create the component. If + the component is already an instance of the component type, it returns + the component as is. Args: - pre_processor (PreProcessor | bool): Pre-processor configuration - - ``True`` -> use default pre-processor - - ``False`` -> no pre-processor - - ``PreProcessor`` -> use provided pre-processor + component (object): Component configuration + component_type (Type): Type of the component + default_callable (Callable): Callable to create default component Returns: - PreProcessor | None: Configured pre-processor + Component | None: Configured component Raises: - TypeError: If pre_processor is invalid type + TypeError: If component is invalid type """ - if isinstance(pre_processor, PreProcessor): - return pre_processor - if isinstance(pre_processor, bool): - return self.configure_pre_processor() if pre_processor else None - msg = f"Invalid pre-processor type: {type(pre_processor)}" + if isinstance(component, component_type): + return component + if isinstance(component, bool): + return default_callable() if component else None + msg = f"Passed object should be {component_type} or bool, got: {type(component)}" raise TypeError(msg) - @classmethod - def configure_pre_processor(cls, image_size: tuple[int, int] | None = None) -> PreProcessor: + @staticmethod + def configure_pre_processor(image_size: tuple[int, int] | None = None) -> PreProcessor: """Configure the default pre-processor. The default pre-processor resizes images and normalizes using ImageNet - statistics. + statistics. Override this method to provide a custom pre-processor for + the model. Args: image_size (tuple[int, int] | None, optional): Target size for @@ -348,31 +360,12 @@ def configure_pre_processor(cls, image_size: tuple[int, int] | None = None) -> P ]), ) - def _resolve_post_processor(self, post_processor: PostProcessor | bool) -> PostProcessor | None: - """Resolve and validate the post-processor configuration. - - Args: - post_processor (PostProcessor | bool): Post-processor configuration - - ``True`` -> use default post-processor - - ``False`` -> no post-processor - - ``PostProcessor`` -> use provided post-processor - - Returns: - PostProcessor | None: Configured post-processor - - Raises: - TypeError: If post_processor is invalid type - """ - if isinstance(post_processor, PostProcessor): - return post_processor - if isinstance(post_processor, bool): - return self.configure_post_processor() if post_processor else None - msg = f"Invalid post-processor type: {type(post_processor)}" - raise TypeError(msg) - def configure_post_processor(self) -> PostProcessor | None: """Configure the default post-processor. + The default post-processor is based on the model's learning type. Override + this method to provide a custom post-processor for the model. + Returns: PostProcessor | None: Configured post-processor based on learning type @@ -394,34 +387,12 @@ def configure_post_processor(self) -> PostProcessor | None: ) raise NotImplementedError(msg) - def _resolve_evaluator(self, evaluator: Evaluator | bool) -> Evaluator | None: - """Resolve and validate the evaluator configuration. - - Args: - evaluator (Evaluator | bool): Evaluator configuration - - ``True`` -> use default evaluator - - ``False`` -> no evaluator - - ``Evaluator`` -> use provided evaluator - - Returns: - Evaluator | None: Configured evaluator - - Raises: - TypeError: If evaluator is invalid type - """ - if isinstance(evaluator, Evaluator): - return evaluator - if isinstance(evaluator, bool): - return self.configure_evaluator() if evaluator else None - msg = f"evaluator must be of type Evaluator or bool, got {type(evaluator)}" - raise TypeError(msg) - @staticmethod def configure_evaluator() -> Evaluator: """Configure the default evaluator. The default evaluator includes metrics for both image-level and - pixel-level evaluation. + pixel-level evaluation. Override this method to provide custom metrics for the model. Returns: Evaluator: Configured evaluator with default metrics @@ -438,32 +409,12 @@ def configure_evaluator() -> Evaluator: test_metrics = [image_auroc, image_f1score, pixel_auroc, pixel_f1score] return Evaluator(test_metrics=test_metrics) - def _resolve_visualizer(self, visualizer: Visualizer | bool) -> Visualizer | None: - """Resolve and validate the visualizer configuration. - - Args: - visualizer (Visualizer | bool): Visualizer configuration - - ``True`` -> use default visualizer - - ``False`` -> no visualizer - - ``Visualizer`` -> use provided visualizer - - Returns: - Visualizer | None: Configured visualizer - - Raises: - TypeError: If visualizer is invalid type - """ - if isinstance(visualizer, Visualizer): - return visualizer - if isinstance(visualizer, bool): - return self.configure_visualizer() if visualizer else None - msg = f"Visualizer must be of type Visualizer or bool, got {type(visualizer)}" - raise TypeError(msg) - @classmethod def configure_visualizer(cls) -> ImageVisualizer: """Configure the default visualizer. + Override this method to provide a custom visualizer for the model. + Returns: ImageVisualizer: Default image visualizer instance diff --git a/tests/unit/models/components/base/test_anomaly_module.py b/tests/unit/models/components/base/test_anomaly_module.py index 1578fc9e17..0c522998ae 100644 --- a/tests/unit/models/components/base/test_anomaly_module.py +++ b/tests/unit/models/components/base/test_anomaly_module.py @@ -6,6 +6,7 @@ from pathlib import Path import pytest +from torch import nn from anomalib.models.components.base import AnomalibModule @@ -57,3 +58,60 @@ def test_from_config(self, model_name: str) -> None: model = AnomalibModule.from_config(config_path=config_path) assert model is not None assert isinstance(model, AnomalibModule) + + +class TestResolveComponents: + """Test AnomalibModule._resolve_component.""" + + class DummyComponent(nn.Module): + """Dummy component class.""" + + def __init__(self, value: int) -> None: + self.value = value + + @classmethod + def dummy_configure_component(cls) -> DummyComponent: + """Dummy configure component method, simulates configure_ methods in module.""" + return cls.DummyComponent(value=1) + + def test_component_passed(self) -> None: + """Test that the component is returned as is if it is an instance of the component type.""" + component = self.DummyComponent(value=0) + resolved = AnomalibModule._resolve_component( # noqa: SLF001 + component=component, + component_type=self.DummyComponent, + default_callable=self.dummy_configure_component, + ) + assert isinstance(resolved, self.DummyComponent) + assert resolved.value == 0 + + def test_component_true(self) -> None: + """Test that the default_callable is called if component is True.""" + component = True + resolved = AnomalibModule._resolve_component( # noqa: SLF001 + component=component, + component_type=self.DummyComponent, + default_callable=self.dummy_configure_component, + ) + assert isinstance(resolved, self.DummyComponent) + assert resolved.value == 1 + + def test_component_false(self) -> None: + """Test that None is returned if component is False.""" + component = False + resolved = AnomalibModule._resolve_component( # noqa: SLF001 + component=component, + component_type=self.DummyComponent, + default_callable=self.dummy_configure_component, + ) + assert resolved is None + + def test_raises_type_error(self) -> None: + """Test that a TypeError is raised if the component is not of the correct type.""" + component = 1 + with pytest.raises(TypeError): + AnomalibModule._resolve_component( # noqa: SLF001 + component=component, + component_type=self.DummyComponent, + default_callable=self.dummy_configure_component, + )