Apply transforms in PreProcessor #2467

Open · wants to merge 23 commits into release/v2.0.0 · Changes from 1 commit
17 changes: 17 additions & 0 deletions src/anomalib/data/dataclasses/generic.py
@@ -17,7 +17,9 @@

import numpy as np
import torch
from torch import tensor
from torch.utils.data import default_collate
from torchvision.transforms.v2.functional import resize
from torchvision.tv_tensors import Image, Mask, Video

ImageT = TypeVar("ImageT", Image, Video, np.ndarray)
@@ -656,5 +658,20 @@ def batch_size(self) -> int:
def collate(cls: type["BatchIterateMixin"], items: list[ItemT]) -> "BatchIterateMixin":
"""Convert a list of DatasetItem objects to a Batch object."""
keys = [key for key, value in asdict(items[0]).items() if value is not None]

# Check if all images have the same shape. If not, resize before collating
im_shapes = torch.vstack([tensor(item.image.shape) for item in items if item.image is not None])[..., 1:]
if torch.unique(im_shapes, dim=0).size(0) != 1: # check if batch has heterogeneous shapes
target_shape = im_shapes[
torch.unravel_index(im_shapes.argmax(), im_shapes.shape)[0],
:,
] # shape of image with largest H or W
for item in items:
for key in keys:
value = getattr(item, key)
if isinstance(value, Image | Mask):
setattr(item, key, resize(value, target_shape))

# collate the batch
out_dict = {key: default_collate([getattr(item, key) for item in items]) for key in keys}
return cls(**out_dict)
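
The collate change above lets a batch contain images of different resolutions: when shapes disagree, every Image or Mask field is resized to the (H, W) of the sample holding the largest dimension before default_collate stacks them. A minimal, self-contained sketch of the same idea, using plain tv_tensors instead of anomalib's DatasetItem/Batch dataclasses and converting the target size to a list of ints (function name `collate_heterogeneous` is made up):

```python
# Sketch only: harmonize spatial sizes, then collate.
import torch
from torch.utils.data import default_collate
from torchvision.transforms.v2.functional import resize
from torchvision.tv_tensors import Image, Mask


def collate_heterogeneous(images: list[Image], masks: list[Mask]) -> tuple[torch.Tensor, torch.Tensor]:
    """Resize all images/masks to a common (H, W), then stack with default_collate."""
    shapes = torch.stack([torch.tensor(img.shape[-2:]) for img in images])
    if torch.unique(shapes, dim=0).size(0) != 1:  # heterogeneous shapes in the batch
        # pick the (H, W) of the sample that contains the single largest dimension
        row = torch.unravel_index(shapes.argmax(), shapes.shape)[0]  # requires torch>=2.2
        target_hw = shapes[row].tolist()
        images = [resize(img, target_hw) for img in images]
        masks = [resize(mask, target_hw) for mask in masks]
    return default_collate(images), default_collate(masks)


# Example: the 224x224 sample is resized to 256x192 before stacking.
imgs = [Image(torch.rand(3, 224, 224)), Image(torch.rand(3, 256, 192))]
msks = [Mask(torch.zeros(1, 224, 224)), Mask(torch.zeros(1, 256, 192))]
batch_images, batch_masks = collate_heterogeneous(imgs, msks)  # (2, 3, 256, 192), (2, 1, 256, 192)
```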
110 changes: 43 additions & 67 deletions src/anomalib/pre_processing/pre_processing.py
@@ -3,27 +3,17 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import TYPE_CHECKING

import torch
from lightning import Callback, LightningModule, Trainer
from lightning.pytorch.trainer.states import TrainerFn
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms.v2 import Transform

from anomalib.data import Batch

from .utils.transform import (
get_dataloaders_transforms,
get_exportable_transform,
set_dataloaders_transforms,
set_datamodule_stage_transform,
)

if TYPE_CHECKING:
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS

from anomalib.data import AnomalibDataModule


class PreProcessor(nn.Module, Callback):
"""Anomalib pre-processor.
@@ -109,63 +99,49 @@ def __init__(
self.predict_transform = self.test_transform
self.export_transform = get_exportable_transform(self.test_transform)

def setup_datamodule_transforms(self, datamodule: "AnomalibDataModule") -> None:
"""Set up datamodule transforms."""
# If PreProcessor has transforms, propagate them to datamodule
if any([self.train_transform, self.val_transform, self.test_transform]):
transforms = {
"fit": self.train_transform,
"val": self.val_transform,
"test": self.test_transform,
"predict": self.predict_transform,
}

for stage, transform in transforms.items():
if transform is not None:
set_datamodule_stage_transform(datamodule, transform, stage)

def setup_dataloader_transforms(self, dataloaders: "EVAL_DATALOADERS | TRAIN_DATALOADERS") -> None:
"""Set up dataloader transforms."""
if isinstance(dataloaders, DataLoader):
dataloaders = [dataloaders]

# If PreProcessor has transforms, propagate them to dataloaders
if any([self.train_transform, self.val_transform, self.test_transform]):
transforms = {
"train": self.train_transform,
"val": self.val_transform,
"test": self.test_transform,
}
set_dataloaders_transforms(dataloaders, transforms)
return

# Try to get transforms from dataloaders
if dataloaders:
dataloaders_transforms = get_dataloaders_transforms(dataloaders)
if dataloaders_transforms:
self.train_transform = dataloaders_transforms.get("train")
self.val_transform = dataloaders_transforms.get("val")
self.test_transform = dataloaders_transforms.get("test")
self.predict_transform = self.test_transform
self.export_transform = get_exportable_transform(self.test_transform)

def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
"""Configure transforms at the start of each stage.

Args:
trainer: The Lightning trainer.
pl_module: The Lightning module.
stage: The stage (e.g., 'fit', 'validate', 'test', 'predict').
"""
stage = TrainerFn(stage).value # Ensure stage is str
def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Batch, batch_idx: int) -> None:
"""Apply transforms to the batch of tensors during training."""
del trainer, pl_module, batch_idx # Unused
if self.train_transform:
batch.image, batch.gt_mask = self.train_transform(batch.image, batch.gt_mask)

def on_validation_batch_start(
self,
trainer: Trainer,
pl_module: LightningModule,
batch: Batch,
batch_idx: int,
) -> None:
"""Apply transforms to the batch of tensors during validation."""
del trainer, pl_module, batch_idx # Unused
if self.val_transform:
batch.image, batch.gt_mask = self.val_transform(batch.image, batch.gt_mask)

if hasattr(trainer, "datamodule"):
self.setup_datamodule_transforms(datamodule=trainer.datamodule)
elif hasattr(trainer, f"{stage}_dataloaders"):
dataloaders = getattr(trainer, f"{stage}_dataloaders")
self.setup_dataloader_transforms(dataloaders=dataloaders)
def on_test_batch_start(
self,
trainer: Trainer,
pl_module: LightningModule,
batch: Batch,
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
"""Apply transforms to the batch of tensors during testing."""
del trainer, pl_module, batch_idx, dataloader_idx # Unused
if self.test_transform:
batch.image, batch.gt_mask = self.test_transform(batch.image, batch.gt_mask)

super().setup(trainer, pl_module, stage)
def on_predict_batch_start(
self,
trainer: Trainer,
pl_module: LightningModule,
batch: Batch,
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
"""Apply transforms to the batch of tensors during prediction."""
del trainer, pl_module, batch_idx, dataloader_idx # Unused
if self.predict_transform:
batch.image, batch.gt_mask = self.predict_transform(batch.image, batch.gt_mask)

def forward(self, batch: torch.Tensor) -> torch.Tensor:
"""Apply transforms to the batch of tensors for inference.
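
Instead of pushing transforms into datamodules and dataloaders during setup, the PreProcessor now applies them to each already-collated batch through Lightning's `on_*_batch_start` hooks. A condensed sketch of that pattern follows; `SimpleBatch` and `BatchTransformCallback` are illustrative names and not anomalib API, and the real PreProcessor above additionally subclasses nn.Module, derives the predict/export transforms from the test transform, and implements all four stage hooks:

```python
# Sketch of batch-time transform application via Lightning hooks.
from dataclasses import dataclass

from lightning import Callback, LightningModule, Trainer
from torchvision.transforms.v2 import Transform
from torchvision.tv_tensors import Image, Mask


@dataclass
class SimpleBatch:  # stand-in for anomalib.data.Batch
    image: Image
    gt_mask: Mask | None = None


class BatchTransformCallback(Callback):
    """Apply stage-specific torchvision v2 transforms to each batch in place."""

    def __init__(self, train_transform: Transform | None = None, eval_transform: Transform | None = None) -> None:
        self.train_transform = train_transform
        self.eval_transform = eval_transform

    def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: SimpleBatch, batch_idx: int) -> None:
        del trainer, pl_module, batch_idx  # unused
        if self.train_transform:
            # v2 transforms handle Image and Mask tv_tensors jointly (nearest-neighbour for masks)
            batch.image, batch.gt_mask = self.train_transform(batch.image, batch.gt_mask)

    def on_validation_batch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        batch: SimpleBatch,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        del trainer, pl_module, batch_idx, dataloader_idx  # unused
        if self.eval_transform:
            batch.image, batch.gt_mask = self.eval_transform(batch.image, batch.gt_mask)
```

The callback is registered with the usual `Trainer(callbacks=[...])` pattern; torchvision v2 transforms generally pass unrecognised inputs (including None) through unchanged, which is what allows the hooks to forward `gt_mask` unconditionally for classification datasets without masks.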
104 changes: 0 additions & 104 deletions src/anomalib/pre_processing/utils/transform.py
@@ -3,115 +3,11 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Sequence

from torch.utils.data import DataLoader
from torchvision.transforms.v2 import CenterCrop, Compose, Resize, Transform

from anomalib.data import AnomalibDataModule
from anomalib.data.transforms import ExportableCenterCrop


def get_dataloaders_transforms(dataloaders: Sequence[DataLoader]) -> dict[str, Transform]:
"""Get transforms from dataloaders.

Args:
dataloaders: The dataloaders to get transforms from.

Returns:
Dictionary mapping stages to their transforms.
"""
transforms: dict[str, Transform] = {}
stage_lookup = {
"fit": "train",
"validate": "val",
"test": "test",
"predict": "test",
}

for dataloader in dataloaders:
if not hasattr(dataloader, "dataset") or not hasattr(dataloader.dataset, "transform"):
continue

for stage in stage_lookup:
if hasattr(dataloader, f"{stage}_dataloader"):
transforms[stage_lookup[stage]] = dataloader.dataset.transform

return transforms


def set_dataloaders_transforms(dataloaders: Sequence[DataLoader], transforms: dict[str, Transform | None]) -> None:
"""Set transforms to dataloaders.

Args:
dataloaders: The dataloaders to propagate transforms to.
transforms: Dictionary mapping stages to their transforms.
"""
stage_mapping = {
"fit": "train",
"validate": "val",
"test": "test",
"predict": "test", # predict uses test transform
}

for loader in dataloaders:
if not hasattr(loader, "dataset"):
continue

for stage in stage_mapping:
if hasattr(loader, f"{stage}_dataloader"):
transform = transforms.get(stage_mapping[stage])
if transform is not None:
set_dataloader_transform([loader], transform)


def set_dataloader_transform(dataloader: DataLoader | Sequence[DataLoader], transform: Transform) -> None:
"""Set a transform for a dataloader or list of dataloaders.

Args:
dataloader: The dataloader(s) to set the transform for.
transform: The transform to set.
"""
if isinstance(dataloader, DataLoader):
if hasattr(dataloader.dataset, "transform"):
dataloader.dataset.transform = transform
elif isinstance(dataloader, Sequence):
for dl in dataloader:
set_dataloader_transform(dl, transform)
else:
msg = f"Unsupported dataloader type: {type(dataloader)}"
raise TypeError(msg)


def set_datamodule_stage_transform(datamodule: AnomalibDataModule, transform: Transform, stage: str) -> None:
"""Set a transform for a specific stage in a AnomalibDataModule.

Args:
datamodule: The AnomalibDataModule to set the transform for.
transform: The transform to set.
stage: The stage to set the transform for.

Note:
The stage parameter maps to dataset attributes as follows:
- 'fit' -> 'train_data'
- 'validate' -> 'val_data'
- 'test' -> 'test_data'
- 'predict' -> 'test_data'
"""
stage_datasets = {
"fit": "train_data",
"validate": "val_data",
"test": "test_data",
"predict": "test_data",
}

dataset_attr = stage_datasets.get(stage)
if dataset_attr and hasattr(datamodule, dataset_attr):
dataset = getattr(datamodule, dataset_attr)
if hasattr(dataset, "transform"):
dataset.transform = transform


def get_exportable_transform(transform: Transform | None) -> Transform | None:
"""Get exportable transform.

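
The only helper that survives in this module is `get_exportable_transform`, which pre_processing.py still uses to derive an export-time variant of the test transform. Its internals are not shown in this diff; purely as a hypothetical illustration of the underlying technique (rewriting a Compose so export-unfriendly ops are swapped out), such a helper could be shaped like the sketch below — `make_export_friendly` and the `crop_replacement` constructor are assumptions, not anomalib's implementation:

```python
# Hypothetical sketch only -- not anomalib's actual get_exportable_transform.
# Recursively rewrites a Compose, swapping CenterCrop for an export-friendly class.
from torchvision.transforms.v2 import CenterCrop, Compose, Transform


def make_export_friendly(transform: Transform | None, crop_replacement: type[Transform]) -> Transform | None:
    """Return a copy of ``transform`` with every CenterCrop replaced."""
    if transform is None:
        return None
    if isinstance(transform, Compose):
        return Compose([make_export_friendly(t, crop_replacement) for t in transform.transforms])
    if isinstance(transform, CenterCrop):
        return crop_replacement(transform.size)  # assumes a CenterCrop-compatible constructor
    return transform
```

The module's remaining imports (ExportableCenterCrop from anomalib.data.transforms) suggest a substitution of this kind, but treat the sketch as the shape of the approach rather than the library's code.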
2 changes: 2 additions & 0 deletions tests/unit/pre_processing/test_pre_processing.py
@@ -91,6 +91,7 @@ def test_different_stage_transforms() -> None:
assert isinstance(processed_batch, torch.Tensor)
assert processed_batch.shape == (1, 3, 288, 288)

@pytest.mark.skip(reason="transforms are no longer set up from dataloaders")
def test_setup_transforms_from_dataloaders(self) -> None:
"""Test setup method when transforms are obtained from dataloaders."""
# Mock dataloader with dataset having a transform
@@ -104,6 +105,7 @@ def test_setup_transforms_from_dataloaders(self) -> None:
assert pre_processor.val_transform == self.common_transform
assert pre_processor.test_transform == self.common_transform

@pytest.mark.skip(reason="setup no longer propagates transforms to datamodules/dataloaders")
def test_setup_transforms_priority(self) -> None:
"""Test setup method prioritizes PreProcessor transforms over datamodule/dataloaders."""
# Mock datamodule
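
With transform discovery from dataloaders gone, the two skipped tests above could eventually be replaced by tests that exercise the new batch-level hooks directly. A rough sketch, assuming the PreProcessor constructor accepts per-stage transforms as in test_different_stage_transforms, that the import path matches the rest of this test suite, and that Batch can be substituted with a simple stand-in object (`DummyBatch` and the test name are made up):

```python
# Sketch of a possible replacement test for the hook-based PreProcessor.
import torch
from torchvision.transforms.v2 import Resize
from torchvision.tv_tensors import Image, Mask

from anomalib.pre_processing import PreProcessor


def test_train_batch_transform_applied() -> None:
    """on_train_batch_start should resize the collated batch in place."""
    pre_processor = PreProcessor(train_transform=Resize((224, 224)))

    class DummyBatch:  # minimal stand-in for anomalib.data.Batch
        image = Image(torch.rand(4, 3, 256, 256))
        gt_mask = Mask(torch.zeros(4, 256, 256))

    batch = DummyBatch()
    pre_processor.on_train_batch_start(trainer=None, pl_module=None, batch=batch, batch_idx=0)
    assert batch.image.shape == (4, 3, 224, 224)
    assert batch.gt_mask.shape == (4, 224, 224)
```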