diff --git a/langkit/metrics/embeddings_types.py b/langkit/metrics/embeddings_types.py
index f97d528..834cb5d 100644
--- a/langkit/metrics/embeddings_types.py
+++ b/langkit/metrics/embeddings_types.py
@@ -5,18 +5,18 @@
 from sentence_transformers import SentenceTransformer
 
 
-class TransformerEmbeddingAdapter:
+class EmbeddingEncoder(Protocol):
+    def encode(self, text: Tuple[str, ...]) -> "torch.Tensor":
+        ...
+
+
+class TransformerEmbeddingAdapter(EmbeddingEncoder):
     def __init__(self, transformer: SentenceTransformer):
         self._transformer = transformer
 
     @lru_cache(maxsize=6, typed=True)
-    def encode(self, text: Tuple[str, ...]) -> "torch.Tensor":
-        return torch.as_tensor(self._transformer.encode(sentences=list(text)))  # type: ignore[reportUnknownMemberType]
-
-
-class EmbeddingEncoder(Protocol):
-    def encode(self, text: Tuple[str, ...]) -> "torch.Tensor":
-        ...
+    def encode(self, text: Tuple[str, ...]) -> "torch.Tensor":  # pyright: ignore[reportIncompatibleMethodOverride]
+        return torch.as_tensor(self._transformer.encode(sentences=list(text), show_progress_bar=False))  # type: ignore[reportUnknownMemberType]
 
 
 class CachingEmbeddingEncoder(EmbeddingEncoder):
diff --git a/langkit/metrics/injections.py b/langkit/metrics/injections.py
index 031dce7..fc1bd0b 100644
--- a/langkit/metrics/injections.py
+++ b/langkit/metrics/injections.py
@@ -10,7 +10,7 @@
 from langkit.config import LANGKIT_CACHE
 from langkit.core.metric import Metric, SingleMetric, SingleMetricResult
 from langkit.metrics.util import retry
-from langkit.transformer import embedding_adapter, sentence_transformer
+from langkit.transformer import embedding_adapter
 
 
 logger = getLogger(__name__)
@@ -34,7 +34,8 @@ def __cache_embeddings(harm_embeddings: pd.DataFrame, embeddings_path: str, file
         logger.warning(f"Injections - unable to serialize embeddings to {embeddings_path_local}. Error: {serialization_error}")
 
 
-def __download_embeddings(filename: str) -> pd.DataFrame:
+def __download_embeddings(version: str) -> pd.DataFrame:
+    filename = f"embeddings_{__transformer_name}_harm_{version}.parquet"
     embeddings_path_remote: str = __injections_base_url + filename
     embeddings_path_local: str = os.path.join(LANGKIT_INJECTIONS_CACHE, filename)
     try:
@@ -60,18 +61,17 @@ def __process_embeddings(harm_embeddings: pd.DataFrame) -> "np.ndarray[Any, Any]
 
 @lru_cache
 def _get_embeddings(version: str) -> "np.ndarray[Any, Any]":
-    filename = f"embeddings_{__transformer_name}_harm_{version}.parquet"
-    harm_embeddings = __download_embeddings(filename)
-    embeddings_norm = __process_embeddings(harm_embeddings)
-    return embeddings_norm
+    return __process_embeddings(__download_embeddings(version))
 
 
 def injections_metric(column_name: str, version: str = "v2", onnx: bool = True) -> Metric:
     def cache_assets():
-        _get_embeddings(version)
+        __download_embeddings(version)
+        embedding_adapter(onnx)
 
     def init():
-        embedding_adapter()
+        _get_embeddings(version)
+        embedding_adapter(onnx)
 
     def udf(text: pd.DataFrame) -> SingleMetricResult:
         if column_name not in text.columns:
@@ -80,12 +80,8 @@ def udf(text: pd.DataFrame) -> SingleMetricResult:
 
         input_series: "pd.Series[str]" = cast("pd.Series[str]", text[column_name])
 
-        if onnx:
-            _transformer = embedding_adapter()
-            target_embeddings: npt.NDArray[np.float32] = _transformer.encode(tuple(input_series)).numpy()
-        else:
-            _transformer = sentence_transformer()
-            target_embeddings: npt.NDArray[np.float32] = _transformer.encode(list(input_series), show_progress_bar=False)  # pyright: ignore[reportAssignmentType, reportUnknownMemberType]
+        _transformer = embedding_adapter(onnx)
+        target_embeddings: npt.NDArray[np.float32] = _transformer.encode(tuple(input_series)).numpy()
 
         target_norms = target_embeddings / np.linalg.norm(target_embeddings, axis=1, keepdims=True)
         cosine_similarities = np.dot(_embeddings, target_norms.T)
diff --git a/langkit/metrics/input_output_similarity.py b/langkit/metrics/input_output_similarity.py
index a8f6ca8..db2a4ca 100644
--- a/langkit/metrics/input_output_similarity.py
+++ b/langkit/metrics/input_output_similarity.py
@@ -7,14 +7,17 @@
 from langkit.transformer import embedding_adapter
 
 
-def input_output_similarity_metric(input_column_name: str = "prompt", output_column_name: str = "response") -> Metric:
+def input_output_similarity_metric(input_column_name: str = "prompt", output_column_name: str = "response", onnx: bool = True) -> Metric:
+    def cache_assets():
+        embedding_adapter(onnx)
+
     def init():
-        embedding_adapter()
+        embedding_adapter(onnx)
 
     def udf(text: pd.DataFrame) -> SingleMetricResult:
         in_np = UdfInput(text).to_list(input_column_name)
         out_np = UdfInput(text).to_list(output_column_name)
-        encoder = embedding_adapter()
+        encoder = embedding_adapter(onnx)
         similarity = compute_embedding_similarity(encoder, in_np, out_np)
 
         if len(similarity.shape) == 1:
@@ -27,6 +30,7 @@ def udf(text: pd.DataFrame) -> SingleMetricResult:
         input_names=[input_column_name, output_column_name],
         evaluate=udf,
         init=init,
+        cache_assets=cache_assets,
     )
 
 
diff --git a/langkit/metrics/library.py b/langkit/metrics/library.py
index 0f5bef6..d08faa6 100644
--- a/langkit/metrics/library.py
+++ b/langkit/metrics/library.py
@@ -249,7 +249,7 @@ def __call__(self) -> MetricCreator:
             ]
 
         @staticmethod
-        def injection(version: Optional[str] = None) -> MetricCreator:
+        def injection(version: Optional[str] = None, onnx: bool = True) -> MetricCreator:
             """
             Analyze the input for injection themes. The injection score is a measure of how similar the input is to
             known injection examples, where 0 indicates no similarity and 1 indicates a high similarity.
@@ -257,19 +257,19 @@ def injection(version: Optional[str] = None) -> MetricCreator:
             from langkit.metrics.injections import injections_metric, prompt_injections_metric
 
             if version:
-                return partial(injections_metric, column_name="prompt", version=version)
+                return partial(injections_metric, column_name="prompt", version=version, onnx=onnx)
 
             return prompt_injections_metric
 
         @staticmethod
-        def jailbreak() -> MetricCreator:
+        def jailbreak(onnx: bool = True) -> MetricCreator:
             """
             Analyze the input for jailbreak themes. The jailbreak score is a measure of how similar the input is to
             known jailbreak examples, where 0 indicates no similarity and 1 indicates a high similarity.
             """
             from langkit.metrics.themes.themes import prompt_jailbreak_similarity_metric
 
-            return prompt_jailbreak_similarity_metric
+            return partial(prompt_jailbreak_similarity_metric, onnx=onnx)
 
     class sentiment:
         def __call__(self) -> MetricCreator:
@@ -302,7 +302,7 @@ def __call__(self) -> MetricCreator:
             return partial(topic_metric, "prompt", self.topics, self.hypothesis_template)
 
         @staticmethod
-        def medicine(onnx: bool = False) -> MetricCreator:
+        def medicine(onnx: bool = True) -> MetricCreator:
             if onnx:
                 from langkit.metrics.topic_onnx import topic_metric
 
@@ -486,24 +486,24 @@ def __call__(self) -> MetricCreator:
             ]
 
         @staticmethod
-        def prompt() -> MetricCreator:
+        def prompt(onnx: bool = True) -> MetricCreator:
             """
             Analyze the similarity between the input and the response. The output of this metric ranges from 0 to 1,
             where 0 indicates no similarity and 1 indicates a high similarity.
             """
             from langkit.metrics.input_output_similarity import prompt_response_input_output_similarity_metric
 
-            return prompt_response_input_output_similarity_metric
+            return partial(prompt_response_input_output_similarity_metric, onnx=onnx)
 
         @staticmethod
-        def refusal() -> MetricCreator:
+        def refusal(onnx: bool = True) -> MetricCreator:
             """
             Analyze the response for refusal themes. The refusal score is a measure of how similar the response is to
             known refusal examples, where 0 indicates no similarity and 1 indicates a high similarity.
             """
             from langkit.metrics.themes.themes import response_refusal_similarity_metric
 
-            return response_refusal_similarity_metric
+            return partial(response_refusal_similarity_metric, onnx=onnx)
 
     class topics:
         def __init__(self, topics: List[str], hypothesis_template: Optional[str] = None, onnx: bool = True):
@@ -522,7 +522,7 @@ def __call__(self) -> MetricCreator:
             return partial(topic_metric, "response", self.topics, self.hypothesis_template)
 
         @staticmethod
-        def medicine(onnx: bool = False) -> MetricCreator:
+        def medicine(onnx: bool = True) -> MetricCreator:
             if onnx:
                 from langkit.metrics.topic_onnx import topic_metric
 
diff --git a/langkit/metrics/themes/themes.py b/langkit/metrics/themes/themes.py
index 8bb2380..bb2eea2 100644
--- a/langkit/metrics/themes/themes.py
+++ b/langkit/metrics/themes/themes.py
@@ -59,7 +59,7 @@ def _get_themes(encoder: TransformerEmbeddingAdapter) -> Dict[str, torch.Tensor]
     return {group: torch.as_tensor(encoder.encode(tuple(themes))) for group, themes in theme_groups.items()}
 
 
-def __themes_metric(column_name: str, themes_group: Literal["jailbreak", "refusal"]) -> Metric:
+def __themes_metric(column_name: str, themes_group: Literal["jailbreak", "refusal"], onnx: bool = True) -> Metric:
     if themes_group == "refusal" and column_name == "prompt":
         raise ValueError("Refusal themes are not applicable to prompt")
 
@@ -67,10 +67,13 @@ def __themes_metric(column_name: str, themes_group: Literal["jailbreak", "refusa
         raise ValueError("Jailbreak themes are not applicable to response")
 
     def cache_assets():
-        _get_themes(embedding_adapter())
+        embedding_adapter(onnx)
+
+    def init():
+        _get_themes(embedding_adapter(onnx))
 
     def udf(text: pd.DataFrame) -> SingleMetricResult:
-        encoder = embedding_adapter()
+        encoder = embedding_adapter(onnx)
         theme = _get_themes(encoder)[themes_group]  # (n_theme_examples, embedding_dim)
         text_list: List[str] = text[column_name].tolist()
         encoded_text = encoder.encode(tuple(text_list))  # (n_input_rows, embedding_dim)
@@ -84,6 +87,7 @@ def udf(text: pd.DataFrame) -> SingleMetricResult:
         input_names=[column_name],
         evaluate=udf,
         cache_assets=cache_assets,
+        init=init,
     )
 
 
diff --git a/langkit/transformer.py b/langkit/transformer.py
index 7e4e0bf..683a3b1 100644
--- a/langkit/transformer.py
+++ b/langkit/transformer.py
@@ -4,12 +4,11 @@
 import torch
 from sentence_transformers import SentenceTransformer
 
-from langkit.metrics.embeddings_types import CachingEmbeddingEncoder, EmbeddingEncoder
+from langkit.metrics.embeddings_types import CachingEmbeddingEncoder, EmbeddingEncoder, TransformerEmbeddingAdapter
 from langkit.onnx_encoder import OnnxSentenceTransformer, TransformerModel
 
 
-@lru_cache
-def sentence_transformer(
+def _sentence_transformer(
     name_revision: Tuple[str, str] = ("all-MiniLM-L6-v2", "44eb4044493a3c34bc6d7faae1a71ec76665ebc6"),
 ) -> SentenceTransformer:
     """
@@ -25,5 +24,8 @@ def sentence_transformer(
 
 
 @lru_cache
-def embedding_adapter() -> EmbeddingEncoder:
-    return CachingEmbeddingEncoder(OnnxSentenceTransformer(TransformerModel.AllMiniLM))
+def embedding_adapter(onnx: bool = True) -> EmbeddingEncoder:
+    if onnx:
+        return CachingEmbeddingEncoder(OnnxSentenceTransformer(TransformerModel.AllMiniLM))
+    else:
+        return TransformerEmbeddingAdapter(_sentence_transformer())