Fix sneaky bug with lambdas
Apparently something about lambdas makes it too hard for pyright to catch
argument type errors. This switches all of the lambdas to partials in
the metric library and corrects the type error for the new topic module.
Anthony Naddeo committed Mar 24, 2024
1 parent c5bec44 commit 4024c50
Showing 3 changed files with 76 additions and 10 deletions.
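
Before the file diffs, a minimal illustrative sketch of the lambda-to-partial pattern the commit message describes. This is not langkit code: fake_topic_metric, creator_with_lambda, and creator_with_partial are made-up stand-ins. The idea is that functools.partial binds its arguments at the line where the partial is created, so a type checker such as pyright can compare them directly against the target function's signature, whereas a lambda defers the call and (per the message above) the mismatch could slip through.

# Illustrative only -- not part of this commit or of langkit itself.
from functools import partial
from typing import Callable, List, Optional


def fake_topic_metric(input_name: str, topics: List[str], hypothesis_template: Optional[str] = None) -> str:
    # Stand-in for langkit.metrics.topic.topic_metric, which really returns a MultiMetric.
    return f"{input_name}: {topics} ({hypothesis_template})"


def creator_with_lambda(topics: List[str]) -> Callable[[], str]:
    # The call happens later, inside the lambda body; per the commit message,
    # pyright had trouble flagging argument type errors written this way.
    return lambda: fake_topic_metric("prompt", topics)


def creator_with_partial(topics: List[str]) -> Callable[[], str]:
    # partial() binds the arguments here, so a mistyped argument is reported
    # at this line, against fake_topic_metric's signature.
    return partial(fake_topic_metric, "prompt", topics)

In the diff that follows, MetricCreator appears to play this zero-argument-callable role, so the partial objects satisfy the same contract the lambdas did.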
25 changes: 17 additions & 8 deletions langkit/metrics/library.py
@@ -1,3 +1,4 @@
+from functools import partial
 from typing import List, Optional

 from langkit.core.metric import MetricCreator
@@ -104,7 +105,7 @@ def pii(entities: Optional[List[str]] = None, input_name: str = "prompt") -> MetricCreator:
     from langkit.metrics.pii import pii_presidio_metric, prompt_presidio_pii_metric

     if entities:
-        return lambda: pii_presidio_metric(entities=entities, input_name=input_name)
+        return partial(pii_presidio_metric, entities=entities, input_name=input_name)

     return prompt_presidio_pii_metric

@@ -185,7 +186,7 @@ def token_count(tiktoken_encoding: Optional[str] = None) -> MetricCreator:
     from langkit.metrics.token import prompt_token_metric, token_metric

     if tiktoken_encoding:
-        return lambda: token_metric(column_name="prompt", encoding=tiktoken_encoding)
+        return partial(token_metric, column_name="prompt", encoding=tiktoken_encoding)

     return prompt_token_metric

@@ -246,7 +247,7 @@ def injection(version: Optional[str] = None) -> MetricCreator:
     from langkit.metrics.injections import injections_metric, prompt_injections_metric

     if version:
-        return lambda: injections_metric(column_name="prompt", version=version)
+        return partial(injections_metric, column_name="prompt", version=version)

     return prompt_injections_metric

@@ -275,10 +276,14 @@ def sentiment_score() -> MetricCreator:
     return prompt_sentiment_polarity

 class topics:
-    def __call__(self, topics: List[str], hypothesis_template: Optional[str] = None) -> MetricCreator:
+    def __init__(self, topics: List[str], hypothesis_template: Optional[str] = None):
+        self.topics = topics
+        self.hypothesis_template = hypothesis_template
+
+    def __call__(self) -> MetricCreator:
         from langkit.metrics.topic import topic_metric

-        return lambda: topic_metric("prompt", topics, hypothesis_template)
+        return partial(topic_metric, "prompt", self.topics, self.hypothesis_template)

     @staticmethod
     def medicine() -> MetricCreator:
@@ -469,13 +474,17 @@ def refusal() -> MetricCreator:
     return response_refusal_similarity_metric

 class topics:
-    def __call__(self, topics: List[str], hypothesis_template: Optional[str] = None) -> MetricCreator:
+    def __init__(self, topics: List[str], hypothesis_template: Optional[str] = None):
+        self.topics = topics
+        self.hypothesis_template = hypothesis_template
+
+    def __call__(self) -> MetricCreator:
         from langkit.metrics.topic import topic_metric

-        return lambda: topic_metric("response", topics, hypothesis_template)
+        return partial(topic_metric, "response", self.topics, self.hypothesis_template)

     @staticmethod
     def medicine() -> MetricCreator:
         from langkit.metrics.topic import topic_metric

-        return lambda: topic_metric("response", ["medicine"])
+        return partial(topic_metric, "response", ["medicine"])
6 changes: 4 additions & 2 deletions langkit/metrics/topic.py
@@ -38,9 +38,11 @@ def __get_scores_per_label(
     return scores_per_label


-def topic_metric(input_name: str, topics: List[str], hypothesis_template: str = _hypothesis_template) -> MultiMetric:
+def topic_metric(input_name: str, topics: List[str], hypothesis_template: Optional[str] = None) -> MultiMetric:
+    hypothesis_template = hypothesis_template or _hypothesis_template
+
     def udf(text: pd.DataFrame) -> MultiMetricResult:
-        metrics: Dict[str, List[Optional[float]]] = {metric_name: [] for metric_name in topics}
+        metrics: Dict[str, List[Optional[float]]] = {topic: [] for topic in topics}

         def process_row(row: pd.DataFrame) -> Dict[str, List[Optional[float]]]:
             value: Any = row[input_name]  # type: ignore
55 changes: 55 additions & 0 deletions tests/langkit/metrics/test_topic.py
@@ -6,6 +6,7 @@
 import whylogs as why
 from langkit.core.metric import WorkflowMetricConfig, WorkflowMetricConfigBuilder
 from langkit.core.workflow import Workflow
+from langkit.metrics.library import lib
 from langkit.metrics.topic import get_custom_topic_modules, prompt_topic_module
 from langkit.metrics.whylogs_compat import create_whylogs_udf_schema

@@ -151,6 +152,60 @@ def test_topic_row():
     assert actual.index.tolist() == expected_columns


+def test_topic_library():
+    df = pd.DataFrame(
+        {
+            "prompt": [
+                "What's the best kind of bait?",
+                "What's the best kind of punch?",
+                "What's the best kind of trail?",
+                "What's the best kind of swimming stroke?",
+            ],
+            "response": [
+                "The best kind of bait is worms.",
+                "The best kind of punch is a jab.",
+                "The best kind of trail is a loop.",
+                "The best kind of stroke is freestyle.",
+            ],
+        }
+    )
+
+    topics = ["fishing", "boxing", "hiking", "swimming"]
+    wf = Workflow(metrics=[lib.prompt.topics(topics), lib.response.topics(topics)])
+    # schema = WorkflowMetricConfigBuilder().add(lib.prompt.topics(topics)).add(lib.response.topics(topics)).build()
+    # schema = WorkflowMetricConfigBuilder().add(custom_topic_modules.prompt_response_topic_module).build()
+
+    # actual = _log(df, schema)
+    result = wf.run(df)
+    actual = result.metrics
+
+    expected_columns = [
+        "prompt.topics.fishing",
+        "prompt.topics.boxing",
+        "prompt.topics.hiking",
+        "prompt.topics.swimming",
+        "response.topics.fishing",
+        "response.topics.boxing",
+        "response.topics.hiking",
+        "response.topics.swimming",
+        "id",
+    ]
+    assert actual.columns.tolist() == expected_columns
+
+    pd.set_option("display.max_columns", None)
+    print(actual.transpose())
+
+    assert actual["prompt.topics.fishing"][0] > 0.50
+    assert actual["prompt.topics.boxing"][1] > 0.50
+    assert actual["prompt.topics.hiking"][2] > 0.50
+    assert actual["prompt.topics.swimming"][3] > 0.50
+
+    assert actual["response.topics.fishing"][0] > 0.50
+    assert actual["response.topics.boxing"][1] > 0.50
+    assert actual["response.topics.hiking"][2] > 0.50
+    assert actual["response.topics.swimming"][3] > 0.50
+
+
 def test_custom_topic():
     df = pd.DataFrame(
         {
