feat: more features to prompt object (#1418)
- support for `generate_multiple()`
- prompt supports name and language
- callback handlers
- output parser
jjmachan authored Oct 3, 2024
1 parent 96cff7a commit 407c2e0
Showing 26 changed files with 618 additions and 336 deletions.
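
Before the file-by-file diff, here is a minimal sketch of how the features listed in the commit message might fit together, modeled on the `PydanticPrompt` usage visible in the notebooks below. Everything except `PydanticPrompt` itself is illustrative, and the exact `generate()`/`generate_multiple()` signatures are assumptions rather than the confirmed API:

from pydantic import BaseModel, Field

from ragas.prompt import PydanticPrompt


class QuestionInput(BaseModel):
    question: str = Field(description="the question to answer")


class AnswerOutput(BaseModel):
    answer: str = Field(description="the model's answer")


class AnswerPrompt(PydanticPrompt[QuestionInput, AnswerOutput]):
    name = "answer_prompt"  # prompts can now carry a name ...
    language = "english"  # ... and a language
    instruction = "Answer the question concisely."
    input_model = QuestionInput
    output_model = AnswerOutput


# Assumed usage: callbacks are forwarded to the underlying LLM call, and the
# prompt's output parser converts raw completions into AnswerOutput instances.
# answer = await AnswerPrompt().generate(data=QuestionInput(question="..."), llm=llm)
# answers = await AnswerPrompt().generate_multiple(data=..., llm=llm, n=3)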
4 changes: 2 additions & 2 deletions Makefile
@@ -24,14 +24,14 @@ clean: ## Clean all generated files
	@cd $(GIT_ROOT)/docs && make clean
	@cd $(GIT_ROOT) || exit 1
	@find . -type f -name '*.py[co]' -delete -o -type d -name __pycache__ -delete
run-ci: format lint type ## Running all CI checks
test: ## Run tests
	@echo "Running tests..."
	@pytest --nbmake tests/unit $(shell if [ -n "$(k)" ]; then echo "-k $(k)"; fi)
test-e2e: ## Run end2end tests
	echo "running end2end tests..."
	@pytest --nbmake tests/e2e -s

run-ci: format lint type test ## Running all CI checks

# Docs
docsite: ## Build and serve documentation
	@echo "Generating reference pages..."
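With this split, `make test` runs the unit suite on its own and forwards an optional `k` variable to pytest's `-k` expression (for example, `make test k=prompt` runs only tests whose names match "prompt"), while `run-ci` now chains the format, lint, type, and test targets.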
42 changes: 21 additions & 21 deletions docs/howtos/customizations/metrics/modifying-prompts-metrics.ipynb
@@ -42,7 +42,8 @@
],
"source": [
"from ragas.metrics._simple_criteria import SimpleCriteriaScoreWithReference\n",
"scorer = SimpleCriteriaScoreWithReference(name='random',definition=\"some definition\")\n",
"\n",
"scorer = SimpleCriteriaScoreWithReference(name=\"random\", definition=\"some definition\")\n",
"scorer.get_prompts()"
]
},
@@ -62,7 +63,7 @@
],
"source": [
"prompts = scorer.get_prompts()\n",
"print(prompts['single_turn_prompt'].to_string())"
"print(prompts[\"single_turn_prompt\"].to_string())"
]
},
{
@@ -81,7 +82,7 @@
"metadata": {},
"outputs": [],
"source": [
"prompt = scorer.get_prompts()['single_turn_prompt']\n",
"prompt = scorer.get_prompts()[\"single_turn_prompt\"]\n",
"prompt.instruction += \"\\nOnly output valid JSON.\""
]
},
@@ -92,9 +93,7 @@
"metadata": {},
"outputs": [],
"source": [
"scorer.set_prompts(**{\n",
" 'single_turn_prompt':prompt\n",
"})"
"scorer.set_prompts(**{\"single_turn_prompt\": prompt})"
]
},
{
@@ -122,7 +121,7 @@
}
],
"source": [
"print(scorer.get_prompts()['single_turn_prompt'].instruction)"
"print(scorer.get_prompts()[\"single_turn_prompt\"].instruction)"
]
},
{
@@ -153,7 +152,7 @@
}
],
"source": [
"prompt = scorer.get_prompts()['single_turn_prompt']\n",
"prompt = scorer.get_prompts()[\"single_turn_prompt\"]\n",
"\n",
"prompt.examples"
]
@@ -165,7 +164,10 @@
"metadata": {},
"outputs": [],
"source": [
"from ragas.metrics._simple_criteria import SingleTurnSimpleCriteriaWithReferenceInput, SimpleCriteriaOutput"
"from ragas.metrics._simple_criteria import (\n",
" SingleTurnSimpleCriteriaWithReferenceInput,\n",
" SimpleCriteriaOutput,\n",
")"
]
},
{
@@ -178,15 +180,15 @@
"new_example = [\n",
" (\n",
" SingleTurnSimpleCriteriaWithReferenceInput(\n",
" user_input='Who was the first president of the United States?',\n",
" response='Thomas Jefferson was the first president of the United States.',\n",
" criteria='Score responses in range of 0 (low) to 5 (high) based similarity with reference.',\n",
" reference='George Washington was the first president of the United States.'\n",
" user_input=\"Who was the first president of the United States?\",\n",
" response=\"Thomas Jefferson was the first president of the United States.\",\n",
" criteria=\"Score responses in range of 0 (low) to 5 (high) based similarity with reference.\",\n",
" reference=\"George Washington was the first president of the United States.\",\n",
" ),\n",
" SimpleCriteriaOutput(\n",
" reason='The response incorrectly states Thomas Jefferson instead of George Washington. While both are significant historical figures, the answer does not match the reference.',\n",
" score=2\n",
" )\n",
" reason=\"The response incorrectly states Thomas Jefferson instead of George Washington. While both are significant historical figures, the answer does not match the reference.\",\n",
" score=2,\n",
" ),\n",
" )\n",
"]"
]
@@ -208,9 +210,7 @@
"metadata": {},
"outputs": [],
"source": [
"scorer.set_prompts(**{\n",
" 'single_turn_prompt':prompt\n",
"})"
"scorer.set_prompts(**{\"single_turn_prompt\": prompt})"
]
},
{
@@ -228,7 +228,7 @@
}
],
"source": [
"print(scorer.get_prompts()['single_turn_prompt'].examples)"
"print(scorer.get_prompts()[\"single_turn_prompt\"].examples)"
]
},
{
@@ -246,7 +246,7 @@
"metadata": {},
"outputs": [],
"source": [
"scorer.get_prompts()['single_turn_prompt'].to_string()"
"scorer.get_prompts()[\"single_turn_prompt\"].to_string()"
]
}
],
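The hunk that actually attaches `new_example` to the prompt is collapsed above; it is presumably along these lines (a sketch, assuming `examples` is a plain list attribute on the prompt object):

prompt = scorer.get_prompts()["single_turn_prompt"]
prompt.examples = new_example  # or prompt.examples.extend(new_example)
scorer.set_prompts(**{"single_turn_prompt": prompt})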
109 changes: 68 additions & 41 deletions docs/howtos/customizations/metrics/write_your_own_metric.ipynb
@@ -73,10 +73,12 @@
"from ragas.prompt import PydanticPrompt\n",
"from pydantic import BaseModel, Field\n",
"\n",
"\n",
"class RefusalInput(BaseModel):\n",
" user_input:str = Field(description=\"the user request\")\n",
" user_input: str = Field(description=\"the user request\")\n",
" response: str = Field(description=\"response from AI\")\n",
"\n",
"\n",
"class RefusalOutput(BaseModel):\n",
" refusal: bool = Field(description=\"boolean indicating if request was refused\")\n",
"\n",
@@ -87,18 +89,20 @@
" output_model = RefusalOutput\n",
" examples = [\n",
" (\n",
" RefusalInput(user_input=\"Can you tell me how to hack a system?\", \n",
" response=\"I'm sorry, I can't assist with that.\"),\n",
" RefusalOutput(refusal=True)\n",
" RefusalInput(\n",
" user_input=\"Can you tell me how to hack a system?\",\n",
" response=\"I'm sorry, I can't assist with that.\",\n",
" ),\n",
" RefusalOutput(refusal=True),\n",
" ),\n",
" (\n",
" RefusalInput(user_input=\"What's the weather like today?\", \n",
" response=\"The weather is sunny with a high of 25°C.\"),\n",
" RefusalOutput(refusal=False)\n",
" )\n",
" ]\n",
"\n",
" \n"
" RefusalInput(\n",
" user_input=\"What's the weather like today?\",\n",
" response=\"The weather is sunny with a high of 25°C.\",\n",
" ),\n",
" RefusalOutput(refusal=False),\n",
" ),\n",
" ]"
]
},
{
@@ -144,14 +148,22 @@
"\n",
" async def _single_turn_ascore(self, sample, callbacks):\n",
"\n",
" prompt_input = RefusalInput(user_input=sample.user_input, response=sample.response)\n",
" prompt_response = await self.refusal_prompt.generate(data=prompt_input,llm=self.llm)\n",
" prompt_input = RefusalInput(\n",
" user_input=sample.user_input, response=sample.response\n",
" )\n",
" prompt_response = await self.refusal_prompt.generate(\n",
" data=prompt_input, llm=self.llm\n",
" )\n",
" return int(prompt_response.refusal)\n",
"\n",
" async def _multi_turn_ascore(self, sample, callbacks):\n",
"\n",
" conversations = sample.user_input\n",
" conversations = [message for message in conversations if isinstance(message, AIMessage) or isinstance(message, HumanMessage)]\n",
" conversations = [\n",
" message\n",
" for message in conversations\n",
" if isinstance(message, AIMessage) or isinstance(message, HumanMessage)\n",
" ]\n",
"\n",
" grouped_messages = []\n",
" for msg in conversations:\n",
Expand All @@ -160,24 +172,19 @@
" elif isinstance(msg, AIMessage) and human_msg:\n",
" grouped_messages.append((human_msg, msg))\n",
" human_msg = None\n",
" \n",
"\n",
" grouped_messages = [item for item in grouped_messages if item[0]]\n",
" scores = []\n",
" for turn in grouped_messages:\n",
" prompt_input = RefusalInput(user_input=turn[0].content, response=turn[1].content)\n",
" prompt_response = await self.refusal_prompt.generate(data=prompt_input,llm=self.llm)\n",
" prompt_input = RefusalInput(\n",
" user_input=turn[0].content, response=turn[1].content\n",
" )\n",
" prompt_response = await self.refusal_prompt.generate(\n",
" data=prompt_input, llm=self.llm\n",
" )\n",
" scores.append(prompt_response.refusal)\n",
"\n",
" return sum(scores)\n",
" \n",
" \n",
" \n",
"\n",
" \n",
" \n",
" \n",
" "
" return sum(scores)"
]
},
{
@@ -255,21 +262,41 @@
"metadata": {},
"outputs": [],
"source": [
"sample = MultiTurnSample(user_input=[\n",
" HumanMessage(content=\"Hey, book a table at the nearest best Chinese restaurant for 8:00pm\"),\n",
" AIMessage(content=\"Sure, let me find the best options for you.\", tool_calls=[\n",
" ToolCall(name=\"restaurant_search\", args={\"cuisine\": \"Chinese\", \"time\": \"8:00pm\"})\n",
" ]),\n",
" ToolMessage(content=\"Found a few options: 1. Golden Dragon, 2. Jade Palace\"),\n",
" AIMessage(content=\"I found some great options: Golden Dragon and Jade Palace. Which one would you prefer?\"),\n",
" HumanMessage(content=\"Let's go with Golden Dragon.\"),\n",
" AIMessage(content=\"Great choice! I'll book a table for 8:00pm at Golden Dragon.\", tool_calls=[\n",
" ToolCall(name=\"restaurant_book\", args={\"name\": \"Golden Dragon\", \"time\": \"8:00pm\"})\n",
" ]),\n",
" ToolMessage(content=\"Table booked at Golden Dragon for 8:00pm.\"),\n",
" AIMessage(content=\"Your table at Golden Dragon is booked for 8:00pm. Enjoy your meal!\"),\n",
" HumanMessage(content=\"thanks\"),\n",
"])"
"sample = MultiTurnSample(\n",
" user_input=[\n",
" HumanMessage(\n",
" content=\"Hey, book a table at the nearest best Chinese restaurant for 8:00pm\"\n",
" ),\n",
" AIMessage(\n",
" content=\"Sure, let me find the best options for you.\",\n",
" tool_calls=[\n",
" ToolCall(\n",
" name=\"restaurant_search\",\n",
" args={\"cuisine\": \"Chinese\", \"time\": \"8:00pm\"},\n",
" )\n",
" ],\n",
" ),\n",
" ToolMessage(content=\"Found a few options: 1. Golden Dragon, 2. Jade Palace\"),\n",
" AIMessage(\n",
" content=\"I found some great options: Golden Dragon and Jade Palace. Which one would you prefer?\"\n",
" ),\n",
" HumanMessage(content=\"Let's go with Golden Dragon.\"),\n",
" AIMessage(\n",
" content=\"Great choice! I'll book a table for 8:00pm at Golden Dragon.\",\n",
" tool_calls=[\n",
" ToolCall(\n",
" name=\"restaurant_book\",\n",
" args={\"name\": \"Golden Dragon\", \"time\": \"8:00pm\"},\n",
" )\n",
" ],\n",
" ),\n",
" ToolMessage(content=\"Table booked at Golden Dragon for 8:00pm.\"),\n",
" AIMessage(\n",
" content=\"Your table at Golden Dragon is booked for 8:00pm. Enjoy your meal!\"\n",
" ),\n",
" HumanMessage(content=\"thanks\"),\n",
" ]\n",
")"
]
},
{
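Scoring the sample above is not shown in this hunk, but would look roughly like this (a sketch, assuming the metric class defined earlier in the notebook is named `RefusalRate` and `evaluator_llm` is whatever wrapped LLM you configure):

scorer = RefusalRate(llm=evaluator_llm)
refusals = await scorer.multi_turn_ascore(sample)  # counts refused turns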
16 changes: 7 additions & 9 deletions docs/howtos/customizations/run_config.ipynb
@@ -29,13 +29,11 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"from ragas.run_config import RunConfig\n",
"\n",
"# increasing max_workers to 64 and timeout to 60 seconds\n",
"\n",
"my_run_config=RunConfig(max_workers=64, timeout=60) "
"my_run_config = RunConfig(max_workers=64, timeout=60)"
]
},
{
@@ -56,15 +54,15 @@
"from datasets import load_dataset\n",
"from ragas import evaluate\n",
"\n",
"dataset = load_dataset(\"explodinggradients/amnesty_qa\",\"english_v3\")\n",
"dataset = load_dataset(\"explodinggradients/amnesty_qa\", \"english_v3\")\n",
"\n",
"samples = []\n",
"for row in dataset['eval']:\n",
"for row in dataset[\"eval\"]:\n",
" sample = SingleTurnSample(\n",
" user_input=row['user_input'],\n",
" reference=row['reference'],\n",
" response=row['response'],\n",
" retrieved_contexts=row['retrieved_contexts']\n",
" user_input=row[\"user_input\"],\n",
" reference=row[\"reference\"],\n",
" response=row[\"response\"],\n",
" retrieved_contexts=row[\"retrieved_contexts\"],\n",
" )\n",
" samples.append(sample)\n",
"\n",
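The diff cuts off before the evaluation call itself; `my_run_config` is presumably then passed through, along these lines (a sketch; the metric choice is illustrative, and `evaluate(..., run_config=...)` is an assumption about this version's signature):

from ragas import EvaluationDataset, evaluate
from ragas.metrics import Faithfulness  # illustrative metric choice

results = evaluate(
    dataset=EvaluationDataset(samples=samples),
    metrics=[Faithfulness()],
    run_config=my_run_config,
)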
6 changes: 4 additions & 2 deletions pyproject.toml
@@ -59,5 +59,7 @@ build-backend = "setuptools.build_meta"
write_to = "src/ragas/_version.py"

[tool.pytest.ini_options]
addopts = "-n 4"
asyncio_default_fixture_loop_scope = "function"
addopts = "-n 0"
asyncio_default_fixture_loop_scope = "function"
[pytest]
testpaths = ["tests"]
12 changes: 12 additions & 0 deletions src/ragas/exceptions.py
@@ -19,3 +19,15 @@ class ExceptionInRunner(RagasException):
    def __init__(self):
        msg = "The runner thread which was running the jobs raised an exception. Read the traceback above to debug it. You can also pass `raise_exceptions=False` in case you want to show only a warning message instead."
        super().__init__(msg)


class RagasOutputParserException(RagasException):
    """
    Exception raised when the output parser fails to parse the output.
    """

    def __init__(self, num_retries: int):
        msg = (
            f"The output parser failed to parse the output after {num_retries} retries."
        )
        super().__init__(msg)
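
Downstream code can now catch parser failures as a typed error instead of a generic exception; for instance (a sketch; the `prompt.generate(...)` call stands in for any call that runs the output parser):

from ragas.exceptions import RagasOutputParserException

try:
    output = await prompt.generate(data=prompt_input, llm=llm)
except RagasOutputParserException as exc:
    # raised once the parser has exhausted its retries
    print(f"could not parse LLM output: {exc}")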
6 changes: 3 additions & 3 deletions src/ragas/integrations/langchain.py
@@ -48,9 +48,9 @@ def __init__(self, metric: Metric, **kwargs: t.Any):
            t.cast(MetricWithLLM, self.metric).llm = LangchainLLMWrapper(llm)
        if isinstance(self.metric, MetricWithEmbeddings):
            embeddings = get_or_init(kwargs, "embeddings", OpenAIEmbeddings)
            t.cast(
                MetricWithEmbeddings, self.metric
            ).embeddings = LangchainEmbeddingsWrapper(embeddings)
            t.cast(MetricWithEmbeddings, self.metric).embeddings = (
                LangchainEmbeddingsWrapper(embeddings)
            )
        self.metric.init(run_config)

        assert isinstance(
1 change: 1 addition & 0 deletions src/ragas/llms/base.py
@@ -47,6 +47,7 @@ def is_multiple_completion_supported(llm: BaseLanguageModel) -> bool:
@dataclass
class BaseRagasLLM(ABC):
    run_config: RunConfig = field(default_factory=RunConfig)
    multiple_completion_supported: bool = False

    def set_run_config(self, run_config: RunConfig):
        self.run_config = run_config
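The new flag presumably lets callers such as `generate_multiple()` choose a strategy per LLM: one batched call when the model can return n completions, n sequential calls otherwise. A sketch of that dispatch (not the shipped implementation; `agenerate_text(..., n=...)` is assumed to match the existing `BaseRagasLLM` interface):

async def generate_n(llm, prompt_value, n: int) -> list[str]:
    if llm.multiple_completion_supported:
        # one call that returns n completions
        result = await llm.agenerate_text(prompt_value, n=n)
        return [g.text for g in result.generations[0]]
    # fall back to n independent single-completion calls
    results = [await llm.agenerate_text(prompt_value, n=1) for _ in range(n)]
    return [r.generations[0][0].text for r in results]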
2 changes: 1 addition & 1 deletion src/ragas/llms/output_parser.py
@@ -53,7 +53,7 @@ def get_json_format_instructions(pydantic_object: t.Type[TBaseModel]) -> str:
    return resp


class RagasoutputParser(PydanticOutputParser):
class RagasOutputParserOld(PydanticOutputParser):
    async def aparse(  # type: ignore
        self, result: str, prompt: PromptValue, llm: BaseRagasLLM, max_retries: int = 1
    ):
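
The rename to `RagasOutputParserOld` suggests the retry-on-parse-failure logic now lives with the new prompt object; the underlying pattern is roughly this (a sketch of the general technique, assuming pydantic v2; not the shipped code):

from pydantic import BaseModel, ValidationError

from ragas.exceptions import RagasOutputParserException


def parse_with_retries(raw: str, model: type[BaseModel], fix_output, max_retries: int = 1):
    # Parse raw text into `model`; on failure, ask `fix_output` (e.g. an LLM
    # re-prompt) for a corrected string, up to `max_retries` times.
    for attempt in range(max_retries + 1):
        try:
            return model.model_validate_json(raw)
        except ValidationError:
            if attempt == max_retries:
                raise RagasOutputParserException(num_retries=max_retries)
            raw = fix_output(raw)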