Flag #301

Merged 2 commits on Apr 24, 2024
2 changes: 1 addition & 1 deletion .bumpversion.cfg
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.0.28.dev13
+current_version = 0.0.28.dev14
 tag = False
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<build>\d+))?
 serialize =
6 changes: 5 additions & 1 deletion langkit/core/validation.py
@@ -1,9 +1,11 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import List, Optional, Union
+from typing import List, Literal, Optional, Union

 import pandas as pd

+ValidationFailureLevel = Literal["flag", "block"]
+

 @dataclass(frozen=True)
 class ValidationFailure:
@@ -19,6 +21,8 @@ class ValidationFailure:
     must_be_none: Optional[bool] = None
     must_be_non_none: Optional[bool] = None

+    failure_level: ValidationFailureLevel = "block"
+

 @dataclass(frozen=True)
 class ValidationResult:
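Note: ValidationFailureLevel is a two-value Literal, so type checkers restrict failure_level to "flag" or "block", and the new field defaults to "block", preserving the prior behavior. A minimal sketch of how a consumer might branch on the field; the warn-versus-raise policy below is an assumption for illustration, not behavior added by this PR:

    from typing import List

    from langkit.core.validation import ValidationFailure

    def handle(failures: List[ValidationFailure]) -> None:
        # Hypothetical policy: surface "flag"-level failures without
        # interrupting the request, and raise only on "block"-level ones.
        for f in failures:
            if f.failure_level == "flag":
                print(f"[flagged] {f.metric}: {f.details}")
            else:  # "block" is the default
                raise ValueError(f"[blocked] {f.metric}: {f.details}")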
53 changes: 36 additions & 17 deletions langkit/validators/comparison.py
@@ -5,10 +5,12 @@
 import numpy as np
 import pandas as pd

-from langkit.core.validation import ValidationFailure, ValidationResult, Validator
+from langkit.core.validation import ValidationFailure, ValidationFailureLevel, ValidationResult, Validator


-def _enforce_upper_threshold(target_metric: str, upper_threshold: Union[int, float], value: Any, id: str) -> Sequence[ValidationFailure]:
+def _enforce_upper_threshold(
+    target_metric: str, upper_threshold: Union[int, float], level: ValidationFailureLevel, value: Any, id: str
+) -> Sequence[ValidationFailure]:
     if not isinstance(value, (float, int)):
         return []

@@ -20,13 +22,16 @@ def _enforce_upper_threshold(target_metric: str, upper_threshold: Union[int, flo
                 details=f"Value {value} is above threshold {upper_threshold}",
                 value=value,
                 upper_threshold=upper_threshold,
+                failure_level=level,
             )
         ]

     return []


-def _enforce_lower_threshold(target_metric: str, lower_threshold: Union[int, float], value: Any, id: str) -> Sequence[ValidationFailure]:
+def _enforce_lower_threshold(
+    target_metric: str, lower_threshold: Union[int, float], level: ValidationFailureLevel, value: Any, id: str
+) -> Sequence[ValidationFailure]:
     if not isinstance(value, (float, int)):
         return []

@@ -38,14 +43,15 @@ def _enforce_lower_threshold(target_metric: str, lower_threshold: Union[int, flo
                 details=f"Value {value} is below threshold {lower_threshold}",
                 value=value,
                 lower_threshold=lower_threshold,
+                failure_level=level,
             )
         ]

     return []


 def _enforce_upper_threshold_inclusive(
-    target_metric: str, upper_threshold_inclusive: Union[int, float], value: Any, id: str
+    target_metric: str, upper_threshold_inclusive: Union[int, float], level: ValidationFailureLevel, value: Any, id: str
 ) -> Sequence[ValidationFailure]:
     if not isinstance(value, (float, int)):
         return []
@@ -58,14 +64,15 @@ def _enforce_upper_threshold_inclusive(
                 details=f"Value {value} is above or equal to threshold {upper_threshold_inclusive}",
                 value=value,
                 upper_threshold=upper_threshold_inclusive,
+                failure_level=level,
             )
         ]

     return []


 def _enforce_lower_threshold_inclusive(
-    target_metric: str, lower_threshold_inclusive: Union[int, float], value: Any, id: str
+    target_metric: str, lower_threshold_inclusive: Union[int, float], level: ValidationFailureLevel, value: Any, id: str
 ) -> Sequence[ValidationFailure]:
     if not isinstance(value, (float, int)):
         return []
@@ -78,13 +85,16 @@ def _enforce_lower_threshold_inclusive(
                 details=f"Value {value} is below or equal to threshold {lower_threshold_inclusive}",
                 value=value,
                 lower_threshold=lower_threshold_inclusive,
+                failure_level=level,
             )
         ]

     return []


-def _enforce_one_of(target_metric: str, one_of: Set[Union[str, float, int]], value: Any, id: str) -> Sequence[ValidationFailure]:
+def _enforce_one_of(
+    target_metric: str, one_of: Set[Union[str, float, int]], level: ValidationFailureLevel, value: Any, id: str
+) -> Sequence[ValidationFailure]:
     if value not in one_of:
         return [
             ValidationFailure(
@@ -93,12 +103,15 @@ def _enforce_one_of(target_metric: str, one_of: Set[Union[str, float, int]], val
                 details=f"Value {value} is not in allowed values {one_of}",
                 value=value,
                 allowed_values=list(one_of),
+                failure_level=level,
             )
         ]
     return []


-def _enforce_none_of(target_metric: str, none_of: Set[Union[str, float, int]], value: Any, id: str) -> Sequence[ValidationFailure]:
+def _enforce_none_of(
+    target_metric: str, none_of: Set[Union[str, float, int]], level: ValidationFailureLevel, value: Any, id: str
+) -> Sequence[ValidationFailure]:
     if value in none_of:
         return [
             ValidationFailure(
@@ -107,12 +120,13 @@ def _enforce_none_of(target_metric: str, none_of: Set[Union[str, float, int]], v
                 details=f"Value {value} is in disallowed values {none_of}",
                 value=value,
                 disallowed_values=list(none_of),
+                failure_level=level,
             )
         ]
     return []


-def _enforce_must_be_none(target_metric: str, value: Any, id: str) -> Sequence[ValidationFailure]:
+def _enforce_must_be_none(target_metric: str, level: ValidationFailureLevel, value: Any, id: str) -> Sequence[ValidationFailure]:
     if value is not None:
         return [
             ValidationFailure(
@@ -121,12 +135,13 @@ def _enforce_must_be_none(target_metric: str, value: Any, id: str) -> Sequence[V
                 details=f"Value {value} is not None",
                 value=value,
                 must_be_none=True,
+                failure_level=level,
             )
         ]
     return []


-def _enforce_must_be_non_none(target_metric: str, value: Any, id: str) -> Sequence[ValidationFailure]:
+def _enforce_must_be_non_none(target_metric: str, level: ValidationFailureLevel, value: Any, id: str) -> Sequence[ValidationFailure]:
     if value is None:
         return [
             ValidationFailure(
@@ -135,6 +150,7 @@ def _enforce_must_be_non_none(target_metric: str, value: Any, id: str) -> Sequen
                 details="Value is None",
                 value=value,
                 must_be_non_none=True,
+                failure_level=level,
             )
         ]
     return []
@@ -151,32 +167,35 @@ class ConstraintValidatorOptions:
     none_of: Optional[Tuple[Union[str, float, int], ...]] = None
     must_be_non_none: Optional[bool] = None
     must_be_none: Optional[bool] = None
+    failure_level: Optional[ValidationFailureLevel] = None


 class ConstraintValidator(Validator):
     def __init__(self, options: ConstraintValidatorOptions):
         validation_functions: List[Callable[[Any, str], Sequence[ValidationFailure]]] = []

+        level = options.failure_level or "block"
+
         if options.upper_threshold is not None:
-            validation_functions.append(partial(_enforce_upper_threshold, options.target_metric, options.upper_threshold))
+            validation_functions.append(partial(_enforce_upper_threshold, options.target_metric, options.upper_threshold, level))
         if options.lower_threshold is not None:
-            validation_functions.append(partial(_enforce_lower_threshold, options.target_metric, options.lower_threshold))
+            validation_functions.append(partial(_enforce_lower_threshold, options.target_metric, options.lower_threshold, level))
         if options.upper_threshold_inclusive is not None:
             validation_functions.append(
-                partial(_enforce_upper_threshold_inclusive, options.target_metric, options.upper_threshold_inclusive)
+                partial(_enforce_upper_threshold_inclusive, options.target_metric, options.upper_threshold_inclusive, level)
             )
         if options.lower_threshold_inclusive is not None:
             validation_functions.append(
-                partial(_enforce_lower_threshold_inclusive, options.target_metric, options.lower_threshold_inclusive)
+                partial(_enforce_lower_threshold_inclusive, options.target_metric, options.lower_threshold_inclusive, level)
             )
         if options.one_of is not None:
-            validation_functions.append(partial(_enforce_one_of, options.target_metric, set(options.one_of)))
+            validation_functions.append(partial(_enforce_one_of, options.target_metric, set(options.one_of), level))
         if options.none_of is not None:
-            validation_functions.append(partial(_enforce_none_of, options.target_metric, set(options.none_of)))
+            validation_functions.append(partial(_enforce_none_of, options.target_metric, set(options.none_of), level))
         if options.must_be_non_none is not None:
-            validation_functions.append(partial(_enforce_must_be_non_none, options.target_metric))
+            validation_functions.append(partial(_enforce_must_be_non_none, options.target_metric, level))
         if options.must_be_none is not None:
-            validation_functions.append(partial(_enforce_must_be_none, options.target_metric))
+            validation_functions.append(partial(_enforce_must_be_none, options.target_metric, level))

         self._target_metric = options.target_metric
         self._validation_functions = validation_functions
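Each _enforce_* helper takes the new level argument before value and id, so functools.partial can pre-bind all of a rule's configuration while leaving the (value, id) pair free; this keeps every bound function compatible with the Callable[[Any, str], Sequence[ValidationFailure]] list above. A self-contained sketch of the same pattern, with hypothetical names:

    from functools import partial
    from typing import Any, Callable, List, Sequence

    def enforce_demo(metric: str, threshold: float, level: str, value: Any, id: str) -> Sequence[str]:
        # Config parameters first, (value, id) last, mirroring the helpers above.
        if isinstance(value, (int, float)) and value > threshold:
            return [f"{id}: {metric}={value} exceeds {threshold} ({level})"]
        return []

    checks: List[Callable[[Any, str], Sequence[str]]] = [
        partial(enforce_demo, "response.stats.token_count", 1, "flag"),
    ]

    print(checks[0](2, "0"))  # ['0: response.stats.token_count=2 exceeds 1 (flag)']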
4 changes: 3 additions & 1 deletion langkit/validators/library.py
@@ -1,6 +1,6 @@
 from typing import List, Literal, Optional, Sequence, Union

-from langkit.core.validation import Validator
+from langkit.core.validation import ValidationFailureLevel, Validator
 from langkit.validators.comparison import (
     ConstraintValidator,
     ConstraintValidatorOptions,
@@ -72,6 +72,7 @@ def constraint(
     none_of: Optional[Sequence[Union[str, float, int]]] = None,
     must_be_non_none: Optional[bool] = None,
     must_be_none: Optional[bool] = None,
+    failure_level: Optional[ValidationFailureLevel] = None,
 ) -> Validator:
     return ConstraintValidator(
         ConstraintValidatorOptions(
@@ -84,6 +85,7 @@ def constraint(
             none_of=tuple(none_of) if none_of else None,
             must_be_non_none=must_be_non_none,
             must_be_none=must_be_none,
+            failure_level=failure_level,
         )
     )
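With failure_level now exposed on the public constraint builder, callers can opt into non-blocking validation per rule; omitting the parameter keeps the blocking default. Usage mirroring the new test below:

    from langkit.validators.library import lib as validator_lib

    # Failures from this rule are reported with failure_level="flag"
    # rather than the default "block".
    rule = validator_lib.constraint(
        target_metric="response.stats.token_count",
        upper_threshold=1,
        failure_level="flag",
    )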
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langkit"
-version = "0.0.28.dev13"
+version = "0.0.28.dev14"
 description = "A language toolkit for monitoring LLM interactions"
 authors = ["WhyLabs.ai <[email protected]>"]
 homepage = "https://docs.whylabs.ai/docs/large-language-model-monitoring"
53 changes: 53 additions & 0 deletions tests/langkit/metrics/test_workflow.py
@@ -2,6 +2,7 @@

 import pytest

+from langkit.core.validation import ValidationFailure
 from langkit.core.workflow import MetricFilterOptions, RunOptions, Workflow
 from langkit.metrics.library import lib
 from langkit.validators.library import lib as validator_lib
@@ -208,6 +209,58 @@ def test_just_prompt_validation():
     ]


+def test_just_prompt_validation_flag():
+    rule = validator_lib.constraint(target_metric="response.stats.token_count", upper_threshold=1, failure_level="flag")
+    wf = Workflow(metrics=[lib.presets.recommended()], validators=[rule])
+
+    result = wf.run({"prompt": "hi", "response": "hello there"})
+    metrics = result.metrics
+
+    metric_names: List[str] = metrics.columns.tolist()  # pyright: ignore[reportUnknownMemberType]
+
+    assert metric_names == [
+        "prompt.pii.phone_number",
+        "prompt.pii.email_address",
+        "prompt.pii.credit_card",
+        "prompt.pii.us_ssn",
+        "prompt.pii.us_bank_number",
+        "prompt.pii.redacted",
+        "prompt.stats.token_count",
+        "prompt.stats.char_count",
+        "prompt.similarity.injection",
+        "prompt.similarity.jailbreak",
+        "response.pii.phone_number",
+        "response.pii.email_address",
+        "response.pii.credit_card",
+        "response.pii.us_ssn",
+        "response.pii.us_bank_number",
+        "response.pii.redacted",
+        "response.stats.token_count",
+        "response.stats.char_count",
+        "response.stats.flesch_reading_ease",
+        "response.sentiment.sentiment_score",
+        "response.toxicity.toxicity_score",
+        "response.similarity.refusal",
+        "id",
+    ]
+
+    assert result.validation_results.report == [
+        ValidationFailure(
+            id="0",
+            metric="response.stats.token_count",
+            details="Value 2 is above threshold 1",
+            value=2,
+            upper_threshold=1,
+            lower_threshold=None,
+            allowed_values=None,
+            disallowed_values=None,
+            must_be_none=None,
+            must_be_non_none=None,
+            failure_level="flag",
+        ),
+    ]
+
+
 def test_just_response():
     wf = Workflow(metrics=[lib.presets.recommended()])
     result = wf.run({"response": "I'm doing great!"})