Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Formatters working E2E, including chunk and response parsers. #708 continued #1171

Draft
wants to merge 4 commits into
base: next
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions js/ai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@
"require": "./lib/reranker.js",
"import": "./lib/reranker.mjs",
"default": "./lib/reranker.js"
},
"./formats": {
"types": "./lib/formats/index.d.ts",
"require": "./lib/formats/index.js",
"import": "./lib/formats/index.mjs",
"default": "./lib/formats/index.js"
}
},
"typesVersions": {
Expand Down
16 changes: 8 additions & 8 deletions js/ai/src/formats/array.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,21 @@ export const arrayFormatter: Formatter<unknown[], unknown[]> = {
contentType: 'application/json',
constrained: true,
},
handler: (request) => {
if (request.output?.schema && request.output?.schema.type !== 'array') {
handler: (schema) => {
if (schema && schema.type !== 'array') {
throw new GenkitError({
status: 'INVALID_ARGUMENT',
message: `Must supply an 'array' schema type when using the 'items' parser format.`,
});
}

let instructions: string | undefined;
if (request.output?.schema) {
if (schema) {
instructions = `Output should be a JSON array conforming to the following schema:

\`\`\`
${JSON.stringify(request.output!.schema!)}
\`\`\`
\`\`\`
${JSON.stringify(schema)}
\`\`\`
`;
}

Expand All @@ -54,8 +54,8 @@ export const arrayFormatter: Formatter<unknown[], unknown[]> = {
return items;
},

parseResponse: (response) => {
const { items } = extractItems(response.text, 0);
parseMessage: (message) => {
const { items } = extractItems(message.text, 0);
return items;
},

Expand Down
15 changes: 7 additions & 8 deletions js/ai/src/formats/enum.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,25 @@ import type { Formatter } from './types';
export const enumFormatter: Formatter<string, string> = {
name: 'enum',
config: {
contentType: 'text/plain',
contentType: 'text/x.enum',
constrained: true,
},
handler: (request) => {
const schemaType = request.output?.schema?.type;
if (schemaType && schemaType !== 'string' && schemaType !== 'enum') {
handler: (schema) => {
if (schema && schema.type !== 'string' && schema.type !== 'enum') {
throw new GenkitError({
status: 'INVALID_ARGUMENT',
message: `Must supply a 'string' or 'enum' schema type when using the enum parser format.`,
});
}

let instructions: string | undefined;
if (request.output?.schema?.enum) {
instructions = `Output should be ONLY one of the following enum values. Do not output any additional information or add quotes.\n\n${request.output?.schema?.enum.map((v) => v.toString()).join('\n')}`;
if (schema?.enum) {
instructions = `Output should be ONLY one of the following enum values. Do not output any additional information or add quotes.\n\n${schema.enum.map((v) => v.toString()).join('\n')}`;
}

return {
parseResponse: (response) => {
return response.text.trim();
parseMessage: (message) => {
return message.text.trim();
},
instructions,
};
Expand Down
71 changes: 67 additions & 4 deletions js/ai/src/formats/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
* limitations under the License.
*/

import { JSONSchema } from '@genkit-ai/core';
import { Registry } from '@genkit-ai/core/registry';
import { MessageData, TextPart } from '../model.js';
import { arrayFormatter } from './array';
import { enumFormatter } from './enum';
import { jsonFormatter } from './json';
Expand All @@ -36,18 +38,79 @@ export function defineFormat(
export type FormatArgument =
| keyof typeof DEFAULT_FORMATS
| Formatter
| Omit<string, keyof typeof DEFAULT_FORMATS>;
| Omit<string, keyof typeof DEFAULT_FORMATS>
| undefined
| null;

export async function resolveFormat(
registry: Registry,
arg: FormatArgument
): Promise<Formatter | undefined> {
): Promise<Formatter<any, any> | undefined> {
if (!arg) return undefined;
if (typeof arg === 'string') {
return registry.lookupValue<Formatter>('format', arg);
}
return arg as Formatter;
}

/**
 * Picks the instruction text that should accompany a generation request.
 *
 * Precedence:
 *  1. A string `instructionsOption` is returned verbatim (caller override).
 *  2. `instructionsOption === false` disables instructions entirely.
 *  3. Otherwise the instructions produced by `format.handler(schema)` are
 *     returned, or `undefined` when no format was supplied.
 *
 * @param format formatter whose handler supplies default instructions
 * @param schema output schema forwarded to the formatter's handler
 * @param instructionsOption explicit override (string) or on/off switch
 * @returns the instruction text, or `undefined` when none applies
 */
export function resolveInstructions(
  format?: Formatter,
  schema?: JSONSchema,
  instructionsOption?: boolean | string
): string | undefined {
  // Caller-provided text always wins, even over a configured format.
  if (typeof instructionsOption === 'string') {
    return instructionsOption;
  }

  // Either the caller opted out, or there is no formatter to ask.
  const optedOut = instructionsOption === false;
  if (optedOut || !format) {
    return undefined;
  }

  const { instructions } = format.handler(schema);
  return instructions;
}

/**
 * Returns a copy of `messages` with output instructions added as a text part
 * tagged `metadata.purpose === 'output'`. The input array and its messages are
 * never mutated.
 *
 * Behavior:
 *  - If `instructions` is not a non-empty string (`undefined`, `false`, `''`,
 *    or a bare `true`, which carries no text), `messages` is returned as-is.
 *  - If any message already holds a non-pending output part, `messages` is
 *    returned as-is so instructions are not duplicated.
 *  - The target is the first `system` message, or else the last `user`
 *    message. With neither present, `messages` is returned as-is.
 *  - A pending output part (`metadata.pending`) in the target is replaced in
 *    place; otherwise the new part is appended to the target's content.
 *
 * @param messages conversation to augment
 * @param instructions instruction text to inject, or a falsy value to skip
 * @returns the original array when nothing changes, else a shallow copy with
 *   the target message replaced
 */
export function injectInstructions(
  messages: MessageData[],
  instructions: string | boolean | undefined
): MessageData[] {
  // Only real text can be injected; `true` has no instruction text of its own.
  if (typeof instructions !== 'string' || !instructions) return messages;

  // Bail out if a non-pending output part is already present.
  const alreadyInjected = messages.some((m) =>
    m.content.some(
      (p) => p.metadata?.purpose === 'output' && !p.metadata?.pending
    )
  );
  if (alreadyInjected) return messages;

  const newPart: TextPart = {
    text: instructions,
    metadata: { purpose: 'output' },
  };

  // Prefer the system message; fall back to the last user message.
  let targetIndex = messages.findIndex((m) => m.role === 'system');
  if (targetIndex < 0) {
    targetIndex = messages.map((m) => m.role).lastIndexOf('user');
  }
  if (targetIndex < 0) return messages;

  const target = messages[targetIndex];
  const content = [...target.content];

  // findIndex reports "not found" as -1, so index 0 is a valid hit and must
  // be replaced rather than falling through to the append branch.
  const pendingIndex = content.findIndex(
    (p) => p.metadata?.purpose === 'output' && p.metadata?.pending
  );
  if (pendingIndex >= 0) {
    content.splice(pendingIndex, 1, newPart);
  } else {
    content.push(newPart);
  }

  const outMessages = [...messages];
  outMessages[targetIndex] = { ...target, content };
  return outMessages;
}

export const DEFAULT_FORMATS: Formatter<any, any>[] = [
jsonFormatter,
arrayFormatter,
Expand All @@ -57,9 +120,9 @@ export const DEFAULT_FORMATS: Formatter<any, any>[] = [
];

/**
* initializeFormats registers the default built-in formats on a registry.
* configureFormats registers the default built-in formats on a registry.
*/
export function initializeFormats(registry: Registry) {
export function configureFormats(registry: Registry) {
for (const format of DEFAULT_FORMATS) {
defineFormat(
registry,
Expand Down
10 changes: 5 additions & 5 deletions js/ai/src/formats/json.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ export const jsonFormatter: Formatter<unknown, unknown> = {
contentType: 'application/json',
constrained: true,
},
handler: (request) => {
handler: (schema) => {
let instructions: string | undefined;

if (request.output?.schema) {
if (schema) {
instructions = `Output should be in JSON format and conform to the following schema:

\`\`\`
${JSON.stringify(request.output!.schema!)}
${JSON.stringify(schema)}
\`\`\`
`;
}
Expand All @@ -40,8 +40,8 @@ ${JSON.stringify(request.output!.schema!)}
return extractJson(chunk.accumulatedText);
},

parseResponse: (response) => {
return extractJson(response.text);
parseMessage: (message) => {
return extractJson(message.text);
},

instructions,
Expand Down
17 changes: 8 additions & 9 deletions js/ai/src/formats/jsonl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,10 @@ export const jsonlFormatter: Formatter<unknown[], unknown[]> = {
config: {
contentType: 'application/jsonl',
},
handler: (request) => {
handler: (schema) => {
if (
request.output?.schema &&
(request.output?.schema.type !== 'array' ||
request.output?.schema.items?.type !== 'object')
schema &&
(schema.type !== 'array' || schema.items?.type !== 'object')
) {
throw new GenkitError({
status: 'INVALID_ARGUMENT',
Expand All @@ -44,11 +43,11 @@ export const jsonlFormatter: Formatter<unknown[], unknown[]> = {
}

let instructions: string | undefined;
if (request.output?.schema?.items) {
instructions = `Output should be JSONL format, a sequence of JSON objects (one per line). Each line should conform to the following schema:
if (schema?.items) {
instructions = `Output should be JSONL format, a sequence of JSON objects (one per line) separated by a newline \`\\n\` character. Each line should be a JSON object conforming to the following schema:

\`\`\`
${JSON.stringify(request.output.schema.items)}
${JSON.stringify(schema.items)}
\`\`\`
`;
}
Expand Down Expand Up @@ -86,8 +85,8 @@ ${JSON.stringify(request.output.schema.items)}
return results;
},

parseResponse: (response) => {
const items = objectLines(response.text)
parseMessage: (message) => {
const items = objectLines(message.text)
.map((l) => extractJson(l))
.filter((l) => !!l);

Expand Down
4 changes: 2 additions & 2 deletions js/ai/src/formats/text.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ export const textFormatter: Formatter<string, string> = {
return chunk.text;
},

parseResponse: (response) => {
return response.text;
parseMessage: (message) => {
return message.text;
},
};
},
Expand Down
17 changes: 8 additions & 9 deletions js/ai/src/formats/types.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,19 @@
* limitations under the License.
*/

import { GenerateResponse, GenerateResponseChunk } from '../generate.js';
import { ModelRequest, Part } from '../model.js';
import { JSONSchema } from '@genkit-ai/core';
import { GenerateResponseChunk } from '../generate.js';
import { Message } from '../message.js';
import { ModelRequest } from '../model.js';

type OutputContentTypes =
| 'application/json'
| 'text/plain'
| 'application/jsonl';
type OutputContentTypes = 'application/json' | 'text/plain';

export interface Formatter<O = unknown, CO = unknown> {
name: string;
config: ModelRequest['output'];
handler: (req: ModelRequest) => {
parseResponse(response: GenerateResponse): O;
handler: (schema?: JSONSchema) => {
parseMessage(message: Message): O;
parseChunk?: (chunk: GenerateResponseChunk, cursor?: CC) => CO;
instructions?: string | Part[];
instructions?: string;
};
}
Loading
Loading