Merge pull request #271 from whylabs/lambda-bad
Fix sneaky bug with lambdas
naddeoa authored Mar 24, 2024
2 parents c5bec44 + 4024c50 commit ec0f067
Showing 3 changed files with 76 additions and 10 deletions.
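
Why the change: every swap in this diff replaces a zero-argument lambda with functools.partial. The commit message doesn't spell out the sneaky bug, but the usual culprit with deferred lambdas is late binding — a lambda closes over the enclosing variable, not its value, so if the variable is reassigned before the deferred call runs, the lambda sees the new value. partial, by contrast, captures argument values eagerly at creation time. A minimal sketch of the difference (metric names invented for illustration):

from functools import partial

versions = ["a", "b", "c"]

# Late binding: every lambda closes over the same comprehension variable
# `v`, which holds "c" once the comprehension has finished.
lazy = [lambda: "metric-" + v for v in versions]

# Eager binding: each partial() captures the value `v` has right now.
eager = [partial("metric-{}".format, v) for v in versions]

print([f() for f in lazy])   # ['metric-c', 'metric-c', 'metric-c']
print([f() for f in eager])  # ['metric-a', 'metric-b', 'metric-c']

A partial is also inspectable (.func, .args, .keywords) where a lambda is opaque, which helps when metric creators need to be compared or logged.
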
25 changes: 17 additions & 8 deletions langkit/metrics/library.py
@@ -1,3 +1,4 @@
+from functools import partial
 from typing import List, Optional
 
 from langkit.core.metric import MetricCreator
@@ -104,7 +105,7 @@ def pii(entities: Optional[List[str]] = None, input_name: str = "prompt") -> MetricCreator:
         from langkit.metrics.pii import pii_presidio_metric, prompt_presidio_pii_metric
 
         if entities:
-            return lambda: pii_presidio_metric(entities=entities, input_name=input_name)
+            return partial(pii_presidio_metric, entities=entities, input_name=input_name)
 
         return prompt_presidio_pii_metric

@@ -185,7 +186,7 @@ def token_count(tiktoken_encoding: Optional[str] = None) -> MetricCreator:
         from langkit.metrics.token import prompt_token_metric, token_metric
 
         if tiktoken_encoding:
-            return lambda: token_metric(column_name="prompt", encoding=tiktoken_encoding)
+            return partial(token_metric, column_name="prompt", encoding=tiktoken_encoding)
 
         return prompt_token_metric

@@ -246,7 +247,7 @@ def injection(version: Optional[str] = None) -> MetricCreator:
         from langkit.metrics.injections import injections_metric, prompt_injections_metric
 
         if version:
-            return lambda: injections_metric(column_name="prompt", version=version)
+            return partial(injections_metric, column_name="prompt", version=version)
 
         return prompt_injections_metric

@@ -275,10 +276,14 @@ def sentiment_score() -> MetricCreator:
         return prompt_sentiment_polarity
 
     class topics:
-        def __call__(self, topics: List[str], hypothesis_template: Optional[str] = None) -> MetricCreator:
+        def __init__(self, topics: List[str], hypothesis_template: Optional[str] = None):
+            self.topics = topics
+            self.hypothesis_template = hypothesis_template
+
+        def __call__(self) -> MetricCreator:
             from langkit.metrics.topic import topic_metric
 
-            return lambda: topic_metric("prompt", topics, hypothesis_template)
+            return partial(topic_metric, "prompt", self.topics, self.hypothesis_template)
 
         @staticmethod
         def medicine() -> MetricCreator:
@@ -469,13 +474,17 @@ def refusal() -> MetricCreator:
         return response_refusal_similarity_metric
 
     class topics:
-        def __call__(self, topics: List[str], hypothesis_template: Optional[str] = None) -> MetricCreator:
+        def __init__(self, topics: List[str], hypothesis_template: Optional[str] = None):
+            self.topics = topics
+            self.hypothesis_template = hypothesis_template
+
+        def __call__(self) -> MetricCreator:
             from langkit.metrics.topic import topic_metric
 
-            return lambda: topic_metric("response", topics, hypothesis_template)
+            return partial(topic_metric, "response", self.topics, self.hypothesis_template)
 
         @staticmethod
         def medicine() -> MetricCreator:
             from langkit.metrics.topic import topic_metric
 
-            return lambda: topic_metric("response", ["medicine"])
+            return partial(topic_metric, "response", ["medicine"])
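
Note that the topics change is more than a mechanical lambda-to-partial swap: configuration moves from __call__ into __init__, so lib.prompt.topics([...]) now constructs a configured creator object that the workflow later invokes with no arguments. A runnable sketch of the new shape — topic_metric is stubbed here, since the real one lives in langkit.metrics.topic:

from functools import partial
from typing import Callable, List, Optional

def topic_metric(input_name: str, topics: List[str], hypothesis_template: Optional[str] = None) -> str:
    # Stand-in for langkit's real topic_metric, only so the sketch runs.
    return f"{input_name} metric over {topics}"

class topics:
    def __init__(self, topics: List[str], hypothesis_template: Optional[str] = None):
        # Configuration is captured at construction time...
        self.topics = topics
        self.hypothesis_template = hypothesis_template

    def __call__(self) -> Callable[..., str]:
        # ...and calling the instance returns a partial with the arguments
        # already pinned, instead of a lambda that re-reads them later.
        return partial(topic_metric, "prompt", self.topics, self.hypothesis_template)

creator = topics(["fishing", "boxing"])  # arguments go to __init__ now
print(creator()())                       # creator() -> partial; calling it builds the metric
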
6 changes: 4 additions & 2 deletions langkit/metrics/topic.py
@@ -38,9 +38,11 @@ def __get_scores_per_label(
     return scores_per_label
 
 
-def topic_metric(input_name: str, topics: List[str], hypothesis_template: str = _hypothesis_template) -> MultiMetric:
+def topic_metric(input_name: str, topics: List[str], hypothesis_template: Optional[str] = None) -> MultiMetric:
+    hypothesis_template = hypothesis_template or _hypothesis_template
+
     def udf(text: pd.DataFrame) -> MultiMetricResult:
-        metrics: Dict[str, List[Optional[float]]] = {metric_name: [] for metric_name in topics}
+        metrics: Dict[str, List[Optional[float]]] = {topic: [] for topic in topics}
 
         def process_row(row: pd.DataFrame) -> Dict[str, List[Optional[float]]]:
             value: Any = row[input_name]  # type: ignore
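
The signature change here pairs with the new topics class above: the default template moves out of the signature so that None becomes a valid "use the default" argument, which matters because the creator now forwards self.hypothesis_template unconditionally and it is often None. A small sketch of the idiom (the template string is a stand-in, not langkit's actual default):

from typing import Optional

_hypothesis_template = "This example is about {}"  # stand-in default

# Before: an explicit None leaks through and silently replaces the default.
def old_style(hypothesis_template: str = _hypothesis_template) -> str:
    return hypothesis_template

# After: None means "use the default", so wrappers can forward an
# Optional value without checking it first.
def new_style(hypothesis_template: Optional[str] = None) -> str:
    return hypothesis_template or _hypothesis_template

print(old_style(None))  # None
print(new_style(None))  # This example is about {}
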
55 changes: 55 additions & 0 deletions tests/langkit/metrics/test_topic.py
@@ -6,6 +6,7 @@
 import whylogs as why
 from langkit.core.metric import WorkflowMetricConfig, WorkflowMetricConfigBuilder
 from langkit.core.workflow import Workflow
+from langkit.metrics.library import lib
 from langkit.metrics.topic import get_custom_topic_modules, prompt_topic_module
 from langkit.metrics.whylogs_compat import create_whylogs_udf_schema

@@ -151,6 +152,60 @@ def test_topic_row():
     assert actual.index.tolist() == expected_columns
 
 
+def test_topic_library():
+    df = pd.DataFrame(
+        {
+            "prompt": [
+                "What's the best kind of bait?",
+                "What's the best kind of punch?",
+                "What's the best kind of trail?",
+                "What's the best kind of swimming stroke?",
+            ],
+            "response": [
+                "The best kind of bait is worms.",
+                "The best kind of punch is a jab.",
+                "The best kind of trail is a loop.",
+                "The best kind of stroke is freestyle.",
+            ],
+        }
+    )
+
+    topics = ["fishing", "boxing", "hiking", "swimming"]
+    wf = Workflow(metrics=[lib.prompt.topics(topics), lib.response.topics(topics)])
+    # schema = WorkflowMetricConfigBuilder().add(lib.prompt.topics(topics)).add(lib.response.topics(topics)).build()
+    # schema = WorkflowMetricConfigBuilder().add(custom_topic_modules.prompt_response_topic_module).build()
+
+    # actual = _log(df, schema)
+    result = wf.run(df)
+    actual = result.metrics
+
+    expected_columns = [
+        "prompt.topics.fishing",
+        "prompt.topics.boxing",
+        "prompt.topics.hiking",
+        "prompt.topics.swimming",
+        "response.topics.fishing",
+        "response.topics.boxing",
+        "response.topics.hiking",
+        "response.topics.swimming",
+        "id",
+    ]
+    assert actual.columns.tolist() == expected_columns
+
+    pd.set_option("display.max_columns", None)
+    print(actual.transpose())
+
+    assert actual["prompt.topics.fishing"][0] > 0.50
+    assert actual["prompt.topics.boxing"][1] > 0.50
+    assert actual["prompt.topics.hiking"][2] > 0.50
+    assert actual["prompt.topics.swimming"][3] > 0.50
+
+    assert actual["response.topics.fishing"][0] > 0.50
+    assert actual["response.topics.boxing"][1] > 0.50
+    assert actual["response.topics.hiking"][2] > 0.50
+    assert actual["response.topics.swimming"][3] > 0.50
+
+
 def test_custom_topic():
     df = pd.DataFrame(
         {
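
Condensed from the new test, the end-to-end call shape after this change — Workflow, lib, and the prompt.topics.* column names all come straight from the diff above:

import pandas as pd
from langkit.core.workflow import Workflow
from langkit.metrics.library import lib

df = pd.DataFrame({
    "prompt": ["What's the best kind of bait?"],
    "response": ["The best kind of bait is worms."],
})

# topics([...]) builds a configured creator; the Workflow invokes it internally.
wf = Workflow(metrics=[lib.prompt.topics(["fishing", "boxing"])])
actual = wf.run(df).metrics
print(actual["prompt.topics.fishing"][0])  # the test expects on-topic scores > 0.5
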
