Add toxicity model and version options #303

Merged 1 commit on May 1, 2024

langkit/metrics/library.py (10 additions, 3 deletions)

```diff
@@ -116,19 +116,26 @@ def __call__(self) -> MetricCreator:
         return self.toxicity_score()
 
     @staticmethod
-    def toxicity_score(onnx: bool = True) -> MetricCreator:
+    def toxicity_score(
+        onnx: bool = True, onnx_tag: Optional[str] = None, hf_model: Optional[str] = None, hf_model_revision: Optional[str] = None
+    ) -> MetricCreator:
         """
         Analyze the input for toxicity. The output of this metric ranges from 0 to 1, where 0 indicates
         non-toxic and 1 indicates toxic.
+
+        :param onnx: Whether to use the ONNX model for toxicity analysis. This is mutually exclusive with the Hugging Face model options.
+        :param hf_model: The Hugging Face model to use for toxicity analysis. Defaults to martin-ha/toxic-comment-model.
+        :param hf_model_revision: The revision of the Hugging Face model to use. The default can change between releases, so
+            specify the revision to lock the metric to a specific version.
         """
         if onnx:
             from langkit.metrics.toxicity_onnx import prompt_toxicity_metric
 
-            return prompt_toxicity_metric
+            return partial(prompt_toxicity_metric, tag=onnx_tag)
         else:
             from langkit.metrics.toxicity import prompt_toxicity_metric
 
-            return prompt_toxicity_metric
+            return partial(prompt_toxicity_metric, hf_model=hf_model, hf_model_revision=hf_model_revision)
 
     class stats:
         def __call__(self) -> MetricCreator:
```
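
The `partial` calls are what keep the public surface unchanged: `toxicity_score` still returns a `MetricCreator`, but the new keyword arguments are pre-bound onto `prompt_toxicity_metric` before the creator is handed back, so downstream code never has to pass them again. A minimal sketch of that binding pattern, using hypothetical stand-ins rather than the real langkit signatures:

```python
from functools import partial
from typing import Callable, Optional

# Hypothetical stand-in for prompt_toxicity_metric: a factory that needs to
# know which model/revision to load before it can build the metric.
def make_toxicity_metric(hf_model: Optional[str] = None, hf_model_revision: Optional[str] = None) -> str:
    model = hf_model or "martin-ha/toxic-comment-model"
    revision = hf_model_revision or "<default-revision>"
    return f"toxicity metric backed by {model}@{revision}"

def toxicity_score(hf_model: Optional[str] = None, hf_model_revision: Optional[str] = None) -> Callable[[], str]:
    # Bind the options now; callers of the returned creator stay zero-argument.
    return partial(make_toxicity_metric, hf_model=hf_model, hf_model_revision=hf_model_revision)

creator = toxicity_score(hf_model_revision="abc123")
print(creator())  # toxicity metric backed by martin-ha/toxic-comment-model@abc123
```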

langkit/metrics/toxicity.py (17 additions, 18 deletions)

```diff
@@ -3,7 +3,7 @@
 # pyright: reportUnknownLambdaType=none
 import os
 from functools import lru_cache, partial
-from typing import List, cast
+from typing import List, Optional, cast
 
 import pandas as pd
 import torch
@@ -22,40 +22,39 @@ def __toxicity(pipeline: TextClassificationPipeline, max_length: int, text: List
     return [result["score"] if result["label"] == "toxic" else 1.0 - result["score"] for result in results]  # type: ignore
 
 
-_model_path = "martin-ha/toxic-comment-model"
-_revision = "9842c08b35a4687e7b211187d676986c8c96256d"
-
-
-def _cache_assets():
-    AutoModelForSequenceClassification.from_pretrained(_model_path, revision=_revision)
-    AutoTokenizer.from_pretrained(_model_path, revision=_revision)
+def _cache_assets(model_path: str, revision: str):
+    AutoModelForSequenceClassification.from_pretrained(model_path, revision=revision)
+    AutoTokenizer.from_pretrained(model_path, revision=revision)
 
 
 @lru_cache
-def _get_tokenizer() -> PreTrainedTokenizerBase:
-    return AutoTokenizer.from_pretrained(_model_path, local_files_only=True, revision=_revision)
+def _get_tokenizer(model_path: str, revision: str) -> PreTrainedTokenizerBase:
+    return AutoTokenizer.from_pretrained(model_path, local_files_only=True, revision=revision)
 
 
 @lru_cache
-def _get_pipeline() -> TextClassificationPipeline:
+def _get_pipeline(model_path: str, revision: str) -> TextClassificationPipeline:
     use_cuda = torch.cuda.is_available() and not bool(os.environ.get("LANGKIT_NO_CUDA", False))
     model: PreTrainedTokenizerBase = AutoModelForSequenceClassification.from_pretrained(
-        _model_path, local_files_only=True, revision=_revision
+        model_path, local_files_only=True, revision=revision
     )
-    tokenizer = _get_tokenizer()
+    tokenizer = _get_tokenizer(model_path, revision)
     return TextClassificationPipeline(model=model, tokenizer=tokenizer, device=0 if use_cuda else -1)
 
 
-def toxicity_metric(column_name: str) -> Metric:
+def toxicity_metric(column_name: str, hf_model: Optional[str] = None, hf_model_revision: Optional[str] = None) -> Metric:
+    model_path = "martin-ha/toxic-comment-model" if hf_model is None else hf_model
+    revision = "9842c08b35a4687e7b211187d676986c8c96256d" if hf_model_revision is None else hf_model_revision
+
     def cache_assets():
-        _cache_assets()
+        _cache_assets(model_path, revision)
 
     def init():
-        _get_pipeline()
+        _get_pipeline(model_path, revision)
 
     def udf(text: pd.DataFrame) -> SingleMetricResult:
-        _tokenizer = _get_tokenizer()
-        _pipeline = _get_pipeline()
+        _tokenizer = _get_tokenizer(model_path, revision)
+        _pipeline = _get_pipeline(model_path, revision)
 
         col = list(UdfInput(text).iter_column_rows(column_name))
         max_length = cast(int, _tokenizer.model_max_length)
```
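
Because `_get_tokenizer` and `_get_pipeline` keep their `@lru_cache` decorators, turning the module-level constants into parameters means the cache is now keyed by `(model_path, revision)`: repeated calls for the same pair reuse one loaded pipeline, while a different revision gets its own entry. A small self-contained sketch of that behavior with a cheap stand-in loader:

```python
from functools import lru_cache

load_count = 0

@lru_cache
def get_pipeline(model_path: str, revision: str) -> str:
    # Stand-in for the expensive AutoModel/AutoTokenizer load in toxicity.py.
    global load_count
    load_count += 1
    return f"pipeline({model_path}@{revision})"

get_pipeline("martin-ha/toxic-comment-model", "rev-a")
get_pipeline("martin-ha/toxic-comment-model", "rev-a")  # cache hit, nothing reloaded
get_pipeline("martin-ha/toxic-comment-model", "rev-b")  # new revision, separate cache entry
print(load_count)  # 2
```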

langkit/metrics/toxicity_onnx.py (14 additions, 14 deletions)

```diff
@@ -3,7 +3,7 @@
 # pyright: reportUnknownLambdaType=none
 import os
 from functools import lru_cache, partial
-from typing import List, cast
+from typing import List, Optional, cast
 
 import numpy as np
 import onnxruntime
@@ -36,35 +36,35 @@ def __toxicity(tokenizer: PreTrainedTokenizerBase, session: onnxruntime.Inferenc
     return [result["score"] if result["label"] == "toxic" else 1.0 - result["score"] for result in results]  # type: ignore
 
 
-def _download_assets():
-    name, tag = TransformerModel.ToxicCommentModel.value
-    return get_asset(name, tag)
+def _download_assets(tag: Optional[str]):
+    name, default_tag = TransformerModel.ToxicCommentModel.value
+    return get_asset(name, tag or default_tag)
 
 
 @lru_cache
-def _get_tokenizer() -> PreTrainedTokenizerBase:
-    return AutoTokenizer.from_pretrained(_download_assets())
+def _get_tokenizer(tag: Optional[str]) -> PreTrainedTokenizerBase:
+    return AutoTokenizer.from_pretrained(_download_assets(tag))
 
 
 @lru_cache
-def _get_session() -> onnxruntime.InferenceSession:
-    downloaded_path = _download_assets()
+def _get_session(tag: Optional[str]) -> onnxruntime.InferenceSession:
+    downloaded_path = _download_assets(tag)
     onnx_model_path = os.path.join(downloaded_path, "model.onnx")
     print(f"Loading ONNX model from {onnx_model_path}")
     return onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
 
 
-def toxicity_metric(column_name: str) -> Metric:
+def toxicity_metric(column_name: str, tag: Optional[str] = None) -> Metric:
     def cache_assets():
-        _download_assets()
+        _download_assets(tag)
 
     def init():
-        _get_session()
-        _get_tokenizer()
+        _get_session(tag)
+        _get_tokenizer(tag)
 
     def udf(text: pd.DataFrame) -> SingleMetricResult:
-        _tokenizer = _get_tokenizer()
-        _session = _get_session()
+        _tokenizer = _get_tokenizer(tag)
+        _session = _get_session(tag)
 
         col = list(UdfInput(text).iter_column_rows(column_name))
         max_length = cast(int, _tokenizer.model_max_length)
```
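
Taken together, both backends can now be pinned from the metric library. A sketch of the intended call patterns; the `onnx_tag` value is a placeholder, since valid asset tags are not listed in this diff, and the Hugging Face revision is the one used in the test below:

```python
from langkit.metrics.library import lib as metrics_lib

# Default: ONNX backend with the bundled asset tag.
default_creator = metrics_lib.prompt.toxicity.toxicity_score()

# ONNX backend pinned to a specific asset tag (placeholder value).
pinned_onnx = metrics_lib.prompt.toxicity.toxicity_score(onnx=True, onnx_tag="some-asset-tag")

# Hugging Face backend pinned to a specific model and revision.
pinned_hf = metrics_lib.prompt.toxicity.toxicity_score(
    onnx=False,
    hf_model="martin-ha/toxic-comment-model",
    hf_model_revision="f1c3aa41130e8baeee31c3ea5d14598a0d3385e5",
)
```

Note that `onnx` defaults to True, and per the branch in `library.py` the ONNX path ignores the Hugging Face options, so pairing them with `onnx=False` as above is the unambiguous form.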

tests/langkit/metrics/test_toxicity.py (11 additions, 0 deletions)

```diff
@@ -5,6 +5,8 @@
 
 import whylogs as why
 from langkit.core.metric import WorkflowMetricConfig, WorkflowMetricConfigBuilder
+from langkit.core.workflow import Workflow
+from langkit.metrics.library import lib as metrics_lib
 from langkit.metrics.toxicity import prompt_response_toxicity_module, prompt_toxicity_metric, response_toxicity_metric
 from langkit.metrics.whylogs_compat import create_whylogs_udf_schema
 
@@ -81,6 +83,15 @@ def test_prompt_toxicity_row_non_toxic():
     assert actual["distribution/max"]["prompt.toxicity.toxicity_score"] < 0.1
 
 
+def test_prompt_toxicity_version():
+    wf = Workflow(metrics=[metrics_lib.prompt.toxicity.toxicity_score(hf_model_revision="f1c3aa41130e8baeee31c3ea5d14598a0d3385e5")])
+    result = wf.run(row)
+
+    expected_columns = ["prompt.toxicity.toxicity_score", "id"]
+
+    assert list(result.metrics.columns) == expected_columns
+
+
 def test_prompt_toxicity_df_non_toxic():
     schema = WorkflowMetricConfigBuilder().add(prompt_toxicity_metric).build()
```
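
The new `test_prompt_toxicity_version` runs the pinned-revision metric through a `Workflow`; the `row` it passes to `wf.run` is a fixture defined earlier in the test module and not shown in this diff. A minimal sketch of such a run, with a hypothetical input row standing in for that fixture:

```python
from langkit.core.workflow import Workflow
from langkit.metrics.library import lib as metrics_lib

wf = Workflow(
    metrics=[metrics_lib.prompt.toxicity.toxicity_score(hf_model_revision="f1c3aa41130e8baeee31c3ea5d14598a0d3385e5")]
)

# Hypothetical input; the real `row` fixture lives elsewhere in test_toxicity.py.
row = {"prompt": "How is the weather today?"}
result = wf.run(row)

# Per the test's assertion, the metrics frame carries the score plus an id column.
print(list(result.metrics.columns))  # ["prompt.toxicity.toxicity_score", "id"]
```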