diff --git a/langkit/metrics/library.py b/langkit/metrics/library.py
index f41fed4..e7a53ae 100644
--- a/langkit/metrics/library.py
+++ b/langkit/metrics/library.py
@@ -116,19 +116,27 @@ def __call__(self) -> MetricCreator:
                 return self.toxicity_score()
 
             @staticmethod
-            def toxicity_score(onnx: bool = True) -> MetricCreator:
+            def toxicity_score(
+                onnx: bool = True, onnx_tag: Optional[str] = None, hf_model: Optional[str] = None, hf_model_revision: Optional[str] = None
+            ) -> MetricCreator:
                 """
                 Analyze the input for toxicity. The output of this metric ranges from 0 to 1, where 0 indicates
                 non-toxic and 1 indicates toxic.
+
+                :param onnx: Whether to use the ONNX model for toxicity analysis. Mutually exclusive with the hf_model options.
+                :param onnx_tag: The tag of the ONNX model asset to use. Defaults to the tag bundled with this release.
+                :param hf_model: The Hugging Face model to use for toxicity analysis. Defaults to martin-ha/toxic-comment-model.
+                :param hf_model_revision: The revision of the Hugging Face model to use. The default revision can change between
+                    releases, so pin a revision here to lock the model to a specific version.
                 """
                 if onnx:
                     from langkit.metrics.toxicity_onnx import prompt_toxicity_metric
 
-                    return prompt_toxicity_metric
+                    return partial(prompt_toxicity_metric, tag=onnx_tag)
                 else:
                     from langkit.metrics.toxicity import prompt_toxicity_metric
 
-                    return prompt_toxicity_metric
+                    return partial(prompt_toxicity_metric, hf_model=hf_model, hf_model_revision=hf_model_revision)
 
         class stats:
             def __call__(self) -> MetricCreator:
diff --git a/langkit/metrics/toxicity.py b/langkit/metrics/toxicity.py
index 36b9d9f..73613e2 100644
--- a/langkit/metrics/toxicity.py
+++ b/langkit/metrics/toxicity.py
@@ -3,7 +3,7 @@
 # pyright: reportUnknownLambdaType=none
 import os
 from functools import lru_cache, partial
-from typing import List, cast
+from typing import List, Optional, cast
 
 import pandas as pd
 import torch
@@ -22,40 +22,39 @@ def __toxicity(pipeline: TextClassificationPipeline, max_length: int, text: List
     return [result["score"] if result["label"] == "toxic" else 1.0 - result["score"] for result in results]  # type: ignore
 
 
-_model_path = "martin-ha/toxic-comment-model"
-_revision = "9842c08b35a4687e7b211187d676986c8c96256d"
-
-
-def _cache_assets():
-    AutoModelForSequenceClassification.from_pretrained(_model_path, revision=_revision)
-    AutoTokenizer.from_pretrained(_model_path, revision=_revision)
+def _cache_assets(model_path: str, revision: str):
+    AutoModelForSequenceClassification.from_pretrained(model_path, revision=revision)
+    AutoTokenizer.from_pretrained(model_path, revision=revision)
 
 
 @lru_cache
-def _get_tokenizer() -> PreTrainedTokenizerBase:
-    return AutoTokenizer.from_pretrained(_model_path, local_files_only=True, revision=_revision)
+def _get_tokenizer(model_path: str, revision: str) -> PreTrainedTokenizerBase:
+    return AutoTokenizer.from_pretrained(model_path, local_files_only=True, revision=revision)
 
 
 @lru_cache
-def _get_pipeline() -> TextClassificationPipeline:
+def _get_pipeline(model_path: str, revision: str) -> TextClassificationPipeline:
     use_cuda = torch.cuda.is_available() and not bool(os.environ.get("LANGKIT_NO_CUDA", False))
     model: PreTrainedTokenizerBase = AutoModelForSequenceClassification.from_pretrained(
-        _model_path, local_files_only=True, revision=_revision
+        model_path, local_files_only=True, revision=revision
     )
-    tokenizer = _get_tokenizer()
+    tokenizer = _get_tokenizer(model_path, revision)
     return TextClassificationPipeline(model=model, tokenizer=tokenizer, device=0 if use_cuda else -1)
 
 
-def toxicity_metric(column_name: str) -> Metric:
+def toxicity_metric(column_name: str, hf_model: Optional[str] = None, hf_model_revision: Optional[str] = None) -> Metric:
+    model_path = "martin-ha/toxic-comment-model" if hf_model is None else hf_model
+    revision = "9842c08b35a4687e7b211187d676986c8c96256d" if hf_model_revision is None else hf_model_revision
+
     def cache_assets():
-        _cache_assets()
+        _cache_assets(model_path, revision)
 
     def init():
-        _get_pipeline()
+        _get_pipeline(model_path, revision)
 
     def udf(text: pd.DataFrame) -> SingleMetricResult:
-        _tokenizer = _get_tokenizer()
-        _pipeline = _get_pipeline()
+        _tokenizer = _get_tokenizer(model_path, revision)
+        _pipeline = _get_pipeline(model_path, revision)
         col = list(UdfInput(text).iter_column_rows(column_name))
         max_length = cast(int, _tokenizer.model_max_length)
 
diff --git a/langkit/metrics/toxicity_onnx.py b/langkit/metrics/toxicity_onnx.py
index 6dd87ac..6ca67f1 100644
--- a/langkit/metrics/toxicity_onnx.py
+++ b/langkit/metrics/toxicity_onnx.py
@@ -3,7 +3,7 @@
 # pyright: reportUnknownLambdaType=none
 import os
 from functools import lru_cache, partial
-from typing import List, cast
+from typing import List, Optional, cast
 
 import numpy as np
 import onnxruntime
@@ -36,35 +36,35 @@ def __toxicity(tokenizer: PreTrainedTokenizerBase, session: onnxruntime.Inferenc
     return [result["score"] if result["label"] == "toxic" else 1.0 - result["score"] for result in results]  # type: ignore
 
 
-def _download_assets():
-    name, tag = TransformerModel.ToxicCommentModel.value
-    return get_asset(name, tag)
+def _download_assets(tag: Optional[str]):
+    name, default_tag = TransformerModel.ToxicCommentModel.value
+    return get_asset(name, tag or default_tag)
 
 
 @lru_cache
-def _get_tokenizer() -> PreTrainedTokenizerBase:
-    return AutoTokenizer.from_pretrained(_download_assets())
+def _get_tokenizer(tag: Optional[str]) -> PreTrainedTokenizerBase:
+    return AutoTokenizer.from_pretrained(_download_assets(tag))
 
 
 @lru_cache
-def _get_session() -> onnxruntime.InferenceSession:
-    downloaded_path = _download_assets()
+def _get_session(tag: Optional[str]) -> onnxruntime.InferenceSession:
+    downloaded_path = _download_assets(tag)
     onnx_model_path = os.path.join(downloaded_path, "model.onnx")
     print(f"Loading ONNX model from {onnx_model_path}")
     return onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
 
 
-def toxicity_metric(column_name: str) -> Metric:
+def toxicity_metric(column_name: str, tag: Optional[str] = None) -> Metric:
     def cache_assets():
-        _download_assets()
+        _download_assets(tag)
 
     def init():
-        _get_session()
-        _get_tokenizer()
+        _get_session(tag)
+        _get_tokenizer(tag)
 
     def udf(text: pd.DataFrame) -> SingleMetricResult:
-        _tokenizer = _get_tokenizer()
-        _session = _get_session()
+        _tokenizer = _get_tokenizer(tag)
+        _session = _get_session(tag)
         col = list(UdfInput(text).iter_column_rows(column_name))
         max_length = cast(int, _tokenizer.model_max_length)
 
diff --git a/tests/langkit/metrics/test_toxicity.py b/tests/langkit/metrics/test_toxicity.py
index c0fd076..e347627 100644
--- a/tests/langkit/metrics/test_toxicity.py
+++ b/tests/langkit/metrics/test_toxicity.py
@@ -5,6 +5,8 @@
 import whylogs as why
 
 from langkit.core.metric import WorkflowMetricConfig, WorkflowMetricConfigBuilder
+from langkit.core.workflow import Workflow
+from langkit.metrics.library import lib as metrics_lib
 from langkit.metrics.toxicity import prompt_response_toxicity_module, prompt_toxicity_metric, response_toxicity_metric
 from langkit.metrics.whylogs_compat import create_whylogs_udf_schema
 
@@ -81,6 +83,15 @@ def test_prompt_toxicity_row_non_toxic():
     assert actual["distribution/max"]["prompt.toxicity.toxicity_score"] < 0.1
 
 
+def test_prompt_toxicity_version():
+    wf = Workflow(metrics=[metrics_lib.prompt.toxicity.toxicity_score(onnx=False, hf_model_revision="f1c3aa41130e8baeee31c3ea5d14598a0d3385e5")])
+    result = wf.run(row)
+
+    expected_columns = ["prompt.toxicity.toxicity_score", "id"]
+
+    assert list(result.metrics.columns) == expected_columns
+
+
 def test_prompt_toxicity_df_non_toxic():
     schema = WorkflowMetricConfigBuilder().add(prompt_toxicity_metric).build()
 
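A minimal usage sketch of the new pinning options, for reference. Assumptions: Workflow.run accepts a single-row dict as in the existing tests, the Hugging Face revision is the one used in the new test above, and the ONNX asset tag value is a hypothetical placeholder rather than a real tag.

    from langkit.core.workflow import Workflow
    from langkit.metrics.library import lib

    # Hugging Face path: disable ONNX and pin the model to an explicit revision so a
    # langkit upgrade cannot silently change the model weights.
    hf_wf = Workflow(
        metrics=[lib.prompt.toxicity.toxicity_score(onnx=False, hf_model_revision="f1c3aa41130e8baeee31c3ea5d14598a0d3385e5")]
    )
    print(hf_wf.run({"prompt": "you are annoying"}).metrics["prompt.toxicity.toxicity_score"])

    # ONNX path (the default): optionally pin the downloaded model asset to a tag.
    # "some-tag" is a placeholder; use a tag that actually exists for the asset.
    onnx_wf = Workflow(metrics=[lib.prompt.toxicity.toxicity_score(onnx_tag="some-tag")])
    print(onnx_wf.run({"prompt": "you are annoying"}).metrics["prompt.toxicity.toxicity_score"])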