From 4024c5039481d246ea409ab62c1deb483371ad8d Mon Sep 17 00:00:00 2001
From: Anthony Naddeo
Date: Sun, 24 Mar 2024 14:34:55 -0700
Subject: [PATCH] Fix sneaky bug with lambdas

Apparently something about lambdas makes it too hard for pyright to catch
argument type errors. This switches all of the lambdas to partials in the
metric library and corrects the type error for the new topic module.
---
A minimal before/after sketch of the lambda vs. partial difference follows
the diff.

 langkit/metrics/library.py          | 25 ++++++++-----
 langkit/metrics/topic.py            |  6 ++--
 tests/langkit/metrics/test_topic.py | 55 +++++++++++++++++++++++++++++
 3 files changed, 76 insertions(+), 10 deletions(-)

diff --git a/langkit/metrics/library.py b/langkit/metrics/library.py
index f528c8f..cac7e40 100644
--- a/langkit/metrics/library.py
+++ b/langkit/metrics/library.py
@@ -1,3 +1,4 @@
+from functools import partial
 from typing import List, Optional
 
 from langkit.core.metric import MetricCreator
@@ -104,7 +105,7 @@ def pii(entities: Optional[List[str]] = None, input_name: str = "prompt") -> Met
             from langkit.metrics.pii import pii_presidio_metric, prompt_presidio_pii_metric
 
             if entities:
-                return lambda: pii_presidio_metric(entities=entities, input_name=input_name)
+                return partial(pii_presidio_metric, entities=entities, input_name=input_name)
 
             return prompt_presidio_pii_metric
 
@@ -185,7 +186,7 @@ def token_count(tiktoken_encoding: Optional[str] = None) -> MetricCreator:
             from langkit.metrics.token import prompt_token_metric, token_metric
 
             if tiktoken_encoding:
-                return lambda: token_metric(column_name="prompt", encoding=tiktoken_encoding)
+                return partial(token_metric, column_name="prompt", encoding=tiktoken_encoding)
 
             return prompt_token_metric
 
@@ -246,7 +247,7 @@ def injection(version: Optional[str] = None) -> MetricCreator:
             from langkit.metrics.injections import injections_metric, prompt_injections_metric
 
             if version:
-                return lambda: injections_metric(column_name="prompt", version=version)
+                return partial(injections_metric, column_name="prompt", version=version)
 
             return prompt_injections_metric
 
@@ -275,10 +276,14 @@ def sentiment_score() -> MetricCreator:
             return prompt_sentiment_polarity
 
         class topics:
-            def __call__(self, topics: List[str], hypothesis_template: Optional[str] = None) -> MetricCreator:
+            def __init__(self, topics: List[str], hypothesis_template: Optional[str] = None):
+                self.topics = topics
+                self.hypothesis_template = hypothesis_template
+
+            def __call__(self) -> MetricCreator:
                 from langkit.metrics.topic import topic_metric
 
-                return lambda: topic_metric("prompt", topics, hypothesis_template)
+                return partial(topic_metric, "prompt", self.topics, self.hypothesis_template)
 
             @staticmethod
             def medicine() -> MetricCreator:
@@ -469,13 +474,17 @@ def refusal() -> MetricCreator:
             return response_refusal_similarity_metric
 
         class topics:
-            def __call__(self, topics: List[str], hypothesis_template: Optional[str] = None) -> MetricCreator:
+            def __init__(self, topics: List[str], hypothesis_template: Optional[str] = None):
+                self.topics = topics
+                self.hypothesis_template = hypothesis_template
+
+            def __call__(self) -> MetricCreator:
                 from langkit.metrics.topic import topic_metric
 
-                return lambda: topic_metric("response", topics, hypothesis_template)
+                return partial(topic_metric, "response", self.topics, self.hypothesis_template)
 
             @staticmethod
             def medicine() -> MetricCreator:
                 from langkit.metrics.topic import topic_metric
 
-                return lambda: topic_metric("response", ["medicine"])
+                return partial(topic_metric, "response", ["medicine"])
diff --git a/langkit/metrics/topic.py b/langkit/metrics/topic.py
index 0bb7c03..8928e90 100644
--- a/langkit/metrics/topic.py
+++ b/langkit/metrics/topic.py
@@ -38,9 +38,11 @@ def __get_scores_per_label(
     return scores_per_label
 
 
-def topic_metric(input_name: str, topics: List[str], hypothesis_template: str = _hypothesis_template) -> MultiMetric:
+def topic_metric(input_name: str, topics: List[str], hypothesis_template: Optional[str] = None) -> MultiMetric:
+    hypothesis_template = hypothesis_template or _hypothesis_template
+
     def udf(text: pd.DataFrame) -> MultiMetricResult:
-        metrics: Dict[str, List[Optional[float]]] = {metric_name: [] for metric_name in topics}
+        metrics: Dict[str, List[Optional[float]]] = {topic: [] for topic in topics}
 
         def process_row(row: pd.DataFrame) -> Dict[str, List[Optional[float]]]:
             value: Any = row[input_name]  # type: ignore
diff --git a/tests/langkit/metrics/test_topic.py b/tests/langkit/metrics/test_topic.py
index b9742bc..d5340cd 100644
--- a/tests/langkit/metrics/test_topic.py
+++ b/tests/langkit/metrics/test_topic.py
@@ -6,6 +6,7 @@ import whylogs as why
 
 from langkit.core.metric import WorkflowMetricConfig, WorkflowMetricConfigBuilder
 from langkit.core.workflow import Workflow
+from langkit.metrics.library import lib
 from langkit.metrics.topic import get_custom_topic_modules, prompt_topic_module
 from langkit.metrics.whylogs_compat import create_whylogs_udf_schema
 
@@ -151,6 +152,60 @@ def test_topic_row():
     assert actual.index.tolist() == expected_columns
 
 
+def test_topic_library():
+    df = pd.DataFrame(
+        {
+            "prompt": [
+                "What's the best kind of bait?",
+                "What's the best kind of punch?",
+                "What's the best kind of trail?",
+                "What's the best kind of swimming stroke?",
+            ],
+            "response": [
+                "The best kind of bait is worms.",
+                "The best kind of punch is a jab.",
+                "The best kind of trail is a loop.",
+                "The best kind of stroke is freestyle.",
+            ],
+        }
+    )
+
+    topics = ["fishing", "boxing", "hiking", "swimming"]
+    wf = Workflow(metrics=[lib.prompt.topics(topics), lib.response.topics(topics)])
+    # schema = WorkflowMetricConfigBuilder().add(lib.prompt.topics(topics)).add(lib.response.topics(topics)).build()
+    # schema = WorkflowMetricConfigBuilder().add(custom_topic_modules.prompt_response_topic_module).build()
+
+    # actual = _log(df, schema)
+    result = wf.run(df)
+    actual = result.metrics
+
+    expected_columns = [
+        "prompt.topics.fishing",
+        "prompt.topics.boxing",
+        "prompt.topics.hiking",
+        "prompt.topics.swimming",
+        "response.topics.fishing",
+        "response.topics.boxing",
+        "response.topics.hiking",
+        "response.topics.swimming",
+        "id",
+    ]
+    assert actual.columns.tolist() == expected_columns
+
+    pd.set_option("display.max_columns", None)
+    print(actual.transpose())
+
+    assert actual["prompt.topics.fishing"][0] > 0.50
+    assert actual["prompt.topics.boxing"][1] > 0.50
+    assert actual["prompt.topics.hiking"][2] > 0.50
+    assert actual["prompt.topics.swimming"][3] > 0.50
+
+    assert actual["response.topics.fishing"][0] > 0.50
+    assert actual["response.topics.boxing"][1] > 0.50
+    assert actual["response.topics.hiking"][2] > 0.50
+    assert actual["response.topics.swimming"][3] > 0.50
+
+
 def test_custom_topic():
     df = pd.DataFrame(
         {
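
A minimal before/after sketch (not part of the patch) of the lambda vs. partial difference the commit
message describes. The names are illustrative stand-ins rather than langkit's API: make_metric plays
the role of topic_metric, creator_before/creator_after mirror the two shapes of the metric creators,
and a plain dict stands in for the metric object.

from functools import partial
from typing import Any, Callable, Dict, List, Optional

_hypothesis_template = "This example is about {}"


def make_metric(input_name: str, topics: List[str], hypothesis_template: Optional[str] = None) -> Dict[str, Any]:
    # Stand-in for topic_metric after the patch: Optional[str] is accepted and
    # resolved to the module default here rather than at every call site.
    hypothesis_template = hypothesis_template or _hypothesis_template
    return {"input": input_name, "topics": topics, "template": hypothesis_template}


def creator_before(topics: List[str], hypothesis_template: Optional[str] = None) -> Callable[[], Dict[str, Any]]:
    # Pre-patch shape: defer construction with a lambda. The arguments are only
    # seen inside the lambda body; per the commit message, pyright missed an
    # Optional[str]-where-str-was-required mismatch written this way.
    return lambda: make_metric("prompt", topics, hypothesis_template)


def creator_after(topics: List[str], hypothesis_template: Optional[str] = None) -> Callable[[], Dict[str, Any]]:
    # Post-patch shape: functools.partial binds the arguments on this line, so a
    # checker that validates partial() calls (as pyright does, and as this patch
    # relies on) compares them against make_metric's signature immediately.
    return partial(make_metric, "prompt", topics, hypothesis_template)


if __name__ == "__main__":
    # Both creators behave the same at runtime; the difference is purely where a
    # static checker gets to look at the bound arguments.
    print(creator_before(["medicine"])())
    print(creator_after(["medicine"])())

Runtime behavior is identical either way; partial just moves the argument binding to a spot the
checker can see, which is also why the patch relaxes topic_metric to accept Optional[str] and fall
back to the module-level template.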