Skip to content

Commit 8401384

Browse files
authored
feat/refactor: Allow pipelines without generators to be used with the RAG eval harness (#31)
1 parent 9973f3b commit 8401384

File tree

4 files changed

+217
-98
lines changed

4 files changed

+217
-98
lines changed

haystack_experimental/evaluation/harness/rag/__init__.py

+2-1
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
from .harness import RAGEvaluationHarness
5+
from .harness import DefaultRAGArchitecture, RAGEvaluationHarness
66
from .parameters import (
77
RAGEvaluationInput,
88
RAGEvaluationMetric,
@@ -13,6 +13,7 @@
1313
)
1414

1515
__all__ = [
16+
"DefaultRAGArchitecture",
1617
"RAGEvaluationHarness",
1718
"RAGExpectedComponent",
1819
"RAGExpectedComponentMetadata",

haystack_experimental/evaluation/harness/rag/harness.py

+154-83
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,8 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
from copy import deepcopy
6-
from typing import Any, Dict, List, Optional, Set
6+
from enum import Enum
7+
from typing import Any, Dict, List, Optional, Set, Union
78

89
from haystack import Pipeline
910
from haystack.evaluation.eval_run_result import EvaluationRunResult
@@ -25,6 +26,83 @@
2526
)
2627

2728

29+
class DefaultRAGArchitecture(Enum):
30+
"""
31+
Represents default RAG pipeline architectures that can be used with the evaluation harness.
32+
"""
33+
34+
#: A RAG pipeline with:
35+
#: - A query embedder component named 'query_embedder' with a 'text' input.
36+
#: - A document retriever component named 'retriever' with a 'documents' output.
37+
EMBEDDING_RETRIEVAL = "embedding_retrieval"
38+
39+
#: A RAG pipeline with:
40+
#: - A document retriever component named 'retriever' with a 'query' input and a 'documents' output.
41+
KEYWORD_RETRIEVAL = "keyword_retrieval"
42+
43+
#: A RAG pipeline with:
44+
#: - A query embedder component named 'query_embedder' with a 'text' input.
45+
#: - A document retriever component named 'retriever' with a 'documents' output.
46+
#: - A response generator component named 'generator' with a 'replies' output.
47+
GENERATION_WITH_EMBEDDING_RETRIEVAL = "generation_with_embedding_retrieval"
48+
49+
#: A RAG pipeline with:
50+
#: - A document retriever component named 'retriever' with a 'query' input and a 'documents' output.
51+
#: - A response generator component named 'generator' with a 'replies' output.
52+
GENERATION_WITH_KEYWORD_RETRIEVAL = "generation_with_keyword_retrieval"
53+
54+
@property
55+
def expected_components(
56+
self,
57+
) -> Dict[RAGExpectedComponent, RAGExpectedComponentMetadata]:
58+
"""
59+
Returns the expected components for the architecture.
60+
61+
:returns:
62+
The expected components.
63+
"""
64+
if self in (
65+
DefaultRAGArchitecture.EMBEDDING_RETRIEVAL,
66+
DefaultRAGArchitecture.GENERATION_WITH_EMBEDDING_RETRIEVAL,
67+
):
68+
expected = {
69+
RAGExpectedComponent.QUERY_PROCESSOR: RAGExpectedComponentMetadata(
70+
name="query_embedder", input_mapping={"query": "text"}
71+
),
72+
RAGExpectedComponent.DOCUMENT_RETRIEVER: RAGExpectedComponentMetadata(
73+
name="retriever",
74+
output_mapping={"retrieved_documents": "documents"},
75+
),
76+
}
77+
elif self in (
78+
DefaultRAGArchitecture.KEYWORD_RETRIEVAL,
79+
DefaultRAGArchitecture.GENERATION_WITH_KEYWORD_RETRIEVAL,
80+
):
81+
expected = {
82+
RAGExpectedComponent.QUERY_PROCESSOR: RAGExpectedComponentMetadata(
83+
name="retriever", input_mapping={"query": "query"}
84+
),
85+
RAGExpectedComponent.DOCUMENT_RETRIEVER: RAGExpectedComponentMetadata(
86+
name="retriever",
87+
output_mapping={"retrieved_documents": "documents"},
88+
),
89+
}
90+
else:
91+
raise NotImplementedError(f"Unexpected default RAG architecture: {self}")
92+
93+
if self in (
94+
DefaultRAGArchitecture.GENERATION_WITH_EMBEDDING_RETRIEVAL,
95+
DefaultRAGArchitecture.GENERATION_WITH_KEYWORD_RETRIEVAL,
96+
):
97+
expected[RAGExpectedComponent.RESPONSE_GENERATOR] = (
98+
RAGExpectedComponentMetadata(
99+
name="generator", output_mapping={"replies": "replies"}
100+
)
101+
)
102+
103+
return expected
104+
105+
28106
class RAGEvaluationHarness(
29107
EvaluationHarness[RAGEvaluationInput, RAGEvaluationOverrides, RAGEvaluationOutput]
30108
):
@@ -35,7 +113,10 @@ class RAGEvaluationHarness(
35113
def __init__(
36114
self,
37115
rag_pipeline: Pipeline,
38-
rag_components: Dict[RAGExpectedComponent, RAGExpectedComponentMetadata],
116+
rag_components: Union[
117+
DefaultRAGArchitecture,
118+
Dict[RAGExpectedComponent, RAGExpectedComponentMetadata],
119+
],
39120
metrics: Set[RAGEvaluationMetric],
40121
):
41122
"""
@@ -44,76 +125,23 @@ def __init__(
44125
:param rag_pipeline:
45126
The RAG pipeline to evaluate.
46127
:param rag_components:
47-
A mapping of expected components to their metadata.
128+
Either a default RAG architecture or a mapping
129+
of expected components to their metadata.
48130
:param metrics:
49131
The metrics to use during evaluation.
50132
"""
51133
super().__init__()
52134

53-
self._validate_rag_components(rag_pipeline, rag_components)
135+
if isinstance(rag_components, DefaultRAGArchitecture):
136+
rag_components = rag_components.expected_components
137+
138+
self._validate_rag_components(rag_pipeline, rag_components, metrics)
54139

55140
self.rag_pipeline = rag_pipeline
56-
self.rag_components = rag_components
57-
self.metrics = metrics
141+
self.rag_components = deepcopy(rag_components)
142+
self.metrics = deepcopy(metrics)
58143
self.evaluation_pipeline = default_rag_evaluation_pipeline(metrics)
59144

60-
@classmethod
61-
def default_with_embedding_retriever(
62-
cls, rag_pipeline: Pipeline, metrics: Set[RAGEvaluationMetric]
63-
) -> "RAGEvaluationHarness":
64-
"""
65-
Create a default evaluation harness for evaluating RAG pipelines with a query embedder.
66-
67-
:param rag_pipeline:
68-
The RAG pipeline to evaluate. The following assumptions are made:
69-
- The query embedder component is named 'query_embedder' and has a 'text' input.
70-
- The document retriever component is named 'retriever' and has a 'documents' output.
71-
- The response generator component is named 'generator' and has a 'replies' output.
72-
:param metrics:
73-
The metrics to use during evaluation.
74-
"""
75-
rag_components = {
76-
RAGExpectedComponent.QUERY_PROCESSOR: RAGExpectedComponentMetadata(
77-
name="query_embedder", input_mapping={"query": "text"}
78-
),
79-
RAGExpectedComponent.DOCUMENT_RETRIEVER: RAGExpectedComponentMetadata(
80-
name="retriever", output_mapping={"retrieved_documents": "documents"}
81-
),
82-
RAGExpectedComponent.RESPONSE_GENERATOR: RAGExpectedComponentMetadata(
83-
name="generator", output_mapping={"replies": "replies"}
84-
),
85-
}
86-
87-
return cls(rag_pipeline, rag_components, deepcopy(metrics))
88-
89-
@classmethod
90-
def default_with_keyword_retriever(
91-
cls, rag_pipeline: Pipeline, metrics: Set[RAGEvaluationMetric]
92-
) -> "RAGEvaluationHarness":
93-
"""
94-
Create a default evaluation harness for evaluating RAG pipelines with a keyword retriever.
95-
96-
:param rag_pipeline:
97-
The RAG pipeline to evaluate. The following assumptions are made:
98-
- The document retriever component is named 'retriever' and has a 'query' input and a 'documents' output.
99-
- The response generator component is named 'generator' and has a 'replies' output.
100-
:param metrics:
101-
The metrics to use during evaluation.
102-
"""
103-
rag_components = {
104-
RAGExpectedComponent.QUERY_PROCESSOR: RAGExpectedComponentMetadata(
105-
name="retriever", input_mapping={"query": "query"}
106-
),
107-
RAGExpectedComponent.DOCUMENT_RETRIEVER: RAGExpectedComponentMetadata(
108-
name="retriever", output_mapping={"retrieved_documents": "documents"}
109-
),
110-
RAGExpectedComponent.RESPONSE_GENERATOR: RAGExpectedComponentMetadata(
111-
name="generator", output_mapping={"replies": "replies"}
112-
),
113-
}
114-
115-
return cls(rag_pipeline, rag_components, deepcopy(metrics))
116-
117145
def run( # noqa: D102
118146
self,
119147
inputs: RAGEvaluationInput,
@@ -141,10 +169,12 @@ def run( # noqa: D102
141169
"retrieved_documents",
142170
)
143171
],
144-
"responses": self._lookup_component_output(
145-
RAGExpectedComponent.RESPONSE_GENERATOR, rag_outputs, "replies"
146-
),
147172
}
173+
if RAGExpectedComponent.RESPONSE_GENERATOR in self.rag_components:
174+
result_inputs["responses"] = self._lookup_component_output(
175+
RAGExpectedComponent.RESPONSE_GENERATOR, rag_outputs, "replies"
176+
)
177+
148178
if inputs.ground_truth_answers is not None:
149179
result_inputs["ground_truth_answers"] = inputs.ground_truth_answers
150180
if inputs.ground_truth_documents is not None:
@@ -199,34 +229,40 @@ def _generate_eval_run_pipelines(
199229
rag_pipeline = self._override_pipeline(self.rag_pipeline, rag_overrides)
200230
eval_pipeline = self._override_pipeline(self.evaluation_pipeline, eval_overrides) # type: ignore
201231

232+
included_first_outputs = {
233+
self.rag_components[RAGExpectedComponent.DOCUMENT_RETRIEVER].name
234+
}
235+
if RAGExpectedComponent.RESPONSE_GENERATOR in self.rag_components:
236+
included_first_outputs.add(
237+
self.rag_components[RAGExpectedComponent.RESPONSE_GENERATOR].name
238+
)
239+
202240
return PipelinePair(
203241
first=rag_pipeline,
204242
second=eval_pipeline,
205243
outputs_to_inputs=self._map_rag_eval_pipeline_io(),
206244
map_first_outputs=lambda x: self._aggregate_rag_outputs( # pylint: disable=unnecessary-lambda
207245
x
208246
),
209-
included_first_outputs={
210-
self.rag_components[RAGExpectedComponent.DOCUMENT_RETRIEVER].name,
211-
self.rag_components[RAGExpectedComponent.RESPONSE_GENERATOR].name,
212-
},
247+
included_first_outputs=included_first_outputs,
213248
)
214249

215250
def _aggregate_rag_outputs(
216251
self, outputs: List[Dict[str, Dict[str, Any]]]
217252
) -> Dict[str, Dict[str, Any]]:
218253
aggregate = aggregate_batched_pipeline_outputs(outputs)
219254

220-
# We only care about the first response from the generator.
221-
generator_name = self.rag_components[
222-
RAGExpectedComponent.RESPONSE_GENERATOR
223-
].name
224-
replies_output_name = self.rag_components[
225-
RAGExpectedComponent.RESPONSE_GENERATOR
226-
].output_mapping["replies"]
227-
aggregate[generator_name][replies_output_name] = [
228-
r[0] for r in aggregate[generator_name][replies_output_name]
229-
]
255+
if RAGExpectedComponent.RESPONSE_GENERATOR in self.rag_components:
256+
# We only care about the first response from the generator.
257+
generator_name = self.rag_components[
258+
RAGExpectedComponent.RESPONSE_GENERATOR
259+
].name
260+
replies_output_name = self.rag_components[
261+
RAGExpectedComponent.RESPONSE_GENERATOR
262+
].output_mapping["replies"]
263+
aggregate[generator_name][replies_output_name] = [
264+
r[0] for r in aggregate[generator_name][replies_output_name]
265+
]
230266

231267
return aggregate
232268

@@ -383,11 +419,46 @@ def _prepare_eval_pipeline_additional_inputs(
383419
def _validate_rag_components(
384420
pipeline: Pipeline,
385421
components: Dict[RAGExpectedComponent, RAGExpectedComponentMetadata],
422+
metrics: Set[RAGEvaluationMetric],
386423
):
387-
for e in RAGExpectedComponent:
388-
if e not in components:
424+
metric_specific_required_components = {
425+
RAGEvaluationMetric.DOCUMENT_MAP: [
426+
RAGExpectedComponent.QUERY_PROCESSOR,
427+
RAGExpectedComponent.DOCUMENT_RETRIEVER,
428+
],
429+
RAGEvaluationMetric.DOCUMENT_MRR: [
430+
RAGExpectedComponent.QUERY_PROCESSOR,
431+
RAGExpectedComponent.DOCUMENT_RETRIEVER,
432+
],
433+
RAGEvaluationMetric.DOCUMENT_RECALL_SINGLE_HIT: [
434+
RAGExpectedComponent.QUERY_PROCESSOR,
435+
RAGExpectedComponent.DOCUMENT_RETRIEVER,
436+
],
437+
RAGEvaluationMetric.DOCUMENT_RECALL_MULTI_HIT: [
438+
RAGExpectedComponent.QUERY_PROCESSOR,
439+
RAGExpectedComponent.DOCUMENT_RETRIEVER,
440+
],
441+
RAGEvaluationMetric.SEMANTIC_ANSWER_SIMILARITY: [
442+
RAGExpectedComponent.QUERY_PROCESSOR,
443+
RAGExpectedComponent.RESPONSE_GENERATOR,
444+
],
445+
RAGEvaluationMetric.FAITHFULNESS: [
446+
RAGExpectedComponent.QUERY_PROCESSOR,
447+
RAGExpectedComponent.DOCUMENT_RETRIEVER,
448+
RAGExpectedComponent.RESPONSE_GENERATOR,
449+
],
450+
RAGEvaluationMetric.CONTEXT_RELEVANCE: [
451+
RAGExpectedComponent.QUERY_PROCESSOR,
452+
RAGExpectedComponent.DOCUMENT_RETRIEVER,
453+
],
454+
}
455+
456+
for m in metrics:
457+
required_components = metric_specific_required_components[m]
458+
if not all(c in components for c in required_components):
389459
raise ValueError(
390-
f"RAG evaluation harness requires metadata for the '{e.value}' component."
460+
f"In order to use the metric '{m}', the RAG evaluation harness requires metadata "
461+
f"for the following components: {required_components}"
391462
)
392463

393464
pipeline_outputs = pipeline.outputs(

haystack_experimental/evaluation/harness/rag/parameters.py

+9-1
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212

1313
class RAGExpectedComponent(Enum):
1414
"""
15-
Represents the basic components in a RAG pipeline that needs to be present for evaluation.
15+
Represents the basic components in a RAG pipeline that are, by default, required to be present for evaluation.
1616
1717
Each of these can be separate components in the pipeline or a single component that performs
1818
multiple tasks.
@@ -27,6 +27,7 @@ class RAGExpectedComponent(Enum):
2727
DOCUMENT_RETRIEVER = "document_retriever"
2828

2929
#: The component in a RAG pipeline that generates responses based on the query and the retrieved documents.
30+
#: Can be optional if the harness is only evaluating retrieval.
3031
#: Expected outputs: `replies` - Name of the output containing the LLM responses. Only the first response is used.
3132
RESPONSE_GENERATOR = "response_generator"
3233

@@ -57,24 +58,31 @@ class RAGEvaluationMetric(Enum):
5758
"""
5859

5960
#: Document Mean Average Precision.
61+
#: Required RAG components: Query Processor, Document Retriever.
6062
DOCUMENT_MAP = "metric_doc_map"
6163

6264
#: Document Mean Reciprocal Rank.
65+
#: Required RAG components: Query Processor, Document Retriever.
6366
DOCUMENT_MRR = "metric_doc_mrr"
6467

6568
#: Document Recall with a single hit.
69+
#: Required RAG components: Query Processor, Document Retriever.
6670
DOCUMENT_RECALL_SINGLE_HIT = "metric_doc_recall_single"
6771

6872
#: Document Recall with multiple hits.
73+
#: Required RAG components: Query Processor, Document Retriever.
6974
DOCUMENT_RECALL_MULTI_HIT = "metric_doc_recall_multi"
7075

7176
#: Semantic Answer Similarity.
77+
#: Required RAG components: Query Processor, Response Generator.
7278
SEMANTIC_ANSWER_SIMILARITY = "metric_sas"
7379

7480
#: Faithfulness.
81+
#: Required RAG components: Query Processor, Document Retriever, Response Generator.
7582
FAITHFULNESS = "metric_faithfulness"
7683

7784
#: Context Relevance.
85+
#: Required RAG components: Query Processor, Document Retriever.
7886
CONTEXT_RELEVANCE = "metric_context_relevance"
7987

8088

0 commit comments

Comments
 (0)