Skip to content

Commit

Permalink
[GSK-2571] modified behaviour of tests w.r.t. cropped DL (#26)
Browse files Browse the repository at this point in the history
* modified behaviour of tests w.r.t. cropped DL

* pdm format
  • Loading branch information
rabah-khalek authored Jan 16, 2024
1 parent e34504f commit 9922af2
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 5 deletions.
12 changes: 11 additions & 1 deletion giskard_vision/landmark_detection/dataloaders/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
self._margins = margins

@property
def name(self):
def name(self) -> str:
"""
Gets the name of the cropped data loader.
Expand All @@ -56,6 +56,16 @@ def name(self):
"""
return f"{self._wrapped_dataloader.name} cropped on {self._part.name}"

@property
def facial_part(self) -> FacialPart:
"""
Gets the facial_part used for the copping.
Returns:
FacialPart: The name of the cropped data loader.
"""
return self._part

def get_image(self, idx: int) -> np.ndarray:
"""
Gets a cropped image based on facial landmarks.
Expand Down
15 changes: 11 additions & 4 deletions giskard_vision/landmark_detection/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,20 +207,23 @@ def run(
self,
model: FaceLandmarksModelBase,
dataloader: DataIteratorBase,
facial_part: FacialPart = FacialParts.ENTIRE.value,
facial_part: FacialPart = None,
) -> TestResult:
"""Run the test on the specified model and dataloader.
Passes if metric <= threhsold.
Args:
model (FaceLandmarksModelBase): Model to be evaluated.
dataloader (DataIteratorBase): Dataloader providing input data.
facial_part (FacialPart, optional): Facial part to consider during the evaluation. Defaults to entire face.
facial_part (FacialPart, optional): Facial part to consider during the evaluation. Defaults to entire face if dataloader doesn't have facial_part as property.
Returns:
TestResult: Result of the test.
"""
facial_part = (
getattr(dataloader, "facial_part", FacialParts.ENTIRE.value) if facial_part is None else facial_part
)
ground_truth = dataloader.all_marks
prediction_result = model.predict(dataloader, facial_part=facial_part)
metric_value = self.metric.get(prediction_result, ground_truth)
Expand Down Expand Up @@ -259,7 +262,7 @@ def run(
model: FaceLandmarksModelBase,
dataloader: DataIteratorBase,
dataloader_ref: DataIteratorBase,
facial_part: FacialPart = FacialParts.ENTIRE.value,
facial_part: Optional[FacialPart] = None, # FacialParts.ENTIRE.value,
) -> TestResult:
"""Run the differential test on the specified model and dataloaders.
Defined as metric_diff = (metric_ref-metric)/metric_ref.
Expand All @@ -269,12 +272,16 @@ def run(
model (FaceLandmarksModelBase): Model to be evaluated.
dataloader (DataIteratorBase): Main dataloader.
dataloader_ref (DataIteratorBase): Reference dataloader for comparison.
facial_part (FacialPart, optional): Facial part to consider during the evaluation. Defaults to entire face.
facial_part (FacialPart, optional): Facial part to consider during the evaluation. Defaults to entire face if dataloader doesn't have facial_part as property.
Returns:
TestResult: Result of the differential test.
"""
facial_part = (
getattr(dataloader, "facial_part", FacialParts.ENTIRE.value) if facial_part is None else facial_part
)

prediction_result = model.predict(dataloader, facial_part=facial_part)
prediction_result_ref = model.predict(dataloader_ref, facial_part=facial_part)

Expand Down
20 changes: 20 additions & 0 deletions tests/landmark_detection/tests_and_metrics/test_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from giskard_vision.landmark_detection.dataloaders.wrappers import CroppedDataLoader
from giskard_vision.landmark_detection.marks.facial_parts import FacialParts
from giskard_vision.landmark_detection.tests.base import Test, TestDiff
from giskard_vision.landmark_detection.tests.performance import NMEMean


def test_tests_on_cropped_dl(opencv_model, dataset_300w):
fp = FacialParts.LEFT_HALF.value
dl = CroppedDataLoader(dataset_300w, part=fp)

for test in [Test, TestDiff]:
kwargs = {"model": opencv_model, "dataloader": dl, "facial_part": fp}
if test == TestDiff:
kwargs["dataloader_ref"] = dataset_300w

test1 = test(metric=NMEMean, threshold=1).run(**kwargs)
kwargs.pop("facial_part")
test2 = test(metric=NMEMean, threshold=1).run(**kwargs)

assert test1.metric_value == test2.metric_value

0 comments on commit 9922af2

Please sign in to comment.