Skip to content

Commit

Permalink
Merge pull request #295 from whylabs/st-config
Browse files Browse the repository at this point in the history
Add support for customizing the sentence transformer
  • Loading branch information
naddeoa authored Apr 15, 2024
2 parents d4c0ee3 + 2236a88 commit 2141d43
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 57 deletions.
4 changes: 2 additions & 2 deletions langkit/metrics/injections.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@ def _get_embeddings(version: str) -> "np.ndarray[Any, Any]":
return __process_embeddings(__download_embeddings(version))


def injections_metric(column_name: str, version: str = "v2", onnx: bool = True) -> Metric:
def injections_metric(column_name: str, version: str = "v2") -> Metric:
def cache_assets():
__download_embeddings(version)

def init():
_get_embeddings(version)

embedding_dep = EmbeddingContextDependency(onnx=onnx, input_column=column_name)
embedding_dep = EmbeddingContextDependency(embedding_choice="default", input_column=column_name)

def udf(text: pd.DataFrame, context: Context) -> SingleMetricResult:
if column_name not in text.columns:
Expand Down
10 changes: 6 additions & 4 deletions langkit/metrics/input_context_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
from langkit.core.context import Context
from langkit.core.metric import Metric, SingleMetric, SingleMetricResult
from langkit.metrics.embeddings_utils import compute_embedding_similarity_encoded
from langkit.transformer import EmbeddingContextDependency, RAGContextDependency
from langkit.transformer import EmbeddingChoiceArg, EmbeddingContextDependency, RAGContextDependency


def input_context_similarity(input_column_name: str = "prompt", context_column_name: str = "context", onnx: bool = True) -> Metric:
prompt_embedding_dep = EmbeddingContextDependency(onnx=onnx, input_column=input_column_name)
context_embedding_dep = RAGContextDependency(onnx=onnx, context_column_name=context_column_name)
def input_context_similarity(
input_column_name: str = "prompt", context_column_name: str = "context", embedding: EmbeddingChoiceArg = "default"
) -> Metric:
prompt_embedding_dep = EmbeddingContextDependency(embedding_choice=embedding, input_column=input_column_name)
context_embedding_dep = RAGContextDependency(embedding_choice=embedding, context_column_name=context_column_name)

def udf(text: pd.DataFrame, context: Context) -> SingleMetricResult:
prompt_embedding = prompt_embedding_dep.get_request_data(context)
Expand Down
10 changes: 6 additions & 4 deletions langkit/metrics/input_output_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from langkit.core.context import Context
from langkit.core.metric import Metric, SingleMetric, SingleMetricResult
from langkit.metrics.embeddings_utils import compute_embedding_similarity_encoded
from langkit.transformer import EmbeddingContextDependency
from langkit.transformer import EmbeddingChoiceArg, EmbeddingContextDependency


def input_output_similarity_metric(input_column_name: str = "prompt", output_column_name: str = "response", onnx: bool = True) -> Metric:
prompt_embedding_dep = EmbeddingContextDependency(onnx=onnx, input_column=input_column_name)
response_embedding_dep = EmbeddingContextDependency(onnx=onnx, input_column=output_column_name)
def input_output_similarity_metric(
input_column_name: str = "prompt", output_column_name: str = "response", embedding: EmbeddingChoiceArg = "default"
) -> Metric:
prompt_embedding_dep = EmbeddingContextDependency(embedding_choice=embedding, input_column=input_column_name)
response_embedding_dep = EmbeddingContextDependency(embedding_choice=embedding, input_column=output_column_name)

def udf(text: pd.DataFrame, context: Context) -> SingleMetricResult:
prompt_embedding = prompt_embedding_dep.get_request_data(context)
Expand Down
27 changes: 14 additions & 13 deletions langkit/metrics/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, Optional

from langkit.core.metric import MetricCreator
from langkit.transformer import EmbeddingChoiceArg


class lib:
Expand Down Expand Up @@ -251,33 +252,33 @@ def __call__(self) -> MetricCreator:
]

@staticmethod
def injection(version: Optional[str] = None, onnx: bool = True) -> MetricCreator:
def injection(version: Optional[str] = None) -> MetricCreator:
"""
Analyze the input for injection themes. The injection score is a measure of how similar the input is
to known injection examples, where 0 indicates no similarity and 1 indicates a high similarity.
"""
from langkit.metrics.injections import prompt_injections_metric

if version:
return partial(prompt_injections_metric, onnx=onnx, version=version)
return partial(prompt_injections_metric, version=version)

return partial(prompt_injections_metric, onnx=onnx)
return partial(prompt_injections_metric)

@staticmethod
def jailbreak(onnx: bool = True) -> MetricCreator:
def jailbreak(embedding: EmbeddingChoiceArg = "default") -> MetricCreator:
"""
Analyze the input for jailbreak themes. The jailbreak score is a measure of how similar the input is
to known jailbreak examples, where 0 indicates no similarity and 1 indicates a high similarity.
"""
from langkit.metrics.themes.themes import prompt_jailbreak_similarity_metric

return partial(prompt_jailbreak_similarity_metric, onnx=onnx)
return partial(prompt_jailbreak_similarity_metric, embedding=embedding)

@staticmethod
def context(onnx: bool = True) -> MetricCreator:
def context(embedding: EmbeddingChoiceArg = "default") -> MetricCreator:
from langkit.metrics.input_context_similarity import input_context_similarity

return partial(input_context_similarity, onnx=onnx)
return partial(input_context_similarity, embedding=embedding)

class sentiment:
def __call__(self) -> MetricCreator:
Expand Down Expand Up @@ -494,30 +495,30 @@ def __call__(self) -> MetricCreator:
]

@staticmethod
def prompt(onnx: bool = True) -> MetricCreator:
def prompt(embedding: EmbeddingChoiceArg = "default") -> MetricCreator:
"""
Analyze the similarity between the input and the response. The output of this metric ranges from 0 to 1,
where 0 indicates no similarity and 1 indicates a high similarity.
"""
from langkit.metrics.input_output_similarity import prompt_response_input_output_similarity_metric

return partial(prompt_response_input_output_similarity_metric, onnx=onnx)
return partial(prompt_response_input_output_similarity_metric, embedding=embedding)

@staticmethod
def refusal(onnx: bool = True) -> MetricCreator:
def refusal(embedding: EmbeddingChoiceArg = "default") -> MetricCreator:
"""
Analyze the response for refusal themes. The refusal score is a measure of how similar the response is
to known refusal examples, where 0 indicates no similarity and 1 indicates a high similarity.
"""
from langkit.metrics.themes.themes import response_refusal_similarity_metric

return partial(response_refusal_similarity_metric, onnx=onnx)
return partial(response_refusal_similarity_metric, embedding=embedding)

@staticmethod
def context(onnx: bool = True) -> MetricCreator:
def context(embedding: EmbeddingChoiceArg = "default") -> MetricCreator:
from langkit.metrics.input_context_similarity import input_context_similarity

return partial(input_context_similarity, onnx=onnx, input_column_name="response")
return partial(input_context_similarity, embedding=embedding, input_column_name="response")

class topics:
def __init__(self, topics: List[str], hypothesis_template: Optional[str] = None, onnx: bool = True):
Expand Down
8 changes: 3 additions & 5 deletions langkit/metrics/themes/themes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from langkit.core.context import Context
from langkit.core.metric import Metric, SingleMetric, SingleMetricResult
from langkit.transformer import EmbeddingContextDependency, embedding_adapter
from langkit.transformer import EmbeddingChoiceArg, EmbeddingContextDependency, embedding_adapter

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -60,7 +60,7 @@ def _get_themes() -> Dict[str, torch.Tensor]:
return {group: torch.as_tensor(encoder.encode(tuple(themes))) for group, themes in theme_groups.items()}


def __themes_metric(column_name: str, themes_group: Literal["jailbreak", "refusal"], onnx: bool = True) -> Metric:
def __themes_metric(column_name: str, themes_group: Literal["jailbreak", "refusal"], embedding: EmbeddingChoiceArg = "default") -> Metric:
if themes_group == "refusal" and column_name == "prompt":
raise ValueError("Refusal themes are not applicable to prompt")

Expand All @@ -70,12 +70,10 @@ def __themes_metric(column_name: str, themes_group: Literal["jailbreak", "refusa
def init():
_get_themes()

embedding_dep = EmbeddingContextDependency(onnx=onnx, input_column=column_name)
embedding_dep = EmbeddingContextDependency(embedding_choice=embedding, input_column=column_name)

def udf(text: pd.DataFrame, context: Context) -> SingleMetricResult:
theme = _get_themes()[themes_group] # (n_theme_examples, embedding_dim)
# text_list: List[str] = text[column_name].tolist()
# encoded_text = encoder.encode(tuple(text_list)) # (n_input_rows, embedding_dim)
encoded_text = embedding_dep.get_request_data(context)
similarities = F.cosine_similarity(encoded_text.unsqueeze(1), theme.unsqueeze(0), dim=2) # (n_input_rows, n_theme_examples)
max_similarities = similarities.max(dim=1)[0] # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] (n_input_rows,)
Expand Down
99 changes: 70 additions & 29 deletions langkit/transformer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Literal, Tuple
from typing import List, Literal, Union

import pandas as pd
import torch
Expand All @@ -12,43 +13,74 @@
from langkit.onnx_encoder import OnnxSentenceTransformer, TransformerModel


def _sentence_transformer(
name_revision: Tuple[str, str] = ("all-MiniLM-L6-v2", "44eb4044493a3c34bc6d7faae1a71ec76665ebc6"),
) -> SentenceTransformer:
"""
Returns a SentenceTransformer model instance.
class EmbeddingChoice(ABC):
@abstractmethod
def get_encoder(self) -> EmbeddingEncoder:
raise NotImplementedError()

The intent of this function is to cache the SentenceTransformer instance to avoid
multple instances being created all over langkit, and have a single place that
can be used to change the transformer name for the metrics that default to the same one.
"""
transformer_name, revision = name_revision
device = "cuda" if torch.cuda.is_available() else "cpu"
return SentenceTransformer(transformer_name, revision=revision, device=device)

class SentenceTransformerChoice(EmbeddingChoice):
def __init__(self, name: str, revision: str):
self.name = name
self.revision = revision

@lru_cache
def embedding_adapter(onnx: bool = True) -> EmbeddingEncoder:
if onnx:
def get_encoder(self) -> EmbeddingEncoder:
device = "cuda" if torch.cuda.is_available() else "cpu"
return TransformerEmbeddingAdapter(SentenceTransformer(self.name, revision=self.revision, device=device))


class DefaultChoice(SentenceTransformerChoice):
def __init__(self):
super().__init__("all-MiniLM-L6-v2", "44eb4044493a3c34bc6d7faae1a71ec76665ebc6")


class OnnxChoice(EmbeddingChoice):
def get_encoder(self) -> EmbeddingEncoder:
return OnnxSentenceTransformer(TransformerModel.AllMiniLM)


@dataclass(frozen=True)
class SentenceTransformerTarget:
name: str
revision: str


EmbeddingChoiceArg = Union[Literal["default"], Literal["onnx"], SentenceTransformerTarget]


@lru_cache
def embedding_adapter(choice: EmbeddingChoiceArg = "default") -> EmbeddingEncoder:
if choice == "default":
return DefaultChoice().get_encoder()
elif choice == "onnx":
return OnnxChoice().get_encoder()
else:
return TransformerEmbeddingAdapter(_sentence_transformer())
return SentenceTransformerChoice(choice.name, choice.revision).get_encoder()


@dataclass(frozen=True)
class EmbeddingContextDependency(ContextDependency[torch.Tensor]):
onnx: bool
embedding_choice: EmbeddingChoiceArg
input_column: str

def name(self) -> str:
return f"{self.input_column}.embedding?onnx={self.onnx}"
if self.embedding_choice == "default":
choice_str = "default"
elif self.embedding_choice == "onnx":
choice_str = "onnx"
else:
choice_str = f"{self.embedding_choice.name}-{self.embedding_choice.revision}"

return f"{self.input_column}.embedding?type={choice_str}"

def _get_encoder(self) -> EmbeddingEncoder:
return embedding_adapter(choice=self.embedding_choice)

def cache_assets(self) -> None:
# TODO do only the downloading
embedding_adapter(onnx=self.onnx)
self._get_encoder()

def init(self) -> None:
embedding_adapter(onnx=self.onnx)
self._get_encoder()

def populate_request(self, context: Context, data: pd.DataFrame):
if self.input_column not in data.columns:
Expand All @@ -57,7 +89,7 @@ def populate_request(self, context: Context, data: pd.DataFrame):
if self.name() in context.request_data:
return

encoder = embedding_adapter(onnx=self.onnx)
encoder = self._get_encoder()
embedding = encoder.encode(tuple(data[self.input_column])) # pyright: ignore[reportUnknownArgumentType]
context.request_data[self.name()] = embedding

Expand All @@ -67,7 +99,7 @@ def get_request_data(self, context: Context) -> torch.Tensor:

@dataclass(frozen=True)
class RAGContextDependency(ContextDependency[torch.Tensor]):
onnx: bool
embedding_choice: EmbeddingChoiceArg
strategy: Literal["combine"] = "combine"
"""
The strategy for converting the context into embeddings.
Expand All @@ -77,14 +109,23 @@ class RAGContextDependency(ContextDependency[torch.Tensor]):
context_column_name: str = "context"

def name(self) -> str:
return f"{self.context_column_name}.context?onnx={self.onnx}"
if self.embedding_choice == "default":
choice_str = "default"
elif self.embedding_choice == "onnx":
choice_str = "onnx"
else:
choice_str = f"{self.embedding_choice.name}-{self.embedding_choice.revision}"

return f"{self.context_column_name}.context?type={choice_str}&strategy={self.strategy}"

def _get_encoder(self) -> EmbeddingEncoder:
return embedding_adapter(choice=self.embedding_choice)

def cache_assets(self) -> None:
# TODO do only the downloading
embedding_adapter(onnx=self.onnx)
self._get_encoder()

def init(self) -> None:
embedding_adapter(onnx=self.onnx)
self._get_encoder()

def populate_request(self, context: Context, data: pd.DataFrame):
if self.context_column_name not in data.columns:
Expand All @@ -104,7 +145,7 @@ def populate_request(self, context: Context, data: pd.DataFrame):
else:
raise ValueError(f"Unknown context embedding strategy {self.strategy}")

encoder = embedding_adapter(onnx=self.onnx)
encoder = self._get_encoder()
embedding = encoder.encode(tuple(combined))
context.request_data[self.name()] = embedding

Expand Down

0 comments on commit 2141d43

Please sign in to comment.