Fix sneaky bug with lambdas #271

Merged · 1 commit · Mar 24, 2024
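Summary of the diff below: the `lambda`-based `MetricCreator` factories in `langkit/metrics/library.py` are replaced with `functools.partial`, the `topics` helper classes move their configuration from `__call__` to `__init__`, `topic_metric` gains an `Optional` `hypothesis_template` that falls back to the module default, and a regression test is added.

Background on the general hazard with lambda factories (standard Python semantics; the diff does not state that this exact failure was observed): a lambda resolves its free variables when it is called, not when it is created, whereas `functools.partial` captures argument values up front. A minimal standalone sketch, not code from this repo:

```python
from functools import partial

def make_with_lambda():
    # Each lambda closes over the *name* v; by the time any of them runs,
    # the comprehension has finished and v is "c".
    return [lambda: v for v in ("a", "b", "c")]

def make_with_partial():
    # partial snapshots the *value* of v when each callable is built.
    identity = lambda x: x
    return [partial(identity, v) for v in ("a", "b", "c")]

print([f() for f in make_with_lambda()])   # ['c', 'c', 'c']
print([f() for f in make_with_partial()])  # ['a', 'b', 'c']
```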
25 changes: 17 additions & 8 deletions langkit/metrics/library.py
@@ -1,3 +1,4 @@
+from functools import partial
 from typing import List, Optional

 from langkit.core.metric import MetricCreator
@@ -104,7 +105,7 @@ def pii(entities: Optional[List[str]] = None, input_name: str = "prompt") -> MetricCreator:
     from langkit.metrics.pii import pii_presidio_metric, prompt_presidio_pii_metric

     if entities:
-        return lambda: pii_presidio_metric(entities=entities, input_name=input_name)
+        return partial(pii_presidio_metric, entities=entities, input_name=input_name)

     return prompt_presidio_pii_metric
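Beyond late binding, two practical differences between the two forms are worth noting (hedged: the diff does not say which one motivated the change): a `partial` over a module-level function can be pickled and exposes its target and bound arguments as data, while a lambda can do neither. Standalone sketch, with `pii_metric` as a hypothetical stand-in for the real `pii_presidio_metric`:

```python
import pickle
from functools import partial

def pii_metric(entities=None, input_name="prompt"):
    # Stand-in for the real metric factory; only the callable shape matters here.
    return (entities, input_name)

p = partial(pii_metric, entities=["EMAIL_ADDRESS"], input_name="prompt")
print(p.func.__name__, p.keywords)      # bound arguments are inspectable data
print(pickle.loads(pickle.dumps(p))())  # round-trips through pickle

f = lambda: pii_metric(entities=["EMAIL_ADDRESS"])
try:
    pickle.dumps(f)
except (pickle.PicklingError, AttributeError) as e:
    print("lambdas don't pickle:", e)
```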

@@ -185,7 +186,7 @@ def token_count(tiktoken_encoding: Optional[str] = None) -> MetricCreator:
     from langkit.metrics.token import prompt_token_metric, token_metric

     if tiktoken_encoding:
-        return lambda: token_metric(column_name="prompt", encoding=tiktoken_encoding)
+        return partial(token_metric, column_name="prompt", encoding=tiktoken_encoding)

     return prompt_token_metric

@@ -246,7 +247,7 @@ def injection(version: Optional[str] = None) -> MetricCreator:
     from langkit.metrics.injections import injections_metric, prompt_injections_metric

     if version:
-        return lambda: injections_metric(column_name="prompt", version=version)
+        return partial(injections_metric, column_name="prompt", version=version)

     return prompt_injections_metric

@@ -275,10 +276,14 @@ def sentiment_score() -> MetricCreator:
     return prompt_sentiment_polarity

 class topics:
-    def __call__(self, topics: List[str], hypothesis_template: Optional[str] = None) -> MetricCreator:
+    def __init__(self, topics: List[str], hypothesis_template: Optional[str] = None):
+        self.topics = topics
+        self.hypothesis_template = hypothesis_template
+
+    def __call__(self) -> MetricCreator:
         from langkit.metrics.topic import topic_metric

-        return lambda: topic_metric("prompt", topics, hypothesis_template)
+        return partial(topic_metric, "prompt", self.topics, self.hypothesis_template)

     @staticmethod
     def medicine() -> MetricCreator:
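Why moving the parameters from `__call__` to `__init__` matters: `lib.prompt.topics` is the class itself, so `lib.prompt.topics([...])` routes the argument to the constructor, never to `__call__`. With only `__call__` defined, that call raises a `TypeError`. This is one plausible reading of the "sneaky bug" in the title, though the PR does not narrate it. Standalone sketch:

```python
class Before:
    def __call__(self, topics, hypothesis_template=None):
        return topics

class After:
    def __init__(self, topics, hypothesis_template=None):
        self.topics = topics
        self.hypothesis_template = hypothesis_template

    def __call__(self):
        return self.topics

print(After(["fishing"])())  # configure at construction, invoke with no args
try:
    Before(["fishing"])      # the list goes to object.__init__, not __call__
except TypeError as e:
    print(e)                 # Before() takes no arguments
```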
@@ -469,13 +474,17 @@ def refusal() -> MetricCreator:
     return response_refusal_similarity_metric

 class topics:
-    def __call__(self, topics: List[str], hypothesis_template: Optional[str] = None) -> MetricCreator:
+    def __init__(self, topics: List[str], hypothesis_template: Optional[str] = None):
+        self.topics = topics
+        self.hypothesis_template = hypothesis_template
+
+    def __call__(self) -> MetricCreator:
         from langkit.metrics.topic import topic_metric

-        return lambda: topic_metric("response", topics, hypothesis_template)
+        return partial(topic_metric, "response", self.topics, self.hypothesis_template)

     @staticmethod
     def medicine() -> MetricCreator:
         from langkit.metrics.topic import topic_metric

-        return lambda: topic_metric("response", ["medicine"])
+        return partial(topic_metric, "response", ["medicine"])
6 changes: 4 additions & 2 deletions langkit/metrics/topic.py
@@ -38,9 +38,11 @@ def __get_scores_per_label(
     return scores_per_label


-def topic_metric(input_name: str, topics: List[str], hypothesis_template: str = _hypothesis_template) -> MultiMetric:
+def topic_metric(input_name: str, topics: List[str], hypothesis_template: Optional[str] = None) -> MultiMetric:
+    hypothesis_template = hypothesis_template or _hypothesis_template
+
     def udf(text: pd.DataFrame) -> MultiMetricResult:
-        metrics: Dict[str, List[Optional[float]]] = {metric_name: [] for metric_name in topics}
+        metrics: Dict[str, List[Optional[float]]] = {topic: [] for topic in topics}

         def process_row(row: pd.DataFrame) -> Dict[str, List[Optional[float]]]:
             value: Any = row[input_name] # type: ignore
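The signature change in `topic_metric` follows from the class change above: the `topics` helpers now forward `self.hypothesis_template`, which may be `None`, so the default has to be applied inside the function body rather than in the signature. Minimal sketch of the difference (the template value is assumed for illustration, not taken from the repo):

```python
_hypothesis_template = "This example is about {}"  # assumed module default

def before(template: str = _hypothesis_template) -> str:
    return template  # an explicit None is used verbatim

def after(template=None) -> str:
    return template or _hypothesis_template  # None falls back to the default

print(before(None))  # None -- leaks through to the classifier
print(after(None))   # This example is about {}
```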
55 changes: 55 additions & 0 deletions tests/langkit/metrics/test_topic.py
@@ -6,6 +6,7 @@
 import whylogs as why
 from langkit.core.metric import WorkflowMetricConfig, WorkflowMetricConfigBuilder
 from langkit.core.workflow import Workflow
+from langkit.metrics.library import lib
 from langkit.metrics.topic import get_custom_topic_modules, prompt_topic_module
 from langkit.metrics.whylogs_compat import create_whylogs_udf_schema

@@ -151,6 +152,60 @@ def test_topic_row():
     assert actual.index.tolist() == expected_columns


+def test_topic_library():
+    df = pd.DataFrame(
+        {
+            "prompt": [
+                "What's the best kind of bait?",
+                "What's the best kind of punch?",
+                "What's the best kind of trail?",
+                "What's the best kind of swimming stroke?",
+            ],
+            "response": [
+                "The best kind of bait is worms.",
+                "The best kind of punch is a jab.",
+                "The best kind of trail is a loop.",
+                "The best kind of stroke is freestyle.",
+            ],
+        }
+    )
+
+    topics = ["fishing", "boxing", "hiking", "swimming"]
+    wf = Workflow(metrics=[lib.prompt.topics(topics), lib.response.topics(topics)])
+    # schema = WorkflowMetricConfigBuilder().add(lib.prompt.topics(topics)).add(lib.response.topics(topics)).build()
+    # schema = WorkflowMetricConfigBuilder().add(custom_topic_modules.prompt_response_topic_module).build()
+
+    # actual = _log(df, schema)
+    result = wf.run(df)
+    actual = result.metrics
+
+    expected_columns = [
+        "prompt.topics.fishing",
+        "prompt.topics.boxing",
+        "prompt.topics.hiking",
+        "prompt.topics.swimming",
+        "response.topics.fishing",
+        "response.topics.boxing",
+        "response.topics.hiking",
+        "response.topics.swimming",
+        "id",
+    ]
+    assert actual.columns.tolist() == expected_columns
+
+    pd.set_option("display.max_columns", None)
+    print(actual.transpose())
+
+    assert actual["prompt.topics.fishing"][0] > 0.50
+    assert actual["prompt.topics.boxing"][1] > 0.50
+    assert actual["prompt.topics.hiking"][2] > 0.50
+    assert actual["prompt.topics.swimming"][3] > 0.50
+
+    assert actual["response.topics.fishing"][0] > 0.50
+    assert actual["response.topics.boxing"][1] > 0.50
+    assert actual["response.topics.hiking"][2] > 0.50
+    assert actual["response.topics.swimming"][3] > 0.50
+
+
 def test_custom_topic():
     df = pd.DataFrame(
         {