Add validation library #238

Merged 2 commits on Feb 23, 2024
2 changes: 1 addition & 1 deletion .bumpversion.cfg
@@ -1,5 +1,5 @@
[bumpversion]
- current_version = 0.0.69
+ current_version = 0.0.70
tag = False
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
serialize =
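
For reference (not part of the PR), the parse pattern above accepts both plain and pre-release version strings; a quick sanity check:

import re

# The parse pattern from .bumpversion.cfg above.
pattern = re.compile(r"(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?")

print(pattern.match("0.0.70").groupdict())
# {'major': '0', 'minor': '0', 'patch': '70', 'release': None, 'build': None}
print(pattern.match("0.0.70-rc1").groupdict())
# {'major': '0', 'minor': '0', 'patch': '70', 'release': 'rc', 'build': '1'}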
Empty file removed langkit/callbacks/__init__.py
47 changes: 3 additions & 44 deletions langkit/core/validation.py
@@ -1,8 +1,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
- from typing import Any, List, Optional, Union
+ from typing import List, Optional, Union

- import numpy as np
import pandas as pd


@@ -15,6 +14,8 @@ class ValidationFailure:
value: Union[int, float, str, None]
upper_threshold: Optional[float] = None
lower_threshold: Optional[float] = None
+ allowed_values: Optional[List[Union[str, float, int]]] = None
+ disallowed_values: Optional[List[Union[str, float, int]]] = None


@dataclass(frozen=True)
@@ -42,45 +43,3 @@ def validate_result(self, df: pd.DataFrame) -> Optional[ValidationResult]:
by default, that will include a prompt and a response column if both were supplied to the evaluation.
"""
return None


def create_validator(target_metric: str, upper_threshold: Optional[float] = None, lower_threshold: Optional[float] = None) -> Validator:
class _Validator(Validator):
def get_target_metric_names(self) -> List[str]:
return [target_metric]

def validate_result(self, df: pd.DataFrame):
failures: List[ValidationFailure] = []
for _index, row in df.iterrows(): # type: ignore
id = str(row["id"]) # type: ignore TODO make sure this is ok
value: Any = row[target_metric]
if isinstance(value, pd.Series) and value.size == 1:
value = value.item()
elif isinstance(value, np.ndarray) and value.size == 1:
value = value.item()

if upper_threshold is not None and target_metric in row and row[target_metric] > upper_threshold:
failures.append(
ValidationFailure(
id,
target_metric,
f"Value {row[target_metric]} is above threshold {upper_threshold}",
value=value, # type: ignore
upper_threshold=upper_threshold,
)
)

if lower_threshold is not None and target_metric in row and row[target_metric] < lower_threshold:
failures.append(
ValidationFailure(
id,
target_metric,
f"Value {row[target_metric]} is below threshold {lower_threshold}",
value=value, # type: ignore
lower_threshold=lower_threshold,
)
)

return ValidationResult(failures)

return _Validator()
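
For downstream code, the removed create_validator has a one-for-one replacement in the new ConstraintValidator, as the test update at the bottom of this diff makes the same swap; a minimal migration sketch:

from langkit.validators.comparison import ConstraintValidator

# Before this PR:
#   validator = create_validator("prompt.char_count", lower_threshold=2)
# After this PR:
validator = ConstraintValidator("prompt.char_count", lower_threshold=2)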
197 changes: 197 additions & 0 deletions langkit/validators/comparison.py
@@ -0,0 +1,197 @@
from functools import partial
from typing import Any, Callable, List, Optional, Sequence, Set, Union

import numpy as np
import pandas as pd

from langkit.core.validation import ValidationFailure, ValidationResult, Validator


def _enforce_upper_threshold(target_metric: str, upper_threshold: Union[int, float], value: Any, id: str) -> Sequence[ValidationFailure]:
if not isinstance(value, (float, int)):
return []

if value > upper_threshold:
return [
ValidationFailure(
id=id,
metric=target_metric,
details=f"Value {value} is above threshold {upper_threshold}",
value=value,
upper_threshold=upper_threshold,
)
]

return []


def _enforce_lower_threshold(target_metric: str, lower_threshold: Union[int, float], value: Any, id: str) -> Sequence[ValidationFailure]:
if not isinstance(value, (float, int)):
return []

if value < lower_threshold:
return [
ValidationFailure(
id=id,
metric=target_metric,
details=f"Value {value} is below threshold {lower_threshold}",
value=value,
lower_threshold=lower_threshold,
)
]

return []


def _enforce_upper_threshold_inclusive(
target_metric: str, upper_threshold_inclusive: Union[int, float], value: Any, id: str
) -> Sequence[ValidationFailure]:
if not isinstance(value, (float, int)):
return []

if value >= upper_threshold_inclusive:
return [
ValidationFailure(
id=id,
metric=target_metric,
details=f"Value {value} is above or equal to threshold {upper_threshold_inclusive}",
value=value,
upper_threshold=upper_threshold_inclusive,
)
]

return []


def _enforce_lower_threshold_inclusive(
target_metric: str, lower_threshold_inclusive: Union[int, float], value: Any, id: str
) -> Sequence[ValidationFailure]:
if not isinstance(value, (float, int)):
return []

if value <= lower_threshold_inclusive:
return [
ValidationFailure(
id=id,
metric=target_metric,
details=f"Value {value} is below or equal to threshold {lower_threshold_inclusive}",
value=value,
lower_threshold=lower_threshold_inclusive,
)
]

return []


def _enforce_one_of(target_metric: str, one_of: Set[Union[str, float, int]], value: Any, id: str) -> Sequence[ValidationFailure]:
if value not in one_of:
return [
ValidationFailure(
id=id,
metric=target_metric,
details=f"Value {value} is not in allowed values {one_of}",
value=value,
allowed_values=list(one_of),
)
]
return []


def _enforce_none_of(target_metric: str, none_of: Set[Union[str, float, int]], value: Any, id: str) -> Sequence[ValidationFailure]:
if value in none_of:
return [
ValidationFailure(
id=id,
metric=target_metric,
details=f"Value {value} is in disallowed values {none_of}",
value=value,
disallowed_values=list(none_of),
)
]
return []


def _enforce_must_be_none(target_metric: str, value: Any, id: str) -> Sequence[ValidationFailure]:
if value is not None:
return [
ValidationFailure(
id=id,
metric=target_metric,
details=f"Value {value} is not None",
value=value,
)
]
return []


def _enforce_must_be_non_none(target_metric: str, value: Any, id: str) -> Sequence[ValidationFailure]:
if value is None:
return [
ValidationFailure(
id=id,
metric=target_metric,
details="Value is None",
value=value,
)
]
return []


class ConstraintValidator(Validator):
def __init__(
self,
target_metric: str,
upper_threshold: Optional[Union[float, int]] = None,
upper_threshold_inclusive: Optional[Union[float, int]] = None,
lower_threshold: Optional[Union[float, int]] = None,
lower_threshold_inclusive: Optional[Union[float, int]] = None,
one_of: Optional[Sequence[Union[str, float, int]]] = None,
none_of: Optional[Sequence[Union[str, float, int]]] = None,
must_be_non_none: Optional[bool] = None,
must_be_none: Optional[bool] = None,
):
validation_functions: List[Callable[[Any, str], Sequence[ValidationFailure]]] = []

if upper_threshold is not None:
validation_functions.append(partial(_enforce_upper_threshold, target_metric, upper_threshold))
if lower_threshold is not None:
validation_functions.append(partial(_enforce_lower_threshold, target_metric, lower_threshold))
if upper_threshold_inclusive is not None:
validation_functions.append(partial(_enforce_upper_threshold_inclusive, target_metric, upper_threshold_inclusive))
if lower_threshold_inclusive is not None:
validation_functions.append(partial(_enforce_lower_threshold_inclusive, target_metric, lower_threshold_inclusive))
if one_of is not None:
validation_functions.append(partial(_enforce_one_of, target_metric, set(one_of)))
if none_of is not None:
validation_functions.append(partial(_enforce_none_of, target_metric, set(none_of)))
if must_be_non_none is not None:
validation_functions.append(partial(_enforce_must_be_non_none, target_metric))
if must_be_none is not None:
validation_functions.append(partial(_enforce_must_be_none, target_metric))

self._target_metric = target_metric
self._validation_functions = validation_functions

if len(validation_functions) == 0:
raise ValueError("At least one constraint must be provided")

def get_target_metric_names(self) -> List[str]:
return [self._target_metric]

def validate_result(self, df: pd.DataFrame) -> Optional[ValidationResult]:
failures: List[ValidationFailure] = []
for _index, row in df.iterrows(): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
id = str(row["id"]) # pyright: ignore[reportUnknownArgumentType]
value: Any = row[self._target_metric]
if isinstance(value, pd.Series) and value.size == 1:
value = value.item()
elif isinstance(value, np.ndarray) and value.size == 1:
value = value.item()

for validation_function in self._validation_functions:
failures.extend(validation_function(value, id))

if len(failures) == 0:
return None

return ValidationResult(failures)
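
To see the new validator end to end, here is a sketch; the frame shape (an "id" column plus the target metric column) follows validate_result above, and "prompt.score" is a made-up metric name with hand-built values, not real workflow output:

import pandas as pd

from langkit.validators.comparison import ConstraintValidator

validator = ConstraintValidator("prompt.score", lower_threshold=0.2, upper_threshold=0.8)

# Hand-built metrics frame; in practice this comes from the evaluation workflow.
df = pd.DataFrame({"id": ["0", "1"], "prompt.score": [0.1, 0.5]})

result = validator.validate_result(df)
# Row "0" fails the lower bound (0.1 < 0.2) and row "1" passes, so result is a
# ValidationResult holding a single ValidationFailure; with no failures it is None.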
30 changes: 30 additions & 0 deletions langkit/validators/library.py
@@ -0,0 +1,30 @@
from typing import Optional, Sequence, Union

from langkit.core.validation import Validator
from langkit.validators.comparison import ConstraintValidator


class lib:
@staticmethod
def constraint(
target_metric: str,
upper_threshold: Optional[float] = None,
upper_threshold_inclusive: Optional[float] = None,
lower_threshold: Optional[float] = None,
lower_threshold_inclusive: Optional[float] = None,
one_of: Optional[Sequence[Union[str, float, int]]] = None,
none_of: Optional[Sequence[Union[str, float, int]]] = None,
must_be_non_none: Optional[bool] = None,
must_be_none: Optional[bool] = None,
) -> Validator:
return ConstraintValidator(
target_metric=target_metric,
upper_threshold=upper_threshold,
upper_threshold_inclusive=upper_threshold_inclusive,
lower_threshold=lower_threshold,
lower_threshold_inclusive=lower_threshold_inclusive,
one_of=one_of,
none_of=none_of,
must_be_non_none=must_be_non_none,
must_be_none=must_be_none,
)
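
The lowercase lib class acts as a namespace of validator factories rather than a module of free functions. Usage mirrors the updated test at the bottom of this diff; a sketch wiring lib.constraint into a workflow (module paths taken from this PR's imports):

from langkit.core.workflow import EvaluationWorkflow
from langkit.metrics.text_statistics import prompt_char_count_module
from langkit.validators.library import lib

# Flag any prompt shorter than 2 characters, mirroring the updated test below.
wf = EvaluationWorkflow(
    metrics=[prompt_char_count_module],
    validators=[lib.constraint("prompt.char_count", lower_threshold=2)],
)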
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langkit"
version = "0.0.69"
version = "0.0.70"
description = "A language toolkit for monitoring LLM interactions"
authors = ["WhyLabs.ai <[email protected]>"]
homepage = "https://docs.whylabs.ai/docs/large-language-model-monitoring"
5 changes: 3 additions & 2 deletions tests/langkit/metrics/test_text_statistics.py
@@ -6,7 +6,7 @@

import whylogs as why
from langkit.core.metric import EvaluationConfig, EvaluationConfigBuilder, Metric, MultiMetric, MultiMetricResult, UdfInput
- from langkit.core.validation import ValidationFailure, ValidationResult, create_validator
+ from langkit.core.validation import ValidationFailure, ValidationResult
from langkit.core.workflow import EvaluationWorkflow
from langkit.metrics.text_statistics import (
prompt_char_count_module,
@@ -20,6 +20,7 @@
)
from langkit.metrics.text_statistics_types import TextStat
from langkit.metrics.whylogs_compat import create_whylogs_udf_schema
+ from langkit.validators.comparison import ConstraintValidator

expected_metrics = [
"cardinality/est",
@@ -257,7 +258,7 @@ def test_prompt_char_count_module():
def test_prompt_char_count_0_module():
wf = EvaluationWorkflow(
metrics=[prompt_char_count_module, response_char_count_module],
- validators=[create_validator("prompt.char_count", lower_threshold=2)],
+ validators=[ConstraintValidator("prompt.char_count", lower_threshold=2)],
)

df = pd.DataFrame(