Skip to content

Commit

Permalink
Merge pull request #301 from whylabs/flag
Browse files (browse the repository at this point in the history)
Flag
  • Loading branch information
naddeoa authored Apr 24, 2024
2 parents dda5396 + f7d079a commit d135ba8
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.0.28.dev13
current_version = 0.0.28.dev14
tag = False
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<build>\d+))?
serialize =
Expand Down
6 changes: 5 additions & 1 deletion langkit/core/validation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Optional, Union
from typing import List, Literal, Optional, Union

import pandas as pd

ValidationFailureLevel = Literal["flag", "block"]


@dataclass(frozen=True)
class ValidationFailure:
Expand All @@ -19,6 +21,8 @@ class ValidationFailure:
must_be_none: Optional[bool] = None
must_be_non_none: Optional[bool] = None

failure_level: ValidationFailureLevel = "block"


@dataclass(frozen=True)
class ValidationResult:
Expand Down
53 changes: 36 additions & 17 deletions langkit/validators/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import numpy as np
import pandas as pd

from langkit.core.validation import ValidationFailure, ValidationResult, Validator
from langkit.core.validation import ValidationFailure, ValidationFailureLevel, ValidationResult, Validator


def _enforce_upper_threshold(target_metric: str, upper_threshold: Union[int, float], value: Any, id: str) -> Sequence[ValidationFailure]:
def _enforce_upper_threshold(
target_metric: str, upper_threshold: Union[int, float], level: ValidationFailureLevel, value: Any, id: str
) -> Sequence[ValidationFailure]:
if not isinstance(value, (float, int)):
return []

Expand All @@ -20,13 +22,16 @@ def _enforce_upper_threshold(target_metric: str, upper_threshold: Union[int, flo
details=f"Value {value} is above threshold {upper_threshold}",
value=value,
upper_threshold=upper_threshold,
failure_level=level,
)
]

return []


def _enforce_lower_threshold(target_metric: str, lower_threshold: Union[int, float], value: Any, id: str) -> Sequence[ValidationFailure]:
def _enforce_lower_threshold(
target_metric: str, lower_threshold: Union[int, float], level: ValidationFailureLevel, value: Any, id: str
) -> Sequence[ValidationFailure]:
if not isinstance(value, (float, int)):
return []

Expand All @@ -38,14 +43,15 @@ def _enforce_lower_threshold(target_metric: str, lower_threshold: Union[int, flo
details=f"Value {value} is below threshold {lower_threshold}",
value=value,
lower_threshold=lower_threshold,
failure_level=level,
)
]

return []


def _enforce_upper_threshold_inclusive(
target_metric: str, upper_threshold_inclusive: Union[int, float], value: Any, id: str
target_metric: str, upper_threshold_inclusive: Union[int, float], level: ValidationFailureLevel, value: Any, id: str
) -> Sequence[ValidationFailure]:
if not isinstance(value, (float, int)):
return []
Expand All @@ -58,14 +64,15 @@ def _enforce_upper_threshold_inclusive(
details=f"Value {value} is above or equal to threshold {upper_threshold_inclusive}",
value=value,
upper_threshold=upper_threshold_inclusive,
failure_level=level,
)
]

return []


def _enforce_lower_threshold_inclusive(
target_metric: str, lower_threshold_inclusive: Union[int, float], value: Any, id: str
target_metric: str, lower_threshold_inclusive: Union[int, float], level: ValidationFailureLevel, value: Any, id: str
) -> Sequence[ValidationFailure]:
if not isinstance(value, (float, int)):
return []
Expand All @@ -78,13 +85,16 @@ def _enforce_lower_threshold_inclusive(
details=f"Value {value} is below or equal to threshold {lower_threshold_inclusive}",
value=value,
lower_threshold=lower_threshold_inclusive,
failure_level=level,
)
]

return []


def _enforce_one_of(target_metric: str, one_of: Set[Union[str, float, int]], value: Any, id: str) -> Sequence[ValidationFailure]:
def _enforce_one_of(
target_metric: str, one_of: Set[Union[str, float, int]], level: ValidationFailureLevel, value: Any, id: str
) -> Sequence[ValidationFailure]:
if value not in one_of:
return [
ValidationFailure(
Expand All @@ -93,12 +103,15 @@ def _enforce_one_of(target_metric: str, one_of: Set[Union[str, float, int]], val
details=f"Value {value} is not in allowed values {one_of}",
value=value,
allowed_values=list(one_of),
failure_level=level,
)
]
return []


def _enforce_none_of(target_metric: str, none_of: Set[Union[str, float, int]], value: Any, id: str) -> Sequence[ValidationFailure]:
def _enforce_none_of(
target_metric: str, none_of: Set[Union[str, float, int]], level: ValidationFailureLevel, value: Any, id: str
) -> Sequence[ValidationFailure]:
if value in none_of:
return [
ValidationFailure(
Expand All @@ -107,12 +120,13 @@ def _enforce_none_of(target_metric: str, none_of: Set[Union[str, float, int]], v
details=f"Value {value} is in disallowed values {none_of}",
value=value,
disallowed_values=list(none_of),
failure_level=level,
)
]
return []


def _enforce_must_be_none(target_metric: str, value: Any, id: str) -> Sequence[ValidationFailure]:
def _enforce_must_be_none(target_metric: str, level: ValidationFailureLevel, value: Any, id: str) -> Sequence[ValidationFailure]:
if value is not None:
return [
ValidationFailure(
Expand All @@ -121,12 +135,13 @@ def _enforce_must_be_none(target_metric: str, value: Any, id: str) -> Sequence[V
details=f"Value {value} is not None",
value=value,
must_be_none=True,
failure_level=level,
)
]
return []


def _enforce_must_be_non_none(target_metric: str, value: Any, id: str) -> Sequence[ValidationFailure]:
def _enforce_must_be_non_none(target_metric: str, level: ValidationFailureLevel, value: Any, id: str) -> Sequence[ValidationFailure]:
if value is None:
return [
ValidationFailure(
Expand All @@ -135,6 +150,7 @@ def _enforce_must_be_non_none(target_metric: str, value: Any, id: str) -> Sequen
details="Value is None",
value=value,
must_be_non_none=True,
failure_level=level,
)
]
return []
Expand All @@ -151,32 +167,35 @@ class ConstraintValidatorOptions:
none_of: Optional[Tuple[Union[str, float, int], ...]] = None
must_be_non_none: Optional[bool] = None
must_be_none: Optional[bool] = None
failure_level: Optional[ValidationFailureLevel] = None


class ConstraintValidator(Validator):
def __init__(self, options: ConstraintValidatorOptions):
validation_functions: List[Callable[[Any, str], Sequence[ValidationFailure]]] = []

level = options.failure_level or "block"

if options.upper_threshold is not None:
validation_functions.append(partial(_enforce_upper_threshold, options.target_metric, options.upper_threshold))
validation_functions.append(partial(_enforce_upper_threshold, options.target_metric, options.upper_threshold, level))
if options.lower_threshold is not None:
validation_functions.append(partial(_enforce_lower_threshold, options.target_metric, options.lower_threshold))
validation_functions.append(partial(_enforce_lower_threshold, options.target_metric, options.lower_threshold, level))
if options.upper_threshold_inclusive is not None:
validation_functions.append(
partial(_enforce_upper_threshold_inclusive, options.target_metric, options.upper_threshold_inclusive)
partial(_enforce_upper_threshold_inclusive, options.target_metric, options.upper_threshold_inclusive, level)
)
if options.lower_threshold_inclusive is not None:
validation_functions.append(
partial(_enforce_lower_threshold_inclusive, options.target_metric, options.lower_threshold_inclusive)
partial(_enforce_lower_threshold_inclusive, options.target_metric, options.lower_threshold_inclusive, level)
)
if options.one_of is not None:
validation_functions.append(partial(_enforce_one_of, options.target_metric, set(options.one_of)))
validation_functions.append(partial(_enforce_one_of, options.target_metric, set(options.one_of), level))
if options.none_of is not None:
validation_functions.append(partial(_enforce_none_of, options.target_metric, set(options.none_of)))
validation_functions.append(partial(_enforce_none_of, options.target_metric, set(options.none_of), level))
if options.must_be_non_none is not None:
validation_functions.append(partial(_enforce_must_be_non_none, options.target_metric))
validation_functions.append(partial(_enforce_must_be_non_none, options.target_metric, level))
if options.must_be_none is not None:
validation_functions.append(partial(_enforce_must_be_none, options.target_metric))
validation_functions.append(partial(_enforce_must_be_none, options.target_metric, level))

self._target_metric = options.target_metric
self._validation_functions = validation_functions
Expand Down
4 changes: 3 additions & 1 deletion langkit/validators/library.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Literal, Optional, Sequence, Union

from langkit.core.validation import Validator
from langkit.core.validation import ValidationFailureLevel, Validator
from langkit.validators.comparison import (
ConstraintValidator,
ConstraintValidatorOptions,
Expand Down Expand Up @@ -72,6 +72,7 @@ def constraint(
none_of: Optional[Sequence[Union[str, float, int]]] = None,
must_be_non_none: Optional[bool] = None,
must_be_none: Optional[bool] = None,
failure_level: Optional[ValidationFailureLevel] = None,
) -> Validator:
return ConstraintValidator(
ConstraintValidatorOptions(
Expand All @@ -84,6 +85,7 @@ def constraint(
none_of=tuple(none_of) if none_of else None,
must_be_non_none=must_be_non_none,
must_be_none=must_be_none,
failure_level=failure_level,
)
)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langkit"
version = "0.0.28.dev13"
version = "0.0.28.dev14"
description = "A language toolkit for monitoring LLM interactions"
authors = ["WhyLabs.ai <[email protected]>"]
homepage = "https://docs.whylabs.ai/docs/large-language-model-monitoring"
Expand Down
53 changes: 53 additions & 0 deletions tests/langkit/metrics/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest

from langkit.core.validation import ValidationFailure
from langkit.core.workflow import MetricFilterOptions, RunOptions, Workflow
from langkit.metrics.library import lib
from langkit.validators.library import lib as validator_lib
Expand Down Expand Up @@ -208,6 +209,58 @@ def test_just_prompt_validation():
]


def test_just_prompt_validation_flag():
    """Check that a constraint built with ``failure_level="flag"`` reports that level.

    Runs the recommended metric preset over one prompt/response pair with an
    upper-threshold rule on ``response.stats.token_count``, then verifies the
    metric columns produced and the single validation failure in the report.
    """
    flag_rule = validator_lib.constraint(
        target_metric="response.stats.token_count",
        upper_threshold=1,
        failure_level="flag",
    )
    workflow = Workflow(metrics=[lib.presets.recommended()], validators=[flag_rule])

    result = workflow.run({"prompt": "hi", "response": "hello there"})

    # Column order matters: the comparison below is an exact list equality.
    expected_columns = [
        "prompt.pii.phone_number",
        "prompt.pii.email_address",
        "prompt.pii.credit_card",
        "prompt.pii.us_ssn",
        "prompt.pii.us_bank_number",
        "prompt.pii.redacted",
        "prompt.stats.token_count",
        "prompt.stats.char_count",
        "prompt.similarity.injection",
        "prompt.similarity.jailbreak",
        "response.pii.phone_number",
        "response.pii.email_address",
        "response.pii.credit_card",
        "response.pii.us_ssn",
        "response.pii.us_bank_number",
        "response.pii.redacted",
        "response.stats.token_count",
        "response.stats.char_count",
        "response.stats.flesch_reading_ease",
        "response.sentiment.sentiment_score",
        "response.toxicity.toxicity_score",
        "response.similarity.refusal",
        "id",
    ]
    actual_columns: List[str] = result.metrics.columns.tolist()  # pyright: ignore[reportUnknownMemberType]
    assert actual_columns == expected_columns

    # The one expected failure: the threshold breach, carrying the "flag" level
    # that was passed to the constraint above.
    expected_failure = ValidationFailure(
        id="0",
        metric="response.stats.token_count",
        details="Value 2 is above threshold 1",
        value=2,
        upper_threshold=1,
        lower_threshold=None,
        allowed_values=None,
        disallowed_values=None,
        must_be_none=None,
        must_be_non_none=None,
        failure_level="flag",
    )
    assert result.validation_results.report == [expected_failure]


def test_just_response():
wf = Workflow(metrics=[lib.presets.recommended()])
result = wf.run({"response": "I'm doing great!"})
Expand Down

0 comments on commit d135ba8

Please sign in to comment.