Skip to content

Commit 7893efd

Browse files
authored
feat(evals): Make context support any type. (#1517)
1 parent b8a4618 commit 7893efd

File tree

10 files changed

+148
-108
lines changed

10 files changed

+148
-108
lines changed

genkit-tools/cli/src/commands/eval-extract-data.ts

+5-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ export const evalExtractData = new Command('eval:extractData')
7474
testCaseId: generateTestCaseId(),
7575
input: extractors.input(trace),
7676
output: extractors.output(trace),
77-
context: JSON.parse(extractors.context(trace)) as string[],
77+
context: toArray(extractors.context(trace)),
7878
// The trace (t) does not contain the traceId, so we have to pull it out of the
7979
// spans, de-dupe, and turn it back into an array.
8080
traceIds: Array.from(
@@ -105,3 +105,7 @@ export const evalExtractData = new Command('eval:extractData')
105105
}
106106
});
107107
});
108+
109+
/**
 * Normalizes a value into an array.
 *
 * @param input - Either a single value or an array of values.
 * @returns `input` itself (same reference, not a copy) when it is already an
 *   array; otherwise a new one-element array containing `input`. Note that
 *   `undefined`, `null` and `''` are wrapped like any other value.
 */
function toArray<T>(input: T | T[]): T[] {
  return Array.isArray(input) ? input : [input];
}

genkit-tools/common/src/eval/evaluate.ts

+3-4
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ async function gatherEvalInput(params: {
376376
input,
377377
output,
378378
error,
379-
context: JSON.parse(context) as string[],
379+
context: Array.isArray(context) ? context : [context],
380380
reference: state.reference,
381381
traceIds: [traceId],
382382
};
@@ -395,12 +395,11 @@ function getSpanErrorMessage(span: SpanData): string | undefined {
395395
}
396396
}
397397

398-
function getErrorFromModelResponse(output: string): string | undefined {
399-
const obj = JSON.parse(output);
398+
function getErrorFromModelResponse(obj: any): string | undefined {
400399
const response = GenerateResponseSchema.parse(obj);
401400

402401
if (!response || !response.candidates || response.candidates.length === 0) {
403-
return `No response was extracted from the output. '${output}'`;
402+
return `No response was extracted from the output. '${JSON.stringify(obj)}'`;
404403
}
405404

406405
// We currently only support the first candidate

genkit-tools/common/src/plugin/config.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ const EvaluationExtractorSchema = z.record(
4444
z.union([
4545
z.string(), // specify the displayName (default to output)
4646
StepSelectorSchema, // e.g. {inputOf: 'my-step-name'}
47-
z.function().args(TraceDataSchema).returns(z.string()), // custom trace extractor
47+
z.function().args(TraceDataSchema).returns(z.any()), // custom trace extractor
4848
])
4949
);
5050
export type EvaluationExtractor = z.infer<typeof EvaluationExtractorSchema>;

genkit-tools/common/src/types/eval.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ export const EvalInputSchema = z.object({
8181
input: z.any(),
8282
output: z.any(),
8383
error: z.string().optional(),
84-
context: z.array(z.string()).optional(),
84+
context: z.array(z.any()).optional(),
8585
reference: z.any().optional(),
8686
traceIds: z.array(z.string()),
8787
});

genkit-tools/common/src/utils/eval.ts

+34-27
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ import { NestedSpanData, TraceData } from '../types/trace';
3131
import { logger } from './logger';
3232
import { stackTraceSpans } from './trace';
3333

34-
export type EvalExtractorFn = (t: TraceData) => string;
35-
const JSON_EMPTY_STRING = '""';
34+
export type EvalExtractorFn = (t: TraceData) => any;
3635

3736
export const EVALUATOR_ACTION_PREFIX = '/evaluator';
3837

@@ -78,30 +77,39 @@ function getRootSpan(trace: TraceData): NestedSpanData | undefined {
7877
return stackTraceSpans(trace);
7978
}
8079

80+
/**
 * Best-effort JSON decoding.
 *
 * @param value - JSON text to decode; may be `undefined` or empty.
 * @returns The decoded value, or the empty string `''` when `value` is
 *   missing, empty, or not valid JSON. This function never throws.
 */
// The result type is intentionally `any`: callers assign it directly to typed
// variables and to extractor results of arbitrary shape.
function safeParse(value?: string): any {
  if (!value) {
    return '';
  }
  try {
    return JSON.parse(value);
  } catch {
    // Deliberate best-effort: malformed JSON yields the same '' fallback as a
    // missing value instead of propagating the SyntaxError.
    return '';
  }
}
90+
8191
const DEFAULT_INPUT_EXTRACTOR: EvalExtractorFn = (trace: TraceData) => {
8292
const rootSpan = getRootSpan(trace);
83-
return (rootSpan?.attributes['genkit:input'] as string) || JSON_EMPTY_STRING;
93+
return safeParse(rootSpan?.attributes['genkit:input'] as string);
8494
};
8595
const DEFAULT_OUTPUT_EXTRACTOR: EvalExtractorFn = (trace: TraceData) => {
8696
const rootSpan = getRootSpan(trace);
87-
return (rootSpan?.attributes['genkit:output'] as string) || JSON_EMPTY_STRING;
97+
return safeParse(rootSpan?.attributes['genkit:output'] as string);
8898
};
8999
const DEFAULT_CONTEXT_EXTRACTOR: EvalExtractorFn = (trace: TraceData) => {
90-
return JSON.stringify(
91-
Object.values(trace.spans)
92-
.filter((s) => s.attributes['genkit:metadata:subtype'] === 'retriever')
93-
.flatMap((s) => {
94-
const output: RetrieverResponse = JSON.parse(
95-
s.attributes['genkit:output'] as string
96-
);
97-
if (!output) {
98-
return [];
99-
}
100-
return output.documents.flatMap((d: DocumentData) =>
101-
d.content.map((c) => c.text).filter((text): text is string => !!text)
102-
);
103-
})
104-
);
100+
return Object.values(trace.spans)
101+
.filter((s) => s.attributes['genkit:metadata:subtype'] === 'retriever')
102+
.flatMap((s) => {
103+
const output: RetrieverResponse = safeParse(
104+
s.attributes['genkit:output'] as string
105+
);
106+
if (!output) {
107+
return [];
108+
}
109+
return output.documents.flatMap((d: DocumentData) =>
110+
d.content.map((c) => c.text).filter((text): text is string => !!text)
111+
);
112+
});
105113
};
106114

107115
const DEFAULT_FLOW_EXTRACTORS: Record<EvalField, EvalExtractorFn> = {
@@ -113,29 +121,29 @@ const DEFAULT_FLOW_EXTRACTORS: Record<EvalField, EvalExtractorFn> = {
113121
const DEFAULT_MODEL_EXTRACTORS: Record<EvalField, EvalExtractorFn> = {
114122
input: DEFAULT_INPUT_EXTRACTOR,
115123
output: DEFAULT_OUTPUT_EXTRACTOR,
116-
context: () => JSON.stringify([]),
124+
context: () => [],
117125
};
118126

119127
function getStepAttribute(
120128
trace: TraceData,
121129
stepName: string,
122130
attributeName?: string
123-
): string {
131+
) {
124132
// Default to output
125133
const attr = attributeName ?? 'genkit:output';
126134
const values = Object.values(trace.spans)
127135
.filter((step) => step.displayName === stepName)
128136
.flatMap((step) => {
129-
return JSON.parse(step.attributes[attr] as string);
137+
return safeParse(step.attributes[attr] as string);
130138
});
131139
if (values.length === 0) {
132-
return JSON_EMPTY_STRING;
140+
return '';
133141
}
134142
if (values.length === 1) {
135-
return JSON.stringify(values[0]);
143+
return values[0];
136144
}
137145
// Return array if multiple steps have the same name
138-
return JSON.stringify(values);
146+
return values;
139147
}
140148

141149
function getExtractorFromStepName(stepName: string): EvalExtractorFn {
@@ -159,7 +167,7 @@ function getExtractorFromStepSelector(
159167
selectedAttribute = 'genkit:output';
160168
}
161169
if (!stepName) {
162-
return JSON_EMPTY_STRING;
170+
return '';
163171
} else {
164172
return getStepAttribute(trace, stepName, selectedAttribute);
165173
}
@@ -196,7 +204,6 @@ export async function getEvalExtractors(
196204
return Promise.resolve(DEFAULT_MODEL_EXTRACTORS);
197205
}
198206
const config = await findToolsConfig();
199-
logger.info(`Found tools config... ${JSON.stringify(config)}`);
200207
const extractors = config?.evaluators
201208
?.filter((e) => e.actionRef === actionRef)
202209
.map((e) => e.extractors);

genkit-tools/common/tests/utils/eval_test.ts

+33-43
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ describe('eval utils', () => {
3838
expect(Object.keys(extractors).sort()).toEqual(
3939
['input', 'output', 'context'].sort()
4040
);
41-
expect(extractors.input(trace)).toEqual(JSON.stringify('My input'));
42-
expect(extractors.output(trace)).toEqual(JSON.stringify('My output'));
43-
expect(extractors.context(trace)).toEqual(JSON.stringify([]));
41+
expect(extractors.input(trace)).toEqual('My input');
42+
expect(extractors.output(trace)).toEqual('My output');
43+
expect(extractors.context(trace)).toEqual([]);
4444
});
4545
});
4646

@@ -63,9 +63,9 @@ describe('eval utils', () => {
6363
expect(Object.keys(extractors).sort()).toEqual(
6464
['input', 'output', 'context'].sort()
6565
);
66-
expect(extractors.input(trace)).toEqual(JSON.stringify('My input'));
67-
expect(extractors.output(trace)).toEqual(JSON.stringify('My output'));
68-
expect(extractors.context(trace)).toEqual(JSON.stringify(CONTEXT_TEXTS));
66+
expect(extractors.input(trace)).toEqual('My input');
67+
expect(extractors.output(trace)).toEqual('My output');
68+
expect(extractors.context(trace)).toEqual(CONTEXT_TEXTS);
6969
});
7070

7171
it('returns custom extractors by stepName', async () => {
@@ -100,11 +100,9 @@ describe('eval utils', () => {
100100

101101
const extractors = await getEvalExtractors('/flow/multiSteps');
102102

103-
expect(extractors.input(trace)).toEqual(JSON.stringify('My input'));
104-
expect(extractors.output(trace)).toEqual(
105-
JSON.stringify({ out: 'my-object-output' })
106-
);
107-
expect(extractors.context(trace)).toEqual(JSON.stringify(CONTEXT_TEXTS));
103+
expect(extractors.input(trace)).toEqual('My input');
104+
expect(extractors.output(trace)).toEqual({ out: 'my-object-output' });
105+
expect(extractors.context(trace)).toEqual(CONTEXT_TEXTS);
108106
});
109107

110108
it('returns custom extractors by stepSelector', async () => {
@@ -146,11 +144,9 @@ describe('eval utils', () => {
146144

147145
const extractors = await getEvalExtractors('/flow/multiSteps');
148146

149-
expect(extractors.input(trace)).toEqual(JSON.stringify('My input'));
150-
expect(extractors.output(trace)).toEqual(JSON.stringify('step2-input'));
151-
expect(extractors.context(trace)).toEqual(
152-
JSON.stringify(['Hello', 'World'])
153-
);
147+
expect(extractors.input(trace)).toEqual('My input');
148+
expect(extractors.output(trace)).toEqual('step2-input');
149+
expect(extractors.context(trace)).toEqual(['Hello', 'World']);
154150
});
155151

156152
it('returns custom extractors by trace function', async () => {
@@ -160,23 +156,21 @@ describe('eval utils', () => {
160156
actionRef: '/flow/multiSteps',
161157
extractors: {
162158
input: (trace: TraceData) => {
163-
return JSON.stringify(
164-
Object.values(trace.spans)
165-
.filter(
166-
(s) =>
167-
s.attributes['genkit:type'] === 'action' &&
168-
s.attributes['genkit:metadata:subtype'] !== 'retriever'
169-
)
170-
.map((s) => {
171-
const inputValue = JSON.parse(
172-
s.attributes['genkit:input'] as string
173-
).start.input;
174-
if (!inputValue) {
175-
return '';
176-
}
177-
return inputValue + ' TEST TEST TEST';
178-
})
179-
);
159+
return Object.values(trace.spans)
160+
.filter(
161+
(s) =>
162+
s.attributes['genkit:type'] === 'action' &&
163+
s.attributes['genkit:metadata:subtype'] !== 'retriever'
164+
)
165+
.map((s) => {
166+
const inputValue = JSON.parse(
167+
s.attributes['genkit:input'] as string
168+
).start.input;
169+
if (!inputValue) {
170+
return '';
171+
}
172+
return inputValue + ' TEST TEST TEST';
173+
});
180174
},
181175
output: { inputOf: 'step2' },
182176
context: { outputOf: 'step3-array' },
@@ -211,13 +205,9 @@ describe('eval utils', () => {
211205

212206
const extractors = await getEvalExtractors('/flow/multiSteps');
213207

214-
expect(extractors.input(trace)).toEqual(
215-
JSON.stringify(['My input TEST TEST TEST'])
216-
);
217-
expect(extractors.output(trace)).toEqual(JSON.stringify('step2-input'));
218-
expect(extractors.context(trace)).toEqual(
219-
JSON.stringify(['Hello', 'World'])
220-
);
208+
expect(extractors.input(trace)).toEqual(['My input TEST TEST TEST']);
209+
expect(extractors.output(trace)).toEqual('step2-input');
210+
expect(extractors.context(trace)).toEqual(['Hello', 'World']);
221211
});
222212

223213
it('returns runs default extractors when trace fails', async () => {
@@ -239,8 +229,8 @@ describe('eval utils', () => {
239229
expect(Object.keys(extractors).sort()).toEqual(
240230
['input', 'output', 'context'].sort()
241231
);
242-
expect(extractors.input(trace)).toEqual(JSON.stringify('My input'));
243-
expect(extractors.output(trace)).toEqual(JSON.stringify(''));
244-
expect(extractors.context(trace)).toEqual(JSON.stringify(CONTEXT_TEXTS));
232+
expect(extractors.input(trace)).toEqual('My input');
233+
expect(extractors.output(trace)).toEqual('');
234+
expect(extractors.context(trace)).toEqual(CONTEXT_TEXTS);
245235
});
246236
});

js/plugins/evaluators/src/metrics/answer_relevancy.ts

+20-6
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,26 @@ export async function answerRelevancyScore<
4040
embedderOptions?: z.infer<CustomEmbedderOptions>
4141
): Promise<Score> {
4242
try {
43-
if (!dataPoint.context?.length) {
44-
throw new Error('Context was not provided');
43+
if (!dataPoint.input) {
44+
throw new Error('Input was not provided');
4545
}
4646
if (!dataPoint.output) {
4747
throw new Error('Output was not provided');
4848
}
49+
if (!dataPoint.context?.length) {
50+
throw new Error('Context was not provided');
51+
}
52+
53+
const input =
54+
typeof dataPoint.input === 'string'
55+
? dataPoint.input
56+
: JSON.stringify(dataPoint.input);
57+
const output =
58+
typeof dataPoint.output === 'string'
59+
? dataPoint.output
60+
: JSON.stringify(dataPoint.output);
61+
const context = dataPoint.context.map((i) => JSON.stringify(i));
62+
4963
const prompt = await loadPromptFile(
5064
ai.registry,
5165
path.resolve(getDirName(), '../../prompts/answer_relevancy.prompt')
@@ -54,9 +68,9 @@ export async function answerRelevancyScore<
5468
model: judgeLlm,
5569
config: judgeConfig,
5670
prompt: prompt.renderText({
57-
question: dataPoint.input as string,
58-
answer: dataPoint.output as string,
59-
context: dataPoint.context.join(' '),
71+
question: input,
72+
answer: output,
73+
context: context.join(' '),
6074
}),
6175
output: {
6276
schema: AnswerRelevancyResponseSchema,
@@ -68,7 +82,7 @@ export async function answerRelevancyScore<
6882

6983
const questionEmbed = await ai.embed({
7084
embedder,
71-
content: dataPoint.input as string,
85+
content: input,
7286
options: embedderOptions,
7387
});
7488
const genQuestionEmbed = await ai.embed({

js/plugins/evaluators/src/metrics/faithfulness.ts

+19-6
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,26 @@ export async function faithfulnessScore<
4343
judgeConfig?: CustomModelOptions
4444
): Promise<Score> {
4545
try {
46-
const { input, output, context } = dataPoint;
47-
if (!context?.length) {
48-
throw new Error('Context was not provided');
46+
if (!dataPoint.input) {
47+
throw new Error('Input was not provided');
4948
}
50-
if (!output) {
49+
if (!dataPoint.output) {
5150
throw new Error('Output was not provided');
5251
}
52+
if (!dataPoint.context?.length) {
53+
throw new Error('Context was not provided');
54+
}
55+
56+
const input =
57+
typeof dataPoint.input === 'string'
58+
? dataPoint.input
59+
: JSON.stringify(dataPoint.input);
60+
const output =
61+
typeof dataPoint.output === 'string'
62+
? dataPoint.output
63+
: JSON.stringify(dataPoint.output);
64+
const context = dataPoint.context.map((i) => JSON.stringify(i));
65+
5366
const longFormPrompt = await loadPromptFile(
5467
ai.registry,
5568
path.resolve(getDirName(), '../../prompts/faithfulness_long_form.prompt')
@@ -58,8 +71,8 @@ export async function faithfulnessScore<
5871
model: judgeLlm,
5972
config: judgeConfig,
6073
prompt: longFormPrompt.renderText({
61-
question: input as string,
62-
answer: output as string,
74+
question: input,
75+
answer: output,
6376
}),
6477
output: {
6578
schema: LongFormResponseSchema,

0 commit comments

Comments
 (0)