Skip to content

Commit

Permalink
Add ability to validate on multiple columns
Browse files Browse the repository at this point in the history
  • Loading branch information
Anthony Naddeo committed Mar 23, 2024
1 parent 3ce949e commit 2483bde
Show file tree
Hide file tree
Showing 4 changed files with 332 additions and 54 deletions.
123 changes: 92 additions & 31 deletions langkit/validators/comparison.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import dataclass, replace
from functools import partial
from typing import Any, Callable, List, Optional, Sequence, Set, Union
from typing import Any, Callable, List, Literal, Optional, Sequence, Set, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -139,39 +140,45 @@ def _enforce_must_be_non_none(target_metric: str, value: Any, id: str) -> Sequen
return []


@dataclass
class ConstraintValidatorOptions:
    """Options describing the checks to run against a single metric.

    ``ConstraintValidator`` reads these fields and registers one check for
    every field (other than ``target_metric``) whose value is not ``None``.
    Field order is part of the interface: callers construct this class
    positionally, with ``target_metric`` first.
    """

    # Name of the metric that every configured check is applied to.
    target_metric: str

    # Numeric bounds; each one is only enforced when it is not None.
    upper_threshold: Union[float, int, None] = None
    upper_threshold_inclusive: Union[float, int, None] = None
    lower_threshold: Union[float, int, None] = None
    lower_threshold_inclusive: Union[float, int, None] = None

    # Membership checks; the validator converts these sequences to sets.
    one_of: Union[Sequence[Union[str, float, int]], None] = None
    none_of: Union[Sequence[Union[str, float, int]], None] = None

    # NOTE(review): ConstraintValidator enables these checks whenever the value
    # "is not None", so passing False still turns the check on.
    must_be_non_none: Union[bool, None] = None
    must_be_none: Union[bool, None] = None


class ConstraintValidator(Validator):
def __init__(
self,
target_metric: str,
upper_threshold: Optional[Union[float, int]] = None,
upper_threshold_inclusive: Optional[Union[float, int]] = None,
lower_threshold: Optional[Union[float, int]] = None,
lower_threshold_inclusive: Optional[Union[float, int]] = None,
one_of: Optional[Sequence[Union[str, float, int]]] = None,
none_of: Optional[Sequence[Union[str, float, int]]] = None,
must_be_non_none: Optional[bool] = None,
must_be_none: Optional[bool] = None,
):
def __init__(self, options: ConstraintValidatorOptions):
validation_functions: List[Callable[[Any, str], Sequence[ValidationFailure]]] = []

if upper_threshold is not None:
validation_functions.append(partial(_enforce_upper_threshold, target_metric, upper_threshold))
if lower_threshold is not None:
validation_functions.append(partial(_enforce_lower_threshold, target_metric, lower_threshold))
if upper_threshold_inclusive is not None:
validation_functions.append(partial(_enforce_upper_threshold_inclusive, target_metric, upper_threshold_inclusive))
if lower_threshold_inclusive is not None:
validation_functions.append(partial(_enforce_lower_threshold_inclusive, target_metric, lower_threshold_inclusive))
if one_of is not None:
validation_functions.append(partial(_enforce_one_of, target_metric, set(one_of)))
if none_of is not None:
validation_functions.append(partial(_enforce_none_of, target_metric, set(none_of)))
if must_be_non_none is not None:
validation_functions.append(partial(_enforce_must_be_non_none, target_metric))
if must_be_none is not None:
validation_functions.append(partial(_enforce_must_be_none, target_metric))

self._target_metric = target_metric
if options.upper_threshold is not None:
validation_functions.append(partial(_enforce_upper_threshold, options.target_metric, options.upper_threshold))
if options.lower_threshold is not None:
validation_functions.append(partial(_enforce_lower_threshold, options.target_metric, options.lower_threshold))
if options.upper_threshold_inclusive is not None:
validation_functions.append(
partial(_enforce_upper_threshold_inclusive, options.target_metric, options.upper_threshold_inclusive)
)
if options.lower_threshold_inclusive is not None:
validation_functions.append(
partial(_enforce_lower_threshold_inclusive, options.target_metric, options.lower_threshold_inclusive)
)
if options.one_of is not None:
validation_functions.append(partial(_enforce_one_of, options.target_metric, set(options.one_of)))
if options.none_of is not None:
validation_functions.append(partial(_enforce_none_of, options.target_metric, set(options.none_of)))
if options.must_be_non_none is not None:
validation_functions.append(partial(_enforce_must_be_non_none, options.target_metric))
if options.must_be_none is not None:
validation_functions.append(partial(_enforce_must_be_none, options.target_metric))

self._target_metric = options.target_metric
self._validation_functions = validation_functions

if len(validation_functions) == 0:
Expand All @@ -197,3 +204,57 @@ def validate_result(self, df: pd.DataFrame) -> Optional[ValidationResult]:
return None

return ValidationResult(failures)


class MultiColumnConstraintValidator(Validator):
    """Validator that combines several single-metric constraints.

    Each entry in ``constraints`` is wrapped in its own ``ConstraintValidator``.
    Their outcomes are combined with ``operator`` and reported according to
    ``report_mode``.
    """

    def __init__(
        self,
        constraints: List[ConstraintValidatorOptions],
        operator: Literal["AND", "OR"] = "AND",
        report_mode: Literal["ALL_FAILED_METRICS", "FIRST_FAILED_METRIC"] = "FIRST_FAILED_METRIC",
    ):
        """
        :param constraints: List of constraint options to validate
        :param operator: Operator to combine the constraints. Either "AND" or "OR". AND requires that all of the
            constraints trigger, while OR requires that at least one triggers.
        :param report_mode: How to report the validation result. If "FIRST_FAILED_METRIC", then this validator will
            return a single validation result when there are failures, and that validation result will contain the
            first failed metric. If "ALL_FAILED_METRICS", then this validator will return each validation failure.
        """
        self._operator = operator
        self._constraints = [ConstraintValidator(constraint) for constraint in constraints]
        self._report_mode = report_mode

    def get_target_metric_names(self) -> List[str]:
        """Return the target metric names of every wrapped constraint, in constraint order."""
        return [name for constraint in self._constraints for name in constraint.get_target_metric_names()]

    def validate_result(self, df: pd.DataFrame) -> Optional[ValidationResult]:
        """
        Validate all of the constraint validators and combine them using the specified operator.
        If the output of the operator is True, then return the validation result according to the report mode.

        With "AND", failures are reported only when every constraint produced at least one failure for ``df``;
        with "OR", failures are reported when any constraint did.

        :param df: Frame of metric values that is handed unchanged to each wrapped constraint validator.
        :return: None when the combined constraints did not trigger, otherwise a ValidationResult.
        """
        all_failures: List[ValidationFailure] = []
        for constraint in self._constraints:
            result = constraint.validate_result(df)
            failures = list(result.report) if result else []
            if failures:
                all_failures.extend(failures)
            elif self._operator == "AND":
                # AND needs every constraint to trigger, so one passing constraint means the combination passes.
                # NOTE(review): this is evaluated across the whole frame; failures are not correlated per row.
                return None

        if not all_failures:
            return None

        if self._report_mode == "FIRST_FAILED_METRIC":
            # Report only the first failure, but extend its message to explain that the failure
            # came from the operator and to name the other failed metrics.
            first_failure = all_failures[0]
            failure_metric_names = [item.metric for item in all_failures]
            trigger_details = (
                f". Triggered because of failures in {', '.join(failure_metric_names)} ({self._operator})." if len(all_failures) > 1 else ""
            )
            failure_details = f"{first_failure.details}{trigger_details}"

            return ValidationResult([replace(first_failure, details=failure_details)])

        return ValidationResult(all_failures)
32 changes: 21 additions & 11 deletions langkit/validators/library.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import List, Optional, Sequence, Union
from typing import List, Literal, Optional, Sequence, Union

from langkit.core.validation import Validator
from langkit.validators.comparison import ConstraintValidator
from langkit.validators.comparison import ConstraintValidator, ConstraintValidatorOptions, MultiColumnConstraintValidator


class lib:
Expand Down Expand Up @@ -69,13 +69,23 @@ def constraint(
must_be_none: Optional[bool] = None,
) -> Validator:
return ConstraintValidator(
target_metric=target_metric,
upper_threshold=upper_threshold,
upper_threshold_inclusive=upper_threshold_inclusive,
lower_threshold=lower_threshold,
lower_threshold_inclusive=lower_threshold_inclusive,
one_of=one_of,
none_of=none_of,
must_be_non_none=must_be_non_none,
must_be_none=must_be_none,
ConstraintValidatorOptions(
target_metric=target_metric,
upper_threshold=upper_threshold,
upper_threshold_inclusive=upper_threshold_inclusive,
lower_threshold=lower_threshold,
lower_threshold_inclusive=lower_threshold_inclusive,
one_of=one_of,
none_of=none_of,
must_be_non_none=must_be_non_none,
must_be_none=must_be_none,
)
)

@staticmethod
def multi_column_constraint(
    constraints: List[ConstraintValidatorOptions],
    operator: Literal["AND", "OR"] = "AND",
    report_mode: Literal["ALL_FAILED_METRICS", "FIRST_FAILED_METRIC"] = "FIRST_FAILED_METRIC",
) -> Validator:
    """Create a ``MultiColumnConstraintValidator`` from the given options.

    All three arguments are forwarded unchanged to the validator's constructor.

    :param constraints: Constraint options, one entry per metric to validate.
    :param operator: How the constraints are combined, "AND" or "OR".
    :param report_mode: "FIRST_FAILED_METRIC" or "ALL_FAILED_METRICS".
    :return: The newly constructed validator.
    """
    validator: Validator = MultiColumnConstraintValidator(
        constraints,
        operator,
        report_mode,
    )
    return validator
4 changes: 2 additions & 2 deletions tests/langkit/metrics/test_text_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)
from langkit.metrics.text_statistics_types import TextStat
from langkit.metrics.whylogs_compat import create_whylogs_udf_schema
from langkit.validators.comparison import ConstraintValidator
from langkit.validators.comparison import ConstraintValidator, ConstraintValidatorOptions

expected_metrics = [
"cardinality/est",
Expand Down Expand Up @@ -251,7 +251,7 @@ def test_prompt_char_count_module():
def test_prompt_char_count_0_module():
wf = Workflow(
metrics=[prompt_char_count_metric, response_char_count_metric],
validators=[ConstraintValidator("prompt.stats.char_count", lower_threshold=2)],
validators=[ConstraintValidator(ConstraintValidatorOptions("prompt.stats.char_count", lower_threshold=2))],
)

df = pd.DataFrame(
Expand Down
Loading

0 comments on commit 2483bde

Please sign in to comment.