Skip to content

Commit

Permalink
Add labels match to OpenAIChatRuntime, RAG with only errors
Browse files Browse the repository at this point in the history
  • Loading branch information
nik committed Nov 30, 2023
1 parent 05e5ba4 commit e73e317
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 7 deletions.
24 changes: 22 additions & 2 deletions adala/runtimes/_openai.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import difflib
from rich import print

from typing import Optional, Dict, Any, List
Expand Down Expand Up @@ -32,9 +33,9 @@ def check_if_new_openai_version():
from tenacity import retry, stop_after_attempt, wait_random


@retry(wait=wait_random(min=5, max=10), stop=stop_after_attempt(6))
@retry(wait=wait_random(min=5, max=10), stop=stop_after_attempt(3))
def chat_completion_call(model, messages):
return openai.ChatCompletion.create(model=model, messages=messages)
return openai.ChatCompletion.create(model=model, messages=messages, timeout=120, request_timeout=120)


class OpenAIChatRuntime(Runtime):
Expand Down Expand Up @@ -158,8 +159,27 @@ def record_to_record(
]

completion_text = self.execute(messages)

field_schema = field_schema or {}
if output_field_name in field_schema and field_schema[output_field_name]["type"] == "array":
# expected output is one item from the array
expected_items = field_schema[output_field_name]['items']['enum']
completion_text = self._match_items(completion_text, expected_items)

return {output_field_name: completion_text}

def _match_items(self, query: str, items: List[str]) -> str:
# hard constraint: the item must be in the query
filtered_items = [item for item in items if item in query]
if not filtered_items:
# make the best guess - find the most similar item to the query
filtered_items = items

# soft constraint: find the most similar item to the query
scores = list(map(lambda item: difflib.SequenceMatcher(None, query, item).ratio(), filtered_items))
matched_item = filtered_items[scores.index(max(scores))]
return matched_item


class OpenAIVisionRuntime(OpenAIChatRuntime):
"""
Expand Down
14 changes: 9 additions & 5 deletions adala/skills/collection/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class RAGSkill(TransformSkill):
output_template: str = "{rag}"
num_results: int = 1
memory: Memory = None
only_errors: bool = True

@model_validator(mode="after")
def init_memory(self):
Expand Down Expand Up @@ -119,13 +120,16 @@ def improve(
runtime: Runtime to use for generation (not used).
"""

error_indices = feedback.match[
(feedback.match.fillna(True) == False).any(axis=1)
].index
inputs = predictions.loc[error_indices]
if self.only_errors:
indices = feedback.match[
(feedback.match.fillna(True) == False).any(axis=1)
].index
else:
indices = feedback.match.index
inputs = predictions.loc[indices]
input_strings = inputs.apply(
lambda r: self.input_template.format(**r), axis=1
).tolist()
fb = feedback.feedback.loc[error_indices].rename(columns=lambda c: f"{c}__fb")
fb = feedback.feedback.loc[indices].rename(columns=lambda c: f"{c}__fb")
inputs = inputs.join(fb)
self.memory.remember_many(input_strings, inputs.to_dict(orient="records"))

0 comments on commit e73e317

Please sign in to comment.