Update library defaults to use onnx
Anthony Naddeo committed Apr 7, 2024
1 parent f94ccb3 commit ec1637e
Showing 6 changed files with 49 additions and 43 deletions.
16 changes: 8 additions & 8 deletions langkit/metrics/embeddings_types.py
@@ -5,18 +5,18 @@
 from sentence_transformers import SentenceTransformer
 
 
-class TransformerEmbeddingAdapter:
+class EmbeddingEncoder(Protocol):
+    def encode(self, text: Tuple[str, ...]) -> "torch.Tensor":
+        ...
+
+
+class TransformerEmbeddingAdapter(EmbeddingEncoder):
     def __init__(self, transformer: SentenceTransformer):
         self._transformer = transformer
 
     @lru_cache(maxsize=6, typed=True)
-    def encode(self, text: Tuple[str, ...]) -> "torch.Tensor":
-        return torch.as_tensor(self._transformer.encode(sentences=list(text)))  # type: ignore[reportUnknownMemberType]
-
-
-class EmbeddingEncoder(Protocol):
-    def encode(self, text: Tuple[str, ...]) -> "torch.Tensor":
-        ...
+    def encode(self, text: Tuple[str, ...]) -> "torch.Tensor":  # pyright: ignore[reportIncompatibleMethodOverride]
+        return torch.as_tensor(self._transformer.encode(sentences=list(text), show_progress_bar=False))  # type: ignore[reportUnknownMemberType]
 
 
 class CachingEmbeddingEncoder(EmbeddingEncoder):
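The reordering above makes EmbeddingEncoder the shared contract: anything exposing encode(Tuple[str, ...]) -> torch.Tensor can back a metric, whether it wraps ONNX or sentence-transformers. A minimal sketch of programming against the protocol (the embed_normalized helper is illustrative only, not part of this commit):

    from typing import Protocol, Tuple

    import torch


    class EmbeddingEncoder(Protocol):
        def encode(self, text: Tuple[str, ...]) -> "torch.Tensor":
            ...


    def embed_normalized(encoder: EmbeddingEncoder, texts: Tuple[str, ...]) -> torch.Tensor:
        # Works identically for the ONNX and SentenceTransformer adapters.
        embeddings = encoder.encode(texts)
        return embeddings / embeddings.norm(dim=1, keepdim=True)
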
24 changes: 10 additions & 14 deletions langkit/metrics/injections.py
@@ -10,7 +10,7 @@
 from langkit.config import LANGKIT_CACHE
 from langkit.core.metric import Metric, SingleMetric, SingleMetricResult
 from langkit.metrics.util import retry
-from langkit.transformer import embedding_adapter, sentence_transformer
+from langkit.transformer import embedding_adapter
 
 logger = getLogger(__name__)
 
@@ -34,7 +34,8 @@ def __cache_embeddings(harm_embeddings: pd.DataFrame, embeddings_path: str, file
         logger.warning(f"Injections - unable to serialize embeddings to {embeddings_path_local}. Error: {serialization_error}")
 
 
-def __download_embeddings(filename: str) -> pd.DataFrame:
+def __download_embeddings(version: str) -> pd.DataFrame:
+    filename = f"embeddings_{__transformer_name}_harm_{version}.parquet"
     embeddings_path_remote: str = __injections_base_url + filename
     embeddings_path_local: str = os.path.join(LANGKIT_INJECTIONS_CACHE, filename)
     try:
@@ -60,18 +61,17 @@ def __process_embeddings(harm_embeddings: pd.DataFrame) -> "np.ndarray[Any, Any]
 
 @lru_cache
 def _get_embeddings(version: str) -> "np.ndarray[Any, Any]":
-    filename = f"embeddings_{__transformer_name}_harm_{version}.parquet"
-    harm_embeddings = __download_embeddings(filename)
-    embeddings_norm = __process_embeddings(harm_embeddings)
-    return embeddings_norm
+    return __process_embeddings(__download_embeddings(version))
 
 
 def injections_metric(column_name: str, version: str = "v2", onnx: bool = True) -> Metric:
     def cache_assets():
-        _get_embeddings(version)
+        __download_embeddings(version)
+        embedding_adapter(onnx)
 
     def init():
-        embedding_adapter()
+        _get_embeddings(version)
+        embedding_adapter(onnx)
 
     def udf(text: pd.DataFrame) -> SingleMetricResult:
         if column_name not in text.columns:
@@ -80,12 +80,8 @@ def udf(text: pd.DataFrame) -> SingleMetricResult:
 
         input_series: "pd.Series[str]" = cast("pd.Series[str]", text[column_name])
 
-        if onnx:
-            _transformer = embedding_adapter()
-            target_embeddings: npt.NDArray[np.float32] = _transformer.encode(tuple(input_series)).numpy()
-        else:
-            _transformer = sentence_transformer()
-            target_embeddings: npt.NDArray[np.float32] = _transformer.encode(list(input_series), show_progress_bar=False)  # pyright: ignore[reportAssignmentType, reportUnknownMemberType]
+        _transformer = embedding_adapter(onnx)
+        target_embeddings: npt.NDArray[np.float32] = _transformer.encode(tuple(input_series)).numpy()
 
         target_norms = target_embeddings / np.linalg.norm(target_embeddings, axis=1, keepdims=True)
         cosine_similarities = np.dot(_embeddings, target_norms.T)
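For reference, the scoring in udf is plain cosine similarity between each row's embedding and the cached harm embeddings. A self-contained numpy sketch, under the assumptions that _embeddings is already row-normalized and that the final score is the best match per row (that reduction is collapsed out of this hunk):

    import numpy as np

    harm_norms = np.random.rand(50, 384).astype(np.float32)        # stand-in for _embeddings
    harm_norms /= np.linalg.norm(harm_norms, axis=1, keepdims=True)

    target_embeddings = np.random.rand(3, 384).astype(np.float32)  # stand-in for encoder output
    target_norms = target_embeddings / np.linalg.norm(target_embeddings, axis=1, keepdims=True)

    cosine_similarities = np.dot(harm_norms, target_norms.T)       # (n_harm_examples, n_rows)
    injection_scores = cosine_similarities.max(axis=0)             # closest harm example per row
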
10 changes: 7 additions & 3 deletions langkit/metrics/input_output_similarity.py
@@ -7,14 +7,17 @@
 from langkit.transformer import embedding_adapter
 
 
-def input_output_similarity_metric(input_column_name: str = "prompt", output_column_name: str = "response") -> Metric:
+def input_output_similarity_metric(input_column_name: str = "prompt", output_column_name: str = "response", onnx: bool = True) -> Metric:
+    def cache_assets():
+        embedding_adapter(onnx)
+
     def init():
-        embedding_adapter()
+        embedding_adapter(onnx)
 
     def udf(text: pd.DataFrame) -> SingleMetricResult:
         in_np = UdfInput(text).to_list(input_column_name)
         out_np = UdfInput(text).to_list(output_column_name)
-        encoder = embedding_adapter()
+        encoder = embedding_adapter(onnx)
         similarity = compute_embedding_similarity(encoder, in_np, out_np)
 
         if len(similarity.shape) == 1:
@@ -27,6 +30,7 @@ def udf(text: pd.DataFrame) -> SingleMetricResult:
         input_names=[input_column_name, output_column_name],
         evaluate=udf,
         init=init,
+        cache_assets=cache_assets,
     )
 
 
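This metric hinges on compute_embedding_similarity, whose implementation is not part of this diff; conceptually it reduces to row-wise cosine similarity between the embeddings of the two columns. A sketch under that assumption (rowwise_similarity is a hypothetical helper):

    from typing import List

    import torch
    import torch.nn.functional as F

    from langkit.metrics.embeddings_types import EmbeddingEncoder


    def rowwise_similarity(encoder: EmbeddingEncoder, prompts: List[str], responses: List[str]) -> torch.Tensor:
        # Encode both columns with the same (ONNX or SentenceTransformer) encoder,
        # then compare prompt i against response i.
        prompt_emb = encoder.encode(tuple(prompts))
        response_emb = encoder.encode(tuple(responses))
        return F.cosine_similarity(prompt_emb, response_emb, dim=1)
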
20 changes: 10 additions & 10 deletions langkit/metrics/library.py
@@ -249,27 +249,27 @@ def __call__(self) -> MetricCreator:
             ]
 
         @staticmethod
-        def injection(version: Optional[str] = None) -> MetricCreator:
+        def injection(version: Optional[str] = None, onnx: bool = True) -> MetricCreator:
             """
             Analyze the input for injection themes. The injection score is a measure of how similar the input is
             to known injection examples, where 0 indicates no similarity and 1 indicates a high similarity.
             """
             from langkit.metrics.injections import injections_metric, prompt_injections_metric
 
             if version:
-                return partial(injections_metric, column_name="prompt", version=version)
+                return partial(injections_metric, column_name="prompt", version=version, onnx=onnx)
 
             return prompt_injections_metric
 
         @staticmethod
-        def jailbreak() -> MetricCreator:
+        def jailbreak(onnx: bool = True) -> MetricCreator:
             """
             Analyze the input for jailbreak themes. The jailbreak score is a measure of how similar the input is
             to known jailbreak examples, where 0 indicates no similarity and 1 indicates a high similarity.
             """
             from langkit.metrics.themes.themes import prompt_jailbreak_similarity_metric
 
-            return prompt_jailbreak_similarity_metric
+            return partial(prompt_jailbreak_similarity_metric, onnx=onnx)
 
         class sentiment:
             def __call__(self) -> MetricCreator:
@@ -302,7 +302,7 @@ def __call__(self) -> MetricCreator:
                 return partial(topic_metric, "prompt", self.topics, self.hypothesis_template)
 
             @staticmethod
-            def medicine(onnx: bool = False) -> MetricCreator:
+            def medicine(onnx: bool = True) -> MetricCreator:
                 if onnx:
                     from langkit.metrics.topic_onnx import topic_metric
 
@@ -486,24 +486,24 @@ def __call__(self) -> MetricCreator:
             ]
 
         @staticmethod
-        def prompt() -> MetricCreator:
+        def prompt(onnx: bool = True) -> MetricCreator:
             """
             Analyze the similarity between the input and the response. The output of this metric ranges from 0 to 1,
             where 0 indicates no similarity and 1 indicates a high similarity.
             """
             from langkit.metrics.input_output_similarity import prompt_response_input_output_similarity_metric
 
-            return prompt_response_input_output_similarity_metric
+            return partial(prompt_response_input_output_similarity_metric, onnx=onnx)
 
         @staticmethod
-        def refusal() -> MetricCreator:
+        def refusal(onnx: bool = True) -> MetricCreator:
             """
             Analyze the response for refusal themes. The refusal score is a measure of how similar the response is
             to known refusal examples, where 0 indicates no similarity and 1 indicates a high similarity.
             """
             from langkit.metrics.themes.themes import response_refusal_similarity_metric
 
-            return response_refusal_similarity_metric
+            return partial(response_refusal_similarity_metric, onnx=onnx)
 
         class topics:
             def __init__(self, topics: List[str], hypothesis_template: Optional[str] = None, onnx: bool = True):
@@ -522,7 +522,7 @@ def __call__(self) -> MetricCreator:
                 return partial(topic_metric, "response", self.topics, self.hypothesis_template)
 
             @staticmethod
-            def medicine(onnx: bool = False) -> MetricCreator:
+            def medicine(onnx: bool = True) -> MetricCreator:
                 if onnx:
                     from langkit.metrics.topic_onnx import topic_metric
 
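In practice this means the library-level creators now bind onnx=True unless the caller opts out per metric. The partials below mirror what injection(version="v2") builds for the two settings, written directly against injections_metric so no assumptions about the collapsed enclosing class names are needed:

    from functools import partial

    from langkit.metrics.injections import injections_metric

    onnx_creator = partial(injections_metric, column_name="prompt", version="v2", onnx=True)    # new default
    torch_creator = partial(injections_metric, column_name="prompt", version="v2", onnx=False)  # SentenceTransformer fallback
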
10 changes: 7 additions & 3 deletions langkit/metrics/themes/themes.py
@@ -59,18 +59,21 @@ def _get_themes(encoder: TransformerEmbeddingAdapter) -> Dict[str, torch.Tensor]
     return {group: torch.as_tensor(encoder.encode(tuple(themes))) for group, themes in theme_groups.items()}
 
 
-def __themes_metric(column_name: str, themes_group: Literal["jailbreak", "refusal"]) -> Metric:
+def __themes_metric(column_name: str, themes_group: Literal["jailbreak", "refusal"], onnx: bool = True) -> Metric:
     if themes_group == "refusal" and column_name == "prompt":
         raise ValueError("Refusal themes are not applicable to prompt")
 
     if themes_group == "jailbreak" and column_name == "response":
         raise ValueError("Jailbreak themes are not applicable to response")
 
     def cache_assets():
-        _get_themes(embedding_adapter())
+        embedding_adapter(onnx)
+
+    def init():
+        _get_themes(embedding_adapter(onnx))
 
     def udf(text: pd.DataFrame) -> SingleMetricResult:
-        encoder = embedding_adapter()
+        encoder = embedding_adapter(onnx)
         theme = _get_themes(encoder)[themes_group]  # (n_theme_examples, embedding_dim)
         text_list: List[str] = text[column_name].tolist()
         encoded_text = encoder.encode(tuple(text_list))  # (n_input_rows, embedding_dim)
@@ -84,6 +87,7 @@ def udf(text: pd.DataFrame) -> SingleMetricResult:
         input_names=[column_name],
         evaluate=udf,
         cache_assets=cache_assets,
+        init=init,
     )
 
 
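The shape comments above imply the rest of udf compares every input row to every theme example and keeps the best match. A standalone torch sketch of that step, with the max-per-row reduction stated as an assumption since the tail of udf is collapsed here:

    import torch
    import torch.nn.functional as F

    theme = torch.randn(12, 384)        # stand-in for _get_themes(encoder)[themes_group]
    encoded_text = torch.randn(3, 384)  # stand-in for encoder.encode(tuple(text_list))

    theme_n = F.normalize(theme, dim=1)
    text_n = F.normalize(encoded_text, dim=1)
    scores = (text_n @ theme_n.T).max(dim=1).values  # (n_input_rows,) similarity to the closest theme example
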
12 changes: 7 additions & 5 deletions langkit/transformer.py
@@ -4,12 +4,11 @@
 import torch
 from sentence_transformers import SentenceTransformer
 
-from langkit.metrics.embeddings_types import CachingEmbeddingEncoder, EmbeddingEncoder
+from langkit.metrics.embeddings_types import CachingEmbeddingEncoder, EmbeddingEncoder, TransformerEmbeddingAdapter
 from langkit.onnx_encoder import OnnxSentenceTransformer, TransformerModel
 
 
-@lru_cache
-def sentence_transformer(
+def _sentence_transformer(
     name_revision: Tuple[str, str] = ("all-MiniLM-L6-v2", "44eb4044493a3c34bc6d7faae1a71ec76665ebc6"),
 ) -> SentenceTransformer:
     """
@@ -25,5 +24,8 @@
 
 
 @lru_cache
-def embedding_adapter() -> EmbeddingEncoder:
-    return CachingEmbeddingEncoder(OnnxSentenceTransformer(TransformerModel.AllMiniLM))
+def embedding_adapter(onnx: bool = True) -> EmbeddingEncoder:
+    if onnx:
+        return CachingEmbeddingEncoder(OnnxSentenceTransformer(TransformerModel.AllMiniLM))
+    else:
+        return TransformerEmbeddingAdapter(_sentence_transformer())
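Both branches return an EmbeddingEncoder, so callers never branch on the backend themselves, and the @lru_cache on embedding_adapter means each configuration is built once per process. A hypothetical smoke test (model downloads happen on the first call; matching output shapes is an expectation, not something this diff shows):

    from langkit.transformer import embedding_adapter

    onnx_encoder = embedding_adapter()             # ONNX backend, the new default
    torch_encoder = embedding_adapter(onnx=False)  # SentenceTransformer fallback

    sample = ("What is the capital of France?",)
    print(onnx_encoder.encode(sample).shape, torch_encoder.encode(sample).shape)  # expected to match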
