diff --git a/README.md b/README.md
index 1af64b3d3..8f076de7a 100644
--- a/README.md
+++ b/README.md
@@ -88,9 +88,14 @@ Find excellent examples of community-built plugins for OpenAI, Anthropic, Cohere
 ## Try Genkit on IDX

-Project IDX logo
-
-Want to try Genkit without a local setup? [Explore it on Project IDX](https://idx.google.com/new/genkit), Google's AI-assisted workspace for full-stack app development in the cloud.
+Want to skip the local setup? Click below to try out Genkit using [Project IDX](https://idx.dev), Google's AI-assisted workspace for full-stack app development in the cloud.
+
+
+  Try in IDX
+

 ## Sample apps

diff --git a/docs/plugins/google-cloud.md b/docs/plugins/google-cloud.md
index 5314558bc..6135eb8f2 100644
--- a/docs/plugins/google-cloud.md
+++ b/docs/plugins/google-cloud.md
@@ -221,32 +221,84 @@ Common dimensions include:
 - `topK` - the inference topK [value](https://ai.google.dev/docs/concepts#model-parameters).
 - `topP` - the inference topP [value](https://ai.google.dev/docs/concepts#model-parameters).

-### Flow-level metrics
+### Feature-level metrics
+
+Features are the top-level entry point to your Genkit code. In most cases, this
+will be a flow; if you do not use flows, this will be the topmost span in a trace.
+
+| Name                    | Type      | Description             |
+| ----------------------- | --------- | ----------------------- |
+| genkit/feature/requests | Counter   | Number of requests      |
+| genkit/feature/latency  | Histogram | Execution latency in ms |
+
+Each feature-level metric contains the following dimensions:
+
+| Name          | Description                                                                      |
+| ------------- | -------------------------------------------------------------------------------- |
+| name          | The name of the feature. In most cases, this is the top-level Genkit flow         |
+| status        | 'success' or 'failure', depending on whether the feature request succeeded        |
+| error         | Only set when `status=failure`. Contains the error type that caused the failure   |
+| source        | The Genkit source language, e.g. 'ts'                                              |
+| sourceVersion | The Genkit framework version                                                       |

-| Name                 | Dimensions                           |
-| -------------------- | ------------------------------------ |
-| genkit/flow/requests | flow_name, error_code, error_message |
-| genkit/flow/latency  | flow_name                            |

 ### Action-level metrics

-| Name                   | Dimensions                           |
-| ---------------------- | ------------------------------------ |
-| genkit/action/requests | flow_name, error_code, error_message |
-| genkit/action/latency  | flow_name                            |
+Actions represent a generic step of execution within Genkit. Each of these steps
+has the following metrics tracked:
+
+| Name                   | Type      | Description                                    |
+| ---------------------- | --------- | ---------------------------------------------- |
+| genkit/action/requests | Counter   | Number of times this action has been executed  |
+| genkit/action/latency  | Histogram | Execution latency in ms                        |
+
+Each action-level metric contains the following dimensions:
+
+| Name          | Description                                                                                          |
+| ------------- | ------------------------------------------------------------------------------------------------------ |
+| name          | The name of the action                                                                                   |
+| featureName   | The name of the parent feature being executed                                                            |
+| path          | The path of execution from the feature root to this action, e.g. '/myFeature/parentAction/thisAction'   |
+| status        | 'success' or 'failure', depending on whether the action succeeded                                        |
+| error         | Only set when `status=failure`. Contains the error type that caused the failure                          |
+| source        | The Genkit source language, e.g. 'ts'                                                                    |
+| sourceVersion | The Genkit framework version                                                                             |

 ### Generate-level metrics

-| Name                                 | Dimensions                                                            |
-| ------------------------------------ | --------------------------------------------------------------------- |
-| genkit/ai/generate                   | flow_path, model, temperature, topK, topP, error_code, error_message  |
-| genkit/ai/generate/input_tokens      | flow_path, model, temperature, topK, topP                              |
-| genkit/ai/generate/output_tokens     | flow_path, model, temperature, topK, topP                              |
-| genkit/ai/generate/input_characters  | flow_path, model, temperature, topK, topP                              |
-| genkit/ai/generate/output_characters | flow_path, model, temperature, topK, topP                              |
-| genkit/ai/generate/input_images      | flow_path, model, temperature, topK, topP                              |
-| genkit/ai/generate/output_images     | flow_path, model, temperature, topK, topP                              |
-| genkit/ai/generate/latency           | flow_path, model, temperature, topK, topP, error_code, error_message  |
+These are special action metrics for actions that interact with a model. In addition
+to requests and latency, input and output are also tracked, with model-specific
+dimensions that make debugging and configuration tuning easier.
+
+| Name                                 | Type      | Description                                 |
+| ------------------------------------ | --------- | ------------------------------------------- |
+| genkit/ai/generate/requests          | Counter   | Number of times this model has been called  |
+| genkit/ai/generate/latency           | Histogram | Execution latency in ms                     |
+| genkit/ai/generate/input/tokens      | Counter   | Input tokens                                |
+| genkit/ai/generate/output/tokens     | Counter   | Output tokens                               |
+| genkit/ai/generate/input/characters  | Counter   | Input characters                            |
+| genkit/ai/generate/output/characters | Counter   | Output characters                           |
+| genkit/ai/generate/input/images      | Counter   | Input images                                |
+| genkit/ai/generate/output/images     | Counter   | Output images                               |
+| genkit/ai/generate/input/audio       | Counter   | Input audio files                           |
+| genkit/ai/generate/output/audio      | Counter   | Output audio files                          |
+
+Each generate-level metric contains the following dimensions:
+
+| Name            | Description                                                                                          |
+| --------------- | ------------------------------------------------------------------------------------------------------ |
+| modelName       | The name of the model                                                                                    |
+| featureName     | The name of the parent feature being executed                                                            |
+| path            | The path of execution from the feature root to this action, e.g. '/myFeature/parentAction/thisAction'   |
+| temperature     | The temperature parameter passed to the model                                                            |
+| maxOutputTokens | The maxOutputTokens parameter passed to the model                                                        |
+| topK            | The topK parameter passed to the model                                                                   |
+| topP            | The topP parameter passed to the model                                                                   |
+| latencyMs       | The response time taken by the model                                                                     |
+| status          | 'success' or 'failure', depending on whether the request succeeded                                       |
+| error           | Only set when `status=failure`. Contains the error type that caused the failure                          |
+| source          | The Genkit source language, e.g. 'ts'                                                                    |
+| sourceVersion   | The Genkit framework version                                                                             |

 Visualizing metrics can be done through the Metrics Explorer.
Using the side menu, select 'Logging' and click 'Metrics explorer' diff --git a/genkit-tools/cli/package.json b/genkit-tools/cli/package.json index 335225733..de2ecd1cb 100644 --- a/genkit-tools/cli/package.json +++ b/genkit-tools/cli/package.json @@ -1,6 +1,6 @@ { "name": "genkit-cli", - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "description": "CLI for interacting with the Google Genkit AI framework", "license": "Apache-2.0", "keywords": [ @@ -28,7 +28,7 @@ "dependencies": { "@genkit-ai/tools-common": "workspace:*", "@genkit-ai/telemetry-server": "workspace:*", - "axios": "^1.6.7", + "axios": "^1.7.7", "colorette": "^2.0.20", "commander": "^11.1.0", "extract-zip": "^2.0.1", diff --git a/genkit-tools/common/package.json b/genkit-tools/common/package.json index f77bbbca7..97fedae33 100644 --- a/genkit-tools/common/package.json +++ b/genkit-tools/common/package.json @@ -1,6 +1,6 @@ { "name": "@genkit-ai/tools-common", - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "scripts": { "compile": "tsc -b ./tsconfig.cjs.json ./tsconfig.esm.json ./tsconfig.types.json", "build:clean": "rimraf ./lib", @@ -12,7 +12,7 @@ "@asteasolutions/zod-to-openapi": "^7.0.0", "@trpc/server": "10.45.0", "adm-zip": "^0.5.12", - "axios": "^1.6.7", + "axios": "^1.7.7", "body-parser": "^1.20.2", "chokidar": "^3.5.3", "colorette": "^2.0.20", diff --git a/genkit-tools/common/src/eval/evaluate.ts b/genkit-tools/common/src/eval/evaluate.ts index 211fb3ba4..a0727a95b 100644 --- a/genkit-tools/common/src/eval/evaluate.ts +++ b/genkit-tools/common/src/eval/evaluate.ts @@ -216,11 +216,12 @@ async function bulkRunAction(params: { testCaseId: c.testCaseId ?? generateTestCaseId(), })); + let states: InferenceRunState[] = []; let evalInputs: EvalInput[] = []; for (const testCase of testCases) { - logger.info(`Running '${actionRef}' ...`); + logger.info(`Running inference '${actionRef}' ...`); if (isModelAction) { - evalInputs.push( + states.push( await runModelAction({ manager, actionRef, @@ -229,7 +230,7 @@ async function bulkRunAction(params: { }) ); } else { - evalInputs.push( + states.push( await runFlowAction({ manager, actionRef, @@ -239,6 +240,11 @@ async function bulkRunAction(params: { ); } } + + logger.info(`Gathering evalInputs...`); + for (const state of states) { + evalInputs.push(await gatherEvalInput({ manager, actionRef, state })); + } return evalInputs; } @@ -247,7 +253,7 @@ async function runFlowAction(params: { actionRef: string; testCase: TestCase; auth?: any; -}): Promise { +}): Promise { const { manager, actionRef, testCase, auth } = { ...params }; let state: InferenceRunState; try { @@ -274,7 +280,7 @@ async function runFlowAction(params: { evalError: `Error when running inference. Details: ${e?.message ?? e}`, }; } - return gatherEvalInput({ manager, actionRef, state }); + return state; } async function runModelAction(params: { @@ -282,7 +288,7 @@ async function runModelAction(params: { actionRef: string; testCase: TestCase; modelConfig?: any; -}): Promise { +}): Promise { const { manager, actionRef, modelConfig, testCase } = { ...params }; let state: InferenceRunState; try { @@ -304,7 +310,7 @@ async function runModelAction(params: { evalError: `Error when running inference. Details: ${e?.message ?? 
e}`, }; } - return gatherEvalInput({ manager, actionRef, state }); + return state; } async function gatherEvalInput(params: { diff --git a/genkit-tools/package.json b/genkit-tools/package.json index 09833d453..641914d51 100644 --- a/genkit-tools/package.json +++ b/genkit-tools/package.json @@ -23,5 +23,5 @@ "zod": "^3.22.4", "zod-to-json-schema": "^3.22.4" }, - "packageManager": "pnpm@9.12.0+sha256.a61b67ff6cc97af864564f4442556c22a04f2e5a7714fbee76a1011361d9b726" + "packageManager": "pnpm@9.12.2+sha256.2ef6e547b0b07d841d605240dce4d635677831148cd30f6d564b8f4f928f73d2" } diff --git a/genkit-tools/pnpm-lock.yaml b/genkit-tools/pnpm-lock.yaml index 7c873c3a8..c8fe635f0 100644 --- a/genkit-tools/pnpm-lock.yaml +++ b/genkit-tools/pnpm-lock.yaml @@ -36,8 +36,8 @@ importers: specifier: workspace:* version: link:../common axios: - specifier: ^1.6.7 - version: 1.6.8 + specifier: ^1.7.7 + version: 1.7.7 colorette: specifier: ^2.0.20 version: 2.0.20 @@ -97,8 +97,8 @@ importers: specifier: ^0.5.12 version: 0.5.12 axios: - specifier: ^1.6.7 - version: 1.6.8 + specifier: ^1.7.7 + version: 1.7.7 body-parser: specifier: ^1.20.2 version: 1.20.2 @@ -1179,8 +1179,8 @@ packages: resolution: {integrity: sha512-wvUjBtSGN7+7SjNpq/9M2Tg350UZD3q62IFZLbRAR1bSMlCo1ZaeW+BJ+D090e4hIIZLBcTDWe4Mh4jvUDajzQ==} engines: {node: '>= 0.4'} - axios@1.6.8: - resolution: {integrity: sha512-v/ZHtJDU39mDpyBoFVkETcd/uNdxrWRrg3bKpOKzXFA6Bvqopts6ALSMU3y6ijYxbw2B+wPrIv46egTzJXCLGQ==} + axios@1.7.7: + resolution: {integrity: sha512-S4kL7XrjgBmvdGut0sN3yJxqYzrDOnivkBiN0OFs6hLiUam3UPvswUo0kqGyhqUZGEOytHyumEdXsAkgCOUf3Q==} babel-jest@29.7.0: resolution: {integrity: sha512-BrvGY3xZSwEcCzKvKsCi2GgHqDqsYkOP4/by5xCgIwGXQxIEh+8ew3gmrE1y7XRR6LHZIj6yLYnUi/mm2KXKBg==} @@ -4201,7 +4201,7 @@ snapshots: dependencies: possible-typed-array-names: 1.0.0 - axios@1.6.8: + axios@1.7.7: dependencies: follow-redirects: 1.15.6 form-data: 4.0.0 diff --git a/genkit-tools/telemetry-server/package.json b/genkit-tools/telemetry-server/package.json index 3665f657a..a49050aa1 100644 --- a/genkit-tools/telemetry-server/package.json +++ b/genkit-tools/telemetry-server/package.json @@ -7,7 +7,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "compile": "tsc -b ./tsconfig.cjs.json ./tsconfig.esm.json ./tsconfig.types.json", diff --git a/js/ai/package.json b/js/ai/package.json index 047817be9..e9e6d2cde 100644 --- a/js/ai/package.json +++ b/js/ai/package.json @@ -7,7 +7,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", diff --git a/js/ai/src/embedder.ts b/js/ai/src/embedder.ts index 5cd278130..89b050253 100644 --- a/js/ai/src/embedder.ts +++ b/js/ai/src/embedder.ts @@ -15,7 +15,7 @@ */ import { Action, defineAction, z } from '@genkit-ai/core'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { Document, DocumentData, DocumentDataSchema } from './document.js'; export type EmbeddingBatch = { embedding: number[] }[]; @@ -68,6 +68,7 @@ function withMetadata( export function defineEmbedder< ConfigSchema extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, options: { name: string; configSchema?: ConfigSchema; @@ -76,6 +77,7 @@ export function defineEmbedder< runner: EmbedderFn ) { const embedder = defineAction( + registry, { actionType: 'embedder', name: options.name, @@ -111,47 +113,91 @@ export type EmbedderArgument< * A veneer for 
interacting with embedder models. */ export async function embed( + registry: Registry, params: EmbedderParams ): Promise { - let embedder: EmbedderAction; - if (typeof params.embedder === 'string') { - embedder = await lookupAction(`/embedder/${params.embedder}`); - } else if (Object.hasOwnProperty.call(params.embedder, 'info')) { - embedder = await lookupAction( - `/embedder/${(params.embedder as EmbedderReference).name}` - ); - } else { - embedder = params.embedder as EmbedderAction; - } - if (!embedder) { - throw new Error('Unable to utilize the provided embedder'); + let embedder = await resolveEmbedder(registry, params); + if (!embedder.embedderAction) { + let embedderId: string; + if (typeof params.embedder === 'string') { + embedderId = params.embedder; + } else if ((params.embedder as EmbedderAction)?.__action?.name) { + embedderId = (params.embedder as EmbedderAction).__action.name; + } else { + embedderId = (params.embedder as EmbedderReference).name; + } + throw new Error(`Unable to resolve embedder ${embedderId}`); } - const response = await embedder({ + const response = await embedder.embedderAction({ input: typeof params.content === 'string' ? [Document.fromText(params.content, params.metadata)] : [params.content], - options: params.options, + options: { + version: embedder.version, + ...embedder.config, + ...params.options, + }, }); return response.embeddings[0].embedding; } +interface ResolvedEmbedder { + embedderAction: EmbedderAction; + config?: z.infer; + version?: string; +} + +async function resolveEmbedder< + CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, +>( + registry: Registry, + params: EmbedderParams +): Promise> { + if (typeof params.embedder === 'string') { + return { + embedderAction: await registry.lookupAction( + `/embedder/${params.embedder}` + ), + }; + } else if (Object.hasOwnProperty.call(params.embedder, '__action')) { + return { + embedderAction: params.embedder as EmbedderAction, + }; + } else if (Object.hasOwnProperty.call(params.embedder, 'name')) { + const ref = params.embedder as EmbedderReference; + return { + embedderAction: await registry.lookupAction( + `/embedder/${(params.embedder as EmbedderReference).name}` + ), + config: { + ...ref.config, + }, + version: ref.version, + }; + } + throw new Error(`failed to resolve embedder ${params.embedder}`); +} + /** * A veneer for interacting with embedder models in bulk. 
*/ export async function embedMany< ConfigSchema extends z.ZodTypeAny = z.ZodTypeAny, ->(params: { - embedder: EmbedderArgument; - content: string[] | DocumentData[]; - metadata?: Record; - options?: z.infer; -}): Promise { +>( + registry: Registry, + params: { + embedder: EmbedderArgument; + content: string[] | DocumentData[]; + metadata?: Record; + options?: z.infer; + } +): Promise { let embedder: EmbedderAction; if (typeof params.embedder === 'string') { - embedder = await lookupAction(`/embedder/${params.embedder}`); + embedder = await registry.lookupAction(`/embedder/${params.embedder}`); } else if (Object.hasOwnProperty.call(params.embedder, 'info')) { - embedder = await lookupAction( + embedder = await registry.lookupAction( `/embedder/${(params.embedder as EmbedderReference).name}` ); } else { @@ -192,6 +238,8 @@ export interface EmbedderReference< name: string; configSchema?: CustomOptions; info?: EmbedderInfo; + config?: z.infer; + version?: string; } /** diff --git a/js/ai/src/evaluator.ts b/js/ai/src/evaluator.ts index 042e069d9..02be11e48 100644 --- a/js/ai/src/evaluator.ts +++ b/js/ai/src/evaluator.ts @@ -16,7 +16,7 @@ import { Action, defineAction, z } from '@genkit-ai/core'; import { logger } from '@genkit-ai/core/logging'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { SPAN_TYPE_ATTR, runInNewSpan } from '@genkit-ai/core/tracing'; import { randomUUID } from 'crypto'; @@ -127,6 +127,7 @@ export function defineEvaluator< typeof BaseEvalDataPointSchema = typeof BaseEvalDataPointSchema, EvaluatorOptions extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, options: { name: string; displayName: string; @@ -143,6 +144,7 @@ export function defineEvaluator< metadata[EVALUATOR_METADATA_KEY_DISPLAY_NAME] = options.displayName; metadata[EVALUATOR_METADATA_KEY_DEFINITION] = options.definition; const evaluator = defineAction( + registry, { actionType: 'evaluator', name: options.name, @@ -239,12 +241,17 @@ export type EvaluatorArgument< export async function evaluate< DataPoint extends typeof BaseDataPointSchema = typeof BaseDataPointSchema, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, ->(params: EvaluatorParams): Promise { +>( + registry: Registry, + params: EvaluatorParams +): Promise { let evaluator: EvaluatorAction; if (typeof params.evaluator === 'string') { - evaluator = await lookupAction(`/evaluator/${params.evaluator}`); + evaluator = await registry.lookupAction(`/evaluator/${params.evaluator}`); } else if (Object.hasOwnProperty.call(params.evaluator, 'info')) { - evaluator = await lookupAction(`/evaluator/${params.evaluator.name}`); + evaluator = await registry.lookupAction( + `/evaluator/${params.evaluator.name}` + ); } else { evaluator = params.evaluator as EvaluatorAction; } diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index b4e95716a..cf5231520 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -21,7 +21,7 @@ import { StreamingCallback, z, } from '@genkit-ai/core'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { parseSchema, toJsonSchema } from '@genkit-ai/core/schema'; import { DocumentData } from './document.js'; import { extractJson } from './extract.js'; @@ -365,6 +365,7 @@ export class GenerateResponseChunk } export async function toGenerateRequest( + registry: Registry, options: GenerateOptions ): Promise { const messages: MessageData[] = []; @@ -402,7 +403,7 @@ export async 
function toGenerateRequest( } let tools: Action[] | undefined; if (options.tools) { - tools = await resolveTools(options.tools); + tools = await resolveTools(registry, options.tools); } const out = { @@ -464,21 +465,28 @@ interface ResolvedModel { version?: string; } -async function resolveModel(options: GenerateOptions): Promise { +async function resolveModel( + registry: Registry, + options: GenerateOptions +): Promise { let model = options.model; if (!model) { throw new Error('Model is required.'); } if (typeof model === 'string') { return { - modelAction: (await lookupAction(`/model/${model}`)) as ModelAction, + modelAction: (await registry.lookupAction( + `/model/${model}` + )) as ModelAction, }; } else if (model.hasOwnProperty('__action')) { return { modelAction: model as ModelAction }; } else { const ref = model as ModelReference; return { - modelAction: (await lookupAction(`/model/${ref.name}`)) as ModelAction, + modelAction: (await registry.lookupAction( + `/model/${ref.name}` + )) as ModelAction, config: { ...ref.config, }, @@ -525,13 +533,14 @@ export async function generate< O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, >( + registry: Registry, options: | GenerateOptions | PromiseLike> ): Promise>> { const resolvedOptions: GenerateOptions = await Promise.resolve(options); - const resolvedModel = await resolveModel(resolvedOptions); + const resolvedModel = await resolveModel(registry, resolvedOptions); const model = resolvedModel.modelAction; if (!model) { let modelId: string; @@ -603,9 +612,9 @@ export async function generate< messages, tools, config: { - ...resolvedModel.config, version: resolvedModel.version, - ...resolvedOptions.config, + ...stripUndefinedOptions(resolvedModel.config), + ...stripUndefinedOptions(resolvedOptions.config), }, output: resolvedOptions.output && { format: resolvedOptions.output.format, @@ -623,12 +632,23 @@ export async function generate< resolvedOptions.streamingCallback, async () => new GenerateResponse( - await generateHelper(params, resolvedOptions.use), - await toGenerateRequest(resolvedOptions) + await generateHelper(registry, params, resolvedOptions.use), + await toGenerateRequest(registry, resolvedOptions) ) ); } +function stripUndefinedOptions(input?: any): any { + if (!input) return input; + const copy = { ...input }; + Object.keys(input).forEach((key) => { + if (copy[key] === undefined) { + delete copy[key]; + } + }); + return copy; +} + export type GenerateStreamOptions< O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, @@ -653,6 +673,7 @@ export async function generateStream< O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, >( + registry: Registry, options: | GenerateOptions | PromiseLike> @@ -678,7 +699,7 @@ export async function generateStream< } try { - generate({ + generate(registry, { ...options, streamingCallback: (chunk) => { firstChunkSent = true; diff --git a/js/ai/src/generateAction.ts b/js/ai/src/generateAction.ts index 996f77af3..7cb4f2d71 100644 --- a/js/ai/src/generateAction.ts +++ b/js/ai/src/generateAction.ts @@ -21,7 +21,7 @@ import { runWithStreamingCallback, z, } from '@genkit-ai/core'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { toJsonSchema, validateSchema } from '@genkit-ai/core/schema'; import { runInNewSpan, SPAN_TYPE_ATTR } from 
'@genkit-ai/core/tracing'; import * as clc from 'colorette'; @@ -70,6 +70,7 @@ export const GenerateUtilParamSchema = z.object({ * Encapsulates all generate logic. This is similar to `generateAction` except not an action and can take middleware. */ export async function generateHelper( + registry: Registry, input: z.infer, middleware?: Middleware[] ): Promise { @@ -86,7 +87,7 @@ export async function generateHelper( async (metadata) => { metadata.name = 'generate'; metadata.input = input; - const output = await generate(input, middleware); + const output = await generate(registry, input, middleware); metadata.output = JSON.stringify(output); return output; } @@ -94,10 +95,11 @@ export async function generateHelper( } async function generate( + registry: Registry, rawRequest: z.infer, middleware?: Middleware[] ): Promise { - const model = (await lookupAction( + const model = (await registry.lookupAction( `/model/${rawRequest.model}` )) as ModelAction; if (!model) { @@ -120,7 +122,7 @@ async function generate( tools = await Promise.all( rawRequest.tools.map(async (toolRef) => { if (typeof toolRef === 'string') { - const tool = (await lookupAction(toolRef)) as ToolAction; + const tool = (await registry.lookupAction(toolRef)) as ToolAction; if (!tool) { throw new Error(`Tool ${toolRef} not found`); } @@ -203,7 +205,7 @@ async function generate( messages: [...request.messages, message], prompt: toolResponses, }; - return await generateHelper(nextRequest, middleware); + return await generateHelper(registry, nextRequest, middleware); } async function actionToGenerateRequest( diff --git a/js/ai/src/index.ts b/js/ai/src/index.ts index 75981dafc..b0945b40c 100644 --- a/js/ai/src/index.ts +++ b/js/ai/src/index.ts @@ -72,9 +72,11 @@ export { export { definePrompt, renderPrompt, + type ExecutablePrompt, type PromptAction, type PromptConfig, type PromptFn, + type PromptGenerateOptions, } from './prompt.js'; export { rerank, diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 01aa159f7..4a176e70b 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -22,6 +22,7 @@ import { StreamingCallback, z, } from '@genkit-ai/core'; +import { Registry } from '@genkit-ai/core/registry'; import { toJsonSchema } from '@genkit-ai/core/schema'; import { performance } from 'node:perf_hooks'; import { DocumentDataSchema } from './document.js'; @@ -330,6 +331,7 @@ export type DefineModelOptions< export function defineModel< CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, options: DefineModelOptions, runner: ( request: GenerateRequest, @@ -344,6 +346,7 @@ export function defineModel< if (!options?.supports?.context) middleware.push(augmentWithContext()); middleware.push(conformOutput()); const act = defineAction( + registry, { actionType: 'model', name: options.name, @@ -386,15 +389,36 @@ export interface ModelReference { info?: ModelInfo; version?: string; config?: z.infer; + + withConfig(cfg: z.infer): ModelReference; + withVersion(version: string): ModelReference; } /** Cretes a model reference. 
*/ export function modelRef< CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny, >( - options: ModelReference + options: Omit< + ModelReference, + 'withConfig' | 'withVersion' + > ): ModelReference { - return { ...options }; + const ref: Partial> = { ...options }; + ref.withConfig = ( + cfg: z.infer + ): ModelReference => { + return modelRef({ + ...options, + config: cfg, + }); + }; + ref.withVersion = (version: string): ModelReference => { + return modelRef({ + ...options, + version, + }); + }; + return ref as ModelReference; } /** Container for counting usage stats for a single input/output {Part} */ @@ -433,16 +457,20 @@ export function getBasicUsageStats( function getPartCounts(parts: Part[]): PartCounts { return parts.reduce( (counts, part) => { + const isImage = + part.media?.contentType?.startsWith('image') || + part.media?.url?.startsWith('data:image'); + const isVideo = + part.media?.contentType?.startsWith('video') || + part.media?.url?.startsWith('data:video'); + const isAudio = + part.media?.contentType?.startsWith('audio') || + part.media?.url?.startsWith('data:audio'); return { characters: counts.characters + (part.text?.length || 0), - images: - counts.images + - (part.media?.contentType?.startsWith('image') ? 1 : 0), - videos: - counts.videos + - (part.media?.contentType?.startsWith('video') ? 1 : 0), - audio: - counts.audio + (part.media?.contentType?.startsWith('audio') ? 1 : 0), + images: counts.images + (isImage ? 1 : 0), + videos: counts.videos + (isVideo ? 1 : 0), + audio: counts.audio + (isAudio ? 1 : 0), }; }, { characters: 0, images: 0, videos: 0, audio: 0 } diff --git a/js/ai/src/prompt.ts b/js/ai/src/prompt.ts index b520b5257..f497dca23 100644 --- a/js/ai/src/prompt.ts +++ b/js/ai/src/prompt.ts @@ -15,14 +15,19 @@ */ import { Action, defineAction, JSONSchema7, z } from '@genkit-ai/core'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { DocumentData } from './document.js'; -import { GenerateOptions } from './generate.js'; +import { + GenerateOptions, + GenerateResponse, + GenerateStreamResponse, +} from './generate.js'; import { GenerateRequest, GenerateRequestSchema, ModelArgument, } from './model.js'; +import { ToolAction } from './tool.js'; export type PromptFn< I extends z.ZodTypeAny = z.ZodTypeAny, @@ -58,6 +63,84 @@ export function isPrompt(arg: any): boolean { ); } +export type PromptGenerateOptions< + I = undefined, + CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, +> = Omit< + GenerateOptions, + 'prompt' | 'input' | 'model' +> & { + model?: ModelArgument; + input?: I; +}; + +/** + * A prompt that can be executed as a function. + */ +export interface ExecutablePrompt< + I = undefined, + O extends z.ZodTypeAny = z.ZodTypeAny, + CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, +> { + /** + * Generates a response by rendering the prompt template with given user input and then calling the model. + * + * @param input Prompt inputs. + * @param opt Options for the prompt template, including user input variables and custom model configuration options. + * @returns the model response as a promise of `GenerateStreamResponse`. + */ + ( + input?: I, + opts?: PromptGenerateOptions + ): Promise>>; + + /** + * Generates a response by rendering the prompt template with given user input and then calling the model. + * @param input Prompt inputs. + * @param opt Options for the prompt template, including user input variables and custom model configuration options. 
+ * @returns the model response as a promise of `GenerateStreamResponse`. + */ + stream( + input?: I, + opts?: PromptGenerateOptions + ): Promise>>; + + /** + * Generates a response by rendering the prompt template with given user input and additional generate options and then calling the model. + * + * @param opt Options for the prompt template, including user input variables and custom model configuration options. + * @returns the model response as a promise of `GenerateResponse`. + */ + generate( + opt: PromptGenerateOptions + ): Promise>>; + + /** + * Generates a streaming response by rendering the prompt template with given user input and additional generate options and then calling the model. + * + * @param opt Options for the prompt template, including user input variables and custom model configuration options. + * @returns the model response as a promise of `GenerateStreamResponse`. + */ + generateStream( + opt: PromptGenerateOptions + ): Promise>>; + + /** + * Renders the prompt template based on user input. + * + * @param opt Options for the prompt template, including user input variables and custom model configuration options. + * @returns a `GenerateOptions` object to be used with the `generate()` function from @genkit-ai/ai. + */ + render( + opt: PromptGenerateOptions + ): Promise>; + + /** + * Returns the prompt usable as a tool. + */ + asTool(): ToolAction; +} + /** * Defines and registers a prompt action. The action can be called to obtain * a `GenerateRequest` which can be passed to a model action. The given @@ -67,10 +150,12 @@ export function isPrompt(arg: any): boolean { * @returns The new `PromptAction`. */ export function definePrompt( + registry: Registry, config: PromptConfig, fn: PromptFn ): PromptAction { const a = defineAction( + registry, { ...config, actionType: 'prompt', @@ -94,16 +179,19 @@ export async function renderPrompt< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, ->(params: { - prompt: PromptArgument; - input: z.infer; - docs?: DocumentData[]; - model: ModelArgument; - config?: z.infer; -}): Promise> { +>( + registry: Registry, + params: { + prompt: PromptArgument; + input: z.infer; + docs?: DocumentData[]; + model: ModelArgument; + config?: z.infer; + } +): Promise> { let prompt: PromptAction; if (typeof params.prompt === 'string') { - prompt = await lookupAction(`/prompt/${params.prompt}`); + prompt = await registry.lookupAction(`/prompt/${params.prompt}`); } else { prompt = params.prompt as PromptAction; } diff --git a/js/ai/src/reranker.ts b/js/ai/src/reranker.ts index 54428d0cb..35d3b2505 100644 --- a/js/ai/src/reranker.ts +++ b/js/ai/src/reranker.ts @@ -15,7 +15,7 @@ */ import { Action, defineAction, z } from '@genkit-ai/core'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { Part, PartSchema } from './document.js'; import { Document, DocumentData, DocumentDataSchema } from './retriever.js'; @@ -101,6 +101,7 @@ function rerankerWithMetadata< * Creates a reranker action for the provided {@link RerankerFn} implementation. 
*/ export function defineReranker( + registry: Registry, options: { name: string; configSchema?: OptionsType; @@ -109,6 +110,7 @@ export function defineReranker( runner: RerankerFn ) { const reranker = defineAction( + registry, { actionType: 'reranker', name: options.name, @@ -157,13 +159,14 @@ export type RerankerArgument< * Reranks documents from a {@link RerankerArgument} based on the provided query. */ export async function rerank( + registry: Registry, params: RerankerParams ): Promise> { let reranker: RerankerAction; if (typeof params.reranker === 'string') { - reranker = await lookupAction(`/reranker/${params.reranker}`); + reranker = await registry.lookupAction(`/reranker/${params.reranker}`); } else if (Object.hasOwnProperty.call(params.reranker, 'info')) { - reranker = await lookupAction(`/reranker/${params.reranker.name}`); + reranker = await registry.lookupAction(`/reranker/${params.reranker.name}`); } else { reranker = params.reranker as RerankerAction; } diff --git a/js/ai/src/retriever.ts b/js/ai/src/retriever.ts index 0d3f23689..0623e297f 100644 --- a/js/ai/src/retriever.ts +++ b/js/ai/src/retriever.ts @@ -15,7 +15,7 @@ */ import { Action, GenkitError, defineAction, z } from '@genkit-ai/core'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { Document, DocumentData, DocumentDataSchema } from './document.js'; import { EmbedderInfo } from './embedder.js'; @@ -111,6 +111,7 @@ function indexerWithMetadata< export function defineRetriever< OptionsType extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, options: { name: string; configSchema?: OptionsType; @@ -119,6 +120,7 @@ export function defineRetriever< runner: RetrieverFn ) { const retriever = defineAction( + registry, { actionType: 'retriever', name: options.name, @@ -149,6 +151,7 @@ export function defineRetriever< * Creates an indexer action for the provided {@link IndexerFn} implementation. */ export function defineIndexer( + registry: Registry, options: { name: string; embedderInfo?: EmbedderInfo; @@ -157,6 +160,7 @@ export function defineIndexer( runner: IndexerFn ) { const indexer = defineAction( + registry, { actionType: 'indexer', name: options.name, @@ -200,13 +204,16 @@ export type RetrieverArgument< * Retrieves documents from a {@link RetrieverArgument} based on the provided query. */ export async function retrieve( + registry: Registry, params: RetrieverParams ): Promise> { let retriever: RetrieverAction; if (typeof params.retriever === 'string') { - retriever = await lookupAction(`/retriever/${params.retriever}`); + retriever = await registry.lookupAction(`/retriever/${params.retriever}`); } else if (Object.hasOwnProperty.call(params.retriever, 'info')) { - retriever = await lookupAction(`/retriever/${params.retriever.name}`); + retriever = await registry.lookupAction( + `/retriever/${params.retriever.name}` + ); } else { retriever = params.retriever as RetrieverAction; } @@ -239,13 +246,14 @@ export interface IndexerParams< * Indexes documents using a {@link IndexerArgument}. 
*/ export async function index( + registry: Registry, params: IndexerParams ): Promise { let indexer: IndexerAction; if (typeof params.indexer === 'string') { - indexer = await lookupAction(`/indexer/${params.indexer}`); + indexer = await registry.lookupAction(`/indexer/${params.indexer}`); } else if (Object.hasOwnProperty.call(params.indexer, 'info')) { - indexer = await lookupAction(`/indexer/${params.indexer.name}`); + indexer = await registry.lookupAction(`/indexer/${params.indexer.name}`); } else { indexer = params.indexer as IndexerAction; } @@ -381,10 +389,12 @@ export function defineSimpleRetriever< C extends z.ZodTypeAny = z.ZodTypeAny, R = any, >( + registry: Registry, options: SimpleRetrieverOptions, handler: (query: Document, config: z.infer) => Promise ) { return defineRetriever( + registry, { name: options.name, configSchema: options.configSchema, diff --git a/js/ai/src/testing/model-tester.ts b/js/ai/src/testing/model-tester.ts index 3cf041ae9..7caa4b0cc 100644 --- a/js/ai/src/testing/model-tester.ts +++ b/js/ai/src/testing/model-tester.ts @@ -15,7 +15,7 @@ */ import { z } from '@genkit-ai/core'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { runInNewSpan } from '@genkit-ai/core/tracing'; import assert from 'node:assert'; import { generate } from '../generate'; @@ -23,8 +23,8 @@ import { ModelAction } from '../model'; import { defineTool } from '../tool'; const tests: Record = { - 'basic hi': async (model: string) => { - const response = await generate({ + 'basic hi': async (registry: Registry, model: string) => { + const response = await generate(registry, { model, prompt: 'just say "Hi", literally', }); @@ -32,14 +32,14 @@ const tests: Record = { const got = response.text.trim(); assert.match(got, /Hi/i); }, - multimodal: async (model: string) => { - const resolvedModel = (await lookupAction( + multimodal: async (registry: Registry, model: string) => { + const resolvedModel = (await registry.lookupAction( `/model/${model}` )) as ModelAction; if (!resolvedModel.__action.metadata?.model.supports?.media) { skip(); } - const response = await generate({ + const response = await generate(registry, { model, prompt: [ { @@ -57,18 +57,18 @@ const tests: Record = { const got = response.text.trim(); assert.match(got, /plus/i); }, - history: async (model: string) => { - const resolvedModel = (await lookupAction( + history: async (registry: Registry, model: string) => { + const resolvedModel = (await registry.lookupAction( `/model/${model}` )) as ModelAction; if (!resolvedModel.__action.metadata?.model.supports?.multiturn) { skip(); } - const response1 = await generate({ + const response1 = await generate(registry, { model, prompt: 'My name is Glorb', }); - const response = await generate({ + const response = await generate(registry, { model, prompt: "What's my name?", messages: response1.messages, @@ -77,8 +77,8 @@ const tests: Record = { const got = response.text.trim(); assert.match(got, /Glorb/); }, - 'system prompt': async (model: string) => { - const { text } = await generate({ + 'system prompt': async (registry: Registry, model: string) => { + const { text } = await generate(registry, { model, prompt: 'Hi', messages: [ @@ -97,8 +97,8 @@ const tests: Record = { const got = text.trim(); assert.equal(got, want); }, - 'structured output': async (model: string) => { - const response = await generate({ + 'structured output': async (registry: Registry, model: string) => { + const response = await 
generate(registry, { model, prompt: 'extract data as json from: Jack was a Lumberjack', output: { @@ -117,15 +117,15 @@ const tests: Record = { const got = response.output; assert.deepEqual(want, got); }, - 'tool calling': async (model: string) => { - const resolvedModel = (await lookupAction( + 'tool calling': async (registry: Registry, model: string) => { + const resolvedModel = (await registry.lookupAction( `/model/${model}` )) as ModelAction; if (!resolvedModel.__action.metadata?.model.supports?.tools) { skip(); } - const { text } = await generate({ + const { text } = await generate(registry, { model, prompt: 'what is a gablorken of 2? use provided tool', tools: ['gablorkenTool'], @@ -149,10 +149,14 @@ type TestReport = { }[]; }[]; -type TestCase = (model: string) => Promise; +type TestCase = (ai: Registry, model: string) => Promise; -export async function testModels(models: string[]): Promise { +export async function testModels( + registry: Registry, + models: string[] +): Promise { const gablorkenTool = defineTool( + registry, { name: 'gablorkenTool', description: 'use when need to calculate a gablorken', @@ -182,7 +186,7 @@ export async function testModels(models: string[]): Promise { }); const modelReport = caseReport.models[caseReport.models.length - 1]; try { - await tests[test](model); + await tests[test](registry, model); } catch (e) { modelReport.passed = false; if (e instanceof SkipTestError) { diff --git a/js/ai/src/tool.ts b/js/ai/src/tool.ts index 9dcb61c4f..a0d85340c 100644 --- a/js/ai/src/tool.ts +++ b/js/ai/src/tool.ts @@ -15,7 +15,7 @@ */ import { Action, defineAction, JSONSchema7, z } from '@genkit-ai/core'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { toJsonSchema } from '@genkit-ai/core/schema'; import { setCustomMetadataAttributes } from '@genkit-ai/core/tracing'; import { ToolDefinition } from './model.js'; @@ -89,11 +89,11 @@ export function asTool( export async function resolveTools< O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, ->(tools: ToolArgument[] = []): Promise { +>(registry: Registry, tools: ToolArgument[] = []): Promise { return await Promise.all( tools.map(async (ref): Promise => { if (typeof ref === 'string') { - const tool = await lookupAction(`/tool/${ref}`); + const tool = await registry.lookupAction(`/tool/${ref}`); if (!tool) { throw new Error(`Tool ${ref} not found`); } @@ -101,7 +101,7 @@ export async function resolveTools< } else if ((ref as Action).__action) { return asTool(ref as Action); } else if (ref.name) { - const tool = await lookupAction(`/tool/${ref.name}`); + const tool = await registry.lookupAction(`/tool/${ref.name}`); if (!tool) { throw new Error(`Tool ${ref} not found`); } @@ -137,10 +137,12 @@ export function toToolDefinition( * A tool is an action that can be passed to a model to be called automatically if it so chooses. 
*/ export function defineTool( + registry: Registry, config: ToolConfig, fn: (input: z.infer) => Promise> ): ToolAction { const a = defineAction( + registry, { ...config, actionType: 'tool', diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index cae8e419c..9a02b0f6e 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -15,7 +15,7 @@ */ import { z } from '@genkit-ai/core'; -import { Registry, runWithRegistry } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { toJsonSchema } from '@genkit-ai/core/schema'; import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; @@ -262,19 +262,18 @@ describe('GenerateResponse', () => { describe('toGenerateRequest', () => { const registry = new Registry(); // register tools - const tellAFunnyJoke = runWithRegistry(registry, () => - defineTool( - { - name: 'tellAFunnyJoke', - description: - 'Tells jokes about an input topic. Use this tool whenever user asks you to tell a joke.', - inputSchema: z.object({ topic: z.string() }), - outputSchema: z.string(), - }, - async (input) => { - return `Why did the ${input.topic} cross the road?`; - } - ) + const tellAFunnyJoke = defineTool( + registry, + { + name: 'tellAFunnyJoke', + description: + 'Tells jokes about an input topic. Use this tool whenever user asks you to tell a joke.', + inputSchema: z.object({ topic: z.string() }), + outputSchema: z.string(), + }, + async (input) => { + return `Why did the ${input.topic} cross the road?`; + } ); const testCases = [ @@ -442,9 +441,7 @@ describe('toGenerateRequest', () => { for (const test of testCases) { it(test.should, async () => { assert.deepStrictEqual( - await runWithRegistry(registry, () => - toGenerateRequest(test.prompt as GenerateOptions) - ), + await toGenerateRequest(registry, test.prompt as GenerateOptions), test.expectedOutput ); }); @@ -530,29 +527,28 @@ describe('generate', () => { beforeEach(() => { registry = new Registry(); - echoModel = runWithRegistry(registry, () => - defineModel( - { - name: 'echoModel', - }, - async (request) => { - return { - message: { - role: 'model', - content: [ - { - text: - 'Echo: ' + - request.messages - .map((m) => m.content.map((c) => c.text).join()) - .join(), - }, - ], - }, - finishReason: 'stop', - }; - } - ) + echoModel = defineModel( + registry, + { + name: 'echoModel', + }, + async (request) => { + return { + message: { + role: 'model', + content: [ + { + text: + 'Echo: ' + + request.messages + .map((m) => m.content.map((c) => c.text).join()) + .join(), + }, + ], + }, + finishReason: 'stop', + }; + } ); }); @@ -592,14 +588,11 @@ describe('generate', () => { }; }; - const response = await runWithRegistry(registry, () => - generate({ - prompt: 'banana', - model: echoModel, - use: [wrapRequest, wrapResponse], - }) - ); - + const response = await generate(registry, { + prompt: 'banana', + model: echoModel, + use: [wrapRequest, wrapResponse], + }); const want = '[Echo: (banana)]'; assert.deepStrictEqual(response.text, want); }); @@ -609,24 +602,21 @@ describe('generate', () => { let registry: Registry; beforeEach(() => { registry = new Registry(); - runWithRegistry(registry, () => - defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - message: input.messages[0], - finishReason: 'stop', - }) - ) + + defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + 
finishReason: 'stop', + }) ); }); it('should preserve the request in the returned response, enabling .messages', async () => { - const response = await runWithRegistry(registry, () => - generate({ - model: 'echo', - prompt: 'Testing messages', - }) - ); - + const response = await generate(registry, { + model: 'echo', + prompt: 'Testing messages', + }); assert.deepEqual( response.messages.map((m) => m.content[0].text), ['Testing messages', 'Testing messages'] diff --git a/js/ai/tests/model/middleware_test.ts b/js/ai/tests/model/middleware_test.ts index 9b3eb7054..3c9aaaffa 100644 --- a/js/ai/tests/model/middleware_test.ts +++ b/js/ai/tests/model/middleware_test.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { Registry, runWithRegistry } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; import { DocumentData } from '../../src/document.js'; @@ -147,24 +147,21 @@ describe('validateSupport', () => { }); const registry = new Registry(); -const echoModel = runWithRegistry(registry, () => - defineModel({ name: 'echo' }, async (req) => { - return { - finishReason: 'stop', - message: { - role: 'model', - content: [{ data: req }], - }, - }; - }) -); - +const echoModel = defineModel(registry, { name: 'echo' }, async (req) => { + return { + finishReason: 'stop', + message: { + role: 'model', + content: [{ data: req }], + }, + }; +}); describe('conformOutput (default middleware)', () => { const schema = { type: 'object', properties: { test: { type: 'boolean' } } }; // return the output tagged part from the request async function testRequest(req: GenerateRequest): Promise { - const response = await runWithRegistry(registry, () => echoModel(req)); + const response = await echoModel(req); const treq = response.message!.content[0].data as GenerateRequest; const lastUserMessage = treq.messages diff --git a/js/ai/tests/prompt/prompt_test.ts b/js/ai/tests/prompt/prompt_test.ts index c35c951c6..702f85444 100644 --- a/js/ai/tests/prompt/prompt_test.ts +++ b/js/ai/tests/prompt/prompt_test.ts @@ -15,7 +15,7 @@ */ import { z } from '@genkit-ai/core'; -import { Registry, runWithRegistry } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import assert from 'node:assert'; import { describe, it } from 'node:test'; import { definePrompt, renderPrompt } from '../../src/prompt.ts'; @@ -23,38 +23,37 @@ import { definePrompt, renderPrompt } from '../../src/prompt.ts'; describe('prompt', () => { let registry = new Registry(); describe('render()', () => { - runWithRegistry(registry, () => { - it('respects output schema in the definition', async () => { - const schema1 = z.object({ - puppyName: z.string({ description: 'A cute name for a puppy' }), - }); - const prompt1 = definePrompt( - { - name: 'prompt1', - inputSchema: z.string({ description: 'Dog breed' }), - }, - async (breed) => { - return { - messages: [ - { - role: 'user', - content: [{ text: `Pick a name for a ${breed} puppy` }], - }, - ], - output: { - format: 'json', - schema: schema1, + it('respects output schema in the definition', async () => { + const schema1 = z.object({ + puppyName: z.string({ description: 'A cute name for a puppy' }), + }); + const prompt1 = definePrompt( + registry, + { + name: 'prompt1', + inputSchema: z.string({ description: 'Dog breed' }), + }, + async (breed) => { + return { + messages: [ + { + role: 'user', + content: [{ text: `Pick a name for a ${breed} 
puppy` }], }, - }; - } - ); - const generateRequest = await renderPrompt({ - prompt: prompt1, - input: 'poodle', - model: 'geminiPro', - }); - assert.equal(generateRequest.output?.schema, schema1); + ], + output: { + format: 'json', + schema: schema1, + }, + }; + } + ); + const generateRequest = await renderPrompt(registry, { + prompt: prompt1, + input: 'poodle', + model: 'geminiPro', }); + assert.equal(generateRequest.output?.schema, schema1); }); }); }); diff --git a/js/ai/tests/reranker/reranker_test.ts b/js/ai/tests/reranker/reranker_test.ts index 4942e02b6..63a8b25e4 100644 --- a/js/ai/tests/reranker/reranker_test.ts +++ b/js/ai/tests/reranker/reranker_test.ts @@ -15,7 +15,7 @@ */ import { GenkitError, z } from '@genkit-ai/core'; -import { Registry, runWithRegistry } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; import { defineReranker, rerank } from '../../src/reranker'; @@ -28,34 +28,32 @@ describe('reranker', () => { registry = new Registry(); }); it('reranks documents based on custom logic', async () => { - const customReranker = runWithRegistry(registry, () => - defineReranker( - { - name: 'reranker', - configSchema: z.object({ - k: z.number().optional(), - }), - }, - async (query, documents, options) => { - // Custom reranking logic: score based on string length similarity to query - const queryLength = query.text.length; - const rerankedDocs = documents.map((doc) => { - const score = Math.abs(queryLength - doc.text.length); - return { - ...doc, - metadata: { ...doc.metadata, score }, - }; - }); - + const customReranker = defineReranker( + registry, + { + name: 'reranker', + configSchema: z.object({ + k: z.number().optional(), + }), + }, + async (query, documents, options) => { + // Custom reranking logic: score based on string length similarity to query + const queryLength = query.text.length; + const rerankedDocs = documents.map((doc) => { + const score = Math.abs(queryLength - doc.text.length); return { - documents: rerankedDocs - .sort((a, b) => a.metadata.score - b.metadata.score) - .slice(0, options.k || 3), + ...doc, + metadata: { ...doc.metadata, score }, }; - } - ) + }); + + return { + documents: rerankedDocs + .sort((a, b) => a.metadata.score - b.metadata.score) + .slice(0, options.k || 3), + }; + } ); - // Sample documents for testing const documents = [ Document.fromText('short'), @@ -64,15 +62,12 @@ describe('reranker', () => { ]; const query = Document.fromText('medium length'); - const rerankedDocuments = await runWithRegistry(registry, () => - rerank({ - reranker: customReranker, - query, - documents, - options: { k: 2 }, - }) - ); - + const rerankedDocuments = await rerank(registry, { + reranker: customReranker, + query, + documents, + options: { k: 2 }, + }); // Validate the reranked results assert.equal(rerankedDocuments.length, 2); assert(rerankedDocuments[0].text.includes('a bit longer')); @@ -80,85 +75,76 @@ describe('reranker', () => { }); it('handles missing options gracefully', async () => { - const customReranker = runWithRegistry(registry, () => - defineReranker( - { - name: 'reranker', - configSchema: z.object({ - k: z.number().optional(), - }), - }, - async (query, documents, options) => { - const rerankedDocs = documents.map((doc) => { - const score = Math.random(); // Simplified scoring for testing - return { - ...doc, - metadata: { ...doc.metadata, score }, - }; - }); - + const customReranker = defineReranker( + 
registry, + { + name: 'reranker', + configSchema: z.object({ + k: z.number().optional(), + }), + }, + async (query, documents, options) => { + const rerankedDocs = documents.map((doc) => { + const score = Math.random(); // Simplified scoring for testing return { - documents: rerankedDocs.sort( - (a, b) => b.metadata.score - a.metadata.score - ), + ...doc, + metadata: { ...doc.metadata, score }, }; - } - ) + }); + + return { + documents: rerankedDocs.sort( + (a, b) => b.metadata.score - a.metadata.score + ), + }; + } ); - const documents = [Document.fromText('doc1'), Document.fromText('doc2')]; const query = Document.fromText('test query'); - const rerankedDocuments = await runWithRegistry(registry, () => - rerank({ - reranker: customReranker, - query, - documents, - options: { k: 2 }, - }) - ); - + const rerankedDocuments = await rerank(registry, { + reranker: customReranker, + query, + documents, + options: { k: 2 }, + }); assert.equal(rerankedDocuments.length, 2); assert(typeof rerankedDocuments[0].metadata.score === 'number'); }); it('validates config schema and throws error on invalid input', async () => { - const customReranker = runWithRegistry(registry, () => - defineReranker( - { - name: 'reranker', - configSchema: z.object({ - k: z.number().min(1), - }), - }, - async (query, documents, options) => { - // Simplified scoring for testing - const rerankedDocs = documents.map((doc) => ({ - ...doc, - metadata: { score: Math.random() }, - })); - return { - documents: rerankedDocs.sort( - (a, b) => b.metadata.score - a.metadata.score - ), - }; - } - ) + const customReranker = defineReranker( + registry, + { + name: 'reranker', + configSchema: z.object({ + k: z.number().min(1), + }), + }, + async (query, documents, options) => { + // Simplified scoring for testing + const rerankedDocs = documents.map((doc) => ({ + ...doc, + metadata: { score: Math.random() }, + })); + return { + documents: rerankedDocs.sort( + (a, b) => b.metadata.score - a.metadata.score + ), + }; + } ); - const documents = [Document.fromText('doc1')]; const query = Document.fromText('test query'); try { - await runWithRegistry(registry, () => - rerank({ - reranker: customReranker, - query, - documents, - options: { k: 0 }, // Invalid input: k must be at least 1 - }) - ); + await rerank(registry, { + reranker: customReranker, + query, + documents, + options: { k: 0 }, // Invalid input: k must be at least 1 + }); assert.fail('Expected validation error'); } catch (err) { assert(err instanceof GenkitError); @@ -167,71 +153,62 @@ describe('reranker', () => { }); it('preserves document metadata after reranking', async () => { - const customReranker = runWithRegistry(registry, () => - defineReranker( - { - name: 'reranker', - }, - async (query, documents) => { - const rerankedDocs = documents.map((doc, i) => ({ - ...doc, - metadata: { ...doc.metadata, score: 2 - i }, - })); - - return { - documents: rerankedDocs.sort( - (a, b) => b.metadata.score - a.metadata.score - ), - }; - } - ) + const customReranker = defineReranker( + registry, + { + name: 'reranker', + }, + async (query, documents) => { + const rerankedDocs = documents.map((doc, i) => ({ + ...doc, + metadata: { ...doc.metadata, score: 2 - i }, + })); + + return { + documents: rerankedDocs.sort( + (a, b) => b.metadata.score - a.metadata.score + ), + }; + } ); - const documents = [ new Document({ content: [], metadata: { originalField: 'test1' } }), new Document({ content: [], metadata: { originalField: 'test2' } }), ]; const query = Document.fromText('test query'); - 
const rerankedDocuments = await runWithRegistry(registry, () => - rerank({ - reranker: customReranker, - query, - documents, - }) - ); - + const rerankedDocuments = await rerank(registry, { + reranker: customReranker, + query, + documents, + }); assert.equal(rerankedDocuments[0].metadata.originalField, 'test1'); assert.equal(rerankedDocuments[1].metadata.originalField, 'test2'); }); it('handles errors thrown by the reranker', async () => { - const customReranker = runWithRegistry(registry, () => - defineReranker( - { - name: 'reranker', - }, - async (query, documents) => { - // Simulate an error in the reranker logic - throw new GenkitError({ - message: 'Something went wrong during reranking', - status: 'INTERNAL', - }); - } - ) + const customReranker = defineReranker( + registry, + { + name: 'reranker', + }, + async (query, documents) => { + // Simulate an error in the reranker logic + throw new GenkitError({ + message: 'Something went wrong during reranking', + status: 'INTERNAL', + }); + } ); - const documents = [Document.fromText('doc1'), Document.fromText('doc2')]; const query = Document.fromText('test query'); try { - await runWithRegistry(registry, () => - rerank({ - reranker: customReranker, - query, - documents, - }) - ); + await rerank(registry, { + reranker: customReranker, + query, + documents, + }); assert.fail('Expected an error to be thrown'); } catch (err) { assert(err instanceof GenkitError); diff --git a/js/core/package.json b/js/core/package.json index 3a26e2536..03fdcd451 100644 --- a/js/core/package.json +++ b/js/core/package.json @@ -7,7 +7,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", diff --git a/js/core/src/action.ts b/js/core/src/action.ts index 1bf8f7a1b..382b80d14 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -17,12 +17,7 @@ import { JSONSchema7 } from 'json-schema'; import { AsyncLocalStorage } from 'node:async_hooks'; import * as z from 'zod'; -import { - ActionType, - initializeAllPlugins, - lookupPlugin, - registerAction, -} from './registry.js'; +import { ActionType, Registry } from './registry.js'; import { parseSchema } from './schema.js'; import { SPAN_TYPE_ATTR, @@ -122,8 +117,8 @@ export function action< ): Action { const actionName = typeof config.name === 'string' - ? validateActionName(config.name) - : `${config.name.pluginId}/${validateActionId(config.name.actionId)}`; + ? 
config.name + : `${config.name.pluginId}/${config.name.actionId}`; const actionFn = async (input: I) => { input = parseSchema(input, { schema: config.inputSchema, @@ -168,16 +163,16 @@ export function action< return actionFn; } -function validateActionName(name: string) { +function validateActionName(registry: Registry, name: string) { if (name.includes('/')) { - validatePluginName(name.split('/', 1)[0]); + validatePluginName(registry, name.split('/', 1)[0]); validateActionId(name.substring(name.indexOf('/') + 1)); } return name; } -function validatePluginName(pluginId: string) { - if (!lookupPlugin(pluginId)) { +function validatePluginName(registry: Registry, pluginId: string) { + if (!registry.lookupPlugin(pluginId)) { throw new Error( `Unable to find plugin name used in the action name: ${pluginId}` ); @@ -200,6 +195,7 @@ export function defineAction< O extends z.ZodTypeAny, M extends Record = Record, >( + registry: Registry, config: ActionParams & { actionType: ActionType; }, @@ -211,13 +207,18 @@ export function defineAction< 'See: https://github.com/firebase/genkit/blob/main/docs/errors/no_new_actions_at_runtime.md' ); } + if (typeof config.name === 'string') { + validateActionName(registry, config.name); + } else { + validateActionId(config.name.actionId); + } const act = action(config, async (i: I): Promise> => { setCustomMetadataAttributes({ subtype: config.actionType }); - await initializeAllPlugins(); + await registry.initializeAllPlugins(); return await runInActionRuntimeContext(() => fn(i)); }); act.__action.actionType = config.actionType; - registerAction(config.actionType, act); + registry.registerAction(config.actionType, act); return act; } diff --git a/js/core/src/flow.ts b/js/core/src/flow.ts index 0061e0cde..107459585 100644 --- a/js/core/src/flow.ts +++ b/js/core/src/flow.ts @@ -31,12 +31,7 @@ import { runWithAuthContext } from './auth.js'; import { getErrorMessage, getErrorStack } from './error.js'; import { FlowActionInputSchema } from './flowTypes.js'; import { logger } from './logging.js'; -import { - getRegistryInstance, - initializeAllPlugins, - Registry, - runWithRegistry, -} from './registry.js'; +import { Registry } from './registry.js'; import { toJsonSchema } from './schema.js'; import { newTrace, @@ -181,6 +176,7 @@ export class Flow< readonly flowFn: FlowFn; constructor( + private registry: Registry, config: FlowConfig | StreamingFlowConfig, flowFn: FlowFn ) { @@ -207,7 +203,7 @@ export class Flow< auth?: unknown; } ): Promise>> { - await initializeAllPlugins(); + await this.registry.initializeAllPlugins(); return await runWithAuthContext(opts.auth, () => newTrace( { @@ -336,84 +332,79 @@ export class Flow< } async expressHandler( - registry: Registry, request: __RequestWithAuth, response: express.Response ): Promise { - await runWithRegistry(registry, async () => { - const { stream } = request.query; - const auth = request.auth; - - let input = request.body.data; + const { stream } = request.query; + const auth = request.auth; + + let input = request.body.data; + + try { + await this.authPolicy?.(auth, input); + } catch (e: any) { + const respBody = { + error: { + status: 'PERMISSION_DENIED', + message: e.message || 'Permission denied to resource', + }, + }; + response.status(403).send(respBody).end(); + return; + } + if (stream === 'true') { + response.writeHead(200, { + 'Content-Type': 'text/plain', + 'Transfer-Encoding': 'chunked', + }); try { - await this.authPolicy?.(auth, input); - } catch (e: any) { - const respBody = { + const result = await 
this.invoke(input, { + streamingCallback: ((chunk: z.infer) => { + response.write(JSON.stringify(chunk) + streamDelimiter); + }) as S extends z.ZodVoid ? undefined : StreamingCallback>, + auth, + }); + response.write({ + result: result.result, // Need more results!!!! + }); + response.end(); + } catch (e) { + response.write({ error: { - status: 'PERMISSION_DENIED', - message: e.message || 'Permission denied to resource', + status: 'INTERNAL', + message: getErrorMessage(e), + details: getErrorStack(e), }, - }; - response.status(403).send(respBody).end(); - return; - } - - if (stream === 'true') { - response.writeHead(200, { - 'Content-Type': 'text/plain', - 'Transfer-Encoding': 'chunked', }); - try { - const result = await this.invoke(input, { - streamingCallback: ((chunk: z.infer) => { - response.write(JSON.stringify(chunk) + streamDelimiter); - }) as S extends z.ZodVoid - ? undefined - : StreamingCallback>, - auth, - }); - response.write({ - result: result.result, // Need more results!!!! - }); - response.end(); - } catch (e) { - response.write({ + response.end(); + } + } else { + try { + const result = await this.invoke(input, { auth }); + response.setHeader('x-genkit-trace-id', result.traceId); + response.setHeader('x-genkit-span-id', result.spanId); + // Responses for non-streaming flows are passed back with the flow result stored in a field called "result." + response + .status(200) + .send({ + result: result.result, + }) + .end(); + } catch (e) { + // Errors for non-streaming flows are passed back as standard API errors. + response + .status(500) + .send({ error: { status: 'INTERNAL', message: getErrorMessage(e), details: getErrorStack(e), }, - }); - response.end(); - } - } else { - try { - const result = await this.invoke(input, { auth }); - response.setHeader('x-genkit-trace-id', result.traceId); - response.setHeader('x-genkit-span-id', result.spanId); - // Responses for non-streaming flows are passed back with the flow result stored in a field called "result." - response - .status(200) - .send({ - result: result.result, - }) - .end(); - } catch (e) { - // Errors for non-streaming flows are passed back as standard API errors. - response - .status(500) - .send({ - error: { - status: 'INTERNAL', - message: getErrorMessage(e), - details: getErrorStack(e), - }, - }) - .end(); - } + }) + .end(); } - }); + } } } @@ -496,9 +487,7 @@ export class FlowServer { flow.middleware?.forEach((middleware) => server.post(flowPath, middleware) ); - server.post(flowPath, (req, res) => - flow.expressHandler(this.registry, req, res) - ); + server.post(flowPath, (req, res) => flow.expressHandler(req, res)); }); } else { logger.warn('No flows registered in flow server.'); @@ -557,17 +546,17 @@ export function defineFlow< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, config: FlowConfig | string, fn: FlowFn ): CallableFlow { const resolvedConfig: FlowConfig = typeof config === 'string' ? 
{ name: config } : config; - const flow = new Flow(resolvedConfig, fn); - registerFlowAction(flow); - const registry = getRegistryInstance(); + const flow = new Flow(registry, resolvedConfig, fn); + registerFlowAction(registry, flow); const callableFlow: CallableFlow = async (input, opts) => { - return runWithRegistry(registry, () => flow.run(input, opts)); + return flow.run(input, opts); }; callableFlow.flow = flow; return callableFlow; @@ -581,14 +570,14 @@ export function defineStreamingFlow< O extends z.ZodTypeAny = z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, config: StreamingFlowConfig, fn: FlowFn ): StreamableFlow { - const flow = new Flow(config, fn); - registerFlowAction(flow); - const registry = getRegistryInstance(); + const flow = new Flow(registry, config, fn); + registerFlowAction(registry, flow); const streamableFlow: StreamableFlow = (input, opts) => { - return runWithRegistry(registry, () => flow.stream(input, opts)); + return flow.stream(input, opts); }; streamableFlow.flow = flow; return streamableFlow; @@ -601,8 +590,12 @@ function registerFlowAction< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, ->(flow: Flow): Action { +>( + registry: Registry, + flow: Flow +): Action { return defineAction( + registry, { actionType: 'flow', name: flow.name, diff --git a/js/core/src/reflection.ts b/js/core/src/reflection.ts index e74a7fa69..e66c73630 100644 --- a/js/core/src/reflection.ts +++ b/js/core/src/reflection.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import express, { NextFunction, Request, Response } from 'express'; +import express from 'express'; import fs from 'fs/promises'; import getPort, { makeRange } from 'get-port'; import { Server } from 'http'; @@ -23,7 +23,7 @@ import z from 'zod'; import { Status, StatusCodes, runWithStreamingCallback } from './action.js'; import { GENKIT_VERSION } from './index.js'; import { logger } from './logging.js'; -import { Registry, runWithRegistry } from './registry.js'; +import { Registry } from './registry.js'; import { toJsonSchema } from './schema.js'; import { flushTracing, @@ -113,16 +113,6 @@ export class ReflectionServer { next(); }); - server.use((req: Request, res: Response, next: NextFunction) => { - runWithRegistry(this.registry, async () => { - try { - next(); - } catch (err) { - next(err); - } - }); - }); - server.get('/api/__health', async (_, response) => { await this.registry.listActions(); response.status(200).send('OK'); diff --git a/js/core/src/registry.ts b/js/core/src/registry.ts index afa679ac6..f7cd0f532 100644 --- a/js/core/src/registry.ts +++ b/js/core/src/registry.ts @@ -14,7 +14,6 @@ * limitations under the License. */ -import { AsyncLocalStorage } from 'async_hooks'; import * as z from 'zod'; import { Action } from './action.js'; import { logger } from './logging.js'; @@ -47,17 +46,6 @@ export interface Schema { jsonSchema?: JSONSchema; } -/** - * Looks up a registry key (action type and key) in the registry. - */ -export function lookupAction< - I extends z.ZodTypeAny, - O extends z.ZodTypeAny, - R extends Action, ->(key: string): Promise { - return getRegistryInstance().lookupAction(key); -} - function parsePluginName(registryKey: string) { const tokens = registryKey.split('/'); if (tokens.length === 4) { @@ -66,99 +54,8 @@ function parsePluginName(registryKey: string) { return undefined; } -/** - * Registers an action in the registry. 
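Throughout this registry.ts change, the module-level helpers (registerAction, lookupAction, runWithRegistry, and friends) are replaced by methods on an explicit `Registry` instance. A minimal sketch of the instance-based pattern, borrowing the relative test imports ('../src/action.js', '../src/registry.js') as an assumption about file layout:

```ts
import { action } from '../src/action.js';
import { Registry } from '../src/registry.js';

// Calls that previously went through the ambient registry (via AsyncLocalStorage)
// now target an explicit Registry instance.
const registry = new Registry();

const echoAction = action({ name: 'echo_something' }, async () => null);
registry.registerAction('model', echoAction);

// Lookups keep the '/<type>/<name>' key format used by the tests.
const found = await registry.lookupAction('/model/echo_something');
```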
- */ -export function registerAction( - type: ActionType, - action: Action -) { - return getRegistryInstance().registerAction(type, action); -} - type ActionsRecord = Record>; -/** - * Initialize all plugins in the registry. - */ -export async function initializeAllPlugins() { - await getRegistryInstance().initializeAllPlugins(); -} - -/** - * Returns all actions in the registry. - */ -export function listActions(): Promise { - return getRegistryInstance().listActions(); -} - -/** - * Registers a plugin provider. - * @param name The name of the plugin to register. - * @param provider The plugin provider. - */ -export function registerPluginProvider(name: string, provider: PluginProvider) { - return getRegistryInstance().registerPluginProvider(name, provider); -} - -/** - * Looks up a plugin. - * @param name The name of the plugin to lookup. - * @returns The plugin. - */ -export function lookupPlugin(name: string) { - return getRegistryInstance().lookupPlugin(name); -} - -/** - * Initializes a plugin that has already been registered. - * @param name The name of the plugin to initialize. - * @returns The plugin. - */ -export async function initializePlugin(name: string) { - return getRegistryInstance().initializePlugin(name); -} - -/** - * Registers a schema. - * @param name The name of the schema to register. - * @param data The schema to register (either a Zod schema or a JSON schema). - */ -export function registerSchema(name: string, data: Schema) { - return getRegistryInstance().registerSchema(name, data); -} - -/** - * Looks up a schema. - * @param name The name of the schema to lookup. - * @returns The schema. - */ -export function lookupSchema(name: string) { - return getRegistryInstance().lookupSchema(name); -} - -const registryAls = new AsyncLocalStorage(); - -/** - * @returns The active registry instance. - */ -export function getRegistryInstance(): Registry { - const registry = registryAls.getStore(); - if (!registry) { - throw new Error('getRegistryInstance() called before runWithRegistry()'); - } - return registry; -} - -/** - * Runs a function with a specific registry instance. - * @param registry The registry instance to use. - * @param fn The function to run. - */ -export function runWithRegistry(registry: Registry, fn: () => R) { - return registryAls.run(registry, fn); -} - /** * The registry is used to store and lookup actions, trace stores, flow state stores, plugins, and schemas. */ @@ -170,14 +67,6 @@ export class Registry { constructor(public parent?: Registry) {} - /** - * Creates a new registry overlaid onto the currently active registry. - * @returns The new overlaid registry. - */ - static withCurrent() { - return new Registry(getRegistryInstance()); - } - /** * Creates a new registry overlaid onto the provided registry. * @param parent The parent registry. 
diff --git a/js/core/src/schema.ts b/js/core/src/schema.ts index 16a45160d..a53da8acb 100644 --- a/js/core/src/schema.ts +++ b/js/core/src/schema.ts @@ -19,7 +19,7 @@ import addFormats from 'ajv-formats'; import { z } from 'zod'; import zodToJsonSchema from 'zod-to-json-schema'; import { GenkitError } from './error.js'; -import { registerSchema } from './registry.js'; +import { Registry } from './registry.js'; const ajv = new Ajv(); addFormats(ajv); @@ -112,14 +112,19 @@ export function parseSchema( } export function defineSchema( + registry: Registry, name: string, schema: T ): T { - registerSchema(name, { schema }); + registry.registerSchema(name, { schema }); return schema; } -export function defineJsonSchema(name: string, jsonSchema: JSONSchema) { - registerSchema(name, { jsonSchema }); +export function defineJsonSchema( + registry: Registry, + name: string, + jsonSchema: JSONSchema +) { + registry.registerSchema(name, { jsonSchema }); return jsonSchema; } diff --git a/js/core/tests/flow_test.ts b/js/core/tests/flow_test.ts index 7d5b74646..cce14e2ee 100644 --- a/js/core/tests/flow_test.ts +++ b/js/core/tests/flow_test.ts @@ -18,10 +18,11 @@ import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; import { defineFlow, defineStreamingFlow } from '../src/flow.js'; import { z } from '../src/index.js'; -import { Registry, runWithRegistry } from '../src/registry.js'; +import { Registry } from '../src/registry.js'; -function createTestFlow() { +function createTestFlow(registry: Registry) { return defineFlow( + registry, { name: 'testFlow', inputSchema: z.string(), @@ -33,8 +34,9 @@ function createTestFlow() { ); } -function createTestStreamingFlow() { +function createTestStreamingFlow(registry: Registry) { return defineStreamingFlow( + registry, { name: 'testFlow', inputSchema: z.number(), @@ -63,7 +65,7 @@ describe('flow', () => { describe('runFlow', () => { it('should run the flow', async () => { - const testFlow = runWithRegistry(registry, createTestFlow); + const testFlow = createTestFlow(registry); const result = await testFlow('foo'); @@ -71,10 +73,8 @@ describe('flow', () => { }); it('should run simple sync flow', async () => { - const testFlow = runWithRegistry(registry, () => { - return defineFlow('testFlow', (input) => { - return `bar ${input}`; - }); + const testFlow = defineFlow(registry, 'testFlow', (input) => { + return `bar ${input}`; }); const result = await testFlow('foo'); @@ -83,17 +83,16 @@ describe('flow', () => { }); it('should rethrow the error', async () => { - const testFlow = runWithRegistry(registry, () => - defineFlow( - { - name: 'throwing', - inputSchema: z.string(), - outputSchema: z.string(), - }, - async (input) => { - throw new Error(`bad happened: ${input}`); - } - ) + const testFlow = defineFlow( + registry, + { + name: 'throwing', + inputSchema: z.string(), + outputSchema: z.string(), + }, + async (input) => { + throw new Error(`bad happened: ${input}`); + } ); await assert.rejects(() => testFlow('foo'), { @@ -103,17 +102,16 @@ describe('flow', () => { }); it('should validate input', async () => { - const testFlow = runWithRegistry(registry, () => - defineFlow( - { - name: 'validating', - inputSchema: z.object({ foo: z.string(), bar: z.number() }), - outputSchema: z.string(), - }, - async (input) => { - return `ok ${input}`; - } - ) + const testFlow = defineFlow( + registry, + { + name: 'validating', + inputSchema: z.object({ foo: z.string(), bar: z.number() }), + outputSchema: z.string(), + }, + async (input) => { + return 
`ok ${input}`; + } ); await assert.rejects( @@ -132,7 +130,7 @@ describe('flow', () => { describe('streamFlow', () => { it('should run the flow', async () => { - const testFlow = runWithRegistry(registry, createTestStreamingFlow); + const testFlow = createTestStreamingFlow(registry); const response = testFlow(3); @@ -146,16 +144,15 @@ describe('flow', () => { }); it('should rethrow the error', async () => { - const testFlow = runWithRegistry(registry, () => - defineStreamingFlow( - { - name: 'throwing', - inputSchema: z.string(), - }, - async (input) => { - throw new Error(`stream bad happened: ${input}`); - } - ) + const testFlow = defineStreamingFlow( + registry, + { + name: 'throwing', + inputSchema: z.string(), + }, + async (input) => { + throw new Error(`stream bad happened: ${input}`); + } ); const response = testFlow('foo'); diff --git a/js/core/tests/registry_test.ts b/js/core/tests/registry_test.ts index 9542cf779..d54fdd415 100644 --- a/js/core/tests/registry_test.ts +++ b/js/core/tests/registry_test.ts @@ -17,175 +17,7 @@ import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; import { action } from '../src/action.js'; -import { - Registry, - listActions, - lookupAction, - registerAction, - registerPluginProvider, - runWithRegistry, -} from '../src/registry.js'; - -describe('global registry', () => { - let registry: Registry; - - beforeEach(() => { - registry = new Registry(); - }); - - describe('listActions', () => { - it('returns all registered actions', async () => { - await runWithRegistry(registry, async () => { - const fooSomethingAction = action( - { name: 'foo_something' }, - async () => null - ); - registerAction('model', fooSomethingAction); - const barSomethingAction = action( - { name: 'bar_something' }, - async () => null - ); - registerAction('model', barSomethingAction); - - assert.deepEqual(await listActions(), { - '/model/foo_something': fooSomethingAction, - '/model/bar_something': barSomethingAction, - }); - }); - }); - - it('returns all registered actions by plugins', async () => { - await runWithRegistry(registry, async () => { - registerPluginProvider('foo', { - name: 'foo', - async initializer() { - registerAction('model', fooSomethingAction); - return {}; - }, - }); - const fooSomethingAction = action( - { - name: { - pluginId: 'foo', - actionId: 'something', - }, - }, - async () => null - ); - registerPluginProvider('bar', { - name: 'bar', - async initializer() { - registerAction('model', barSomethingAction); - return {}; - }, - }); - const barSomethingAction = action( - { - name: { - pluginId: 'bar', - actionId: 'something', - }, - }, - async () => null - ); - - assert.deepEqual(await listActions(), { - '/model/foo/something': fooSomethingAction, - '/model/bar/something': barSomethingAction, - }); - }); - }); - }); - - describe('lookupAction', () => { - it('initializes plugin for action first', async () => { - await runWithRegistry(registry, async () => { - let fooInitialized = false; - registerPluginProvider('foo', { - name: 'foo', - async initializer() { - fooInitialized = true; - return {}; - }, - }); - let barInitialized = false; - registerPluginProvider('bar', { - name: 'bar', - async initializer() { - barInitialized = true; - return {}; - }, - }); - - await lookupAction('/model/foo/something'); - - assert.strictEqual(fooInitialized, true); - assert.strictEqual(barInitialized, false); - - await lookupAction('/model/bar/something'); - - assert.strictEqual(fooInitialized, true); - assert.strictEqual(barInitialized, 
true); - }); - }); - }); - - it('returns registered action', async () => { - await runWithRegistry(registry, async () => { - const fooSomethingAction = action( - { name: 'foo_something' }, - async () => null - ); - registerAction('model', fooSomethingAction); - const barSomethingAction = action( - { name: 'bar_something' }, - async () => null - ); - registerAction('model', barSomethingAction); - - assert.strictEqual( - await lookupAction('/model/foo_something'), - fooSomethingAction - ); - assert.strictEqual( - await lookupAction('/model/bar_something'), - barSomethingAction - ); - }); - }); - - it('returns action registered by plugin', async () => { - await runWithRegistry(registry, async () => { - registerPluginProvider('foo', { - name: 'foo', - async initializer() { - registerAction('model', somethingAction); - return {}; - }, - }); - const somethingAction = action( - { - name: { - pluginId: 'foo', - actionId: 'something', - }, - }, - async () => null - ); - - assert.strictEqual( - await lookupAction('/model/foo/something'), - somethingAction - ); - }); - }); - - it('returns undefined for unknown action', async () => { - await runWithRegistry(registry, async () => { - assert.strictEqual(await lookupAction('/model/foo/something'), undefined); - }); - }); -}); +import { Registry } from '../src/registry.js'; describe('registry class', () => { var registry: Registry; diff --git a/js/genkit/package.json b/js/genkit/package.json index fbed927e3..c5d18f544 100644 --- a/js/genkit/package.json +++ b/js/genkit/package.json @@ -7,7 +7,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "main": "./lib/cjs/index.js", "scripts": { @@ -16,7 +16,8 @@ "build:clean": "rimraf ./lib", "build": "npm-run-all build:clean check compile", "build:watch": "tsup-node --watch", - "test": "node --import tsx --test tests/*_test.ts" + "test": "node --import tsx --test tests/*_test.ts", + "test:watch": "node --watch --import tsx --test tests/*_test.ts" }, "repository": { "type": "git", diff --git a/js/genkit/src/chat.ts b/js/genkit/src/chat.ts index 97405b1b9..7587c941a 100644 --- a/js/genkit/src/chat.ts +++ b/js/genkit/src/chat.ts @@ -15,6 +15,7 @@ */ import { + ExecutablePrompt, GenerateOptions, GenerateResponse, GenerateStreamOptions, @@ -24,24 +25,25 @@ import { Part, } from '@genkit-ai/ai'; import { z } from '@genkit-ai/core'; -import { v4 as uuidv4 } from 'uuid'; import { Genkit } from './genkit'; -import { - Session, - SessionData, - SessionStore, - inMemorySessionStore, -} from './session'; +import { Session, SessionStore } from './session'; export const MAIN_THREAD = 'main'; export type BaseGenerateOptions = Omit; -export type ChatOptions = - BaseGenerateOptions & { - store?: SessionStore; - sessionId?: string; - }; +export interface PromptRenderOptions { + prompt: ExecutablePrompt; + input?: I; +} + +export type ChatOptions< + I = undefined, + S extends z.ZodTypeAny = z.ZodTypeAny, +> = (PromptRenderOptions | BaseGenerateOptions) & { + store?: SessionStore; + sessionId?: string; +}; /** * Chat encapsulates a statful execution environment for chat. 
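The reworked `Chat` below always hangs off a `Session` and resolves its request base lazily. A rough usage sketch; the `genkit` import path and the 'echoModel' model name are placeholders for whatever the host app configures:

```ts
import { genkit } from 'genkit';

const ai = genkit({ model: 'echoModel' }); // placeholder default model

const chat = ai.chat(); // chat() is synchronous after this change
let response = await chat.send('hi');
console.log(response.text);

response = await chat.send('another one'); // earlier turns stay in chat.messages
```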
@@ -56,75 +58,55 @@ export type ChatOptions = * ``` */ export class Chat { + readonly requestBase?: Promise; readonly sessionId: string; readonly schema?: S; - private sessionData?: SessionData; - private store: SessionStore; + private _messages?: MessageData[]; private threadName: string; constructor( - readonly parent: Genkit | Session | Chat, - readonly requestBase?: BaseGenerateOptions, - options?: { - id?: string; - sessionData?: SessionData; - store?: SessionStore; - thread?: string; + readonly session: Session, + requestBase: Promise, + options: { + id: string; + thread: string; + messages?: MessageData[]; } ) { - this.sessionId = options?.id ?? uuidv4(); - this.threadName = options?.thread ?? MAIN_THREAD; - this.sessionData = options?.sessionData; - if (!this.sessionData) { - this.sessionData = { id: this.sessionId }; - } - if (!this.sessionData.threads) { - this.sessionData!.threads = {}; - } - // this is handling dotprompt render case - if (requestBase && requestBase['prompt']) { - const basePrompt = requestBase['prompt'] as string | Part | Part[]; - let promptMessage: MessageData; - if (typeof basePrompt === 'string') { - promptMessage = { - role: 'user', - content: [{ text: basePrompt }], - }; - } else if (Array.isArray(basePrompt)) { - promptMessage = { - role: 'user', - content: basePrompt, - }; - } else { - promptMessage = { - role: 'user', - content: [basePrompt], - }; - } - requestBase.messages = [...(requestBase.messages ?? []), promptMessage]; - } - if (parent instanceof Chat) { - if (!this.sessionData.threads[this.threadName]) { - this!.sessionData.threads[this.threadName] = [ - ...(parent.messages ?? []), - ...(requestBase?.messages ?? []), - ]; - } - } else if (parent instanceof Session) { - if (!this.sessionData.threads[this.threadName]) { - this!.sessionData.threads[this.threadName] = [ - ...(requestBase?.messages ?? []), - ]; + this.sessionId = options.id; + this.threadName = options.thread; + this.requestBase = requestBase?.then((rb) => { + const requestBase = { ...rb }; + // this is handling dotprompt render case + if (requestBase && requestBase['prompt']) { + const basePrompt = requestBase['prompt'] as string | Part | Part[]; + let promptMessage: MessageData; + if (typeof basePrompt === 'string') { + promptMessage = { + role: 'user', + content: [{ text: basePrompt }], + }; + } else if (Array.isArray(basePrompt)) { + promptMessage = { + role: 'user', + content: basePrompt, + }; + } else { + promptMessage = { + role: 'user', + content: [basePrompt], + }; + } + requestBase.messages = [...(requestBase.messages ?? []), promptMessage]; } - } else { - // Genkit - if (!this.sessionData.threads[this.threadName]) { - this.sessionData.threads[this.threadName] = [ - ...(requestBase?.messages ?? []), - ]; - } - } - this.store = options?.store ?? (inMemorySessionStore() as SessionStore); + requestBase.messages = [ + ...(options.messages ?? []), + ...(requestBase.messages ?? 
[]), + ]; + this._messages = requestBase.messages; + return requestBase; + }); + this._messages = options.messages; } async send< @@ -146,7 +128,7 @@ export class Chat { } as GenerateOptions; } const response = await this.genkit.generate({ - ...this.requestBase, + ...(await this.requestBase), messages: this.messages, ...options, }); @@ -173,7 +155,7 @@ export class Chat { } as GenerateOptions; } const { response, stream } = await this.genkit.generateStream({ - ...this.requestBase, + ...(await this.requestBase), messages: this.messages, ...options, }); @@ -187,66 +169,15 @@ export class Chat { } private get genkit(): Genkit { - if (this.parent instanceof Genkit) { - return this.parent; - } - if (this.parent instanceof Session) { - return this.parent.genkit; - } - return this.parent.genkit; + return this.session.genkit; } - get state(): z.infer { - // We always get state from the parent. Parent session is the source of truth. - if (this.parent instanceof Session) { - return this.parent.state; - } - return this.sessionData!.state; - } - - async updateState(data: z.infer): Promise { - // We always update the state on the parent. Parent session is the source of truth. - if (this.parent instanceof Session) { - return this.parent.updateState(data); - } - let sessionData = await this.store.get(this.sessionId); - if (!sessionData) { - sessionData = {} as SessionData; - } - sessionData.state = data; - this.sessionData = sessionData; - - await this.store.save(this.sessionId, sessionData); - } - - get messages(): MessageData[] | undefined { - if (!this.sessionData?.threads) { - return undefined; - } - return this.sessionData?.threads[this.threadName]; + get messages(): MessageData[] { + return this._messages ?? []; } async updateMessages(messages: MessageData[]): Promise { - let sessionData = await this.store.get(this.sessionId); - if (!sessionData) { - sessionData = { id: this.sessionId, threads: {} }; - } - if (!sessionData.threads) { - sessionData.threads = {}; - } - sessionData.threads[this.threadName] = messages; - this.sessionData = sessionData; - await this.store.save(this.sessionId, sessionData); - } - - toJSON() { - if (this.parent instanceof Session) { - return this.parent.toJSON(); - } - return this.sessionData; - } - - static fromJSON(data: SessionData) { - //return new Session(); + this._messages = messages; + await this.session.updateMessages(this.threadName, messages); } } diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts index 85ad82d8b..c9ff26a75 100644 --- a/js/genkit/src/genkit.ts +++ b/js/genkit/src/genkit.ts @@ -26,6 +26,7 @@ import { EvalResponses, evaluate, EvaluatorParams, + ExecutablePrompt, generate, GenerateOptions, GenerateRequest, @@ -35,13 +36,13 @@ import { GenerateStreamOptions, GenerateStreamResponse, GenerationCommonConfigSchema, - index, IndexerParams, ModelArgument, ModelReference, Part, PromptAction, PromptFn, + PromptGenerateOptions, RankedDocument, rerank, RerankerParams, @@ -81,6 +82,7 @@ import { defineRetriever, defineSimpleRetriever, DocumentData, + index, IndexerAction, IndexerFn, RetrieverFn, @@ -109,10 +111,7 @@ import { defineDotprompt, defineHelper, definePartial, - Dotprompt, loadPromptFolder, - prompt, - PromptGenerateOptions, PromptMetadata, } from '@genkit-ai/dotprompt'; import { v4 as uuidv4 } from 'uuid'; @@ -120,7 +119,7 @@ import { Chat, ChatOptions } from './chat.js'; import { BaseEvalDataPointSchema } from './evaluator.js'; import { logger } from './logging.js'; import { GenkitPlugin, genkitPlugin } from './plugin.js'; -import { 
lookupAction, Registry, runWithRegistry } from './registry.js'; +import { Registry } from './registry.js'; import { getCurrentSession, Session, @@ -144,65 +143,6 @@ export interface GenkitOptions { flowServer?: FlowServerOptions | boolean; } -export interface ExecutablePrompt< - I extends z.ZodTypeAny = z.ZodTypeAny, - O extends z.ZodTypeAny = z.ZodTypeAny, - CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, -> { - /** - * Generates a response by rendering the prompt template with given user input and then calling the model. - * - * @param input Prompt inputs. - * @param opt Options for the prompt template, including user input variables and custom model configuration options. - * @returns the model response as a promise of `GenerateStreamResponse`. - */ - ( - input?: z.infer, - opts?: z.infer - ): Promise>>; - - /** - * Generates a response by rendering the prompt template with given user input and then calling the model. - * @param input Prompt inputs. - * @param opt Options for the prompt template, including user input variables and custom model configuration options. - * @returns the model response as a promise of `GenerateStreamResponse`. - */ - stream( - input?: z.infer, - opts?: z.infer - ): Promise>>; - - /** - * Generates a response by rendering the prompt template with given user input and additional generate options and then calling the model. - * - * @param opt Options for the prompt template, including user input variables and custom model configuration options. - * @returns the model response as a promise of `GenerateResponse`. - */ - generate( - opt: PromptGenerateOptions, CustomOptions> - ): Promise>>; - - /** - * Generates a streaming response by rendering the prompt template with given user input and additional generate options and then calling the model. - * - * @param opt Options for the prompt template, including user input variables and custom model configuration options. - * @returns the model response as a promise of `GenerateStreamResponse`. - */ - generateStream( - opt: PromptGenerateOptions, CustomOptions> - ): Promise>>; - - /** - * Renders the prompt template based on user input. - * - * @param opt Options for the prompt template, including user input variables and custom model configuration options. - * @returns a `GenerateOptions` object to be used with the `generate()` function from @genkit-ai/ai. - */ - render( - opt: PromptGenerateOptions, CustomOptions> - ): Promise>; -} - /** * `Genkit` encapsulates a single Genkit instance including the {@link Registry}, {@link ReflectionServer}, {@link FlowServer}, and configuration. 
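Since every define*/generate call on this class now receives `this.registry` directly, user code no longer wraps anything in runWithRegistry. A rough end-to-end sketch, again treating the `genkit` import path, 'echoModel', and 'greetingFlow' as placeholders:

```ts
import { genkit } from 'genkit';

const ai = genkit({ model: 'echoModel' }); // placeholder default model

// Flows and other primitives are defined straight on the instance.
const greetingFlow = ai.defineFlow('greetingFlow', (name: string) => `Hello, ${name}!`);

const response = await ai.generate({
  prompt: 'hi',
  config: { temperature: 11 },
});
console.log(response.text);
```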
* @@ -253,7 +193,7 @@ export class Genkit { I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, >(config: FlowConfig | string, fn: FlowFn): CallableFlow { - const flow = runWithRegistry(this.registry, () => defineFlow(config, fn)); + const flow = defineFlow(this.registry, config, fn); this.registeredFlows.push(flow.flow); return flow; } @@ -271,9 +211,7 @@ export class Genkit { config: StreamingFlowConfig, fn: FlowFn ): StreamableFlow { - const flow = runWithRegistry(this.registry, () => - defineStreamingFlow(config, fn) - ); + const flow = defineStreamingFlow(this.registry, config, fn); this.registeredFlows.push(flow.flow); return flow; } @@ -287,7 +225,7 @@ export class Genkit { config: ToolConfig, fn: (input: z.infer) => Promise> ): ToolAction { - return runWithRegistry(this.registry, () => defineTool(config, fn)); + return defineTool(this.registry, config, fn); } /** @@ -296,7 +234,7 @@ export class Genkit { * Defined schemas can be referenced by `name` in prompts in place of inline schemas. */ defineSchema(name: string, schema: T): T { - return runWithRegistry(this.registry, () => defineSchema(name, schema)); + return defineSchema(this.registry, name, schema); } /** @@ -305,9 +243,7 @@ export class Genkit { * Defined schemas can be referenced by `name` in prompts in place of inline schemas. */ defineJsonSchema(name: string, jsonSchema: JSONSchema) { - return runWithRegistry(this.registry, () => - defineJsonSchema(name, jsonSchema) - ); + return defineJsonSchema(this.registry, name, jsonSchema); } /** @@ -320,7 +256,7 @@ export class Genkit { streamingCallback?: StreamingCallback ) => Promise ): ModelAction { - return runWithRegistry(this.registry, () => defineModel(options, runner)); + return defineModel(this.registry, options, runner); } /** @@ -328,33 +264,21 @@ export class Genkit { * * @todo TODO: Show an example of a name and variant. */ - prompt< + async prompt< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, >( name: string, options?: { variant?: string } - ): Promise> { - return runWithRegistry(this.registry, async () => { - const action = (await lookupAction(`/prompt/${name}`)) as PromptAction; - if ( - action.__action?.metadata?.prompt && - Object.keys(action.__action.metadata.prompt).length > 0 - ) { - const p = await prompt(name, options); - return this.wrapDotpromptInExecutablePrompt(p, {}) as ExecutablePrompt< - I, - O, - CustomOptions - >; - } else { - return this.wrapPromptActionInExecutablePrompt( - action, - {} - ) as ExecutablePrompt; - } - }); + ): Promise, O, CustomOptions>> { + const action = (await this.registry.lookupAction( + `/prompt/${name}` + )) as PromptAction; + return this.wrapPromptActionInExecutablePrompt( + action, + {} + ) as ExecutablePrompt; } /** @@ -387,7 +311,7 @@ export class Genkit { name: string; }, template: string - ): ExecutablePrompt; + ): ExecutablePrompt, O, CustomOptions>; /** * Defines and registers a function-based prompt. 
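A small sketch of how a template-based prompt defined this way might be used, patterned on the package tests; the prompt name, input shape, and default 'echoModel' are illustrative assumptions:

```ts
import { genkit } from 'genkit';

const ai = genkit({ model: 'echoModel' }); // placeholder default model

const hi = ai.definePrompt(
  { name: 'hi', config: { version: 'abc' } },
  'hi {{ name }}'
);

// Render to GenerateOptions (e.g. to seed a chat), or execute in one step.
const rendered = await hi.render({ input: { name: 'Genkit' } });
const response = await hi({ name: 'Genkit' });
```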
@@ -424,7 +348,7 @@ export class Genkit { name: string; }, fn: PromptFn - ): ExecutablePrompt; + ): ExecutablePrompt, O, CustomOptions>; definePrompt< I extends z.ZodTypeAny = z.ZodTypeAny, @@ -436,106 +360,35 @@ export class Genkit { name: string; }, templateOrFn: string | PromptFn - ): ExecutablePrompt { + ): ExecutablePrompt, O, CustomOptions> { if (!options.name) { throw new Error('options.name is required'); } - return runWithRegistry(this.registry, () => { - if (!options.name) { - throw new Error('options.name is required'); - } - if (typeof templateOrFn === 'string') { - const dotprompt = defineDotprompt(options, templateOrFn as string); - return this.wrapDotpromptInExecutablePrompt(dotprompt, options); - } else { - const p = definePrompt( - { - name: options.name!, - inputJsonSchema: options.input?.jsonSchema, - inputSchema: options.input?.schema, - }, - templateOrFn as PromptFn - ); - return this.wrapPromptActionInExecutablePrompt(p, options); - } - }); - } - - private wrapDotpromptInExecutablePrompt< - I extends z.ZodTypeAny = z.ZodTypeAny, - O extends z.ZodTypeAny = z.ZodTypeAny, - CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, - >( - dotprompt: Dotprompt>, - options: PromptMetadata - ): ExecutablePrompt { - const executablePrompt = ( - input?: z.infer, - opts?: z.infer - ): Promise> => { - return runWithRegistry(this.registry, async () => { - const model = await this.resolveModel(options.model); - return dotprompt.generate({ - model, - input, - config: opts, - }); - }); - }; - (executablePrompt as ExecutablePrompt).stream = ( - input?: z.infer, - opts?: z.infer - ): Promise> => { - return runWithRegistry(this.registry, async () => { - const model = await this.resolveModel(options.model); - return dotprompt.generateStream({ - model, - input, - config: opts, - }) as Promise>; - }); - }; - (executablePrompt as ExecutablePrompt).generate = ( - opt: PromptGenerateOptions - ): Promise> => { - return runWithRegistry(this.registry, async () => { - const model = !opt.model - ? await this.resolveModel(options.model) - : undefined; - return dotprompt.generate({ - model, - ...opt, - }); - }); - }; - (executablePrompt as ExecutablePrompt).generateStream = - ( - opt: PromptGenerateOptions - ): Promise> => { - return runWithRegistry(this.registry, async () => { - const model = !opt.model - ? await this.resolveModel(options.model) - : undefined; - return dotprompt.generateStream({ - model, - ...opt, - }) as Promise>; - }); - }; - (executablePrompt as ExecutablePrompt).render = < - Out extends O, - >( - opt: PromptGenerateOptions - ): Promise> => { - return runWithRegistry( + if (!options.name) { + throw new Error('options.name is required'); + } + if (typeof templateOrFn === 'string') { + const dotprompt = defineDotprompt( this.registry, - async () => - dotprompt.render({ - ...opt, - }) as GenerateOptions + options, + templateOrFn as string ); - }; - return executablePrompt as ExecutablePrompt; + return this.wrapPromptActionInExecutablePrompt( + dotprompt.promptAction! 
as PromptAction, + options + ); + } else { + const p = definePrompt( + this.registry, + { + name: options.name!, + inputJsonSchema: options.input?.jsonSchema, + inputSchema: options.input?.schema, + }, + templateOrFn as PromptFn + ); + return this.wrapPromptActionInExecutablePrompt(p, options); + } } private wrapPromptActionInExecutablePrompt< @@ -546,135 +399,83 @@ export class Genkit { p: PromptAction, options: PromptMetadata ): ExecutablePrompt { - const executablePrompt = ( + const executablePrompt = async ( input?: z.infer, - opts?: z.infer + opts?: PromptGenerateOptions ): Promise => { - return runWithRegistry(this.registry, async () => { - const model = await this.resolveModel(options.model); - const promptResult = await p(input); - return this.generate({ - model, - messages: promptResult.messages, - docs: promptResult.docs, - tools: promptResult.tools, - output: { - format: promptResult.output?.format, - jsonSchema: promptResult.output?.schema, - }, - config: { - ...options.config, - ...opts, - ...promptResult.config, - }, - }); + const renderedOpts = await ( + executablePrompt as ExecutablePrompt + ).render({ + ...opts, + input, }); + return this.generate(renderedOpts); }; - (executablePrompt as ExecutablePrompt).stream = ( + (executablePrompt as ExecutablePrompt).stream = async ( input?: z.infer, opts?: z.infer ): Promise> => { - return runWithRegistry(this.registry, async () => { - const model = await this.resolveModel(options.model); - const promptResult = await p(input); - return this.generateStream({ - model, - messages: promptResult.messages, - docs: promptResult.docs, - tools: promptResult.tools, - output: { - format: promptResult.output?.format, - jsonSchema: promptResult.output?.schema, - }, - config: { - ...options.config, - ...promptResult.config, - ...opts, - }, - }); - }); - }; - (executablePrompt as ExecutablePrompt).generate = ( - opt: PromptGenerateOptions - ): Promise> => { - return runWithRegistry(this.registry, async () => { - const model = !opt.model - ? await this.resolveModel(options.model) - : undefined; - const promptResult = await p(opt.input); - return this.generate({ - model, - messages: promptResult.messages, - docs: promptResult.docs, - tools: promptResult.tools, - output: { - format: promptResult.output?.format, - jsonSchema: promptResult.output?.schema, - }, - ...opt, - config: { - ...options.config, - ...promptResult.config, - ...opt.config, - }, - }); + const renderedOpts = await ( + executablePrompt as ExecutablePrompt + ).render({ + ...opts, + input, }); + return this.generateStream(renderedOpts); }; + (executablePrompt as ExecutablePrompt).generate = + async ( + opt: PromptGenerateOptions + ): Promise> => { + const renderedOpts = await ( + executablePrompt as ExecutablePrompt + ).render(opt); + return this.generate(renderedOpts); + }; (executablePrompt as ExecutablePrompt).generateStream = - ( + async ( opt: PromptGenerateOptions ): Promise> => { - return runWithRegistry(this.registry, async () => { - const model = !opt.model - ? 
await this.resolveModel(options.model) - : undefined; - const promptResult = await p(opt.input); - return this.generateStream({ - model, - messages: promptResult.messages, - docs: promptResult.docs, - tools: promptResult.tools, - output: { - format: promptResult.output?.format, - jsonSchema: promptResult.output?.schema, - } as any /* FIXME - schema type inference is borken */, - ...opt, - config: { - ...options.config, - ...promptResult.config, - ...opt.config, - }, - }); - }); + const renderedOpts = await ( + executablePrompt as ExecutablePrompt + ).render(opt); + return this.generateStream(renderedOpts); }; - (executablePrompt as ExecutablePrompt).render = < + (executablePrompt as ExecutablePrompt).render = async < Out extends O, >( opt: PromptGenerateOptions ): Promise> => { - return runWithRegistry(this.registry, async () => { - const model = !opt.model - ? await this.resolveModel(options.model) - : undefined; - const promptResult = await p(opt.input); - return { - model, - messages: promptResult.messages, - docs: promptResult.docs, - tools: promptResult.tools, - output: { - format: promptResult.output?.format, - jsonSchema: promptResult.output?.schema, - }, - ...opt, - config: { - ...options.config, - ...promptResult.config, - ...opt.config, - }, - } as GenerateOptions; - }); + let model: ModelAction | undefined; + try { + model = await this.resolveModel(opt?.model ?? options.model); + } catch (e) { + // ignore, no model on a render is OK. + } + + const promptResult = await p(opt.input); + const resultOptions = { + messages: promptResult.messages, + docs: promptResult.docs, + tools: promptResult.tools, + output: { + format: promptResult.output?.format, + jsonSchema: promptResult.output?.schema, + }, + config: { + ...options.config, + ...promptResult.config, + ...opt.config, + }, + model, + } as GenerateOptions; + delete (resultOptions as PromptGenerateOptions).input; + return resultOptions; }; + (executablePrompt as ExecutablePrompt).asTool = + (): ToolAction => { + return p as unknown as ToolAction; + }; return executablePrompt as ExecutablePrompt; } @@ -689,9 +490,7 @@ export class Genkit { }, runner: RetrieverFn ): RetrieverAction { - return runWithRegistry(this.registry, () => - defineRetriever(options, runner) - ); + return defineRetriever(this.registry, options, runner); } /** @@ -706,9 +505,7 @@ export class Genkit { options: SimpleRetrieverOptions, handler: (query: Document, config: z.infer) => Promise ): RetrieverAction { - return runWithRegistry(this.registry, () => - defineSimpleRetriever(options, handler) - ); + return defineSimpleRetriever(this.registry, options, handler); } /** @@ -722,7 +519,7 @@ export class Genkit { }, runner: IndexerFn ): IndexerAction { - return runWithRegistry(this.registry, () => defineIndexer(options, runner)); + return defineIndexer(this.registry, options, runner); } /** @@ -744,9 +541,7 @@ export class Genkit { }, runner: EvaluatorFn ): EvaluatorAction { - return runWithRegistry(this.registry, () => - defineEvaluator(options, runner) - ); + return defineEvaluator(this.registry, options, runner); } /** @@ -760,23 +555,21 @@ export class Genkit { }, runner: EmbedderFn ): EmbedderAction { - return runWithRegistry(this.registry, () => - defineEmbedder(options, runner) - ); + return defineEmbedder(this.registry, options, runner); } /** * create a handlebards helper (https://handlebarsjs.com/guide/block-helpers.html) to be used in dotpormpt templates. 
*/ defineHelper(name: string, fn: Handlebars.HelperDelegate) { - return runWithRegistry(this.registry, () => defineHelper(name, fn)); + return defineHelper(name, fn); } /** * Creates a handlebars partial (https://handlebarsjs.com/guide/partials.html) to be used in dotpormpt templates. */ definePartial(name: string, source: string) { - return runWithRegistry(this.registry, () => definePartial(name, source)); + return definePartial(name, source); } /** @@ -790,9 +583,7 @@ export class Genkit { }, runner: RerankerFn ) { - return runWithRegistry(this.registry, () => - defineReranker(options, runner) - ); + return defineReranker(this.registry, options, runner); } /** @@ -801,7 +592,7 @@ export class Genkit { embed( params: EmbedderParams ): Promise { - return runWithRegistry(this.registry, () => embed(params)); + return embed(this.registry, params); } /** @@ -813,7 +604,7 @@ export class Genkit { metadata?: Record; options?: z.infer; }): Promise { - return runWithRegistry(this.registry, () => embedMany(params)); + return embedMany(this.registry, params); } /** @@ -823,7 +614,7 @@ export class Genkit { DataPoint extends typeof BaseDataPointSchema = typeof BaseDataPointSchema, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, >(params: EvaluatorParams): Promise { - return runWithRegistry(this.registry, () => evaluate(params)); + return evaluate(this.registry, params); } /** @@ -832,7 +623,7 @@ export class Genkit { rerank( params: RerankerParams ): Promise> { - return runWithRegistry(this.registry, () => rerank(params)); + return rerank(this.registry, params); } /** @@ -841,7 +632,7 @@ export class Genkit { index( params: IndexerParams ): Promise { - return runWithRegistry(this.registry, () => index(params)); + return index(this.registry, params); } /** @@ -850,7 +641,7 @@ export class Genkit { retrieve( params: RetrieverParams ): Promise> { - return runWithRegistry(this.registry, () => retrieve(params)); + return retrieve(this.registry, params); } /** @@ -917,7 +708,7 @@ export class Genkit { O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, >( - parts: + opts: | GenerateOptions | PromiseLike> ): Promise>>; @@ -945,7 +736,7 @@ export class Genkit { if (!resolvedOptions.model) { resolvedOptions.model = this.options.model; } - return runWithRegistry(this.registry, () => generate(resolvedOptions)); + return generate(this.registry, resolvedOptions); } /** @@ -1052,9 +843,7 @@ export class Genkit { if (!resolvedOptions.model) { resolvedOptions.model = this.options.model; } - return runWithRegistry(this.registry, () => - generateStream(resolvedOptions) - ); + return generateStream(this.registry, resolvedOptions); } /** @@ -1068,23 +857,20 @@ export class Genkit { * response = await chat.send('another one') * ``` */ - async chat(options?: ChatOptions): Promise { - const session = await this.createSession(); + chat(options?: ChatOptions): Chat { + const session = this.createSession(); return session.chat(options); } /** * Create a session for this environment. */ - async createSession(options?: SessionOptions): Promise { + createSession(options?: SessionOptions): Session { const sessionId = uuidv4(); const sessionData: SessionData = { id: sessionId, - state: options?.state, + state: options?.initialState, }; - if (options?.store) { - await options.store.save(sessionId, sessionData); - } return new Session(this, { id: sessionId, sessionData, @@ -1132,9 +918,7 @@ export class Genkit { const plugins = [...(this.options.plugins ?? 
[])]; if (this.options.promptDir !== null) { const dotprompt = genkitPlugin('dotprompt', async (ai) => { - runWithRegistry(ai.registry, async () => - loadPromptFolder(this.options.promptDir ?? './prompts') - ); + loadPromptFolder(this.registry, this.options.promptDir ?? './prompts'); }); plugins.push(dotprompt); } @@ -1145,9 +929,7 @@ export class Genkit { name: loadedPlugin.name, async initializer() { logger.debug(`Initializing plugin ${loadedPlugin.name}:`); - return runWithRegistry(activeRegistry, () => - loadedPlugin.initializer() - ); + loadedPlugin.initializer(); }, }); }); @@ -1172,12 +954,16 @@ export class Genkit { return this.resolveModel(this.options.model); } if (typeof modelArg === 'string') { - return (await lookupAction(`/model/${modelArg}`)) as ModelAction; - } else if (modelArg.hasOwnProperty('name')) { - const ref = modelArg as ModelReference; - return (await lookupAction(`/model/${ref.name}`)) as ModelAction; - } else { + return (await this.registry.lookupAction( + `/model/${modelArg}` + )) as ModelAction; + } else if ((modelArg as ModelAction).__action) { return modelArg as ModelAction; + } else { + const ref = modelArg as ModelReference; + return (await this.registry.lookupAction( + `/model/${ref.name}` + )) as ModelAction; } } } diff --git a/js/genkit/src/registry.ts b/js/genkit/src/registry.ts index 0dab20e68..8c45c10d4 100644 --- a/js/genkit/src/registry.ts +++ b/js/genkit/src/registry.ts @@ -19,15 +19,4 @@ export { AsyncProvider, Registry, Schema, - getRegistryInstance, - initializeAllPlugins, - initializePlugin, - listActions, - lookupAction, - lookupPlugin, - lookupSchema, - registerAction, - registerPluginProvider, - registerSchema, - runWithRegistry, } from '@genkit-ai/core/registry'; diff --git a/js/genkit/src/session.ts b/js/genkit/src/session.ts index 20a7557b3..78824e15b 100644 --- a/js/genkit/src/session.ts +++ b/js/genkit/src/session.ts @@ -18,7 +18,7 @@ import { GenerateOptions, MessageData } from '@genkit-ai/ai'; import { z } from '@genkit-ai/core'; import { AsyncLocalStorage } from 'node:async_hooks'; import { v4 as uuidv4 } from 'uuid'; -import { Chat, ChatOptions, MAIN_THREAD } from './chat'; +import { Chat, ChatOptions, MAIN_THREAD, PromptRenderOptions } from './chat'; import { Genkit } from './genkit'; export type BaseGenerateOptions = Omit; @@ -29,7 +29,7 @@ export interface SessionOptions { /** Session store implementation for persisting the session state. */ store?: SessionStore; /** Initial state of the session. */ - state?: z.infer; + initialState?: z.infer; /** Custom session Id. */ sessionId?: string; } @@ -53,7 +53,7 @@ export class Session { private store: SessionStore; constructor( - readonly parent: Genkit, + readonly genkit: Genkit, options?: { id?: string; stateSchema?: S; @@ -63,7 +63,9 @@ export class Session { ) { this.id = options?.id ?? uuidv4(); this.schema = options?.stateSchema; - this.sessionData = options?.sessionData; + this.sessionData = options?.sessionData ?? { + id: this.id, + }; if (!this.sessionData) { this.sessionData = { id: this.id }; } @@ -73,24 +75,19 @@ export class Session { this.store = options?.store ?? new InMemorySessionStore(); } - get genkit(): Genkit { - return this.parent; - } - get state(): z.infer { // We always get state from the parent. Parent session is the source of truth. - if (this.parent instanceof Session) { - return this.parent.state; + if (this.genkit instanceof Session) { + return this.genkit.state; } return this.sessionData!.state; } + /** + * Update session state data. 
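The session API below pairs with this: `createSession` is now synchronous, seeds `sessionData` from `initialState`, and chats persist their threads through `updateMessages`. A rough sketch with a made-up state shape and placeholder model:

```ts
import { genkit } from 'genkit';

const ai = genkit({ model: 'echoModel' }); // placeholder default model

const session = ai.createSession({ initialState: { favoriteColor: 'blue' } });
await session.updateState({ favoriteColor: 'green' });

// Named threads keep separate message histories within the same session state.
const pirateChat = session.chat('pirateChat');
const reply = await pirateChat.send('tell me a joke');
```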
+ */ async updateState(data: z.infer): Promise { - // We always update the state on the parent. Parent session is the source of truth. - if (this.parent instanceof Session) { - return this.parent.updateState(data); - } - let sessionData = await this.store.get(this.id); + let sessionData = this.sessionData; if (!sessionData) { sessionData = {} as SessionData; } @@ -100,6 +97,26 @@ export class Session { await this.store.save(this.id, sessionData); } + /** + * Update messages for a given thread. + */ + async updateMessages( + thread: string, + messasges: MessageData[] + ): Promise { + let sessionData = this.sessionData; + if (!sessionData) { + sessionData = {} as SessionData; + } + if (!sessionData.threads) { + sessionData.threads = {}; + } + sessionData.threads[thread] = messasges; + this.sessionData = sessionData; + + await this.store.save(this.id, sessionData); + } + /** * Create a chat session with the provided options. * @@ -111,9 +128,7 @@ export class Session { * response = await chat.send('another one') * ``` */ - chat( - options?: ChatOptions - ): Chat; + chat(options?: ChatOptions): Chat; /** * Craete a separaete chat conversation ("thread") within the same session state. @@ -129,14 +144,11 @@ export class Session { * await pirateChat.send('tell me a joke') * ``` */ - chat( - threadName: string, - options?: ChatOptions - ): Chat; - - chat( - optionsOrThreadName?: ChatOptions | string, - maybeOptions?: ChatOptions + chat(threadName: string, options?: ChatOptions): Chat; + + chat( + optionsOrThreadName?: ChatOptions | string, + maybeOptions?: ChatOptions ): Chat { let options: ChatOptions | undefined; let threadName = MAIN_THREAD; @@ -150,18 +162,22 @@ export class Session { options = optionsOrThreadName as ChatOptions; } } - return new Chat( - this, - { - ...options, - }, - { - thread: threadName, - id: this.id, - sessionData: this.sessionData, - store: this.store ?? options?.store, - } - ); + let requestBase: Promise; + if (!!(options as PromptRenderOptions)?.prompt?.render) { + const renderOptions = options as PromptRenderOptions; + requestBase = renderOptions.prompt.render({ + input: renderOptions.input, + }); + } else { + requestBase = Promise.resolve(options as BaseGenerateOptions); + } + return new Chat(this, requestBase, { + thread: threadName, + id: this.id, + messages: + (this.sessionData?.threads && this.sessionData?.threads[threadName]) ?? 
+ [], + }); } toJSON() { diff --git a/js/genkit/tests/chat_test.ts b/js/genkit/tests/chat_test.ts index a00eda1f9..f6e28eb74 100644 --- a/js/genkit/tests/chat_test.ts +++ b/js/genkit/tests/chat_test.ts @@ -30,7 +30,7 @@ describe('session', () => { }); it('maintains history in the session', async () => { - const session = await ai.chat(); + const session = ai.chat(); let response = await session.send('hi'); assert.strictEqual(response.text, 'Echo: hi; config: {}'); @@ -59,8 +59,8 @@ describe('session', () => { }); it('maintains history in the session with streaming', async () => { - const session = await ai.chat(); - let { response, stream } = await session.sendStream('hi'); + const chat = ai.chat(); + let { response, stream } = await chat.sendStream('hi'); let chunks: string[] = []; for await (const chunk of stream) { @@ -69,7 +69,7 @@ describe('session', () => { assert.strictEqual((await response).text, 'Echo: hi; config: {}'); assert.deepStrictEqual(chunks, ['3', '2', '1']); - ({ response, stream } = await session.sendStream('bye')); + ({ response, stream } = await chat.sendStream('bye')); chunks = []; for await (const chunk of stream) { @@ -100,6 +100,7 @@ describe('session', () => { it('can init a session with a prompt', async () => { const prompt = ai.definePrompt({ name: 'hi' }, 'hi {{ name }}'); + const session = await ai.chat( await prompt.render({ input: { name: 'Genkit' }, @@ -114,12 +115,29 @@ describe('session', () => { ); }); - it('can send a prompt session to a session', async () => { + it('can start chat from a prompt', async () => { + const prompt = ai.definePrompt( + { name: 'hi', config: { version: 'abc' } }, + 'hi {{ name }} from template' + ); + const session = await ai.chat({ + prompt, + input: { name: 'Genkit' }, + }); + const response = await session.send('send it'); + + assert.strictEqual( + response.text, + 'Echo: hi Genkit from template,send it; config: {"version":"abc"}' + ); + }); + + it('can send a rendered prompt to chat', async () => { const prompt = ai.definePrompt( { name: 'hi', config: { version: 'abc' } }, 'hi {{ name }}' ); - const session = await ai.chat(); + const session = ai.chat(); const response = await session.send( await prompt.render({ input: { name: 'Genkit' }, diff --git a/js/genkit/tests/embed_test.ts b/js/genkit/tests/embed_test.ts new file mode 100644 index 000000000..18d940dde --- /dev/null +++ b/js/genkit/tests/embed_test.ts @@ -0,0 +1,140 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import { Document, EmbedderAction, embedderRef } from '@genkit-ai/ai'; +import assert from 'node:assert'; +import { beforeEach, describe, it } from 'node:test'; +import { Genkit, genkit } from '../src/genkit'; + +describe('embed', () => { + describe('default model', () => { + let ai: Genkit; + let embedder: EmbedderAction; + + beforeEach(() => { + ai = genkit({}); + embedder = defineTestEmbedder(ai); + }); + + it('passes string content as docs', async () => { + const response = await ai.embed({ + embedder: 'echoEmbedder', + content: 'hi', + }); + assert.deepStrictEqual((embedder as any).lastRequest, [ + [Document.fromText('hi')], + { + version: undefined, + }, + ]); + assert.deepStrictEqual(response, [1, 2, 3, 4]); + }); + + it('passes docs content as docs', async () => { + const response = await ai.embed({ + embedder: 'echoEmbedder', + content: Document.fromText('hi'), + }); + assert.deepStrictEqual((embedder as any).lastRequest, [ + [Document.fromText('hi')], + { + version: undefined, + }, + ]); + assert.deepStrictEqual(response, [1, 2, 3, 4]); + }); + }); + + describe('config', () => { + let ai: Genkit; + let embedder: EmbedderAction; + + beforeEach(() => { + ai = genkit({}); + embedder = defineTestEmbedder(ai); + }); + + it('takes config passed to generate', async () => { + const response = await ai.embed({ + embedder: 'echoEmbedder', + content: 'hi', + options: { + temperature: 11, + }, + }); + assert.deepStrictEqual(response, [1, 2, 3, 4]); + assert.deepStrictEqual((embedder as any).lastRequest[1], { + temperature: 11, + version: undefined, + }); + }); + + it('merges config from the ref', async () => { + const response = await ai.embed({ + embedder: embedderRef({ + name: 'echoEmbedder', + config: { + version: 'abc', + }, + }), + content: 'hi', + options: { + temperature: 11, + }, + }); + assert.deepStrictEqual(response, [1, 2, 3, 4]); + assert.deepStrictEqual((embedder as any).lastRequest[1], { + temperature: 11, + version: 'abc', + }); + }); + + it('picks up the top-level version from the ref', async () => { + const response = await ai.embed({ + embedder: embedderRef({ + name: 'echoEmbedder', + version: 'abc', + }), + content: 'hi', + options: { + temperature: 11, + }, + }); + assert.deepStrictEqual(response, [1, 2, 3, 4]); + assert.deepStrictEqual((embedder as any).lastRequest[1], { + temperature: 11, + version: 'abc', + }); + }); + }); +}); + +function defineTestEmbedder(ai: Genkit) { + const embedder = ai.defineEmbedder( + { name: 'echoEmbedder' }, + async (input, config) => { + (embedder as any).lastRequest = [input, config]; + return { + embeddings: [ + { + embedding: [1, 2, 3, 4], + }, + ], + }; + } + ); + return embedder; +} diff --git a/js/genkit/tests/generate_test.ts b/js/genkit/tests/generate_test.ts new file mode 100644 index 000000000..1c63499f2 --- /dev/null +++ b/js/genkit/tests/generate_test.ts @@ -0,0 +1,158 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import assert from 'node:assert'; +import { beforeEach, describe, it } from 'node:test'; +import { modelRef } from '../../ai/src/model'; +import { Genkit, genkit } from '../src/genkit'; +import { defineEchoModel } from './helpers'; + +describe('generate', () => { + describe('default model', () => { + let ai: Genkit; + + beforeEach(() => { + ai = genkit({ + model: 'echoModel', + }); + defineEchoModel(ai); + }); + + it('calls the default model', async () => { + const response = await ai.generate({ + prompt: 'hi', + }); + assert.strictEqual(response.text, 'Echo: hi; config: {}'); + }); + + it('calls the default model with just a string prompt', async () => { + const response = await ai.generate('hi'); + assert.strictEqual(response.text, 'Echo: hi; config: {}'); + }); + + it('calls the default model with just parts prompt', async () => { + const response = await ai.generate([{ text: 'hi' }]); + assert.strictEqual(response.text, 'Echo: hi; config: {}'); + }); + + it('calls the default model system', async () => { + const response = await ai.generate({ + prompt: 'hi', + system: 'talk like a pirate', + }); + assert.strictEqual( + response.text, + 'Echo: system: talk like a pirate,hi; config: {}' + ); + assert.deepStrictEqual(response.request, { + config: undefined, + docs: undefined, + messages: [ + { + role: 'system', + content: [{ text: 'talk like a pirate' }], + }, + { + role: 'user', + content: [{ text: 'hi' }], + }, + ], + output: { + format: 'text', + }, + tools: [], + }); + }); + + it('streams the default model', async () => { + const { response, stream } = await ai.generateStream('hi'); + + const chunks: string[] = []; + for await (const chunk of stream) { + chunks.push(chunk.text); + } + assert.strictEqual((await response).text, 'Echo: hi; config: {}'); + assert.deepStrictEqual(chunks, ['3', '2', '1']); + }); + }); + + describe('default model', () => { + let ai: Genkit; + + beforeEach(() => { + ai = genkit({}); + defineEchoModel(ai); + }); + + it('calls the explicitly passed in model', async () => { + const response = await ai.generate({ + model: 'echoModel', + prompt: 'hi', + }); + assert.strictEqual(response.text, 'Echo: hi; config: {}'); + }); + }); + + describe('config', () => { + let ai: Genkit; + + beforeEach(() => { + ai = genkit({}); + defineEchoModel(ai); + }); + + it('takes config passed to generate', async () => { + const response = await ai.generate({ + prompt: 'hi', + model: 'echoModel', + config: { + temperature: 11, + }, + }); + assert.strictEqual(response.text, 'Echo: hi; config: {"temperature":11}'); + }); + + it('merges config from the ref', async () => { + const response = await ai.generate({ + prompt: 'hi', + model: modelRef({ name: 'echoModel' }).withConfig({ + version: 'abc', + }), + config: { + temperature: 11, + }, + }); + assert.strictEqual( + response.text, + 'Echo: hi; config: {"version":"abc","temperature":11}' + ); + }); + + it('picks up the top-level version from the ref', async () => { + const response = await ai.generate({ + prompt: 'hi', + model: modelRef({ name: 'echoModel' }).withVersion('bcd'), + config: { + temperature: 11, + }, + }); + assert.strictEqual( + response.text, + 'Echo: hi; config: {"version":"bcd","temperature":11}' + ); + }); + }); +}); diff --git a/js/genkit/tests/models_test.ts b/js/genkit/tests/models_test.ts deleted file mode 100644 index 23298d6c0..000000000 --- a/js/genkit/tests/models_test.ts +++ /dev/null @@ -1,109 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the 
"License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import assert from 'node:assert'; -import { beforeEach, describe, it } from 'node:test'; -import { Genkit, genkit } from '../src/genkit'; -import { defineEchoModel } from './helpers'; - -describe('models', () => { - describe('generate', () => { - describe('default model', () => { - let ai: Genkit; - - beforeEach(() => { - ai = genkit({ - model: 'echoModel', - }); - defineEchoModel(ai); - }); - - it('calls the default model', async () => { - const response = await ai.generate({ - prompt: 'hi', - }); - assert.strictEqual(response.text, 'Echo: hi; config: {}'); - }); - - it('calls the default model with just a string prompt', async () => { - const response = await ai.generate('hi'); - assert.strictEqual(response.text, 'Echo: hi; config: {}'); - }); - - it('calls the default model with just parts prompt', async () => { - const response = await ai.generate([{ text: 'hi' }]); - assert.strictEqual(response.text, 'Echo: hi; config: {}'); - }); - - it('calls the default model system', async () => { - const response = await ai.generate({ - prompt: 'hi', - system: 'talk like a pirate', - }); - assert.strictEqual( - response.text, - 'Echo: system: talk like a pirate,hi; config: {}' - ); - assert.deepStrictEqual(response.request, { - config: undefined, - context: undefined, - messages: [ - { - role: 'system', - content: [{ text: 'talk like a pirate' }], - }, - { - role: 'user', - content: [{ text: 'hi' }], - }, - ], - output: { - format: 'text', - }, - tools: [], - }); - }); - - it('streams the default model', async () => { - const { response, stream } = await ai.generateStream('hi'); - - const chunks: string[] = []; - for await (const chunk of stream) { - chunks.push(chunk.text); - } - assert.strictEqual((await response).text, 'Echo: hi; config: {}'); - assert.deepStrictEqual(chunks, ['3', '2', '1']); - }); - }); - - describe('default model', () => { - let ai: Genkit; - - beforeEach(() => { - ai = genkit({}); - defineEchoModel(ai); - }); - - it('calls the explicitly passed in model', async () => { - const response = await ai.generate({ - model: 'echoModel', - prompt: 'hi', - }); - assert.strictEqual(response.text, 'Echo: hi; config: {}'); - }); - }); - }); -}); diff --git a/js/genkit/tests/prompts_test.ts b/js/genkit/tests/prompts_test.ts index db77d44de..627c400f3 100644 --- a/js/genkit/tests/prompts_test.ts +++ b/js/genkit/tests/prompts_test.ts @@ -340,21 +340,21 @@ describe('definePrompt - dotprompt', () => { assert.deepStrictEqual(response, { config: {}, docs: undefined, - messages: [], - prompt: [ + messages: [ { - text: 'hi Genkit', + content: [ + { + text: 'hi Genkit', + }, + ], + role: 'user', }, ], output: { - format: undefined, + format: 'text', jsonSchema: undefined, - schema: undefined, }, - returnToolRequests: undefined, - streamingCallback: undefined, tools: [], - use: undefined, }); }); }); @@ -609,7 +609,9 @@ describe('definePrompt', () => { const response = await hi( { name: 'Genkit' }, { - version: 'abc', + config: { + version: 'abc', + }, } ); assert.strictEqual( @@ -716,9 
+718,6 @@ describe('definePrompt', () => { assert.deepStrictEqual(response, { config: {}, docs: undefined, - input: { - name: 'Genkit', - }, messages: [ { content: [ diff --git a/js/genkit/tests/session_test.ts b/js/genkit/tests/session_test.ts index 2c77a18b0..be2ff5ada 100644 --- a/js/genkit/tests/session_test.ts +++ b/js/genkit/tests/session_test.ts @@ -30,7 +30,7 @@ describe('session', () => { }); it('maintains history in the session', async () => { - const session = await ai.createSession(); + const session = ai.createSession(); const chat = session.chat(); let response = await chat.send('hi'); @@ -61,15 +61,15 @@ describe('session', () => { it('maintains multithreaded history in the session', async () => { const store = new TestMemorySessionStore(); - const session = await ai.createSession({ + const session = ai.createSession({ store, - - state: { + initialState: { name: 'Genkit', }, }); - let response = await session.chat().send('hi main'); + let mainChat = session.chat(); + let response = await mainChat.send('hi main'); assert.strictEqual(response.text, 'Echo: hi main; config: {}'); const lawyerChat = session.chat('lawyerChat', { @@ -131,7 +131,7 @@ describe('session', () => { }); it('maintains history in the session with streaming', async () => { - const session = await ai.createSession(); + const session = ai.createSession(); const chat = session.chat(); let { response, stream } = await chat.sendStream('hi'); @@ -174,39 +174,40 @@ describe('session', () => { it('stores state and messages in the store', async () => { const store = new TestMemorySessionStore(); - const session = await ai.createSession({ + const session = ai.createSession({ store, + initialState: { + foo: 'bar', + }, }); - const initialState = await store.get(session.id); - delete initialState.id; // ignore - assert.deepStrictEqual(initialState, { - state: undefined, - threads: {}, - }); - const chat = session.chat(); await chat.send('hi'); await chat.send('bye'); const state = await store.get(session.id); - - assert.deepStrictEqual(state?.threads, { - main: [ - { content: [{ text: 'hi' }], role: 'user' }, - { - content: [{ text: 'Echo: hi' }, { text: '; config: {}' }], - role: 'model', - }, - { content: [{ text: 'bye' }], role: 'user' }, - { - content: [ - { text: 'Echo: hi,Echo: hi,; config: {},bye' }, - { text: '; config: {}' }, - ], - role: 'model', - }, - ], + delete state.id; + assert.deepStrictEqual(state, { + state: { + foo: 'bar', + }, + threads: { + main: [ + { content: [{ text: 'hi' }], role: 'user' }, + { + content: [{ text: 'Echo: hi' }, { text: '; config: {}' }], + role: 'model', + }, + { content: [{ text: 'bye' }], role: 'user' }, + { + content: [ + { text: 'Echo: hi,Echo: hi,; config: {},bye' }, + { text: '; config: {}' }, + ], + role: 'model', + }, + ], + }, }); }); @@ -214,8 +215,8 @@ describe('session', () => { it('loads session from store', async () => { const store = new TestMemorySessionStore(); // init the store - const originalSession = await ai.createSession({ store }); - const originalMainChat = await originalSession.chat({ + const originalSession = ai.createSession({ store }); + const originalMainChat = originalSession.chat({ config: { temperature: 1, }, @@ -227,7 +228,7 @@ describe('session', () => { // load const session = await ai.loadSession(sessionId, { store }); - const mainChat = await session.chat(); + const mainChat = session.chat(); assert.deepStrictEqual(mainChat.messages, [ { content: [{ text: 'hi' }], role: 'user' }, { diff --git a/js/package.json b/js/package.json index 
bf2aac4ea..368fbb71d 100644 --- a/js/package.json +++ b/js/package.json @@ -22,5 +22,5 @@ "only-allow": "^1.2.1", "typescript": "^4.9.0" }, - "packageManager": "pnpm@9.12.0+sha256.a61b67ff6cc97af864564f4442556c22a04f2e5a7714fbee76a1011361d9b726" + "packageManager": "pnpm@9.12.2+sha256.2ef6e547b0b07d841d605240dce4d635677831148cd30f6d564b8f4f928f73d2" } diff --git a/js/plugins/chroma/package.json b/js/plugins/chroma/package.json index c14cb31fc..14a79d909 100644 --- a/js/plugins/chroma/package.json +++ b/js/plugins/chroma/package.json @@ -13,7 +13,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", diff --git a/js/plugins/dev-local-vectorstore/package.json b/js/plugins/dev-local-vectorstore/package.json index 63ad4a706..deb3a56af 100644 --- a/js/plugins/dev-local-vectorstore/package.json +++ b/js/plugins/dev-local-vectorstore/package.json @@ -10,7 +10,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", diff --git a/js/plugins/dotprompt/package.json b/js/plugins/dotprompt/package.json index 2a65c33c8..f069e57d1 100644 --- a/js/plugins/dotprompt/package.json +++ b/js/plugins/dotprompt/package.json @@ -9,7 +9,7 @@ "prompting", "templating" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", diff --git a/js/plugins/dotprompt/src/index.ts b/js/plugins/dotprompt/src/index.ts index 85a12e6e2..f6c8d6072 100644 --- a/js/plugins/dotprompt/src/index.ts +++ b/js/plugins/dotprompt/src/index.ts @@ -14,6 +14,7 @@ * limitations under the License. */ +import { Registry } from '@genkit-ai/core/registry'; import { readFileSync } from 'fs'; import { basename } from 'path'; import { @@ -43,10 +44,15 @@ export interface DotpromptPluginOptions { } export async function prompt( + registry: Registry, name: string, options?: { variant?: string } ): Promise> { - return (await lookupPrompt(name, options?.variant)) as Dotprompt; + return (await lookupPrompt( + registry, + name, + options?.variant + )) as Dotprompt; } export function promptRef( @@ -56,19 +62,22 @@ export function promptRef( return new DotpromptRef(name, options); } -export function loadPromptFile(path: string): Dotprompt { +export function loadPromptFile(registry: Registry, path: string): Dotprompt { return Dotprompt.parse( + registry, basename(path).split('.')[0], readFileSync(path, 'utf-8') ); } export async function loadPromptUrl( + registry: Registry, + name: string, url: string ): Promise { const fetch = (await import('node-fetch')).default; const response = await fetch(url); const text = await response.text(); - return Dotprompt.parse(name, text); + return Dotprompt.parse(registry, name, text); } diff --git a/js/plugins/dotprompt/src/metadata.ts b/js/plugins/dotprompt/src/metadata.ts index 79e59e122..165176919 100644 --- a/js/plugins/dotprompt/src/metadata.ts +++ b/js/plugins/dotprompt/src/metadata.ts @@ -25,7 +25,7 @@ import { } from '@genkit-ai/ai/model'; import { ToolArgument } from '@genkit-ai/ai/tool'; import { z } from '@genkit-ai/core'; -import { lookupSchema } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { JSONSchema, parseSchema, toJsonSchema } from '@genkit-ai/core/schema'; import { picoschema } from './picoschema.js'; @@ -39,6 +39,9 @@ export interface PromptMetadata< /** The name of the prompt. 
*/ name?: string; + /** Description (intent) of the prompt, used when prompt passed as tool to an LLM. */ + description?: string; + /** The variant name for the prompt. */ variant?: string; @@ -119,27 +122,33 @@ function stripUndefinedOrNull(obj: any) { return obj; } -function fmSchemaToSchema(fmSchema: any) { +function fmSchemaToSchema(registry: Registry, fmSchema: any) { if (!fmSchema) return {}; - if (typeof fmSchema === 'string') return lookupSchema(fmSchema); + if (typeof fmSchema === 'string') return registry.lookupSchema(fmSchema); return { jsonSchema: picoschema(fmSchema) }; } -export function toMetadata(attributes: unknown): Partial { +export function toMetadata( + registry: Registry, + attributes: unknown +): Partial { const fm = parseSchema>(attributes, { schema: PromptFrontmatterSchema, }); let input: PromptMetadata['input'] | undefined; if (fm.input) { - input = { default: fm.input.default, ...fmSchemaToSchema(fm.input.schema) }; + input = { + default: fm.input.default, + ...fmSchemaToSchema(registry, fm.input.schema), + }; } let output: PromptMetadata['output'] | undefined; if (fm.output) { output = { format: fm.output.format, - ...fmSchemaToSchema(fm.output.schema), + ...fmSchemaToSchema(registry, fm.output.schema), }; } diff --git a/js/plugins/dotprompt/src/prompt.ts b/js/plugins/dotprompt/src/prompt.ts index 3ae4e04c6..b216c21a8 100644 --- a/js/plugins/dotprompt/src/prompt.ts +++ b/js/plugins/dotprompt/src/prompt.ts @@ -27,6 +27,7 @@ import { import { MessageData, ModelArgument } from '@genkit-ai/ai/model'; import { DocumentData } from '@genkit-ai/ai/retriever'; import { GenkitError, z } from '@genkit-ai/core'; +import { Registry } from '@genkit-ai/core/registry'; import { parseSchema } from '@genkit-ai/core/schema'; import { runInNewSpan, @@ -49,7 +50,10 @@ export type PromptData = PromptFrontmatter & { template: string }; export type PromptGenerateOptions< V = unknown, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, -> = Omit, 'prompt' | 'model'> & { +> = Omit< + GenerateOptions, + 'prompt' | 'input' | 'model' +> & { model?: ModelArgument; input?: V; }; @@ -73,16 +77,22 @@ export class Dotprompt implements PromptMetadata { tools?: PromptMetadata['tools']; config?: PromptMetadata['config']; + private _promptAction?: PromptAction; + private _render: (input: I, options?: RenderMetadata) => MessageData[]; - static parse(name: string, source: string) { + static parse(registry: Registry, name: string, source: string) { try { const fmResult = (fm as any)(source.trimStart(), { allowUnsafe: false, }) as FrontMatterResult; return new Dotprompt( - { ...toMetadata(fmResult.attributes), name } as PromptMetadata, + registry, + { + ...toMetadata(registry, fmResult.attributes), + name, + } as PromptMetadata, fmResult.body ); } catch (e: any) { @@ -94,7 +104,7 @@ export class Dotprompt implements PromptMetadata { } } - static fromAction(action: PromptAction): Dotprompt { + static fromAction(registry: Registry, action: PromptAction): Dotprompt { const { template, ...options } = action.__action.metadata!.prompt; const pm = options as PromptMetadata; if (pm.input?.schema) { @@ -104,11 +114,15 @@ export class Dotprompt implements PromptMetadata { if (pm.output?.schema) { pm.output.jsonSchema = options.output?.schema; } - const prompt = new Dotprompt(options as PromptMetadata, template); + const prompt = new Dotprompt(registry, options as PromptMetadata, template); return prompt; } - constructor(options: PromptMetadata, template: string) { + constructor( + private registry: Registry, + 
options: PromptMetadata, + template: string + ) { this.name = options.name || 'untitledPrompt'; this.variant = options.variant; this.model = options.model; @@ -164,11 +178,12 @@ export class Dotprompt implements PromptMetadata { return { ...toFrontmatter(this), template: this.template }; } - define(options?: { ns: string }): void { - definePrompt( + define(options?: { ns?: string; description?: string }): void { + this._promptAction = definePrompt( + this.registry, { name: registryDefinitionKey(this.name, this.variant, options?.ns), - description: 'Defined by Dotprompt', + description: options?.description ?? 'Defined by Dotprompt', inputSchema: this.input?.schema, inputJsonSchema: this.input?.jsonSchema, metadata: { @@ -176,10 +191,15 @@ export class Dotprompt implements PromptMetadata { prompt: this.toJSON(), }, }, - async (input?: I) => toGenerateRequest(this.render({ input })) + async (input?: I) => + toGenerateRequest(this.registry, this.render({ input })) ); } + get promptAction(): PromptAction | undefined { + return this._promptAction; + } + private _generateOptions< O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, @@ -188,11 +208,20 @@ export class Dotprompt implements PromptMetadata { messages: options.messages, docs: options.docs, }); + let renderedPrompt; + let renderedMessages; + if (messages.length > 0 && messages[messages.length - 1].role === 'user') { + renderedPrompt = messages[messages.length - 1].content; + renderedMessages = messages.slice(0, messages.length - 1); + } else { + renderedPrompt = undefined; + renderedMessages = messages; + } return { model: options.model || this.model!, config: { ...this.config, ...options.config }, - messages: messages.slice(0, messages.length - 1), - prompt: messages[messages.length - 1].content, + messages: renderedMessages, + prompt: renderedPrompt, docs: options.docs, output: { format: options.output?.format || this.output?.format || undefined, @@ -258,7 +287,7 @@ export class Dotprompt implements PromptMetadata { opt: PromptGenerateOptions ): Promise>> { const renderedOpts = this.renderInNewSpan(opt); - return generate(renderedOpts); + return generate(this.registry, renderedOpts); } /** @@ -271,7 +300,7 @@ export class Dotprompt implements PromptMetadata { opt: PromptGenerateOptions ): Promise { const renderedOpts = await this.renderInNewSpan(opt); - return generateStream(renderedOpts); + return generateStream(this.registry, renderedOpts); } } @@ -294,9 +323,10 @@ export class DotpromptRef { } /** Loads the prompt which is referenced. 
*/ - async loadPrompt(): Promise> { + async loadPrompt(registry: Registry): Promise> { if (this._prompt) return this._prompt; this._prompt = (await lookupPrompt( + registry, this.name, this.variant, this.dir @@ -315,9 +345,10 @@ export class DotpromptRef { CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, opt: PromptGenerateOptions ): Promise>> { - const prompt = await this.loadPrompt(); + const prompt = await this.loadPrompt(registry); return prompt.generate(opt); } @@ -331,9 +362,11 @@ export class DotpromptRef { CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, + opt: PromptGenerateOptions ): Promise> { - const prompt = await this.loadPrompt(); + const prompt = await this.loadPrompt(registry); return prompt.render(opt); } } @@ -349,10 +382,11 @@ export function defineDotprompt< I extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, >( + registry: Registry, options: PromptMetadata, template: string ): Dotprompt> { - const prompt = new Dotprompt(options, template); - prompt.define(); + const prompt = new Dotprompt(registry, options, template); + prompt.define({ description: options.description }); return prompt; } diff --git a/js/plugins/dotprompt/src/registry.ts b/js/plugins/dotprompt/src/registry.ts index 3397f6b56..f0af18eec 100644 --- a/js/plugins/dotprompt/src/registry.ts +++ b/js/plugins/dotprompt/src/registry.ts @@ -17,7 +17,7 @@ import { PromptAction } from '@genkit-ai/ai'; import { GenkitError } from '@genkit-ai/core'; import { logger } from '@genkit-ai/core/logging'; -import { lookupAction } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { existsSync, readdir, readFileSync } from 'fs'; import { basename, join, resolve } from 'path'; import { Dotprompt } from './prompt.js'; @@ -37,23 +37,27 @@ export function registryLookupKey(name: string, variant?: string, ns?: string) { } export async function lookupPrompt( + registry: Registry, name: string, variant?: string, dir: string = './prompts' ): Promise { let registryPrompt = - (await lookupAction(registryLookupKey(name, variant))) || - (await lookupAction(registryLookupKey(name, variant, 'dotprompt'))); + (await registry.lookupAction(registryLookupKey(name, variant))) || + (await registry.lookupAction( + registryLookupKey(name, variant, 'dotprompt') + )); if (registryPrompt) { - return Dotprompt.fromAction(registryPrompt as PromptAction); + return Dotprompt.fromAction(registry, registryPrompt as PromptAction); } else { // Handle the case where initialization isn't complete // or a file was added after the prompt folder was loaded. 
- return maybeLoadPrompt(dir, name, variant); + return maybeLoadPrompt(registry, dir, name, variant); } } async function maybeLoadPrompt( + registry: Registry, dir: string, name: string, variant?: string @@ -62,7 +66,7 @@ async function maybeLoadPrompt( const promptFolder = resolve(dir); const promptExists = existsSync(join(promptFolder, expectedFileName)); if (promptExists) { - return loadPrompt(promptFolder, expectedFileName); + return loadPrompt(registry, promptFolder, expectedFileName); } else { throw new GenkitError({ source: 'dotprompt', @@ -73,6 +77,8 @@ async function maybeLoadPrompt( } export async function loadPromptFolder( + registry: Registry, + dir: string = './prompts' ): Promise { const promptsPath = resolve(dir); @@ -114,7 +120,7 @@ export async function loadPromptFolder( .replace(`${promptsPath}/`, '') .replace(/\//g, '-'); } - loadPrompt(dirEnt.path, dirEnt.name, prefix); + loadPrompt(registry, dirEnt.path, dirEnt.name, prefix); } } }); @@ -129,6 +135,7 @@ export async function loadPromptFolder( } export function loadPrompt( + registry: Registry, path: string, filename: string, prefix = '' @@ -141,7 +148,7 @@ export function loadPrompt( variant = parts[1]; } const source = readFileSync(join(path, filename), 'utf8'); - const prompt = Dotprompt.parse(name, source); + const prompt = Dotprompt.parse(registry, name, source); if (variant) { prompt.variant = variant; } diff --git a/js/plugins/dotprompt/tests/prompt_test.ts b/js/plugins/dotprompt/tests/prompt_test.ts index 39857923c..0e6d31e1a 100644 --- a/js/plugins/dotprompt/tests/prompt_test.ts +++ b/js/plugins/dotprompt/tests/prompt_test.ts @@ -16,7 +16,7 @@ import { defineModel, ModelAction } from '@genkit-ai/ai/model'; import { z } from '@genkit-ai/core'; -import { Registry, runWithRegistry } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { defineJsonSchema, defineSchema, @@ -29,11 +29,12 @@ import { defineDotprompt, Dotprompt, prompt, promptRef } from '../src/index.js'; import { PromptMetadata } from '../src/metadata.js'; function testPrompt( + registry: Registry, model: ModelAction, template: string, options?: Partial ): Dotprompt { - return new Dotprompt({ name: 'test', model, ...options }, template); + return new Dotprompt(registry, { name: 'test', model, ...options }, template); } describe('Prompt', () => { @@ -44,184 +45,194 @@ describe('Prompt', () => { describe('#render', () => { it('should render variables', () => { - runWithRegistry(registry, () => { - const model = defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - message: input.messages[0], - finishReason: 'stop', - }) - ); - const prompt = testPrompt(model, `Hello {{name}}, how are you?`); - - const rendered = prompt.render({ input: { name: 'Michael' } }); - assert.deepStrictEqual(rendered.prompt, [ - { text: 'Hello Michael, how are you?' }, - ]); - }); + const model = defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) + ); + const prompt = testPrompt( + registry, + model, + `Hello {{name}}, how are you?` + ); + + const rendered = prompt.render({ input: { name: 'Michael' } }); + assert.deepStrictEqual(rendered.prompt, [ + { text: 'Hello Michael, how are you?' 
}, + ]); }); it('should render default variables', () => { - runWithRegistry(registry, () => { - const model = defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - message: input.messages[0], - finishReason: 'stop', - }) - ); - const prompt = testPrompt(model, `Hello {{name}}, how are you?`, { + const model = defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) + ); + const prompt = testPrompt( + registry, + model, + `Hello {{name}}, how are you?`, + { input: { default: { name: 'Fellow Human' } }, - }); + } + ); - const rendered = prompt.render({ input: {} }); - assert.deepStrictEqual(rendered.prompt, [ - { - text: 'Hello Fellow Human, how are you?', - }, - ]); - }); + const rendered = prompt.render({ input: {} }); + assert.deepStrictEqual(rendered.prompt, [ + { + text: 'Hello Fellow Human, how are you?', + }, + ]); }); it('rejects input not matching the schema', async () => { - await runWithRegistry(registry, async () => { - const invalidSchemaPrompt = defineDotprompt( - { - name: 'invalidInput', - model: 'echo', - input: { - jsonSchema: { - properties: { foo: { type: 'boolean' } }, - required: ['foo'], - }, + const invalidSchemaPrompt = defineDotprompt( + registry, + { + name: 'invalidInput', + model: 'echo', + input: { + jsonSchema: { + properties: { foo: { type: 'boolean' } }, + required: ['foo'], }, }, - `You asked for {{foo}}.` - ); + }, + `You asked for {{foo}}.` + ); - await assert.rejects(async () => { - invalidSchemaPrompt.render({ input: { foo: 'baz' } }); - }, ValidationError); - }); + await assert.rejects(async () => { + invalidSchemaPrompt.render({ input: { foo: 'baz' } }); + }, ValidationError); }); it('should render with overridden fields', () => { - runWithRegistry(registry, () => { - const model = defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - message: input.messages[0], - finishReason: 'stop', - }) - ); - const prompt = testPrompt(model, `Hello {{name}}, how are you?`); - - const streamingCallback = (c) => console.log(c); - const middleware = []; - - const rendered = prompt.render({ - input: { name: 'Michael' }, - streamingCallback, - returnToolRequests: true, - use: middleware, - }); - assert.strictEqual(rendered.streamingCallback, streamingCallback); - assert.strictEqual(rendered.returnToolRequests, true); - assert.strictEqual(rendered.use, middleware); + const model = defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) + ); + const prompt = testPrompt( + registry, + model, + `Hello {{name}}, how are you?` + ); + + const streamingCallback = (c) => console.log(c); + const middleware = []; + + const rendered = prompt.render({ + input: { name: 'Michael' }, + streamingCallback, + returnToolRequests: true, + use: middleware, }); + assert.strictEqual(rendered.streamingCallback, streamingCallback); + assert.strictEqual(rendered.returnToolRequests, true); + assert.strictEqual(rendered.use, middleware); }); it('should support system prompt with history', () => { - runWithRegistry(registry, () => { - const model = defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - message: input.messages[0], - finishReason: 'stop', - }) - ); - const prompt = testPrompt( - model, - `{{ role "system" }}Testing system {{name}}` - ); - - const rendered = prompt.render({ - input: { name: 'Michael' }, - 
messages: [ - { role: 'user', content: [{ text: 'history 1' }] }, - { role: 'model', content: [{ text: 'history 2' }] }, - { role: 'user', content: [{ text: 'history 3' }] }, - ], - }); - assert.deepStrictEqual(rendered.messages, [ - { role: 'system', content: [{ text: 'Testing system Michael' }] }, + const model = defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) + ); + const prompt = testPrompt( + registry, + model, + `{{ role "system" }}Testing system {{name}}` + ); + + const rendered = prompt.render({ + input: { name: 'Michael' }, + messages: [ { role: 'user', content: [{ text: 'history 1' }] }, { role: 'model', content: [{ text: 'history 2' }] }, - ]); - assert.deepStrictEqual(rendered.prompt, [{ text: 'history 3' }]); + { role: 'user', content: [{ text: 'history 3' }] }, + ], }); + assert.deepStrictEqual(rendered.messages, [ + { role: 'system', content: [{ text: 'Testing system Michael' }] }, + { role: 'user', content: [{ text: 'history 1' }] }, + { role: 'model', content: [{ text: 'history 2' }] }, + ]); + assert.deepStrictEqual(rendered.prompt, [{ text: 'history 3' }]); }); }); describe('#generate', () => { it('renders and calls the model', async () => { - await runWithRegistry(registry, async () => { - const model = defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - message: input.messages[0], - finishReason: 'stop', - }) - ); - const prompt = testPrompt(model, `Hello {{name}}, how are you?`); - const response = await prompt.generate({ input: { name: 'Bob' } }); - assert.equal(response.text, `Hello Bob, how are you?`); - }); + const model = defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) + ); + const prompt = testPrompt( + registry, + model, + `Hello {{name}}, how are you?` + ); + const response = await prompt.generate({ input: { name: 'Bob' } }); + assert.equal(response.text, `Hello Bob, how are you?`); }); it('rejects input not matching the schema', async () => { - await runWithRegistry(registry, async () => { - const invalidSchemaPrompt = defineDotprompt( - { - name: 'invalidInput', - model: 'echo', - input: { - jsonSchema: { - properties: { foo: { type: 'boolean' } }, - required: ['foo'], - }, + const invalidSchemaPrompt = defineDotprompt( + registry, + { + name: 'invalidInput', + model: 'echo', + input: { + jsonSchema: { + properties: { foo: { type: 'boolean' } }, + required: ['foo'], }, }, - `You asked for {{foo}}.` - ); + }, + `You asked for {{foo}}.` + ); - await assert.rejects(async () => { - await invalidSchemaPrompt.generate({ input: { foo: 'baz' } }); - }, ValidationError); - }); + await assert.rejects(async () => { + await invalidSchemaPrompt.generate({ input: { foo: 'baz' } }); + }, ValidationError); }); }); describe('#toJSON', () => { it('should convert zod to json schema', () => { - runWithRegistry(registry, () => { - const schema = z.object({ name: z.string() }); - const model = defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - message: input.messages[0], - finishReason: 'stop', - }) - ); - const prompt = testPrompt(model, `hello {{name}}`, { - input: { schema }, - }); - - assert.deepStrictEqual( - prompt.toJSON().input?.schema, - toJsonSchema({ schema }) - ); + const schema = z.object({ name: z.string() }); + const model = defineModel( + registry, + { name: 'echo', supports: { tools: true } }, 
+ async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) + ); + const prompt = testPrompt(registry, model, `hello {{name}}`, { + input: { schema }, }); + + assert.deepStrictEqual( + prompt.toJSON().input?.schema, + toJsonSchema({ schema }) + ); }); }); @@ -230,6 +241,7 @@ describe('Prompt', () => { assert.throws( () => { Dotprompt.parse( + registry, 'example', `--- input: { @@ -247,6 +259,7 @@ This is the rest of the prompt` it('should parse picoschema', () => { const p = Dotprompt.parse( + registry, 'example', `--- input: @@ -277,54 +290,53 @@ output: }); it('should use registered schemas', () => { - runWithRegistry(registry, () => { - const MyInput = defineSchema('MyInput', z.number()); - defineJsonSchema('MyOutput', { type: 'boolean' }); + const MyInput = defineSchema(registry, 'MyInput', z.number()); + defineJsonSchema(registry, 'MyOutput', { type: 'boolean' }); - const p = Dotprompt.parse( - 'example2', - `--- + const p = Dotprompt.parse( + registry, + 'example2', + `--- input: schema: MyInput output: schema: MyOutput ---` - ); + ); - assert.deepEqual(p.input, { schema: MyInput }); - assert.deepEqual(p.output, { jsonSchema: { type: 'boolean' } }); - }); + assert.deepEqual(p.input, { schema: MyInput }); + assert.deepEqual(p.output, { jsonSchema: { type: 'boolean' } }); }); }); describe('defineDotprompt', () => { it('registers a prompt and its variant', async () => { - await runWithRegistry(registry, async () => { - defineDotprompt( - { - name: 'promptName', - model: 'echo', - }, - `This is a prompt.` - ); - - defineDotprompt( - { - name: 'promptName', - variant: 'variantName', - model: 'echo', - }, - `And this is its variant.` - ); - - const basePrompt = await prompt('promptName'); - assert.equal('This is a prompt.', basePrompt.template); + defineDotprompt( + registry, + { + name: 'promptName', + model: 'echo', + }, + `This is a prompt.` + ); - const variantPrompt = await prompt('promptName', { + defineDotprompt( + registry, + { + name: 'promptName', variant: 'variantName', - }); - assert.equal('And this is its variant.', variantPrompt.template); + model: 'echo', + }, + `And this is its variant.` + ); + + const basePrompt = await prompt(registry, 'promptName'); + assert.equal('This is a prompt.', basePrompt.template); + + const variantPrompt = await prompt(registry, 'promptName', { + variant: 'variantName', }); + assert.equal('And this is its variant.', variantPrompt.template); }); }); }); @@ -336,138 +348,153 @@ describe('DotpromptRef', () => { }); it('Should load a prompt correctly', async () => { - await runWithRegistry(registry, async () => { - defineDotprompt( - { - name: 'promptName', - model: 'echo', - }, - `This is a prompt.` - ); + defineDotprompt( + registry, + { + name: 'promptName', + model: 'echo', + }, + `This is a prompt.` + ); - const ref = promptRef('promptName'); + const ref = promptRef('promptName'); - const p = await ref.loadPrompt(); + const p = await ref.loadPrompt(registry); - const isDotprompt = p instanceof Dotprompt; + const isDotprompt = p instanceof Dotprompt; - assert.equal(isDotprompt, true); - assert.equal(p.template, 'This is a prompt.'); - }); + assert.equal(isDotprompt, true); + assert.equal(p.template, 'This is a prompt.'); }); it('Should generate output correctly using DotpromptRef', async () => { - await runWithRegistry(registry, async () => { - const model = defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - message: input.messages[0], - finishReason: 'stop', - }) - ); - defineDotprompt( - { 
- name: 'generatePrompt', - model: 'echo', - }, - `Hello {{name}}, this is a test prompt.` - ); - - const ref = promptRef('generatePrompt'); - const response = await ref.generate({ input: { name: 'Alice' } }); - - assert.equal(response.text, 'Hello Alice, this is a test prompt.'); - }); + const model = defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) + ); + defineDotprompt( + registry, + { + name: 'generatePrompt', + model: 'echo', + }, + `Hello {{name}}, this is a test prompt.` + ); + + const ref = promptRef('generatePrompt'); + const response = await ref.generate(registry, { input: { name: 'Alice' } }); + + assert.equal(response.text, 'Hello Alice, this is a test prompt.'); }); it('Should render correctly using DotpromptRef', async () => { - await runWithRegistry(registry, async () => { - defineDotprompt( - { - name: 'renderPrompt', - model: 'echo', - }, - `Hi {{name}}, welcome to the system.` - ); - - const ref = promptRef('renderPrompt'); - const rendered = await ref.render({ input: { name: 'Bob' } }); - - assert.deepStrictEqual(rendered.prompt, [ - { text: 'Hi Bob, welcome to the system.' }, - ]); - }); + defineDotprompt( + registry, + { + name: 'renderPrompt', + model: 'echo', + }, + `Hi {{name}}, welcome to the system.` + ); + + const ref = promptRef('renderPrompt'); + const rendered = await ref.render(registry, { input: { name: 'Bob' } }); + + assert.deepStrictEqual(rendered.prompt, [ + { text: 'Hi Bob, welcome to the system.' }, + ]); }); it('Should handle invalid schema input in DotpromptRef', async () => { - await runWithRegistry(registry, async () => { - defineDotprompt( - { - name: 'invalidSchemaPromptRef', - model: 'echo', - input: { - jsonSchema: { - properties: { foo: { type: 'boolean' } }, - required: ['foo'], - }, + defineDotprompt( + registry, + { + name: 'invalidSchemaPromptRef', + model: 'echo', + input: { + jsonSchema: { + properties: { foo: { type: 'boolean' } }, + required: ['foo'], }, }, - `This is the prompt with foo={{foo}}.` - ); + }, + `This is the prompt with foo={{foo}}.` + ); - const ref = promptRef('invalidSchemaPromptRef'); + const ref = promptRef('invalidSchemaPromptRef'); - await assert.rejects(async () => { - await ref.generate({ input: { foo: 'not_a_boolean' } }); - }, ValidationError); - }); + await assert.rejects(async () => { + await ref.generate(registry, { input: { foo: 'not_a_boolean' } }); + }, ValidationError); }); it('Should support streamingCallback in DotpromptRef', async () => { - await runWithRegistry(registry, async () => { - defineDotprompt( - { - name: 'streamingCallbackPrompt', - model: 'echo', - }, - `Hello {{name}}, streaming test.` - ); - - const ref = promptRef('streamingCallbackPrompt'); - - const streamingCallback = (chunk) => console.log(chunk); - const options = { - input: { name: 'Charlie' }, - streamingCallback, - returnToolRequests: true, - }; - - const rendered = await ref.render(options); - - assert.strictEqual(rendered.streamingCallback, streamingCallback); - assert.strictEqual(rendered.returnToolRequests, true); - }); + defineDotprompt( + registry, + { + name: 'streamingCallbackPrompt', + model: 'echo', + }, + `Hello {{name}}, streaming test.` + ); + + const ref = promptRef('streamingCallbackPrompt'); + + const streamingCallback = (chunk) => console.log(chunk); + const options = { + input: { name: 'Charlie' }, + streamingCallback, + returnToolRequests: true, + }; + + const rendered = await ref.render(registry, 
options); + + assert.strictEqual(rendered.streamingCallback, streamingCallback); + assert.strictEqual(rendered.returnToolRequests, true); }); it('Should cache loaded prompt in DotpromptRef', async () => { - await runWithRegistry(registry, async () => { - defineDotprompt( - { - name: 'cacheTestPrompt', - model: 'echo', - }, - `This is a prompt for cache test.` - ); - - const ref = promptRef('cacheTestPrompt'); - const firstLoad = await ref.loadPrompt(); - const secondLoad = await ref.loadPrompt(); + defineDotprompt( + registry, + { + name: 'cacheTestPrompt', + model: 'echo', + }, + `This is a prompt for cache test.` + ); + + const ref = promptRef('cacheTestPrompt'); + const firstLoad = await ref.loadPrompt(registry); + const secondLoad = await ref.loadPrompt(registry); + + assert.strictEqual( + firstLoad, + secondLoad, + 'Loaded prompts should be identical (cached).' + ); + }); - assert.strictEqual( - firstLoad, - secondLoad, - 'Loaded prompts should be identical (cached).' - ); - }); + it('should render system prompt', () => { + const model = defineModel( + registry, + { name: 'echo', supports: { tools: true } }, + async (input) => ({ + message: input.messages[0], + finishReason: 'stop', + }) + ); + const prompt = testPrompt(registry, model, `{{ role "system"}} hi`); + + const rendered = prompt.render({ input: {} }); + assert.deepStrictEqual(rendered.messages, [ + { + content: [{ text: ' hi' }], + role: 'system', + }, + ]); }); }); diff --git a/js/plugins/evaluators/package.json b/js/plugins/evaluators/package.json index 1692268b4..82eee18c9 100644 --- a/js/plugins/evaluators/package.json +++ b/js/plugins/evaluators/package.json @@ -11,7 +11,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", diff --git a/js/plugins/evaluators/src/metrics/answer_relevancy.ts b/js/plugins/evaluators/src/metrics/answer_relevancy.ts index 74e91e743..833a5ce33 100644 --- a/js/plugins/evaluators/src/metrics/answer_relevancy.ts +++ b/js/plugins/evaluators/src/metrics/answer_relevancy.ts @@ -47,6 +47,7 @@ export async function answerRelevancyScore< throw new Error('Output was not provided'); } const prompt = await loadPromptFile( + ai.registry, path.resolve(getDirName(), '../../prompts/answer_relevancy.prompt') ); const response = await ai.generate({ diff --git a/js/plugins/evaluators/src/metrics/faithfulness.ts b/js/plugins/evaluators/src/metrics/faithfulness.ts index 3b3ed9e0e..244d0f10a 100644 --- a/js/plugins/evaluators/src/metrics/faithfulness.ts +++ b/js/plugins/evaluators/src/metrics/faithfulness.ts @@ -54,6 +54,7 @@ export async function faithfulnessScore< throw new Error('Output was not provided'); } const longFormPrompt = await loadPromptFile( + ai.registry, path.resolve(getDirName(), '../../prompts/faithfulness_long_form.prompt') ); const longFormResponse = await ai.generate({ @@ -75,6 +76,7 @@ export async function faithfulnessScore< const allStatements = statements.map((s) => `statement: ${s}`).join('\n'); const allContext = context.join('\n'); const nliPrompt = await loadPromptFile( + ai.registry, path.resolve(getDirName(), '../../prompts/faithfulness_nli.prompt') ); const response = await ai.generate({ diff --git a/js/plugins/evaluators/src/metrics/maliciousness.ts b/js/plugins/evaluators/src/metrics/maliciousness.ts index 048b7a9bb..5538cbc25 100644 --- a/js/plugins/evaluators/src/metrics/maliciousness.ts +++ b/js/plugins/evaluators/src/metrics/maliciousness.ts @@ -39,6 +39,7 @@ export async function 
maliciousnessScore< } const prompt = await loadPromptFile( + ai.registry, path.resolve(getDirName(), '../../prompts/maliciousness.prompt') ); //TODO: safetySettings are gemini specific - pull these out so they are tied to the LLM diff --git a/js/plugins/firebase/package.json b/js/plugins/firebase/package.json index 0dce7761d..cfce903c1 100644 --- a/js/plugins/firebase/package.json +++ b/js/plugins/firebase/package.json @@ -13,7 +13,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", @@ -37,8 +37,8 @@ }, "peerDependencies": { "@google-cloud/firestore": "^7.6.0", - "firebase-admin": "^12.2.0", - "firebase-functions": "^4.8.0 || ^5.0.0", + "firebase-admin": ">=12.2", + "firebase-functions": ">=4.8", "genkit": "workspace:*" }, "devDependencies": { diff --git a/js/plugins/firebase/src/functions.ts b/js/plugins/firebase/src/functions.ts index 1e0e64870..89248274f 100644 --- a/js/plugins/firebase/src/functions.ts +++ b/js/plugins/firebase/src/functions.ts @@ -131,7 +131,7 @@ function wrapHttpsFlow< } await config.authPolicy.provider(req, res, () => - flow.expressHandler(genkit.registry, req, res) + flow.expressHandler(req, res) ); } ); diff --git a/js/plugins/google-cloud/package.json b/js/plugins/google-cloud/package.json index b136bc19c..9ff63dd0d 100644 --- a/js/plugins/google-cloud/package.json +++ b/js/plugins/google-cloud/package.json @@ -13,7 +13,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", diff --git a/js/plugins/google-cloud/src/telemetry/action.ts b/js/plugins/google-cloud/src/telemetry/action.ts index 9eca330bd..6b2c0b131 100644 --- a/js/plugins/google-cloud/src/telemetry/action.ts +++ b/js/plugins/google-cloud/src/telemetry/action.ts @@ -55,8 +55,7 @@ class ActionTelemetry implements Telemetry { const actionName = (attributes['genkit:name'] as string) || ''; const path = (attributes['genkit:path'] as string) || ''; - let featureName = (attributes['genkit:metadata:flow:name'] || - extractOuterFeatureNameFromPath(path)) as string; + let featureName = extractOuterFeatureNameFromPath(path); if (!featureName || featureName === '') { featureName = actionName; } @@ -68,13 +67,11 @@ class ActionTelemetry implements Telemetry { if (state === 'success') { this.writeSuccess(actionName, featureName, path, latencyMs); - return; - } - if (state === 'error') { + } else if (state === 'error') { this.writeFailure(actionName, featureName, path, latencyMs, errorName); + } else { + logger.warn(`Unknown action state; ${state}`); } - - logger.warn(`Unknown action state; ${state}`); } private writeSuccess( diff --git a/js/plugins/google-cloud/tests/metrics_test.ts b/js/plugins/google-cloud/tests/metrics_test.ts index bb8e8821a..f9ba5c2cc 100644 --- a/js/plugins/google-cloud/tests/metrics_test.ts +++ b/js/plugins/google-cloud/tests/metrics_test.ts @@ -30,7 +30,6 @@ import { } from '@opentelemetry/sdk-metrics'; import { ReadableSpan } from '@opentelemetry/sdk-trace-base'; import { GenerateResponseData, Genkit, genkit, run, z } from 'genkit'; -import { runWithRegistry } from 'genkit/registry'; import { SPAN_TYPE_ATTR, appendSpan } from 'genkit/tracing'; import assert from 'node:assert'; import { after, before, beforeEach, describe, it } from 'node:test'; @@ -176,20 +175,20 @@ describe('GoogleCloudMetrics', () => { assert.equal(requestCounter.attributes.source, 'ts'); assert.equal(requestCounter.attributes.status, 'success'); 
assert.ok(requestCounter.attributes.sourceVersion); + assert.equal(requestCounter.attributes.featureName, 'testFlowWithActions'); assert.equal(latencyHistogram.value.count, 6); assert.equal(latencyHistogram.attributes.name, 'testAction'); assert.equal(latencyHistogram.attributes.source, 'ts'); assert.equal(latencyHistogram.attributes.status, 'success'); assert.ok(latencyHistogram.attributes.sourceVersion); + assert.equal(requestCounter.attributes.featureName, 'testFlowWithActions'); }); it('writes feature metrics for an action', async () => { const testAction = createAction(ai, 'featureAction'); - await runWithRegistry(ai.registry, async () => { - await testAction(null); - await testAction(null); - }); + await testAction(null); + await testAction(null); await getExportedSpans(); @@ -211,11 +210,9 @@ describe('GoogleCloudMetrics', () => { // after PR #1029 it('writes feature metrics for generate', async () => { - await runWithRegistry(ai.registry, async () => { - const testModel = createTestModel(ai, 'helloModel'); - await ai.generate({ model: testModel, prompt: 'Hi' }); - await ai.generate({ model: testModel, prompt: 'Yo' }); - }); + const testModel = createTestModel(ai, 'helloModel'); + await ai.generate({ model: testModel, prompt: 'Hi' }); + await ai.generate({ model: testModel, prompt: 'Yo' }); const spans = await getExportedSpans(); @@ -261,9 +258,7 @@ describe('GoogleCloudMetrics', () => { }); assert.rejects(async () => { - return await runWithRegistry(ai.registry, async () => { - return testAction(null); - }); + return testAction(null); }); await getExportedSpans(); @@ -414,9 +409,7 @@ describe('GoogleCloudMetrics', () => { }); }); - await runWithRegistry(ai.registry, async () => { - testAction(null); - }); + testAction(null); await getExportedSpans(); @@ -904,13 +897,11 @@ describe('GoogleCloudMetrics', () => { name: string, fn: () => Promise = async () => {} ) { - return runWithRegistry(ai.registry, () => - ai.defineFlow( - { - name, - }, - fn - ) + return ai.defineFlow( + { + name, + }, + fn ); } @@ -921,9 +912,7 @@ describe('GoogleCloudMetrics', () => { name: string, respFn: () => Promise ) { - return runWithRegistry(ai.registry, () => - ai.defineModel({ name }, (req) => respFn()) - ); + return ai.defineModel({ name }, (req) => respFn()); } function createTestModel(ai: Genkit, name: string) { diff --git a/js/plugins/google-cloud/tests/traces_test.ts b/js/plugins/google-cloud/tests/traces_test.ts index b4b687ed4..298002acb 100644 --- a/js/plugins/google-cloud/tests/traces_test.ts +++ b/js/plugins/google-cloud/tests/traces_test.ts @@ -16,7 +16,6 @@ import { ReadableSpan } from '@opentelemetry/sdk-trace-base'; import { Genkit, genkit, run, z } from 'genkit'; -import { runWithRegistry } from 'genkit/registry'; import { appendSpan } from 'genkit/tracing'; import assert from 'node:assert'; import { after, before, beforeEach, describe, it } from 'node:test'; @@ -135,29 +134,27 @@ describe('GoogleCloudTracing', () => { }); it('adds the genkit/model label for model actions', async () => { - const echoModel = runWithRegistry(ai.registry, () => - ai.defineModel( - { - name: 'echoModel', - }, - async (request) => { - return { - message: { - role: 'model', - content: [ - { - text: - 'Echo: ' + - request.messages - .map((m) => m.content.map((c) => c.text).join()) - .join(), - }, - ], - }, - finishReason: 'stop', - }; - } - ) + const echoModel = ai.defineModel( + { + name: 'echoModel', + }, + async (request) => { + return { + message: { + role: 'model', + content: [ + { + text: + 'Echo: ' + + 
request.messages + .map((m) => m.content.map((c) => c.text).join()) + .join(), + }, + ], + }, + finishReason: 'stop', + }; + } ); const testFlow = createFlow(ai, 'modelFlow', async () => { return run('runFlow', async () => { diff --git a/js/plugins/googleai/package.json b/js/plugins/googleai/package.json index 3385b440b..b581f5b18 100644 --- a/js/plugins/googleai/package.json +++ b/js/plugins/googleai/package.json @@ -13,7 +13,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", diff --git a/js/plugins/googleai/src/embedder.ts b/js/plugins/googleai/src/embedder.ts index 78d729fe1..9e0818488 100644 --- a/js/plugins/googleai/src/embedder.ts +++ b/js/plugins/googleai/src/embedder.ts @@ -15,7 +15,7 @@ */ import { EmbedContentRequest, GoogleGenerativeAI } from '@google/generative-ai'; -import { Genkit, z } from 'genkit'; +import { EmbedderReference, Genkit, z } from 'genkit'; import { embedderRef } from 'genkit/embedder'; import { PluginOptions } from './index.js'; @@ -28,22 +28,21 @@ export const TaskTypeSchema = z.enum([ ]); export type TaskType = z.infer; -export const TextEmbeddingGeckoConfigSchema = z.object({ +export const GeminiEmbeddingConfigSchema = z.object({ /** * The `task_type` parameter is defined as the intended downstream application to help the model * produce better quality embeddings. **/ taskType: TaskTypeSchema.optional(), title: z.string().optional(), + version: z.string().optional(), }); -export type TextEmbeddingGeckoConfig = z.infer< - typeof TextEmbeddingGeckoConfigSchema ->; +export type GeminiEmbeddingConfig = z.infer; export const textEmbeddingGecko001 = embedderRef({ name: 'googleai/embedding-001', - configSchema: TextEmbeddingGeckoConfigSchema, + configSchema: GeminiEmbeddingConfigSchema, info: { dimensions: 768, label: 'Google Gen AI - Text Embedding Gecko (Legacy)', @@ -57,7 +56,7 @@ export const SUPPORTED_MODELS = { 'embedding-001': textEmbeddingGecko001, }; -export function textEmbeddingGeckoEmbedder( +export function defineGoogleAIEmbedder( ai: Genkit, name: string, options: PluginOptions @@ -71,17 +70,36 @@ export function textEmbeddingGeckoEmbedder( 'Please pass in the API key or set either GOOGLE_GENAI_API_KEY or GOOGLE_API_KEY environment variable.\n' + 'For more details see https://firebase.google.com/docs/genkit/plugins/google-genai' ); - const client = new GoogleGenerativeAI(apiKey).getGenerativeModel({ - model: name, - }); - const embedder = SUPPORTED_MODELS[name]; + const embedder: EmbedderReference = + SUPPORTED_MODELS[name] ?? + embedderRef({ + name: name, + configSchema: GeminiEmbeddingConfigSchema, + info: { + dimensions: 768, + label: `Google AI - ${name}`, + supports: { + input: ['text'], + }, + }, + }); + const apiModelName = embedder.name.startsWith('googleai/') + ? 
embedder.name.substring('googleai/'.length) + : embedder.name; return ai.defineEmbedder( { name: embedder.name, - configSchema: TextEmbeddingGeckoConfigSchema, + configSchema: GeminiEmbeddingConfigSchema, info: embedder.info!, }, async (input, options) => { + const client = new GoogleGenerativeAI(apiKey!).getGenerativeModel({ + model: + options?.version || + embedder.config?.version || + embedder.version || + apiModelName, + }); const embeddings = await Promise.all( input.map(async (doc) => { const response = await client.embedContent({ diff --git a/js/plugins/googleai/src/gemini.ts b/js/plugins/googleai/src/gemini.ts index 27aae6607..003db2346 100644 --- a/js/plugins/googleai/src/gemini.ts +++ b/js/plugins/googleai/src/gemini.ts @@ -37,6 +37,7 @@ import { MediaPart, MessageData, ModelAction, + ModelInfo, ModelMiddleware, modelRef, ModelReference, @@ -67,16 +68,16 @@ const SafetySettingsSchema = z.object({ ]), }); -const GeminiConfigSchema = GenerationCommonConfigSchema.extend({ +export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({ safetySettings: z.array(SafetySettingsSchema).optional(), codeExecution: z.union([z.boolean(), z.object({}).strict()]).optional(), }); -export const geminiPro = modelRef({ - name: 'googleai/gemini-pro', +export const gemini10Pro = modelRef({ + name: 'googleai/gemini-1.0-pro', info: { label: 'Google AI - Gemini Pro', - versions: ['gemini-1.0-pro', 'gemini-1.0-pro-latest', 'gemini-1.0-pro-001'], + versions: ['gemini-pro', 'gemini-1.0-pro-latest', 'gemini-1.0-pro-001'], supports: { multiturn: true, media: false, @@ -87,28 +88,8 @@ export const geminiPro = modelRef({ configSchema: GeminiConfigSchema, }); -/** - * @deprecated Use `gemini15Pro` or `gemini15Flash` instead. - */ -export const geminiProVision = modelRef({ - name: 'googleai/gemini-pro-vision', - info: { - label: 'Google AI - Gemini Pro Vision', - // none declared on https://ai.google.dev/models/gemini#model-variations - versions: [], - supports: { - multiturn: true, - media: true, - tools: false, - systemRole: false, - }, - stage: 'deprecated', - }, - configSchema: GeminiConfigSchema, -}); - export const gemini15Pro = modelRef({ - name: 'googleai/gemini-1.5-pro-latest', + name: 'googleai/gemini-1.5-pro', info: { label: 'Google AI - Gemini 1.5 Pro', supports: { @@ -118,13 +99,17 @@ export const gemini15Pro = modelRef({ systemRole: true, output: ['text', 'json'], }, - versions: ['gemini-1.5-pro-001'], + versions: [ + 'gemini-1.5-pro-latest', + 'gemini-1.5-pro-001', + 'gemini-1.5-pro-002', + ], }, configSchema: GeminiConfigSchema, }); export const gemini15Flash = modelRef({ - name: 'googleai/gemini-1.5-flash-latest', + name: 'googleai/gemini-1.5-flash', info: { label: 'Google AI - Gemini 1.5 Flash', supports: { @@ -134,44 +119,45 @@ export const gemini15Flash = modelRef({ systemRole: true, output: ['text', 'json'], }, - versions: ['gemini-1.5-flash-001'], + versions: [ + 'gemini-1.5-flash-latest', + 'gemini-1.5-flash-001', + 'gemini-1.5-flash-002', + ], }, configSchema: GeminiConfigSchema, }); -export const geminiUltra = modelRef({ - name: 'googleai/gemini-ultra', +export const gemini15Flash8b = modelRef({ + name: 'googleai/gemini-1.5-flash-8b', info: { - label: 'Google AI - Gemini Ultra', - versions: [], + label: 'Google AI - Gemini 1.5 Flash', supports: { multiturn: true, - media: false, + media: true, tools: true, systemRole: true, + output: ['text', 'json'], }, + versions: ['gemini-1.5-flash-8b-latest', 'gemini-1.5-flash-8b-001'], }, configSchema: GeminiConfigSchema, }); -export const 
SUPPORTED_V1_MODELS: Record< - string, - ModelReference -> = { - 'gemini-pro': geminiPro, - 'gemini-pro-vision': geminiProVision, - // 'gemini-ultra': geminiUltra, +export const SUPPORTED_V1_MODELS = { + 'gemini-1.0-pro': gemini10Pro, }; -export const SUPPORTED_V15_MODELS: Record< - string, - ModelReference -> = { - 'gemini-1.5-pro-latest': gemini15Pro, - 'gemini-1.5-flash-latest': gemini15Flash, +export const SUPPORTED_V15_MODELS = { + 'gemini-1.5-pro': gemini15Pro, + 'gemini-1.5-flash': gemini15Flash, + 'gemini-1.5-flash-8b': gemini15Flash8b, }; -const SUPPORTED_MODELS = { +export const SUPPORTED_GEMINI_MODELS: Record< + string, + ModelReference +> = { ...SUPPORTED_V1_MODELS, ...SUPPORTED_V15_MODELS, }; @@ -451,17 +437,17 @@ export function fromGeminiCandidate( } /** - * + * Defines a new GoogleAI model. */ -export function googleAIModel( +export function defineGoogleAIModel( ai: Genkit, name: string, apiKey?: string, apiVersion?: string, - baseUrl?: string + baseUrl?: string, + info?: ModelInfo, + defaultConfig?: z.infer ): ModelAction { - const modelName = `googleai/${name}`; - if (!apiKey) { apiKey = process.env.GOOGLE_GENAI_API_KEY || process.env.GOOGLE_API_KEY; } @@ -471,15 +457,33 @@ export function googleAIModel( 'For more details see https://firebase.google.com/docs/genkit/plugins/google-genai' ); } - - const model: ModelReference = SUPPORTED_MODELS[name]; - if (!model) throw new Error(`Unsupported model: ${name}`); + const apiModelName = name.startsWith('googleai/') + ? name.substring('googleai/'.length) + : name; + + const model: ModelReference = + SUPPORTED_GEMINI_MODELS[name] ?? + modelRef({ + name, + info: { + label: `Google AI - ${apiModelName}`, + supports: { + multiturn: true, + media: true, + tools: true, + systemRole: true, + output: ['text', 'json'], + }, + ...info, + }, + configSchema: GeminiConfigSchema, + }); const middleware: ModelMiddleware[] = []; if (SUPPORTED_V1_MODELS[name]) { middleware.push(simulateSystemPrompt()); } - if (model?.info?.supports?.media) { + if (model.info?.supports?.media) { // the gemini api doesn't support downloading media from http(s) middleware.push( downloadRequestMedia({ @@ -495,7 +499,7 @@ export function googleAIModel( return ai.defineModel( { - name: modelName, + name: model.name, ...model.info, configSchema: model.configSchema, use: middleware, @@ -508,9 +512,18 @@ export function googleAIModel( if (apiVersion) { options.baseUrl = baseUrl; } + const requestConfig = { + ...defaultConfig, + ...request.config, + }; + const client = new GoogleGenerativeAI(apiKey!).getGenerativeModel( { - model: request.config?.version || model.version || name, + model: + requestConfig.version || + model.config?.version || + model.version || + apiModelName, }, options ); @@ -540,7 +553,7 @@ export function googleAIModel( }); } - if (request.config?.codeExecution) { + if (requestConfig.codeExecution) { tools.push({ codeExecution: request.config.codeExecution === true @@ -556,11 +569,11 @@ export function googleAIModel( const generationConfig: GenerationConfig = { candidateCount: request.candidates || undefined, - temperature: request.config?.temperature, - maxOutputTokens: request.config?.maxOutputTokens, - topK: request.config?.topK, - topP: request.config?.topP, - stopSequences: request.config?.stopSequences, + temperature: requestConfig.temperature, + maxOutputTokens: requestConfig.maxOutputTokens, + topK: requestConfig.topK, + topP: requestConfig.topP, + stopSequences: requestConfig.stopSequences, responseMimeType: jsonMode ? 
'application/json' : undefined, }; @@ -571,7 +584,7 @@ export function googleAIModel( history: messages .slice(0, -1) .map((message) => toGeminiMessage(message, model)), - safetySettings: request.config?.safetySettings, + safetySettings: requestConfig.safetySettings, } as StartChatParams; const msg = toGeminiMessage(messages[messages.length - 1], model); const fromJSONModeScopedGeminiCandidate = ( diff --git a/js/plugins/googleai/src/index.ts b/js/plugins/googleai/src/index.ts index 2a62bca9d..abb731426 100644 --- a/js/plugins/googleai/src/index.ts +++ b/js/plugins/googleai/src/index.ts @@ -18,25 +18,18 @@ import { Genkit } from 'genkit'; import { GenkitPlugin, genkitPlugin } from 'genkit/plugin'; import { SUPPORTED_MODELS as EMBEDDER_MODELS, + defineGoogleAIEmbedder, textEmbeddingGecko001, - textEmbeddingGeckoEmbedder, } from './embedder.js'; import { SUPPORTED_V15_MODELS, SUPPORTED_V1_MODELS, + defineGoogleAIModel, + gemini10Pro, gemini15Flash, gemini15Pro, - geminiPro, - geminiProVision, - googleAIModel, } from './gemini.js'; -export { - gemini15Flash, - gemini15Pro, - geminiPro, - geminiProVision, - textEmbeddingGecko001, -}; +export { gemini10Pro, gemini15Flash, gemini15Pro, textEmbeddingGecko001 }; export interface PluginOptions { apiKey?: string; @@ -56,21 +49,40 @@ export function googleAI(options?: PluginOptions): GenkitPlugin { } } if (apiVersions.includes('v1beta')) { - Object.keys(SUPPORTED_V15_MODELS).map((name) => - googleAIModel(ai, name, options?.apiKey, 'v1beta', options?.baseUrl) + Object.keys(SUPPORTED_V15_MODELS).forEach((name) => + defineGoogleAIModel( + ai, + name, + options?.apiKey, + 'v1beta', + options?.baseUrl + ) ); } if (apiVersions.includes('v1')) { - Object.keys(SUPPORTED_V1_MODELS).map((name) => - googleAIModel(ai, name, options?.apiKey, undefined, options?.baseUrl) + Object.keys(SUPPORTED_V1_MODELS).forEach((name) => + defineGoogleAIModel( + ai, + name, + options?.apiKey, + undefined, + options?.baseUrl + ) ); - Object.keys(SUPPORTED_V15_MODELS).map((name) => - googleAIModel(ai, name, options?.apiKey, undefined, options?.baseUrl) + Object.keys(SUPPORTED_V15_MODELS).forEach((name) => + defineGoogleAIModel( + ai, + name, + options?.apiKey, + undefined, + options?.baseUrl + ) ); - Object.keys(EMBEDDER_MODELS).map((name) => - textEmbeddingGeckoEmbedder(ai, name, { apiKey: options?.apiKey }) + Object.keys(EMBEDDER_MODELS).forEach((name) => + defineGoogleAIEmbedder(ai, name, { apiKey: options?.apiKey }) ); } }); } + export default googleAI; diff --git a/js/plugins/langchain/package.json b/js/plugins/langchain/package.json index d0bf80e49..ecf255071 100644 --- a/js/plugins/langchain/package.json +++ b/js/plugins/langchain/package.json @@ -9,7 +9,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", diff --git a/js/plugins/ollama/package.json b/js/plugins/ollama/package.json index 23cb82be9..7feefb673 100644 --- a/js/plugins/ollama/package.json +++ b/js/plugins/ollama/package.json @@ -10,7 +10,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", diff --git a/js/plugins/ollama/src/embeddings.ts b/js/plugins/ollama/src/embeddings.ts index 922c15217..592703cf3 100644 --- a/js/plugins/ollama/src/embeddings.ts +++ b/js/plugins/ollama/src/embeddings.ts @@ -13,23 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -import { Genkit, z } from 'genkit'; +import { Genkit } from 'genkit'; import { logger } from 'genkit/logging'; import { OllamaPluginParams } from './index.js'; -// Define the schema for Ollama embedding configuration -export const OllamaEmbeddingConfigSchema = z.object({ - modelName: z.string(), - serverAddress: z.string(), -}); - -export type OllamaEmbeddingConfig = z.infer; - -// Define the structure of the request and response for embedding -interface OllamaEmbeddingInstance { - content: string; -} - interface OllamaEmbeddingPrediction { embedding: number[]; } @@ -48,9 +35,7 @@ export function defineOllamaEmbedder( return ai.defineEmbedder( { name, - configSchema: OllamaEmbeddingConfigSchema, // Use the Zod schema directly here info: { - // TODO: do we want users to be able to specify the label when they call this method directly? label: 'Ollama Embedding - ' + modelName, dimensions, supports: { @@ -59,7 +44,7 @@ export function defineOllamaEmbedder( }, }, }, - async (input, _config) => { + async (input) => { const serverAddress = options.serverAddress; const responses = await Promise.all( input.map(async (i) => { @@ -69,7 +54,6 @@ export function defineOllamaEmbedder( }; let res: Response; try { - console.log('MODEL NAME: ', modelName); res = await fetch(`${serverAddress}/api/embeddings`, { method: 'POST', headers: { diff --git a/js/plugins/ollama/tests/embeddings_test.ts b/js/plugins/ollama/tests/embeddings_test.ts index 14f966d0a..e61a94b99 100644 --- a/js/plugins/ollama/tests/embeddings_test.ts +++ b/js/plugins/ollama/tests/embeddings_test.ts @@ -16,11 +16,9 @@ import { Genkit, genkit } from 'genkit'; import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; -import { - OllamaEmbeddingConfigSchema, - defineOllamaEmbedder, -} from '../src/embeddings.js'; // Adjust the import path as necessary -import { OllamaPluginParams } from '../src/index.js'; // Adjust the import path as necessary +import { defineOllamaEmbedder } from '../src/embeddings.js'; +import { OllamaPluginParams } from '../src/index.js'; + // Mock fetch to simulate API responses global.fetch = async (input: RequestInfo | URL, options?: RequestInit) => { const url = typeof input === 'string' ? 
input : input.toString(); @@ -41,6 +39,7 @@ global.fetch = async (input: RequestInfo | URL, options?: RequestInit) => { } throw new Error('Unknown API endpoint'); }; + describe('defineOllamaEmbedder', () => { const options: OllamaPluginParams = { models: [{ name: 'test-model' }], @@ -91,24 +90,6 @@ describe('defineOllamaEmbedder', () => { ); }); - it('should validate the embedding configuration schema', async () => { - const validConfig = { - modelName: 'test-model', - serverAddress: 'http://localhost:3000', - }; - const invalidConfig = { - modelName: 123, // Invalid type - serverAddress: 'http://localhost:3000', - }; - // Valid configuration should pass - assert.doesNotThrow(() => { - OllamaEmbeddingConfigSchema.parse(validConfig); - }); - // Invalid configuration should throw - assert.throws(() => { - OllamaEmbeddingConfigSchema.parse(invalidConfig); - }); - }); it('should throw an error if the fetch response is not ok', async () => { const embedder = defineOllamaEmbedder(ai, { name: 'test-embedder', diff --git a/js/plugins/pinecone/package.json b/js/plugins/pinecone/package.json index 9ff34e53c..d738c04be 100644 --- a/js/plugins/pinecone/package.json +++ b/js/plugins/pinecone/package.json @@ -13,7 +13,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", diff --git a/js/plugins/vertexai/package.json b/js/plugins/vertexai/package.json index dd0cf6a4d..157f361d1 100644 --- a/js/plugins/vertexai/package.json +++ b/js/plugins/vertexai/package.json @@ -17,7 +17,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", @@ -48,7 +48,7 @@ "genkit": "workspace:*" }, "optionalDependencies": { - "firebase-admin": "^12.1.0", + "firebase-admin": ">=12.2", "@google-cloud/bigquery": "^7.8.0" }, "devDependencies": { diff --git a/js/plugins/vertexai/src/embedder.ts b/js/plugins/vertexai/src/embedder.ts index 3efea1811..10d2ca18c 100644 --- a/js/plugins/vertexai/src/embedder.ts +++ b/js/plugins/vertexai/src/embedder.ts @@ -27,9 +27,10 @@ export const TaskTypeSchema = z.enum([ 'CLASSIFICATION', 'CLUSTERING', ]); + export type TaskType = z.infer; -export const TextEmbeddingGeckoConfigSchema = z.object({ +export const VertexEmbeddingConfigSchema = z.object({ /** * The `task_type` parameter is defined as the intended downstream application to help the model * produce better quality embeddings. 
@@ -37,92 +38,47 @@ export const TextEmbeddingGeckoConfigSchema = z.object({ taskType: TaskTypeSchema.optional(), title: z.string().optional(), location: z.string().optional(), -}); -export type TextEmbeddingGeckoConfig = z.infer< - typeof TextEmbeddingGeckoConfigSchema ->; - -export const textEmbeddingGecko003 = embedderRef({ - name: 'vertexai/textembedding-gecko@003', - configSchema: TextEmbeddingGeckoConfigSchema, - info: { - dimensions: 768, - label: 'Vertex AI - Text Embedding Gecko', - supports: { - input: ['text'], - }, - }, -}); - -export const textEmbeddingGecko002 = embedderRef({ - name: 'vertexai/textembedding-gecko@002', - configSchema: TextEmbeddingGeckoConfigSchema, - info: { - dimensions: 768, - label: 'Vertex AI - Text Embedding Gecko', - supports: { - input: ['text'], - }, - }, -}); - -export const textEmbeddingGecko001 = embedderRef({ - name: 'vertexai/textembedding-gecko@001', - configSchema: TextEmbeddingGeckoConfigSchema, - info: { - dimensions: 768, - label: 'Vertex AI - Text Embedding Gecko (Legacy)', - supports: { - input: ['text'], - }, - }, + version: z.string().optional(), }); -export const textEmbedding004 = embedderRef({ - name: 'vertexai/text-embedding-004', - configSchema: TextEmbeddingGeckoConfigSchema, - info: { - dimensions: 768, - label: 'Vertex AI - Text Embedding 004', - supports: { - input: ['text'], - }, - }, -}); +export type VertexEmbeddingConfig = z.infer; -export const textMultilingualEmbedding002 = embedderRef({ - name: 'vertexai/text-multilingual-embedding-002', - configSchema: TextEmbeddingGeckoConfigSchema, - info: { - dimensions: 768, - label: 'Vertex AI - Text Multilingual Embedding 002', - supports: { - input: ['text'], - }, - }, -}); - -export const textEmbeddingGeckoMultilingual001 = embedderRef({ - name: 'vertexai/textembedding-gecko-multilingual@001', - configSchema: TextEmbeddingGeckoConfigSchema, - info: { - dimensions: 768, - label: 'Vertex AI - Multilingual Text Embedding Gecko 001', - supports: { - input: ['text'], +function commonRef( + name: string, + input?: ('text' | 'image')[] +): EmbedderReference { + return embedderRef({ + name: `vertexai/${name}`, + configSchema: VertexEmbeddingConfigSchema, + info: { + dimensions: 768, + label: `Vertex AI - ${name}`, + supports: { + input: input ?? 
['text'], + }, }, - }, -}); + }); +} -export const textEmbeddingGecko = textEmbeddingGecko003; +export const textEmbeddingGecko003 = commonRef('textembedding-gecko@003'); +export const textEmbedding004 = commonRef('text-embedding-004'); +export const textEmbeddingGeckoMultilingual001 = commonRef( + 'textembedding-gecko-multilingual@001' +); +export const textMultilingualEmbedding002 = commonRef( + 'text-multilingual-embedding-002' +); export const SUPPORTED_EMBEDDER_MODELS: Record = { 'textembedding-gecko@003': textEmbeddingGecko003, - 'textembedding-gecko@002': textEmbeddingGecko002, - 'textembedding-gecko@001': textEmbeddingGecko001, 'text-embedding-004': textEmbedding004, 'textembedding-gecko-multilingual@001': textEmbeddingGeckoMultilingual001, 'text-multilingual-embedding-002': textMultilingualEmbedding002, + // TODO: add support for multimodal embeddings + // 'multimodalembedding@001': commonRef('multimodalembedding@001', [ + // 'image', + // 'text', + // ]), }; interface EmbeddingInstance { @@ -140,7 +96,7 @@ interface EmbeddingPrediction { }; } -export function textEmbeddingGeckoEmbedder( +export function defineVertexAIEmbedder( ai: Genkit, name: string, client: GoogleAuth, @@ -152,7 +108,7 @@ export function textEmbeddingGeckoEmbedder( PredictClient > = {}; const predictClientFactory = ( - config: TextEmbeddingGeckoConfig + config: VertexEmbeddingConfig ): PredictClient => { const requestLocation = config?.location || options.location; if (!predictClients[requestLocation]) { diff --git a/js/plugins/vertexai/src/gemini.ts b/js/plugins/vertexai/src/gemini.ts index 316c9373e..42dfb2215 100644 --- a/js/plugins/vertexai/src/gemini.ts +++ b/js/plugins/vertexai/src/gemini.ts @@ -73,11 +73,11 @@ export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({ googleSearchRetrieval: GoogleSearchRetrievalSchema.optional(), }); -export const geminiPro = modelRef({ +export const gemini10Pro = modelRef({ name: 'vertexai/gemini-1.0-pro', info: { label: 'Vertex AI - Gemini Pro', - versions: ['gemini-1.0-pro', 'gemini-1.0-pro-001'], + versions: ['gemini-1.0-pro-001', 'gemini-1.0-pro-002'], supports: { multiturn: true, media: false, @@ -88,29 +88,11 @@ export const geminiPro = modelRef({ configSchema: GeminiConfigSchema, }); -export const geminiProVision = modelRef({ - name: 'vertexai/gemini-1.0-pro-vision', - info: { - label: 'Vertex AI - Gemini Pro Vision', - versions: ['gemini-1.0-pro-vision', 'gemini-1.0-pro-vision-001'], - supports: { - multiturn: true, - media: true, - tools: false, - systemRole: false, - }, - }, - configSchema: GeminiConfigSchema.omit({ - googleSearchRetrieval: true, - vertexRetrieval: true, - }), -}); - export const gemini15Pro = modelRef({ name: 'vertexai/gemini-1.5-pro', info: { label: 'Vertex AI - Gemini 1.5 Pro', - versions: ['gemini-1.5-pro-001'], + versions: ['gemini-1.5-pro-001', 'gemini-1.5-pro-002'], supports: { multiturn: true, media: true, @@ -121,43 +103,11 @@ export const gemini15Pro = modelRef({ configSchema: GeminiConfigSchema, }); -export const gemini15ProPreview = modelRef({ - name: 'vertexai/gemini-1.5-pro-preview', - info: { - label: 'Vertex AI - Gemini 1.5 Pro Preview', - versions: ['gemini-1.5-pro-preview-0409'], - supports: { - multiturn: true, - media: true, - tools: true, - systemRole: true, - }, - }, - configSchema: GeminiConfigSchema, - version: 'gemini-1.5-pro-preview-0409', -}); - -export const gemini15FlashPreview = modelRef({ - name: 'vertexai/gemini-1.5-flash-preview', - info: { - label: 'Vertex AI - Gemini 1.5 Flash', - versions: 
['gemini-1.5-flash-preview-0514'], - supports: { - multiturn: true, - media: true, - tools: true, - systemRole: true, - }, - }, - configSchema: GeminiConfigSchema, - version: 'gemini-1.5-flash-preview-0514', -}); - export const gemini15Flash = modelRef({ name: 'vertexai/gemini-1.5-flash', info: { label: 'Vertex AI - Gemini 1.5 Flash', - versions: ['gemini-1.5-flash-001'], + versions: ['gemini-1.5-flash-001', 'gemini-1.5-flash-002'], supports: { multiturn: true, media: true, @@ -169,16 +119,12 @@ export const gemini15Flash = modelRef({ }); export const SUPPORTED_V1_MODELS = { - 'gemini-1.0-pro': geminiPro, - 'gemini-1.0-pro-vision': geminiProVision, - // 'gemini-ultra': geminiUltra, + 'gemini-1.0-pro': gemini10Pro, }; export const SUPPORTED_V15_MODELS = { 'gemini-1.5-pro': gemini15Pro, 'gemini-1.5-flash': gemini15Flash, - 'gemini-1.5-pro-preview': gemini15ProPreview, - 'gemini-1.5-flash-preview': gemini15FlashPreview, }; export const SUPPORTED_GEMINI_MODELS = { @@ -458,9 +404,9 @@ const convertSchemaProperty = (property) => { }; /** - * + * Define a Vertex AI Gemini model. */ -export function geminiModel( +export function defineGeminiModel( ai: Genkit, name: string, vertexClientFactory: ( diff --git a/js/plugins/vertexai/src/index.ts b/js/plugins/vertexai/src/index.ts index d6231141f..c6c4a227e 100644 --- a/js/plugins/vertexai/src/index.ts +++ b/js/plugins/vertexai/src/index.ts @@ -29,12 +29,9 @@ import { } from './anthropic.js'; import { SUPPORTED_EMBEDDER_MODELS, + defineVertexAIEmbedder, textEmbedding004, - textEmbeddingGecko, - textEmbeddingGecko001, - textEmbeddingGecko002, textEmbeddingGecko003, - textEmbeddingGeckoEmbedder, textEmbeddingGeckoMultilingual001, textMultilingualEmbedding002, } from './embedder.js'; @@ -46,13 +43,10 @@ import { import { GeminiConfigSchema, SUPPORTED_GEMINI_MODELS, + defineGeminiModel, + gemini10Pro, gemini15Flash, - gemini15FlashPreview, gemini15Pro, - gemini15ProPreview, - geminiModel, - geminiPro, - geminiProVision, } from './gemini.js'; import { SUPPORTED_IMAGEN_MODELS, @@ -94,12 +88,9 @@ export { claude3Haiku, claude3Opus, claude3Sonnet, + gemini10Pro, gemini15Flash, - gemini15FlashPreview, gemini15Pro, - gemini15ProPreview, - geminiPro, - geminiProVision, imagen2, imagen3, imagen3Fast, @@ -107,9 +98,6 @@ export { llama31, llama32, textEmbedding004, - textEmbeddingGecko, - textEmbeddingGecko001, - textEmbeddingGecko002, textEmbeddingGecko003, textEmbeddingGeckoMultilingual001, textMultilingualEmbedding002, @@ -206,7 +194,7 @@ export function vertexAI(options?: PluginOptions): GenkitPlugin { imagenModel(ai, name, authClient, { projectId, location }) ); Object.keys(SUPPORTED_GEMINI_MODELS).map((name) => - geminiModel(ai, name, vertexClientFactory, { projectId, location }) + defineGeminiModel(ai, name, vertexClientFactory, { projectId, location }) ); if (options?.modelGardenModels || options?.modelGarden?.models) { @@ -239,7 +227,7 @@ export function vertexAI(options?: PluginOptions): GenkitPlugin { } const embedders = Object.keys(SUPPORTED_EMBEDDER_MODELS).map((name) => - textEmbeddingGeckoEmbedder(ai, name, authClient, { projectId, location }) + defineVertexAIEmbedder(ai, name, authClient, { projectId, location }) ); if ( diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml index 313fcc1dc..fcd40b9db 100644 --- a/js/pnpm-lock.yaml +++ b/js/pnpm-lock.yaml @@ -325,11 +325,11 @@ importers: specifier: ^4.21.0 version: 4.21.0 firebase-admin: - specifier: ^12.2.0 - version: 12.2.0(encoding@0.1.13) + specifier: '>=12.2' + version: 12.3.1(encoding@0.1.13) 
firebase-functions: - specifier: ^4.8.0 || ^5.0.0 - version: 4.8.1(encoding@0.1.13)(firebase-admin@12.2.0(encoding@0.1.13)) + specifier: '>=4.8' + version: 4.8.1(encoding@0.1.13)(firebase-admin@12.3.1(encoding@0.1.13)) genkit: specifier: workspace:* version: link:../../genkit @@ -611,8 +611,8 @@ importers: specifier: ^7.8.0 version: 7.8.0(encoding@0.1.13) firebase-admin: - specifier: ^12.1.0 - version: 12.2.0(encoding@0.1.13) + specifier: '>=12.2' + version: 12.3.1(encoding@0.1.13) devDependencies: '@types/node': specifier: ^20.11.16 @@ -735,7 +735,7 @@ importers: specifier: ^1.22.0 version: 1.25.1(@opentelemetry/api@1.9.0) firebase-admin: - specifier: ^12.3.0 + specifier: '>=12.2' version: 12.3.1(encoding@0.1.13) genkit: specifier: workspace:* @@ -790,8 +790,8 @@ importers: specifier: workspace:* version: link:../../plugins/vertexai firebase-admin: - specifier: ^12.1.0 - version: 12.1.0(encoding@0.1.13) + specifier: '>=12.2' + version: 12.3.1(encoding@0.1.13) genkit: specifier: workspace:* version: link:../../genkit @@ -989,8 +989,8 @@ importers: specifier: ^1.25.0 version: 1.25.1(@opentelemetry/api@1.9.0) firebase-admin: - specifier: ^12.1.0 - version: 12.1.0(encoding@0.1.13) + specifier: '>=12.2' + version: 12.3.1(encoding@0.1.13) genkit: specifier: workspace:* version: link:../../genkit @@ -1365,8 +1365,8 @@ importers: specifier: ^4.21.0 version: 4.21.0 firebase-admin: - specifier: ^12.1.0 - version: 12.2.0(encoding@0.1.13) + specifier: '>=12.2' + version: 12.3.1(encoding@0.1.13) genkit: specifier: workspace:* version: link:../../genkit @@ -1863,10 +1863,6 @@ packages: cpu: [x64] os: [win32] - '@fastify/busboy@2.1.1': - resolution: {integrity: sha512-vBZP4NlzfOlerQTnba4aqZoMhE/a9HY7HRqoOPaETQcSQuWEIyZMHGfVu6w9wGtGK5fED5qRs2DteVCjOH60sA==} - engines: {node: '>=14'} - '@fastify/busboy@3.0.0': resolution: {integrity: sha512-83rnH2nCvclWaPQQKvkJ2pdOjG4TZyEVuFDnlOF6KP08lDaaceVyw/W63mDuafQT+MKHCvXIPpE5uYWeM0rT4w==} @@ -1919,10 +1915,6 @@ packages: resolution: {integrity: sha512-WUDbaLY8UnPxgwsyIaxj6uxCtSDAaUyvzWJykNH5rZ9i92/SZCsPNNMN0ajrVpAR81hPIL4amXTaMJ40y5L+Yg==} engines: {node: '>=14.0.0'} - '@google-cloud/firestore@7.8.0': - resolution: {integrity: sha512-m21BWVZLz7H7NF8HZ5hCGUSCEJKNwYB5yzQqDTuE9YUzNDRMDei3BwVDht5k4xF636sGlnobyBL+dcbthSGONg==} - engines: {node: '>=14.0.0'} - '@google-cloud/firestore@7.9.0': resolution: {integrity: sha512-c4ALHT3G08rV7Zwv8Z2KG63gZh66iKdhCBeDfCpIkLrjX6EAjTD/szMdj14M+FnQuClZLFfW5bAgoOjfNmLtJg==} engines: {node: '>=14.0.0'} @@ -3412,9 +3404,6 @@ packages: binary-search@1.3.6: resolution: {integrity: sha512-nbE1WxOTTrUWIfsfZ4aHGYu5DOuNkbxGokjV6Z2kxfJK3uaAb8zNK1muzOeipoLHZjInT4Br88BHpzevc681xA==} - bl@4.1.0: - resolution: {integrity: sha512-1W07cM9gS6DcLperZfFSj+bWLtaPGSOHWhPiGzXmvVJbRLdG82sH/Kn8EtW1VqWVA54AKf2h5k5BbnIbwF3h6w==} - body-parser@1.20.2: resolution: {integrity: sha512-ml9pReCu3M61kGlqoTm2umSXTlRTuGTx0bfYj+uIUKKYycG5NtSbeetV3faSU6R7ajOPw0g/J1PvK4qNy7s5bA==} engines: {node: '>= 0.8', npm: 1.2.8000 || >= 1.4.16} @@ -3451,9 +3440,6 @@ packages: buffer-from@1.1.2: resolution: {integrity: sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==} - buffer@5.7.1: - resolution: {integrity: sha512-EHcyIPBQ4BSGlvjB16k5KgAJ27CIsHY/2JBmCRReo48y9rQ3MaUzWX3KVlBa4U7MyX02HdVj0K7C3WaB3ju7FQ==} - bundle-require@4.0.2: resolution: {integrity: sha512-jwzPOChofl67PSTW2SGubV9HBQAhhR2i6nskiOThauo9dzwDUgOWQScFVaJkjEfYX+UXiD+LEx8EblQMc2wIag==} engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0} @@ -3516,9 +3502,6 @@ packages: 
resolution: {integrity: sha512-7VT13fmjotKpGipCW9JEQAusEPE+Ei8nl6/g4FBAmIm0GOOLMua9NDDo/DWp0ZAxCr3cPq5ZpBqmPAQgDda2Pw==} engines: {node: '>= 8.10.0'} - chownr@1.1.4: - resolution: {integrity: sha512-jJ0bqzaylmJtVnNgzTeSOs8DPavpbYgEr/b0YL8/2GO3xJEhInFmhKMUnEJQjZumK7KXGFhUy89PrsJWlakBVg==} - chownr@2.0.0: resolution: {integrity: sha512-bIomtDF5KGpdogkLd9VspvFzk9KfpyyGlS8YFVZl7TGPBHL5snIOnxeshwVgPteQ9b4Eydl+pVbIyE1DcvCWgQ==} engines: {node: '>=10'} @@ -3724,10 +3707,6 @@ packages: resolution: {integrity: sha512-jOSne2qbyE+/r8G1VU+G/82LBs2Fs4LAsTiLSHOCOMZQl2OKZ6i8i4IyHemTe+/yIXOtTcRQMzPcgyhoFlqPkw==} engines: {node: '>=8'} - decompress-response@6.0.0: - resolution: {integrity: sha512-aW35yZM6Bb/4oJlZncMH2LCoZtJXTRxES17vE3hoRiowU2kWHaJKFkSBDnDR+cm9J+9QhXmREyIfv0pji9ejCQ==} - engines: {node: '>=10'} - dedent@1.5.3: resolution: {integrity: sha512-NHQtfOOW68WD8lgypbLA5oT+Bt0xXJhiYvoR6SmmNXZfpzOGXwdKWmcwG8N7PwVVWV3eF/68nmD9BaJSsTBhyQ==} peerDependencies: @@ -3736,10 +3715,6 @@ packages: babel-plugin-macros: optional: true - deep-extend@0.6.0: - resolution: {integrity: sha512-LOHxIOaPYdHlJRtCQfDIVZtfw/ufM8+rVj649RIHzcm/vGwQRXFt6OPqIFWsm2XEMrNIEtWR64sY1LEKD2vAOA==} - engines: {node: '>=4.0.0'} - deepmerge@4.3.1: resolution: {integrity: sha512-3sUqbMEc77XqpdNO7FRyRog+eW3ph+GYCbj+rK+uYyRMuwsVy0rMiVtPn+QJlKFvWP/1PYpapqYn0Me2knFn+A==} engines: {node: '>=0.10.0'} @@ -3932,10 +3907,6 @@ packages: resolution: {integrity: sha512-Zk/eNKV2zbjpKzrsQ+n1G6poVbErQxJ0LBOJXaKZ1EViLzH+hrLu9cdXI4zw9dBQJslwBEpbQ2P1oS7nDxs6jQ==} engines: {node: '>= 0.8.0'} - expand-template@2.0.3: - resolution: {integrity: sha512-XYfuKMvj4O35f/pOXLObndIRvyQ+/+6AhODh+OKWj9S9498pHHn/IMszH+gt0fBCRWMNfk1ZSp5x3AifmnI2vg==} - engines: {node: '>=6'} - expect@29.7.0: resolution: {integrity: sha512-2Zks0hf1VLFYI1kbh0I5jP3KHHyCHpkfyHBzsSXRFgl/Bg9mWYfMW8oD+PdMPlEwy5HNsR9JutYy6pMeOh61nw==} engines: {node: ^14.15.0 || ^16.10.0 || >=18.0.0} @@ -3958,10 +3929,6 @@ packages: resolution: {integrity: sha512-6ypT4XfgqJk/F3Yuv4SX26I3doUjt0GTG4a+JgWxXQpxXzTBq8fPUeGHfcYMMDPHJHm3yPOSjaeBwBGAHWXCdA==} engines: {node: '>=18.0.0'} - farmhash@3.3.1: - resolution: {integrity: sha512-XUizHanzlr/v7suBr/o85HSakOoWh6HKXZjFYl5C2+Gj0f0rkw+XTUZzrd9odDsgI9G5tRUcF4wSbKaX04T0DQ==} - engines: {node: '>=10'} - fast-deep-equal@3.1.3: resolution: {integrity: sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==} @@ -4029,14 +3996,6 @@ packages: resolution: {integrity: sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw==} engines: {node: '>=8'} - firebase-admin@12.1.0: - resolution: {integrity: sha512-bU7uPKMmIXAihWxntpY/Ma9zucn5y3ec+HQPqFQ/zcEfP9Avk9E/6D8u+yT/VwKHNZyg7yDVWOoJi73TIdR4Ww==} - engines: {node: '>=14'} - - firebase-admin@12.2.0: - resolution: {integrity: sha512-R9xxENvPA/19XJ3mv0Kxfbz9kPXd9/HrM4083LZWOO0qAQGheRzcCQamYRe+JSrV2cdKXP3ZsfFGTYMrFM0pJg==} - engines: {node: '>=14'} - firebase-admin@12.3.1: resolution: {integrity: sha512-vEr3s3esl8nPIA9r/feDT4nzIXCfov1CyyCSpMQWp6x63Q104qke0MEGZlrHUZVROtl8FLus6niP/M9I1s4VBA==} engines: {node: '>=14'} @@ -4092,9 +4051,6 @@ packages: front-matter@4.0.2: resolution: {integrity: sha512-I8ZuJ/qG92NWX8i5x1Y8qyj3vizhXS31OxjKDu3LKP+7/qBgfIKValiZIEwoVoJKUHlhWtYrktkxV1XsX+pPlg==} - fs-constants@1.0.0: - resolution: {integrity: sha512-y6OAwoSIf7FyjMIv94u+b5rdheZEjzR63GTyZJm5qh4Bi+2YgwLCcI/fPFZkL5PSixOt6ZNKm+w+Hfp/Bciwow==} - fs-minipass@2.1.0: resolution: {integrity: 
sha512-V/JgOLFCS+R6Vcq0slCuaeWEdNC3ouDlJMNIsacH2VtALiu9mV4LPrHc5cDl8k5aw6J8jwgWWpiTo5RYhmIzvg==} engines: {node: '>= 8'} @@ -4186,9 +4142,6 @@ packages: get-tsconfig@4.8.1: resolution: {integrity: sha512-k9PN+cFBmaLWtVz29SkUoqU5O0slLuHJXt/2P+tMVFT+phsSGXGkp9t3rQIqdz0e+06EHNGs3oM6ZX1s2zHxRg==} - github-from-package@0.0.0: - resolution: {integrity: sha512-SyHy3T1v2NUXn29OsWdxmK6RwHD+vkj3v8en8AOBZ1wBQ/hCAQ5bAQTD02kW4W9tUp/3Qh6J8r9EvntiyCmOOw==} - glob-parent@5.1.2: resolution: {integrity: sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==} engines: {node: '>= 6'} @@ -4349,9 +4302,6 @@ packages: resolution: {integrity: sha512-4fCk79wshMdzMp2rH06qWrJE4iolqLhCUH+OiuIgU++RB0+94NlDL81atO7GX55uUKueo0txHNtvEyI6D7WdMw==} engines: {node: '>=0.10.0'} - ieee754@1.2.1: - resolution: {integrity: sha512-dcyqhDvX1C46lXZcVqCpK+FtMRQVdIMN6/Df5js2zouUsqG7I6sFxitIC+7KYK29KdXOLHdu9zL4sFnoVQnqaA==} - ignore@5.3.1: resolution: {integrity: sha512-5Fytz/IraMjqpwfd34ke28PTVMjZjJG2MPn5t7OE4eUCUNf8BAa7b5WUS9/Qvr6mwOQS7Mk6vdsMno5he+T8Xw==} engines: {node: '>= 4'} @@ -4378,9 +4328,6 @@ packages: inherits@2.0.4: resolution: {integrity: sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==} - ini@1.3.8: - resolution: {integrity: sha512-JV/yugV2uzW5iMRSiZAyDtQd+nxtUnjeLt0acNdw98kKLrvuRVyB80tsREOE7yvGVgalhZ6RNXCmEHkUKBKxew==} - internal-slot@1.0.7: resolution: {integrity: sha512-NGnrKwXzSms2qUUih/ILZ5JBqNTSa1+ZmP6flaIp6KmSElgE9qdndzS3cqjrDovwFdmwsGsLdeFgB6suw+1e9g==} engines: {node: '>= 0.4'} @@ -5089,10 +5036,6 @@ packages: resolution: {integrity: sha512-wXqjST+SLt7R009ySCglWBCFpjUygmCIfD790/kVbiGmUgfYGuB14PiTd5DwVxSV4NcYHjzMkoj5LjQZwTQLEA==} engines: {node: '>=8'} - mimic-response@3.1.0: - resolution: {integrity: sha512-z0yWI+4FDrrweS8Zmt4Ej5HdJmky15+L2e6Wgn3+iK5fWzb6T3fhNFq2+MeTRb064c6Wr4N/wv0DzQTjNzHNGQ==} - engines: {node: '>=10'} - minimatch@10.0.1: resolution: {integrity: sha512-ethXTt3SGGR+95gudmqJ1eNhRO7eGEGIgYA9vnPatK4/etz2MEVDno5GMCibdMTuBMyElzIlgxMna3K94XDIDQ==} engines: {node: 20 || >=22} @@ -5131,9 +5074,6 @@ packages: resolution: {integrity: sha512-bAxsR8BVfj60DWXHE3u30oHzfl4G7khkSuPW+qvpd7jFRHm7dLxOjUk1EHACJ/hxLY8phGJ0YhYHZo7jil7Qdg==} engines: {node: '>= 8'} - mkdirp-classic@0.5.3: - resolution: {integrity: sha512-gKLcREMhtuZRwRAfqP3RFW+TK4JqApVBtOIftVgjuABpAtpxhPGaDcfvbhNvD0B8iD1oUr/txX35NjcaY6Ns/A==} - mkdirp@1.0.4: resolution: {integrity: sha512-vVqVZQyf3WLx2Shd0qJ9xuvqgAyKPLAiqITEtqW0oIUjzo3PePDd6fW9iFz30ef7Ysp/oiWqbhszeGWW2T6Gzw==} engines: {node: '>=10'} @@ -5181,9 +5121,6 @@ packages: engines: {node: ^10 || ^12 || ^13.7 || ^14 || >=15.0.1} hasBin: true - napi-build-utils@1.0.2: - resolution: {integrity: sha512-ONmRUqK7zj7DWX0D9ADe03wbwOBZxNAfF20PlGfCWQcD3+/MakShIHrMqx9YwPTfxDdF1zLeL+RGZiR9kGMLdg==} - natural-compare@1.4.0: resolution: {integrity: sha512-OWND8ei3VtNC9h7V60qff3SVobHr996CTwgxubgyQYEpg290h9J0buyECNNJexkFm5sOajh5G116RYA1c8ZMSw==} @@ -5197,13 +5134,6 @@ packages: nice-try@1.0.5: resolution: {integrity: sha512-1nh45deeb5olNY7eX82BkPO7SSxR5SSYJiPTrTdFUVYwAl8CKMA5N9PjTYkHiRjisVcxcQ1HXdLhx2qxxJzLNQ==} - node-abi@3.62.0: - resolution: {integrity: sha512-CPMcGa+y33xuL1E0TcNIu4YyaZCxnnvkVaEXrsosR3FxN+fV8xvb7Mzpb7IgKler10qeMkE6+Dp8qJhpzdq35g==} - engines: {node: '>=10'} - - node-addon-api@5.1.0: - resolution: {integrity: sha512-eh0GgfEkpnoWDq+VY8OyvYhFEzBk6jIYbRKdIlyTiAXIVJ8PyBaKb0rp7oDtoddbdoHWhq8wwr+XZ81F1rpNdA==} - node-domexception@1.0.0: resolution: {integrity: 
sha512-/jKZoMpw0F8GRwl4/eLROPA3cfcXtLApP0QzLmUT/HuPCZWyB7IY9ZrMeKw2O/nFIqPQB3PVM9aYm0F312AXDQ==} engines: {node: '>=10.5.0'} @@ -5516,11 +5446,6 @@ packages: resolution: {integrity: sha512-9ZhXKM/rw350N1ovuWHbGxnGh/SNJ4cnxHiM0rxE4VN41wsg8P8zWn9hv/buK00RP4WvlOyr/RBDiptyxVbkZQ==} engines: {node: '>=0.10.0'} - prebuild-install@7.1.2: - resolution: {integrity: sha512-UnNke3IQb6sgarcZIDU3gbMeTp/9SSU1DAIkil7PrqG1vZlBtY5msYccSKSHDqa3hNg436IXK+SNImReuA1wEQ==} - engines: {node: '>=10'} - hasBin: true - prettier-plugin-organize-imports@3.2.4: resolution: {integrity: sha512-6m8WBhIp0dfwu0SkgfOxJqh+HpdyfqSSLfKKRZSFbDuEQXDDndb8fTpRWkUrX/uBenkex3MgnVk0J3b3Y5byog==} peerDependencies: @@ -5613,10 +5538,6 @@ packages: resolution: {integrity: sha512-8zGqypfENjCIqGhgXToC8aB2r7YrBX+AQAfIPs/Mlk+BtPTztOvTS01NRW/3Eh60J+a48lt8qsCzirQ6loCVfA==} engines: {node: '>= 0.8'} - rc@1.2.8: - resolution: {integrity: sha512-y3bGgqKj3QBdxLbLkomlohkvsA8gdAiUQlSBJnBhfn+BPxg4bc62d8TcBW15wavDfgexCgccckhcZvywyQYPOw==} - hasBin: true - react-is@18.3.1: resolution: {integrity: sha512-/LLMVyas0ljjAtoYiPqYiL8VWXzUUdThrmU5+n20DZv+a+ClRoevUzw5JxU+Ieh5/c87ytoTBV9G1FiKfNJdmg==} @@ -5807,9 +5728,6 @@ packages: simple-get@3.1.1: resolution: {integrity: sha512-CQ5LTKGfCpvE1K0n2us+kuMPbk/q0EKl82s4aheV9oXjFEz6W/Y7oQFVJuU6QG77hRT4Ghb5RURteF5vnWjupA==} - simple-get@4.0.1: - resolution: {integrity: sha512-brv7p5WgH0jmQJr1ZDDfKDOSeWWg+OVypG99A/5vYGPqJ6pxiaHLy8nxtFjBA7oMa01ebA9gfh1uMCFqOuXxvA==} - simple-swizzle@0.2.2: resolution: {integrity: sha512-JA//kQgZtbuY83m+xT+tXJkmJncGMTFT+C+g2h2R9uxkYIrE2yy9sgmcLhCnw57/WSD+Eh3J97FPEDFnbXnDUg==} @@ -5917,10 +5835,6 @@ packages: resolution: {integrity: sha512-BrpvfNAE3dcvq7ll3xVumzjKjZQ5tI1sEUIKr3Uoks0XUl45St3FlatVqef9prk4jRDzhW6WZg+3bk93y6pLjA==} engines: {node: '>=6'} - strip-json-comments@2.0.1: - resolution: {integrity: sha512-4gB8na07fecVVkOI6Rs4e7T6NOTki5EmL7TUduTs6bu3EdnSycntVJ4re8kgZA+wx9IueI2Y11bfbgwtzuE0KQ==} - engines: {node: '>=0.10.0'} - strip-json-comments@3.1.1: resolution: {integrity: sha512-6fPc+R4ihwqP6N/aIv2f1gMH8lOVtWQHoqC4yK6oSDVVocumAsfCqjkXnqiYMhmMwS/mEHLp7Vehlt3ql6lEig==} engines: {node: '>=8'} @@ -5952,13 +5866,6 @@ packages: resolution: {integrity: sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==} engines: {node: '>= 0.4'} - tar-fs@2.1.1: - resolution: {integrity: sha512-V0r2Y9scmbDRLCNex/+hYzvp/zyYjvFbHPNgVTKfQvVrb6guiE/fxP+XblDNR011utopbkex2nM4dHNV6GDsng==} - - tar-stream@2.2.0: - resolution: {integrity: sha512-ujeqbceABgwMZxEJnk2HDY2DlnUZ+9oEcb1KzTVfYHio0UE6dG71n60d8D2I4qNvleWrrXpmjpt7vZeF1LnMZQ==} - engines: {node: '>=6'} - tar@6.2.1: resolution: {integrity: sha512-DZ4yORTwrbTj/7MZYq2w+/ZFdI6OZ/f9SFHR+71gIVUZhOQPHzVCLpvRnPgyaMpfWxxk/4ONva3GQSyNIKRv6A==} engines: {node: '>=10'} @@ -6096,9 +6003,6 @@ packages: engines: {node: '>=18.0.0'} hasBin: true - tunnel-agent@0.6.0: - resolution: {integrity: sha512-McnNiV1l8RYeY8tBgEpuodCC1mLUdbSN+CYBL7kJsJNInOP8UjDDEwdk6Mw60vdLLrr5NHKZhMAOSrR2NZuQ+w==} - type-detect@4.0.8: resolution: {integrity: sha512-0fr/mIH1dlO+x7TlcMy+bIDqKPsw/70tVyeHW787goQjhmqaZe10uwLujubK9q9Lg6Fiho1KUKDYz0Z7k7g5/g==} engines: {node: '>=4'} @@ -6741,8 +6645,6 @@ snapshots: '@esbuild/win32-x64@0.23.1': optional: true - '@fastify/busboy@2.1.1': {} - '@fastify/busboy@3.0.0': {} '@firebase/app-check-interop-types@0.3.1': {} @@ -6871,17 +6773,6 @@ snapshots: - encoding - supports-color - '@google-cloud/firestore@7.8.0(encoding@0.1.13)': - dependencies: - fast-deep-equal: 3.1.3 - functional-red-black-tree: 1.0.1 - 
google-gax: 4.3.7(encoding@0.1.13) - protobufjs: 7.3.2 - transitivePeerDependencies: - - encoding - - supports-color - optional: true - '@google-cloud/firestore@7.9.0(encoding@0.1.13)': dependencies: fast-deep-equal: 3.1.3 @@ -6979,7 +6870,7 @@ snapshots: '@google-cloud/storage@7.10.1(encoding@0.1.13)': dependencies: - '@google-cloud/paginator': 5.0.0 + '@google-cloud/paginator': 5.0.2 '@google-cloud/projectify': 4.0.0 '@google-cloud/promisify': 4.0.0 abort-controller: 3.0.0 @@ -8474,12 +8365,6 @@ snapshots: binary-search@1.3.6: {} - bl@4.1.0: - dependencies: - buffer: 5.7.1 - inherits: 2.0.4 - readable-stream: 3.6.2 - body-parser@1.20.2: dependencies: bytes: 3.1.2 @@ -8546,11 +8431,6 @@ snapshots: buffer-from@1.1.2: {} - buffer@5.7.1: - dependencies: - base64-js: 1.5.1 - ieee754: 1.2.1 - bundle-require@4.0.2(esbuild@0.19.12): dependencies: esbuild: 0.19.12 @@ -8618,8 +8498,6 @@ snapshots: optionalDependencies: fsevents: 2.3.3 - chownr@1.1.4: {} - chownr@2.0.0: optional: true @@ -8813,14 +8691,8 @@ snapshots: mimic-response: 2.1.0 optional: true - decompress-response@6.0.0: - dependencies: - mimic-response: 3.1.0 - dedent@1.5.3: {} - deep-extend@0.6.0: {} - deepmerge@4.3.1: {} define-data-property@1.1.4: @@ -8844,7 +8716,8 @@ snapshots: destroy@1.2.0: {} - detect-libc@2.0.3: {} + detect-libc@2.0.3: + optional: true detect-newline@3.1.0: {} @@ -9074,8 +8947,6 @@ snapshots: exit@0.1.2: {} - expand-template@2.0.3: {} - expect@29.7.0: dependencies: '@jest/expect-utils': 29.7.0 @@ -9162,11 +9033,6 @@ snapshots: farmhash-modern@1.1.0: {} - farmhash@3.3.1: - dependencies: - node-addon-api: 5.1.0 - prebuild-install: 7.1.2 - fast-deep-equal@3.1.3: {} fast-glob@3.3.2: @@ -9253,44 +9119,6 @@ snapshots: locate-path: 5.0.0 path-exists: 4.0.0 - firebase-admin@12.1.0(encoding@0.1.13): - dependencies: - '@fastify/busboy': 2.1.1 - '@firebase/database-compat': 1.0.4 - '@firebase/database-types': 1.0.2 - '@types/node': 20.16.9 - farmhash: 3.3.1 - jsonwebtoken: 9.0.2 - jwks-rsa: 3.1.0 - long: 5.2.3 - node-forge: 1.3.1 - uuid: 9.0.1 - optionalDependencies: - '@google-cloud/firestore': 7.8.0(encoding@0.1.13) - '@google-cloud/storage': 7.10.1(encoding@0.1.13) - transitivePeerDependencies: - - encoding - - supports-color - - firebase-admin@12.2.0(encoding@0.1.13): - dependencies: - '@fastify/busboy': 2.1.1 - '@firebase/database-compat': 1.0.4 - '@firebase/database-types': 1.0.2 - '@types/node': 20.16.9 - farmhash-modern: 1.1.0 - jsonwebtoken: 9.0.2 - jwks-rsa: 3.1.0 - long: 5.2.3 - node-forge: 1.3.1 - uuid: 10.0.0 - optionalDependencies: - '@google-cloud/firestore': 7.8.0(encoding@0.1.13) - '@google-cloud/storage': 7.10.1(encoding@0.1.13) - transitivePeerDependencies: - - encoding - - supports-color - firebase-admin@12.3.1(encoding@0.1.13): dependencies: '@fastify/busboy': 3.0.0 @@ -9309,15 +9137,15 @@ snapshots: - encoding - supports-color - firebase-functions@4.8.1(encoding@0.1.13)(firebase-admin@12.2.0(encoding@0.1.13)): + firebase-functions@4.8.1(encoding@0.1.13)(firebase-admin@12.3.1(encoding@0.1.13)): dependencies: '@types/cors': 2.8.17 '@types/express': 4.17.3 cors: 2.8.5 express: 4.21.0 - firebase-admin: 12.2.0(encoding@0.1.13) + firebase-admin: 12.3.1(encoding@0.1.13) node-fetch: 2.7.0(encoding@0.1.13) - protobufjs: 7.2.6 + protobufjs: 7.3.2 transitivePeerDependencies: - encoding - supports-color @@ -9366,8 +9194,6 @@ snapshots: dependencies: js-yaml: 3.14.1 - fs-constants@1.0.0: {} - fs-minipass@2.1.0: dependencies: minipass: 3.3.6 @@ -9489,8 +9315,6 @@ snapshots: dependencies: resolve-pkg-maps: 1.0.0 - 
github-from-package@0.0.0: {} - glob-parent@5.1.2: dependencies: is-glob: 4.0.3 @@ -9753,8 +9577,6 @@ snapshots: dependencies: safer-buffer: 2.1.2 - ieee754@1.2.1: {} - ignore@5.3.1: {} import-in-the-middle@1.11.0: @@ -9780,8 +9602,6 @@ snapshots: inherits@2.0.4: {} - ini@1.3.8: {} - internal-slot@1.0.7: dependencies: es-errors: 1.3.0 @@ -10304,7 +10124,7 @@ snapshots: lodash.isstring: 4.0.1 lodash.once: 4.1.1 ms: 2.1.3 - semver: 7.6.0 + semver: 7.6.3 jwa@1.4.1: dependencies: @@ -10322,7 +10142,7 @@ snapshots: dependencies: '@types/express': 4.17.21 '@types/jsonwebtoken': 9.0.6 - debug: 4.3.4 + debug: 4.3.7 jose: 4.15.5 limiter: 1.1.5 lru-memoizer: 2.2.0 @@ -10594,8 +10414,6 @@ snapshots: mimic-response@2.1.0: optional: true - mimic-response@3.1.0: {} - minimatch@10.0.1: dependencies: brace-expansion: 2.0.1 @@ -10632,8 +10450,6 @@ snapshots: yallist: 4.0.0 optional: true - mkdirp-classic@0.5.3: {} - mkdirp@1.0.4: optional: true @@ -10680,8 +10496,6 @@ snapshots: nanoid@3.3.7: optional: true - napi-build-utils@1.0.2: {} - natural-compare@1.4.0: {} negotiator@0.6.3: {} @@ -10690,12 +10504,6 @@ snapshots: nice-try@1.0.5: {} - node-abi@3.62.0: - dependencies: - semver: 7.6.0 - - node-addon-api@5.1.0: {} - node-domexception@1.0.0: {} node-ensure@0.0.0: {} @@ -10981,21 +10789,6 @@ snapshots: dependencies: xtend: 4.0.2 - prebuild-install@7.1.2: - dependencies: - detect-libc: 2.0.3 - expand-template: 2.0.3 - github-from-package: 0.0.0 - minimist: 1.2.8 - mkdirp-classic: 0.5.3 - napi-build-utils: 1.0.2 - node-abi: 3.62.0 - pump: 3.0.0 - rc: 1.2.8 - simple-get: 4.0.1 - tar-fs: 2.1.1 - tunnel-agent: 0.6.0 - prettier-plugin-organize-imports@3.2.4(prettier@3.2.5)(typescript@4.9.5): dependencies: prettier: 3.2.5 @@ -11103,13 +10896,6 @@ snapshots: iconv-lite: 0.4.24 unpipe: 1.0.0 - rc@1.2.8: - dependencies: - deep-extend: 0.6.0 - ini: 1.3.8 - minimist: 1.2.8 - strip-json-comments: 2.0.1 - react-is@18.3.1: {} read-pkg@3.0.0: @@ -11364,7 +11150,8 @@ snapshots: signal-exit@4.1.0: {} - simple-concat@1.0.1: {} + simple-concat@1.0.1: + optional: true simple-get@3.1.1: dependencies: @@ -11373,12 +11160,6 @@ snapshots: simple-concat: 1.0.1 optional: true - simple-get@4.0.1: - dependencies: - decompress-response: 6.0.0 - once: 1.4.0 - simple-concat: 1.0.1 - simple-swizzle@0.2.2: dependencies: is-arrayish: 0.3.2 @@ -11492,8 +11273,6 @@ snapshots: strip-final-newline@2.0.0: {} - strip-json-comments@2.0.1: {} - strip-json-comments@3.1.1: {} strnum@1.0.5: @@ -11525,21 +11304,6 @@ snapshots: supports-preserve-symlinks-flag@1.0.0: {} - tar-fs@2.1.1: - dependencies: - chownr: 1.1.4 - mkdirp-classic: 0.5.3 - pump: 3.0.0 - tar-stream: 2.2.0 - - tar-stream@2.2.0: - dependencies: - bl: 4.1.0 - end-of-stream: 1.4.4 - fs-constants: 1.0.0 - inherits: 2.0.4 - readable-stream: 3.6.2 - tar@6.2.1: dependencies: chownr: 2.0.0 @@ -11691,10 +11455,6 @@ snapshots: optionalDependencies: fsevents: 2.3.3 - tunnel-agent@0.6.0: - dependencies: - safe-buffer: 5.2.1 - type-detect@4.0.8: {} type-fest@0.21.3: {} diff --git a/js/testapps/byo-evaluator/src/deliciousness/deliciousness.ts b/js/testapps/byo-evaluator/src/deliciousness/deliciousness.ts index d3750cb84..1c887d50d 100644 --- a/js/testapps/byo-evaluator/src/deliciousness/deliciousness.ts +++ b/js/testapps/byo-evaluator/src/deliciousness/deliciousness.ts @@ -40,6 +40,7 @@ export async function deliciousnessScore< throw new Error('Output is required for Funniness detection'); } const finalPrompt = await loadPromptFile( + ai.registry, path.resolve(__dirname, 
'../../prompts/deliciousness.prompt') ); const response = await ai.generate({ diff --git a/js/testapps/byo-evaluator/src/funniness/funniness.ts b/js/testapps/byo-evaluator/src/funniness/funniness.ts index 3f38f0e1e..e1a1df5cf 100644 --- a/js/testapps/byo-evaluator/src/funniness/funniness.ts +++ b/js/testapps/byo-evaluator/src/funniness/funniness.ts @@ -42,6 +42,7 @@ export async function funninessScore( throw new Error('Output is required for Funniness detection'); } const finalPrompt = await loadPromptFile( + ai.registry, path.resolve(__dirname, '../../prompts/funniness.prompt') ); diff --git a/js/testapps/byo-evaluator/src/index.ts b/js/testapps/byo-evaluator/src/index.ts index 2463d53a6..9f9e4fdb8 100644 --- a/js/testapps/byo-evaluator/src/index.ts +++ b/js/testapps/byo-evaluator/src/index.ts @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -import { geminiPro, googleAI } from '@genkit-ai/googleai'; +import { gemini15Flash, googleAI } from '@genkit-ai/googleai'; import { Genkit, ModelReference, genkit, z } from 'genkit'; import { GenkitPlugin, genkitPlugin } from 'genkit/plugin'; import { @@ -41,7 +41,7 @@ export const ai = genkit({ plugins: [ googleAI({ apiVersion: ['v1', 'v1beta'] }), byoEval({ - judge: geminiPro, + judge: gemini15Flash, judgeConfig: PERMISSIVE_SAFETY_SETTINGS, metrics: [ // regexMatcher will register an evaluator with a name in the format diff --git a/js/testapps/byo-evaluator/src/pii/pii_detection.ts b/js/testapps/byo-evaluator/src/pii/pii_detection.ts index b9d296f5d..d0079fdd1 100644 --- a/js/testapps/byo-evaluator/src/pii/pii_detection.ts +++ b/js/testapps/byo-evaluator/src/pii/pii_detection.ts @@ -37,6 +37,7 @@ export async function piiDetectionScore< throw new Error('Output is required for PII detection'); } const finalPrompt = await loadPromptFile( + ai.registry, path.resolve(__dirname, '../../prompts/pii_detection.prompt') ); diff --git a/js/testapps/cat-eval/package.json b/js/testapps/cat-eval/package.json index e3942b353..8bcede3fc 100644 --- a/js/testapps/cat-eval/package.json +++ b/js/testapps/cat-eval/package.json @@ -24,7 +24,7 @@ "@genkit-ai/vertexai": "workspace:*", "@google-cloud/firestore": "^7.9.0", "@opentelemetry/sdk-trace-base": "^1.22.0", - "firebase-admin": "^12.3.0", + "firebase-admin": ">=12.2", "genkitx-pinecone": "workspace:*", "llm-chunk": "^0.0.1", "pdf-parse": "^1.1.1", diff --git a/js/testapps/cat-eval/src/genkit.ts b/js/testapps/cat-eval/src/genkit.ts index 6f70dba5b..580886658 100644 --- a/js/testapps/cat-eval/src/genkit.ts +++ b/js/testapps/cat-eval/src/genkit.ts @@ -17,7 +17,7 @@ import { devLocalVectorstore } from '@genkit-ai/dev-local-vectorstore'; import { genkitEval, GenkitMetric } from '@genkit-ai/evaluator'; import { gemini15Pro, googleAI } from '@genkit-ai/googleai'; -import { textEmbeddingGecko, vertexAI } from '@genkit-ai/vertexai'; +import { textEmbedding004, vertexAI } from '@genkit-ai/vertexai'; import { genkit } from 'genkit'; // Turn off safety checks for evaluation so that the LLM as an evaluator can @@ -50,7 +50,7 @@ export const ai = genkit({ judge: gemini15Pro, judgeConfig: PERMISSIVE_SAFETY_SETTINGS, metrics: [GenkitMetric.MALICIOUSNESS], - embedder: textEmbeddingGecko, + embedder: textEmbedding004, }), vertexAI({ location: 'us-central1', @@ -58,7 +58,7 @@ export const ai = genkit({ devLocalVectorstore([ { indexName: 'pdfQA', - embedder: textEmbeddingGecko, + embedder: textEmbedding004, }, ]), ], diff --git 
a/js/testapps/cat-eval/src/pdf_rag.ts b/js/testapps/cat-eval/src/pdf_rag.ts index 464e17aa9..d76d2024d 100644 --- a/js/testapps/cat-eval/src/pdf_rag.ts +++ b/js/testapps/cat-eval/src/pdf_rag.ts @@ -18,7 +18,7 @@ import { devLocalIndexerRef, devLocalRetrieverRef, } from '@genkit-ai/dev-local-vectorstore'; -import { geminiPro } from '@genkit-ai/googleai'; +import { gemini15Flash } from '@genkit-ai/googleai'; import { run, z } from 'genkit'; import { Document } from 'genkit/retriever'; import { chunk } from 'llm-chunk'; @@ -64,7 +64,7 @@ export const pdfQA = ai.defineFlow( context: docs.map((d) => d.text).join('\n\n'), }); const llmResponse = await ai.generate({ - model: geminiPro, + model: gemini15Flash, prompt: augmentedPrompt, }); return llmResponse.text; @@ -141,7 +141,7 @@ export const synthesizeQuestions = ai.defineFlow( const questions: string[] = []; for (let i = 0; i < chunks.length; i++) { const qResponse = await ai.generate({ - model: geminiPro, + model: gemini15Flash, prompt: { text: `Generate one question about the text below: ${chunks[i]}`, }, diff --git a/js/testapps/cat-eval/src/pdf_rag_firebase.ts b/js/testapps/cat-eval/src/pdf_rag_firebase.ts index 9b42954e6..06b0575dc 100644 --- a/js/testapps/cat-eval/src/pdf_rag_firebase.ts +++ b/js/testapps/cat-eval/src/pdf_rag_firebase.ts @@ -15,14 +15,13 @@ */ import { defineFirestoreRetriever } from '@genkit-ai/firebase'; -import { geminiPro } from '@genkit-ai/googleai'; -import { textEmbeddingGecko } from '@genkit-ai/vertexai'; +import { gemini15Flash } from '@genkit-ai/googleai'; +import { textEmbedding004 } from '@genkit-ai/vertexai'; import { FieldValue } from '@google-cloud/firestore'; import { initializeApp } from 'firebase-admin/app'; import { getFirestore } from 'firebase-admin/firestore'; import { readFile } from 'fs/promises'; import { run, z } from 'genkit'; -import { runWithRegistry } from 'genkit/registry'; import { chunk } from 'llm-chunk'; import path from 'path'; import pdf from 'pdf-parse'; @@ -58,17 +57,15 @@ Question: ${question} Helpful Answer:`; } -export const pdfChatRetrieverFirebase = runWithRegistry(ai.registry, () => - defineFirestoreRetriever(ai, { - name: 'pdfChatRetrieverFirebase', - firestore, - collection: 'pdf-qa', - contentField: 'facts', - vectorField: 'embedding', - embedder: textEmbeddingGecko, - distanceMeasure: 'COSINE', - }) -); +export const pdfChatRetrieverFirebase = defineFirestoreRetriever(ai, { + name: 'pdfChatRetrieverFirebase', + firestore, + collection: 'pdf-qa', + contentField: 'facts', + vectorField: 'embedding', + embedder: textEmbedding004, + distanceMeasure: 'COSINE', +}); // Define a simple RAG flow, we will evaluate this flow export const pdfQAFirebase = ai.defineFlow( @@ -90,7 +87,7 @@ export const pdfQAFirebase = ai.defineFlow( context: docs.map((d) => d.text).join('\n\n'), }); const llmResponse = await ai.generate({ - model: geminiPro, + model: gemini15Flash, prompt: augmentedPrompt, }); return llmResponse.text; @@ -102,7 +99,7 @@ const indexConfig = { collection: 'pdf-qa', contentField: 'facts', vectorField: 'embedding', - embedder: textEmbeddingGecko, + embedder: textEmbedding004, }; const chunkingConfig = { diff --git a/js/testapps/dev-ui-gallery/package.json b/js/testapps/dev-ui-gallery/package.json index d738755d2..7ece82a3b 100644 --- a/js/testapps/dev-ui-gallery/package.json +++ b/js/testapps/dev-ui-gallery/package.json @@ -28,7 +28,7 @@ "@genkit-ai/firebase": "workspace:*", "@genkit-ai/googleai": "workspace:*", "@genkit-ai/vertexai": "workspace:*", - "firebase-admin": 
"^12.1.0", + "firebase-admin": ">=12.2", "genkit": "workspace:*", "genkitx-chromadb": "workspace:*", "genkitx-ollama": "workspace:*", diff --git a/js/testapps/dev-ui-gallery/src/genkit.ts b/js/testapps/dev-ui-gallery/src/genkit.ts index ba901f9c7..3a1e9aa97 100644 --- a/js/testapps/dev-ui-gallery/src/genkit.ts +++ b/js/testapps/dev-ui-gallery/src/genkit.ts @@ -16,12 +16,12 @@ import { devLocalVectorstore } from '@genkit-ai/dev-local-vectorstore'; import { genkitEval, GenkitMetric } from '@genkit-ai/evaluator'; -import { geminiPro, googleAI } from '@genkit-ai/googleai'; +import { gemini15Flash, googleAI } from '@genkit-ai/googleai'; import { claude3Haiku, claude3Opus, claude3Sonnet, - textEmbeddingGecko, + textEmbedding004, vertexAI, VertexAIEvaluationMetricType, } from '@genkit-ai/vertexai'; @@ -95,30 +95,30 @@ export const ai = genkit({ chroma([ { collectionName: 'chroma-collection', - embedder: textEmbeddingGecko, + embedder: textEmbedding004, embedderOptions: { taskType: 'RETRIEVAL_DOCUMENT' }, }, ]), devLocalVectorstore([ { indexName: 'naive-index', - embedder: textEmbeddingGecko, + embedder: textEmbedding004, embedderOptions: { taskType: 'RETRIEVAL_DOCUMENT' }, }, ]), pinecone([ { indexId: 'pinecone-index', - embedder: textEmbeddingGecko, + embedder: textEmbedding004, embedderOptions: { taskType: 'RETRIEVAL_DOCUMENT' }, }, ]), // evaluation genkitEval({ - judge: geminiPro, + judge: gemini15Flash, judgeConfig: PERMISSIVE_SAFETY_SETTINGS, - embedder: textEmbeddingGecko, + embedder: textEmbedding004, metrics: [ GenkitMetric.ANSWER_RELEVANCY, GenkitMetric.FAITHFULNESS, diff --git a/js/testapps/docs-menu-basic/src/index.ts b/js/testapps/docs-menu-basic/src/index.ts index ff7fdf421..a81a854a4 100644 --- a/js/testapps/docs-menu-basic/src/index.ts +++ b/js/testapps/docs-menu-basic/src/index.ts @@ -16,7 +16,7 @@ // This sample is referenced by the genkit docs. Changes should be made to // both. 
-import { geminiPro, googleAI } from '@genkit-ai/googleai';
+import { gemini15Flash, googleAI } from '@genkit-ai/googleai';
 import { genkit, z } from 'genkit';
 
 const ai = genkit({
@@ -33,7 +33,7 @@ export const menuSuggestionFlow = ai.defineFlow(
   async (subject) => {
     const llmResponse = await ai.generate({
       prompt: `Suggest an item for the menu of a ${subject} themed restaurant`,
-      model: geminiPro,
+      model: gemini15Flash,
       config: {
         temperature: 1,
       },
diff --git a/js/testapps/docs-menu-rag/src/index.ts b/js/testapps/docs-menu-rag/src/index.ts
index fe1705998..377bdf110 100644
--- a/js/testapps/docs-menu-rag/src/index.ts
+++ b/js/testapps/docs-menu-rag/src/index.ts
@@ -15,7 +15,7 @@
  */
 
 import { devLocalVectorstore } from '@genkit-ai/dev-local-vectorstore';
-import { textEmbeddingGecko, vertexAI } from '@genkit-ai/vertexai';
+import { textEmbedding004, vertexAI } from '@genkit-ai/vertexai';
 import { genkit, z } from 'genkit';
 import { indexMenu } from './indexer';
 
@@ -25,7 +25,7 @@ export const ai = genkit({
     devLocalVectorstore([
       {
         indexName: 'menuQA',
-        embedder: textEmbeddingGecko,
+        embedder: textEmbedding004,
       },
     ]),
   ],
diff --git a/js/testapps/docs-menu-rag/src/menuQA.ts b/js/testapps/docs-menu-rag/src/menuQA.ts
index 85092b85a..364409d42 100644
--- a/js/testapps/docs-menu-rag/src/menuQA.ts
+++ b/js/testapps/docs-menu-rag/src/menuQA.ts
@@ -15,7 +15,7 @@
  */
 
 import { devLocalRetrieverRef } from '@genkit-ai/dev-local-vectorstore';
-import { geminiPro } from '@genkit-ai/vertexai';
+import { gemini15Flash } from '@genkit-ai/vertexai';
 import { z } from 'genkit';
 import { ai } from './index.js';
 
@@ -34,7 +34,7 @@
 
     // generate a response
     const llmResponse = await ai.generate({
-      model: geminiPro,
+      model: gemini15Flash,
       prompt: `
You are acting as a helpful AI assistant that can answer questions about the
food available on the menu at Genkit Grub Pub.
diff --git a/js/testapps/eval/src/index.ts b/js/testapps/eval/src/index.ts
index 17cb994b9..e4445b740 100644
--- a/js/testapps/eval/src/index.ts
+++ b/js/testapps/eval/src/index.ts
@@ -15,7 +15,7 @@
  */
 
 import { genkitEval, genkitEvalRef, GenkitMetric } from '@genkit-ai/evaluator';
-import { geminiPro, textEmbeddingGecko, vertexAI } from '@genkit-ai/vertexai';
+import { gemini15Flash, textEmbedding004, vertexAI } from '@genkit-ai/vertexai';
 import { genkit, z } from 'genkit';
 import { Dataset, EvalResponse, EvalResponseSchema } from 'genkit/evaluator';
 
@@ -23,13 +23,13 @@ const ai = genkit({
   plugins: [
     vertexAI(),
     genkitEval({
-      judge: geminiPro,
+      judge: gemini15Flash,
       metrics: [
         GenkitMetric.FAITHFULNESS,
         GenkitMetric.ANSWER_RELEVANCY,
         GenkitMetric.MALICIOUSNESS,
       ],
-      embedder: textEmbeddingGecko,
+      embedder: textEmbedding004,
     }),
   ],
 });
diff --git a/js/testapps/evaluator-gut-check/src/index.ts b/js/testapps/evaluator-gut-check/src/index.ts
index c08cf5dec..14a63ca55 100644
--- a/js/testapps/evaluator-gut-check/src/index.ts
+++ b/js/testapps/evaluator-gut-check/src/index.ts
@@ -16,8 +16,8 @@
 
 import { devLocalVectorstore } from '@genkit-ai/dev-local-vectorstore';
 import { genkitEval, GenkitMetric } from '@genkit-ai/evaluator';
-import { geminiPro, googleAI } from '@genkit-ai/googleai';
-import { textEmbeddingGecko, vertexAI } from '@genkit-ai/vertexai';
+import { gemini15Flash, googleAI } from '@genkit-ai/googleai';
+import { textEmbedding004, vertexAI } from '@genkit-ai/vertexai';
 import { genkit } from 'genkit';
 
 // Turn off safety checks for evaluation so that the LLM as an evaluator can
@@ -47,20 +47,20 @@ const ai = genkit({
   plugins: [
     googleAI(),
     genkitEval({
-      judge: geminiPro,
+      judge: gemini15Flash,
       judgeConfig: PERMISSIVE_SAFETY_SETTINGS,
       metrics: [
         GenkitMetric.ANSWER_RELEVANCY,
         GenkitMetric.FAITHFULNESS,
         GenkitMetric.MALICIOUSNESS,
       ],
-      embedder: textEmbeddingGecko,
+      embedder: textEmbedding004,
     }),
     vertexAI(),
     devLocalVectorstore([
       {
         indexName: 'evaluating-evaluators',
-        embedder: textEmbeddingGecko,
+        embedder: textEmbedding004,
       },
     ]),
   ],
diff --git a/js/testapps/firebase-functions-sample1/functions/package.json b/js/testapps/firebase-functions-sample1/functions/package.json
index fc8d73492..9407d1c84 100644
--- a/js/testapps/firebase-functions-sample1/functions/package.json
+++ b/js/testapps/firebase-functions-sample1/functions/package.json
@@ -17,8 +17,8 @@
     "genkit": "*",
     "@genkit-ai/firebase": "*",
     "@genkit-ai/vertexai": "*",
-    "firebase-admin": "^11.8.0",
-    "firebase-functions": "^4.8.0 || ^5.0.0"
+    "firebase-admin": ">=12.2",
+    "firebase-functions": ">=4.8"
   },
   "devDependencies": {
     "firebase-functions-test": "^3.1.0",
diff --git a/js/testapps/flow-simple-ai/package.json b/js/testapps/flow-simple-ai/package.json
index 6b9aca8d5..fd0f02e9e 100644
--- a/js/testapps/flow-simple-ai/package.json
+++ b/js/testapps/flow-simple-ai/package.json
@@ -22,7 +22,7 @@
     "@genkit-ai/vertexai": "workspace:*",
     "@google/generative-ai": "^0.15.0",
     "@opentelemetry/sdk-trace-base": "^1.25.0",
-    "firebase-admin": "^12.1.0",
+    "firebase-admin": ">=12.2",
     "partial-json": "^0.1.7"
   },
   "devDependencies": {
diff --git a/js/testapps/flow-simple-ai/src/index.ts b/js/testapps/flow-simple-ai/src/index.ts
index f5d58718c..f845323ab 100644
--- a/js/testapps/flow-simple-ai/src/index.ts
+++ b/js/testapps/flow-simple-ai/src/index.ts
@@ -19,20 +19,14 @@ import { enableGoogleCloudTelemetry } from '@genkit-ai/google-cloud';
 import {
   gemini15Flash,
   googleAI,
-  geminiPro as googleGeminiPro,
+  gemini10Pro as googleGemini10Pro,
 } from '@genkit-ai/googleai';
-import {
-  gemini15ProPreview,
-  geminiPro,
-  textEmbeddingGecko,
-  vertexAI,
-} from '@genkit-ai/vertexai';
+import { textEmbedding004, vertexAI } from '@genkit-ai/vertexai';
 import { GoogleAIFileManager } from '@google/generative-ai/server';
 import { AlwaysOnSampler } from '@opentelemetry/sdk-trace-base';
 import { initializeApp } from 'firebase-admin/app';
 import { getFirestore } from 'firebase-admin/firestore';
 import { MessageSchema, genkit, run, z } from 'genkit';
-import { runWithRegistry } from 'genkit/registry';
 import { Allow, parse } from 'partial-json';
 
 enableGoogleCloudTelemetry({
@@ -106,7 +100,7 @@ export const streamFlow = ai.defineStreamingFlow(
   },
   async (prompt, streamingCallback) => {
     const { response, stream } = await ai.generateStream({
-      model: geminiPro,
+      model: gemini15Flash,
       prompt,
     });
 
@@ -148,7 +142,7 @@
     }
 
     const { response, stream } = await ai.generateStream({
-      model: geminiPro,
+      model: gemini15Flash,
       output: {
         schema: GameCharactersSchema,
       },
@@ -195,7 +189,7 @@ export const jokeWithToolsFlow = ai.defineFlow(
   {
     name: 'jokeWithToolsFlow',
     inputSchema: z.object({
-      modelName: z.enum([geminiPro.name, googleGeminiPro.name]),
+      modelName: z.enum([gemini15Flash.name, googleGemini10Pro.name]),
       subject: z.string(),
     }),
     outputSchema: z.object({ model: z.string(), joke: z.string() }),
@@ -246,7 +240,7 @@ export const vertexStreamer = ai.defineFlow(
   async (input, streamingCallback) => {
     return await run('call-llm', async () => {
       const llmResponse = await ai.generate({
-        model: geminiPro,
+        model: gemini15Flash,
         prompt: `Tell me a very long joke about ${input}.`,
         streamingCallback,
       });
@@ -274,16 +268,14 @@ export const multimodalFlow = ai.defineFlow(
   }
 );
 
-const destinationsRetriever = runWithRegistry(ai.registry, () =>
-  defineFirestoreRetriever(ai, {
-    name: 'destinationsRetriever',
-    firestore: getFirestore(app),
-    collection: 'destinations',
-    contentField: 'knownFor',
-    embedder: textEmbeddingGecko,
-    vectorField: 'embedding',
-  })
-);
+const destinationsRetriever = defineFirestoreRetriever(ai, {
+  name: 'destinationsRetriever',
+  firestore: getFirestore(app),
+  collection: 'destinations',
+  contentField: 'knownFor',
+  embedder: textEmbedding004,
+  vectorField: 'embedding',
+});
 
 export const searchDestinations = ai.defineFlow(
   {
@@ -299,7 +291,7 @@ export const searchDestinations = ai.defineFlow(
     });
 
     const result = await ai.generate({
-      model: geminiPro,
+      model: gemini15Flash,
       prompt: `Give me a list of vacation options based on the provided context. Use only the options provided below, and describe how it fits with my query.
 
 Query: ${input}
@@ -375,7 +367,7 @@ export const toolCaller = ai.defineStreamingFlow(
     }
 
     const { response, stream } = await ai.generateStream({
-      model: gemini15ProPreview,
+      model: gemini15Flash,
       config: {
         temperature: 1,
       },
diff --git a/js/testapps/menu/src/01/prompts.ts b/js/testapps/menu/src/01/prompts.ts
index 9ef465a46..4a0e7cbe6 100644
--- a/js/testapps/menu/src/01/prompts.ts
+++ b/js/testapps/menu/src/01/prompts.ts
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-import { geminiPro } from '@genkit-ai/vertexai';
+import { gemini15Flash } from '@genkit-ai/vertexai';
 import { GenerateRequest } from 'genkit';
 import { ai } from '../index.js';
 import { MenuQuestionInput, MenuQuestionInputSchema } from '../types.js';
@@ -50,7 +50,7 @@ export const s01_vanillaPrompt = ai.definePrompt(
 export const s01_staticMenuDotPrompt = ai.definePrompt(
   {
     name: 's01_staticMenuDotPrompt',
-    model: geminiPro,
+    model: gemini15Flash,
     input: { schema: MenuQuestionInputSchema },
     output: { format: 'text' },
   },
diff --git a/js/testapps/menu/src/02/prompts.ts b/js/testapps/menu/src/02/prompts.ts
index 82ba0bae1..c21696859 100644
--- a/js/testapps/menu/src/02/prompts.ts
+++ b/js/testapps/menu/src/02/prompts.ts
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-import { geminiPro } from '@genkit-ai/vertexai';
+import { gemini15Flash } from '@genkit-ai/vertexai';
 import { ai } from '../index.js';
 import { MenuQuestionInputSchema } from '../types.js';
 import { menuTool } from './tools.js';
@@ -25,7 +25,7 @@ import { menuTool } from './tools.js';
 export const s02_dataMenuPrompt = ai.definePrompt(
   {
     name: 's02_dataMenu',
-    model: geminiPro,
+    model: gemini15Flash,
     input: { schema: MenuQuestionInputSchema },
     output: { format: 'text' },
     tools: [menuTool],
diff --git a/js/testapps/menu/src/03/flows.ts b/js/testapps/menu/src/03/flows.ts
index 71521320b..9374b9660 100644
--- a/js/testapps/menu/src/03/flows.ts
+++ b/js/testapps/menu/src/03/flows.ts
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-import { geminiPro } from '@genkit-ai/vertexai';
+import { gemini15Flash } from '@genkit-ai/vertexai';
 import { run } from 'genkit';
 import { MessageData } from 'genkit/model';
 import { ai } from '../index.js';
@@ -78,7 +78,7 @@ export const s03_multiTurnChatFlow = ai.defineFlow(
 
     // Generate the response
     const llmResponse = await ai.generate({
-      model: geminiPro,
+      model: gemini15Flash,
       messages: history,
       prompt: {
         text: input.question,
diff --git a/js/testapps/menu/src/04/prompts.ts b/js/testapps/menu/src/04/prompts.ts
index be5076ab2..000a06c4e 100644
--- a/js/testapps/menu/src/04/prompts.ts
+++ b/js/testapps/menu/src/04/prompts.ts
@@ -14,14 +14,14 @@
  * limitations under the License.
  */
 
-import { geminiPro } from '@genkit-ai/vertexai';
+import { gemini15Flash } from '@genkit-ai/vertexai';
 import { ai } from '../index.js';
 import { DataMenuQuestionInputSchema } from '../types.js';
 
 export const s04_ragDataMenuPrompt = ai.definePrompt(
   {
     name: 's04_ragDataMenu',
-    model: geminiPro,
+    model: gemini15Flash,
     input: { schema: DataMenuQuestionInputSchema },
     output: { format: 'text' },
     config: { temperature: 0.3 },
diff --git a/js/testapps/menu/src/05/prompts.ts b/js/testapps/menu/src/05/prompts.ts
index ffd6ce784..149e576f4 100644
--- a/js/testapps/menu/src/05/prompts.ts
+++ b/js/testapps/menu/src/05/prompts.ts
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-import { geminiPro, geminiProVision } from '@genkit-ai/vertexai';
+import { gemini15Flash } from '@genkit-ai/vertexai';
 import { z } from 'genkit';
 import { ai } from '../index.js';
 import { TextMenuQuestionInputSchema } from '../types.js';
@@ -22,7 +22,7 @@ import { TextMenuQuestionInputSchema } from '../types.js';
 export const s05_readMenuPrompt = ai.definePrompt(
   {
     name: 's05_readMenu',
-    model: geminiProVision,
+    model: gemini15Flash,
     input: {
       schema: z.object({
         imageUrl: z.string(),
@@ -42,7 +42,7 @@ from the following image of a restaurant menu.
 export const s05_textMenuPrompt = ai.definePrompt(
   {
     name: 's05_textMenu',
-    model: geminiPro,
+    model: gemini15Flash,
     input: { schema: TextMenuQuestionInputSchema },
     output: { format: 'text' },
     config: { temperature: 0.3 },
diff --git a/js/testapps/menu/src/index.ts b/js/testapps/menu/src/index.ts
index 276de8628..bdf1b3b98 100644
--- a/js/testapps/menu/src/index.ts
+++ b/js/testapps/menu/src/index.ts
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 import { devLocalVectorstore } from '@genkit-ai/dev-local-vectorstore';
-import { textEmbeddingGecko, vertexAI } from '@genkit-ai/vertexai';
+import { textEmbedding004, vertexAI } from '@genkit-ai/vertexai';
 import { genkit } from 'genkit';
 
 // Initialize Genkit
@@ -26,7 +26,7 @@ export const ai = genkit({
     devLocalVectorstore([
       {
         indexName: 'menu-items',
-        embedder: textEmbeddingGecko,
+        embedder: textEmbedding004,
         embedderOptions: { taskType: 'RETRIEVAL_DOCUMENT' },
       },
     ]),
diff --git a/js/testapps/model-tester/src/index.ts b/js/testapps/model-tester/src/index.ts
index 5d874c785..cc9fb5361 100644
--- a/js/testapps/model-tester/src/index.ts
+++ b/js/testapps/model-tester/src/index.ts
@@ -44,7 +44,7 @@ export const ai = genkit({
   ],
 });
 
-testModels([
+testModels(ai.registry, [
   'googleai/gemini-1.5-pro-latest',
   'googleai/gemini-1.5-flash-latest',
   'vertexai/gemini-1.5-pro',
diff --git a/js/testapps/rag/src/genkit.ts b/js/testapps/rag/src/genkit.ts
index 9a0c995c6..5e2cd4163 100644
--- a/js/testapps/rag/src/genkit.ts
+++ b/js/testapps/rag/src/genkit.ts
@@ -20,7 +20,7 @@ import { gemini15Flash, googleAI } from '@genkit-ai/googleai';
 import {
   claude3Sonnet,
   llama31,
-  textEmbeddingGecko,
+  textEmbedding004,
   vertexAI,
 } from '@genkit-ai/vertexai';
 import { genkit } from 'genkit';
@@ -83,17 +83,17 @@ export const ai = genkit({
     pinecone([
       {
         indexId: 'cat-facts',
-        embedder: textEmbeddingGecko,
+        embedder: textEmbedding004,
       },
       {
         indexId: 'pdf-chat',
-        embedder: textEmbeddingGecko,
+        embedder: textEmbedding004,
       },
     ]),
     chroma([
       {
         collectionName: 'dogfacts_collection',
-        embedder: textEmbeddingGecko,
+        embedder: textEmbedding004,
         createCollectionIfMissing: true,
         clientParams: async () => {
           // Replace this with your Cloud Run Instance URL
@@ -114,11 +114,11 @@ export const ai = genkit({
     devLocalVectorstore([
       {
         indexName: 'dog-facts',
-        embedder: textEmbeddingGecko,
+        embedder: textEmbedding004,
       },
       {
         indexName: 'pdfQA',
-        embedder: textEmbeddingGecko,
+        embedder: textEmbedding004,
       },
     ]),
   ],
diff --git a/js/testapps/rag/src/pdf_rag.ts b/js/testapps/rag/src/pdf_rag.ts
index d02dbad04..018b21ecb 100644
--- a/js/testapps/rag/src/pdf_rag.ts
+++ b/js/testapps/rag/src/pdf_rag.ts
@@ -18,7 +18,7 @@ import {
   devLocalIndexerRef,
   devLocalRetrieverRef,
 } from '@genkit-ai/dev-local-vectorstore';
-import { geminiPro } from '@genkit-ai/vertexai';
+import { gemini15Flash } from '@genkit-ai/vertexai';
 import fs from 'fs';
 import { Document, run, z } from 'genkit';
 import { chunk } from 'llm-chunk';
@@ -117,7 +117,7 @@ export const synthesizeQuestions = ai.defineFlow(
     const questions: string[] = [];
     for (let i = 0; i < chunks.length; i++) {
       const qResponse = await ai.generate({
-        model: geminiPro,
+        model: gemini15Flash,
         prompt: {
           text: `Generate one question about the text below: ${chunks[i]}`,
         },
diff --git a/js/testapps/vertexai-vector-search-firestore/package.json b/js/testapps/vertexai-vector-search-firestore/package.json
index aca1337d0..cf0a1f390 100644
--- a/js/testapps/vertexai-vector-search-firestore/package.json
+++ b/js/testapps/vertexai-vector-search-firestore/package.json
@@ -22,7 +22,7 @@
     "@genkit-ai/vertexai": "workspace:*",
     "dotenv": "^16.4.5",
     "express": "^4.21.0",
-    "firebase-admin": "^12.1.0",
+    "firebase-admin": ">=12.2",
     "genkitx-chromadb": "workspace:*",
     "genkitx-langchain": "workspace:*",
     "genkitx-pinecone": "workspace:*",
diff --git a/package.json b/package.json
index 8812a0515..f14d1f6ce 100644
--- a/package.json
+++ b/package.json
@@ -40,5 +40,5 @@
     "ts-node": "^10.9.2",
     "tsx": "^4.7.1"
   },
-  "packageManager": "pnpm@9.12.0+sha256.a61b67ff6cc97af864564f4442556c22a04f2e5a7714fbee76a1011361d9b726"
+  "packageManager": "pnpm@9.12.2+sha256.2ef6e547b0b07d841d605240dce4d635677831148cd30f6d564b8f4f928f73d2"
 }
diff --git a/scripts/release_main.sh b/scripts/release_main.sh
new file mode 100755
index 000000000..b28e12eb2
--- /dev/null
+++ b/scripts/release_main.sh
@@ -0,0 +1,78 @@
+#!/bin/bash
+
+# git clone git@github.com:firebase/genkit.git
+# cd genkit
+# pnpm i
+# pnpm build
+# pnpm test:all
+# Run from root: scripts/release_main.sh
+
+pnpm login --registry https://wombat-dressing-room.appspot.com
+
+
+CURRENT=`pwd`
+
+cd genkit-tools/common
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd genkit-tools/cli
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/core
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/ai
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/flow
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/dotprompt
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/chroma
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/dev-local-vectorstore
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/firebase
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/google-cloud
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/googleai
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/ollama
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/pinecone
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/vertexai
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/evaluators
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/langchain
+pnpm publish --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
diff --git a/scripts/release_next.sh b/scripts/release_next.sh
new file mode 100755
index 000000000..3e62e982b
--- /dev/null
+++ b/scripts/release_next.sh
@@ -0,0 +1,83 @@
+#!/bin/bash
+
+# git clone git@github.com:firebase/genkit.git
+# cd genkit
+# git checkout next
+# pnpm i
+# pnpm build
+# pnpm test:all
+
+# Run from root: scripts/release_next.sh
+
+pnpm login --registry https://wombat-dressing-room.appspot.com
+
+CURRENT=`pwd`
+
+cd genkit-tools/cli
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd genkit-tools/common
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd genkit-tools/telemetry-server
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/core
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/ai
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/genkit
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/dotprompt
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/chroma
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/dev-local-vectorstore
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/firebase
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/google-cloud
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/googleai
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/ollama
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/pinecone
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/vertexai
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/evaluators
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+
+cd js/plugins/langchain
+pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com
+cd $CURRENT
+