diff --git a/README.md b/README.md
index 1af64b3d3..8f076de7a 100644
--- a/README.md
+++ b/README.md
@@ -88,9 +88,14 @@ Find excellent examples of community-built plugins for OpenAI, Anthropic, Cohere
## Try Genkit on IDX
-
-
-Want to try Genkit without a local setup? [Explore it on Project IDX](https://idx.google.com/new/genkit), Google's AI-assisted workspace for full-stack app development in the cloud.
+Want to skip the local setup? Click below to try out Genkit using [Project IDX](https://idx.dev), Google's AI-assisted workspace for full-stack app development in the cloud.
+
+
+
+
## Sample apps
diff --git a/docs/plugins/google-cloud.md b/docs/plugins/google-cloud.md
index 5314558bc..6135eb8f2 100644
--- a/docs/plugins/google-cloud.md
+++ b/docs/plugins/google-cloud.md
@@ -221,32 +221,84 @@ Common dimensions include:
- `topK` - the inference topK [value](https://ai.google.dev/docs/concepts#model-parameters).
- `topP` - the inference topP [value](https://ai.google.dev/docs/concepts#model-parameters).
-### Flow-level metrics
+### Feature-level metrics
+
+Features are the top-level entry-point to your Genkit code. In most cases, this
+will be a flow, but if you do not use flows, this will be the top-most span in a trace.
+
+| Name | Type | Description |
+| ----------------------- | --------- | ----------------------- |
+| genkit/feature/requests | Counter | Number of requests |
+| genkit/feature/latency | Histogram | Execution latency in ms |
+
+Each feature-level metric contains the following dimensions:
+
+| Name | Description |
+| ------------- | -------------------------------------------------------------------------------- |
+| name | The name of the feature. In most cases, this is the top-level Genkit flow |
+| status | 'success' or 'failure' depending on whether or not the feature request succeeded |
+| error | Only set when `status=failure`. Contains the error type that caused the failure |
+| source | The Genkit source language. Eg. 'ts' |
+| sourceVersion | The Genkit framework version |
-| Name | Dimensions |
-| -------------------- | ------------------------------------ |
-| genkit/flow/requests | flow_name, error_code, error_message |
-| genkit/flow/latency | flow_name |
### Action-level metrics
-| Name | Dimensions |
-| ---------------------- | ------------------------------------ |
-| genkit/action/requests | flow_name, error_code, error_message |
-| genkit/action/latency | flow_name |
+Actions represent a generic step of execution within Genkit. Each of these steps
+will have the following metrics tracked:
+
+| Name | Type | Description |
+| ----------------------- | --------- | --------------------------------------------- |
+| genkit/action/requests | Counter | Number of times this action has been executed |
+| genkit/action/latency | Histogram | Execution latency in ms |
+
+Each action-level metric contains the following dimensions:
+
+| Name | Description |
+| ------------- | ---------------------------------------------------------------------------------------------------- |
+| name | The name of the action |
+| featureName | The name of the parent feature being executed |
+| path | The path of execution from the feature root to this action. eg. '/myFeature/parentAction/thisAction' |
+| status | 'success' or 'failure' depending on whether or not the action succeeded |
+| error | Only set when `status=failure`. Contains the error type that caused the failure |
+| source | The Genkit source language. Eg. 'ts' |
+| sourceVersion | The Genkit framework version |
### Generate-level metrics
-| Name | Dimensions |
-| ------------------------------------ | -------------------------------------------------------------------- |
-| genkit/ai/generate | flow_path, model, temperature, topK, topP, error_code, error_message |
-| genkit/ai/generate/input_tokens | flow_path, model, temperature, topK, topP |
-| genkit/ai/generate/output_tokens | flow_path, model, temperature, topK, topP |
-| genkit/ai/generate/input_characters | flow_path, model, temperature, topK, topP |
-| genkit/ai/generate/output_characters | flow_path, model, temperature, topK, topP |
-| genkit/ai/generate/input_images | flow_path, model, temperature, topK, topP |
-| genkit/ai/generate/output_images | flow_path, model, temperature, topK, topP |
-| genkit/ai/generate/latency | flow_path, model, temperature, topK, topP, error_code, error_message |
+These are special action metrics relating to actions that interact with a model.
+In addition to requests and latency, input and output are also tracked, with model
+specific dimensions that make debugging and configuration tuning easier.
+
+| Name | Type | Description |
+| ------------------------------------ | --------- | ------------------------------------------ |
+| genkit/ai/generate/requests | Counter | Number of times this model has been called |
+| genkit/ai/generate/latency | Histogram | Execution latency in ms |
+| genkit/ai/generate/input/tokens | Counter | Input tokens |
+| genkit/ai/generate/output/tokens | Counter | Output tokens |
+| genkit/ai/generate/input/characters | Counter | Input characters |
+| genkit/ai/generate/output/characters | Counter | Output characters |
+| genkit/ai/generate/input/images | Counter | Input images |
+| genkit/ai/generate/output/images | Counter | Output images |
+| genkit/ai/generate/input/audio | Counter | Input audio files |
+| genkit/ai/generate/output/audio | Counter | Output audio files |
+
+Each generate-level metric contains the following dimensions:
+
+| Name | Description |
+| --------------- | ---------------------------------------------------------------------------------------------------- |
+| modelName | The name of the model |
+| featureName | The name of the parent feature being executed |
+| path | The path of execution from the feature root to this action. eg. '/myFeature/parentAction/thisAction' |
+| temperature | The temperature parameter passed to the model |
+| maxOutputTokens | The maxOutputTokens parameter passed to the model |
+| topK | The topK parameter passed to the model |
+| topP | The topP parameter passed to the model |
+| latencyMs | The response time taken by the model |
+| status | 'success' or 'failure' depending on whether or not the feature request succeeded |
+| error | Only set when `status=failure`. Contains the error type that caused the failure |
+| source | The Genkit source language. Eg. 'ts' |
+| sourceVersion | The Genkit framework version |
Visualizing metrics can be done through the Metrics Explorer. Using the side menu, select 'Logging' and click 'Metrics explorer'
diff --git a/genkit-tools/cli/package.json b/genkit-tools/cli/package.json
index 335225733..de2ecd1cb 100644
--- a/genkit-tools/cli/package.json
+++ b/genkit-tools/cli/package.json
@@ -1,6 +1,6 @@
{
"name": "genkit-cli",
- "version": "0.6.0-dev.2",
+ "version": "0.9.0-dev.1",
"description": "CLI for interacting with the Google Genkit AI framework",
"license": "Apache-2.0",
"keywords": [
@@ -28,7 +28,7 @@
"dependencies": {
"@genkit-ai/tools-common": "workspace:*",
"@genkit-ai/telemetry-server": "workspace:*",
- "axios": "^1.6.7",
+ "axios": "^1.7.7",
"colorette": "^2.0.20",
"commander": "^11.1.0",
"extract-zip": "^2.0.1",
diff --git a/genkit-tools/common/package.json b/genkit-tools/common/package.json
index f77bbbca7..97fedae33 100644
--- a/genkit-tools/common/package.json
+++ b/genkit-tools/common/package.json
@@ -1,6 +1,6 @@
{
"name": "@genkit-ai/tools-common",
- "version": "0.6.0-dev.2",
+ "version": "0.9.0-dev.1",
"scripts": {
"compile": "tsc -b ./tsconfig.cjs.json ./tsconfig.esm.json ./tsconfig.types.json",
"build:clean": "rimraf ./lib",
@@ -12,7 +12,7 @@
"@asteasolutions/zod-to-openapi": "^7.0.0",
"@trpc/server": "10.45.0",
"adm-zip": "^0.5.12",
- "axios": "^1.6.7",
+ "axios": "^1.7.7",
"body-parser": "^1.20.2",
"chokidar": "^3.5.3",
"colorette": "^2.0.20",
diff --git a/genkit-tools/common/src/eval/evaluate.ts b/genkit-tools/common/src/eval/evaluate.ts
index 211fb3ba4..a0727a95b 100644
--- a/genkit-tools/common/src/eval/evaluate.ts
+++ b/genkit-tools/common/src/eval/evaluate.ts
@@ -216,11 +216,12 @@ async function bulkRunAction(params: {
testCaseId: c.testCaseId ?? generateTestCaseId(),
}));
+ let states: InferenceRunState[] = [];
let evalInputs: EvalInput[] = [];
for (const testCase of testCases) {
- logger.info(`Running '${actionRef}' ...`);
+ logger.info(`Running inference '${actionRef}' ...`);
if (isModelAction) {
- evalInputs.push(
+ states.push(
await runModelAction({
manager,
actionRef,
@@ -229,7 +230,7 @@ async function bulkRunAction(params: {
})
);
} else {
- evalInputs.push(
+ states.push(
await runFlowAction({
manager,
actionRef,
@@ -239,6 +240,11 @@ async function bulkRunAction(params: {
);
}
}
+
+ logger.info(`Gathering evalInputs...`);
+ for (const state of states) {
+ evalInputs.push(await gatherEvalInput({ manager, actionRef, state }));
+ }
return evalInputs;
}
@@ -247,7 +253,7 @@ async function runFlowAction(params: {
actionRef: string;
testCase: TestCase;
auth?: any;
-}): Promise {
+}): Promise {
const { manager, actionRef, testCase, auth } = { ...params };
let state: InferenceRunState;
try {
@@ -274,7 +280,7 @@ async function runFlowAction(params: {
evalError: `Error when running inference. Details: ${e?.message ?? e}`,
};
}
- return gatherEvalInput({ manager, actionRef, state });
+ return state;
}
async function runModelAction(params: {
@@ -282,7 +288,7 @@ async function runModelAction(params: {
actionRef: string;
testCase: TestCase;
modelConfig?: any;
-}): Promise {
+}): Promise {
const { manager, actionRef, modelConfig, testCase } = { ...params };
let state: InferenceRunState;
try {
@@ -304,7 +310,7 @@ async function runModelAction(params: {
evalError: `Error when running inference. Details: ${e?.message ?? e}`,
};
}
- return gatherEvalInput({ manager, actionRef, state });
+ return state;
}
async function gatherEvalInput(params: {
diff --git a/genkit-tools/package.json b/genkit-tools/package.json
index 09833d453..641914d51 100644
--- a/genkit-tools/package.json
+++ b/genkit-tools/package.json
@@ -23,5 +23,5 @@
"zod": "^3.22.4",
"zod-to-json-schema": "^3.22.4"
},
- "packageManager": "pnpm@9.12.0+sha256.a61b67ff6cc97af864564f4442556c22a04f2e5a7714fbee76a1011361d9b726"
+ "packageManager": "pnpm@9.12.2+sha256.2ef6e547b0b07d841d605240dce4d635677831148cd30f6d564b8f4f928f73d2"
}
diff --git a/genkit-tools/pnpm-lock.yaml b/genkit-tools/pnpm-lock.yaml
index 7c873c3a8..c8fe635f0 100644
--- a/genkit-tools/pnpm-lock.yaml
+++ b/genkit-tools/pnpm-lock.yaml
@@ -36,8 +36,8 @@ importers:
specifier: workspace:*
version: link:../common
axios:
- specifier: ^1.6.7
- version: 1.6.8
+ specifier: ^1.7.7
+ version: 1.7.7
colorette:
specifier: ^2.0.20
version: 2.0.20
@@ -97,8 +97,8 @@ importers:
specifier: ^0.5.12
version: 0.5.12
axios:
- specifier: ^1.6.7
- version: 1.6.8
+ specifier: ^1.7.7
+ version: 1.7.7
body-parser:
specifier: ^1.20.2
version: 1.20.2
@@ -1179,8 +1179,8 @@ packages:
resolution: {integrity: sha512-wvUjBtSGN7+7SjNpq/9M2Tg350UZD3q62IFZLbRAR1bSMlCo1ZaeW+BJ+D090e4hIIZLBcTDWe4Mh4jvUDajzQ==}
engines: {node: '>= 0.4'}
- axios@1.6.8:
- resolution: {integrity: sha512-v/ZHtJDU39mDpyBoFVkETcd/uNdxrWRrg3bKpOKzXFA6Bvqopts6ALSMU3y6ijYxbw2B+wPrIv46egTzJXCLGQ==}
+ axios@1.7.7:
+ resolution: {integrity: sha512-S4kL7XrjgBmvdGut0sN3yJxqYzrDOnivkBiN0OFs6hLiUam3UPvswUo0kqGyhqUZGEOytHyumEdXsAkgCOUf3Q==}
babel-jest@29.7.0:
resolution: {integrity: sha512-BrvGY3xZSwEcCzKvKsCi2GgHqDqsYkOP4/by5xCgIwGXQxIEh+8ew3gmrE1y7XRR6LHZIj6yLYnUi/mm2KXKBg==}
@@ -4201,7 +4201,7 @@ snapshots:
dependencies:
possible-typed-array-names: 1.0.0
- axios@1.6.8:
+ axios@1.7.7:
dependencies:
follow-redirects: 1.15.6
form-data: 4.0.0
diff --git a/genkit-tools/telemetry-server/package.json b/genkit-tools/telemetry-server/package.json
index 3665f657a..a49050aa1 100644
--- a/genkit-tools/telemetry-server/package.json
+++ b/genkit-tools/telemetry-server/package.json
@@ -7,7 +7,7 @@
"genai",
"generative-ai"
],
- "version": "0.6.0-dev.2",
+ "version": "0.9.0-dev.1",
"type": "commonjs",
"scripts": {
"compile": "tsc -b ./tsconfig.cjs.json ./tsconfig.esm.json ./tsconfig.types.json",
diff --git a/js/ai/package.json b/js/ai/package.json
index 047817be9..e9e6d2cde 100644
--- a/js/ai/package.json
+++ b/js/ai/package.json
@@ -7,7 +7,7 @@
"genai",
"generative-ai"
],
- "version": "0.6.0-dev.2",
+ "version": "0.9.0-dev.1",
"type": "commonjs",
"scripts": {
"check": "tsc",
diff --git a/js/ai/src/embedder.ts b/js/ai/src/embedder.ts
index 5cd278130..89b050253 100644
--- a/js/ai/src/embedder.ts
+++ b/js/ai/src/embedder.ts
@@ -15,7 +15,7 @@
*/
import { Action, defineAction, z } from '@genkit-ai/core';
-import { lookupAction } from '@genkit-ai/core/registry';
+import { Registry } from '@genkit-ai/core/registry';
import { Document, DocumentData, DocumentDataSchema } from './document.js';
export type EmbeddingBatch = { embedding: number[] }[];
@@ -68,6 +68,7 @@ function withMetadata(
export function defineEmbedder<
ConfigSchema extends z.ZodTypeAny = z.ZodTypeAny,
>(
+ registry: Registry,
options: {
name: string;
configSchema?: ConfigSchema;
@@ -76,6 +77,7 @@ export function defineEmbedder<
runner: EmbedderFn
) {
const embedder = defineAction(
+ registry,
{
actionType: 'embedder',
name: options.name,
@@ -111,47 +113,91 @@ export type EmbedderArgument<
* A veneer for interacting with embedder models.
*/
export async function embed(
+ registry: Registry,
params: EmbedderParams
): Promise {
- let embedder: EmbedderAction;
- if (typeof params.embedder === 'string') {
- embedder = await lookupAction(`/embedder/${params.embedder}`);
- } else if (Object.hasOwnProperty.call(params.embedder, 'info')) {
- embedder = await lookupAction(
- `/embedder/${(params.embedder as EmbedderReference).name}`
- );
- } else {
- embedder = params.embedder as EmbedderAction;
- }
- if (!embedder) {
- throw new Error('Unable to utilize the provided embedder');
+ let embedder = await resolveEmbedder(registry, params);
+ if (!embedder.embedderAction) {
+ let embedderId: string;
+ if (typeof params.embedder === 'string') {
+ embedderId = params.embedder;
+ } else if ((params.embedder as EmbedderAction)?.__action?.name) {
+ embedderId = (params.embedder as EmbedderAction).__action.name;
+ } else {
+ embedderId = (params.embedder as EmbedderReference).name;
+ }
+ throw new Error(`Unable to resolve embedder ${embedderId}`);
}
- const response = await embedder({
+ const response = await embedder.embedderAction({
input:
typeof params.content === 'string'
? [Document.fromText(params.content, params.metadata)]
: [params.content],
- options: params.options,
+ options: {
+ version: embedder.version,
+ ...embedder.config,
+ ...params.options,
+ },
});
return response.embeddings[0].embedding;
}
+interface ResolvedEmbedder {
+ embedderAction: EmbedderAction;
+ config?: z.infer;
+ version?: string;
+}
+
+async function resolveEmbedder<
+ CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
+>(
+ registry: Registry,
+ params: EmbedderParams
+): Promise> {
+ if (typeof params.embedder === 'string') {
+ return {
+ embedderAction: await registry.lookupAction(
+ `/embedder/${params.embedder}`
+ ),
+ };
+ } else if (Object.hasOwnProperty.call(params.embedder, '__action')) {
+ return {
+ embedderAction: params.embedder as EmbedderAction,
+ };
+ } else if (Object.hasOwnProperty.call(params.embedder, 'name')) {
+ const ref = params.embedder as EmbedderReference;
+ return {
+ embedderAction: await registry.lookupAction(
+ `/embedder/${(params.embedder as EmbedderReference).name}`
+ ),
+ config: {
+ ...ref.config,
+ },
+ version: ref.version,
+ };
+ }
+ throw new Error(`failed to resolve embedder ${params.embedder}`);
+}
+
/**
* A veneer for interacting with embedder models in bulk.
*/
export async function embedMany<
ConfigSchema extends z.ZodTypeAny = z.ZodTypeAny,
->(params: {
- embedder: EmbedderArgument;
- content: string[] | DocumentData[];
- metadata?: Record;
- options?: z.infer;
-}): Promise {
+>(
+ registry: Registry,
+ params: {
+ embedder: EmbedderArgument;
+ content: string[] | DocumentData[];
+ metadata?: Record;
+ options?: z.infer;
+ }
+): Promise {
let embedder: EmbedderAction;
if (typeof params.embedder === 'string') {
- embedder = await lookupAction(`/embedder/${params.embedder}`);
+ embedder = await registry.lookupAction(`/embedder/${params.embedder}`);
} else if (Object.hasOwnProperty.call(params.embedder, 'info')) {
- embedder = await lookupAction(
+ embedder = await registry.lookupAction(
`/embedder/${(params.embedder as EmbedderReference).name}`
);
} else {
@@ -192,6 +238,8 @@ export interface EmbedderReference<
name: string;
configSchema?: CustomOptions;
info?: EmbedderInfo;
+ config?: z.infer;
+ version?: string;
}
/**
diff --git a/js/ai/src/evaluator.ts b/js/ai/src/evaluator.ts
index 042e069d9..02be11e48 100644
--- a/js/ai/src/evaluator.ts
+++ b/js/ai/src/evaluator.ts
@@ -16,7 +16,7 @@
import { Action, defineAction, z } from '@genkit-ai/core';
import { logger } from '@genkit-ai/core/logging';
-import { lookupAction } from '@genkit-ai/core/registry';
+import { Registry } from '@genkit-ai/core/registry';
import { SPAN_TYPE_ATTR, runInNewSpan } from '@genkit-ai/core/tracing';
import { randomUUID } from 'crypto';
@@ -127,6 +127,7 @@ export function defineEvaluator<
typeof BaseEvalDataPointSchema = typeof BaseEvalDataPointSchema,
EvaluatorOptions extends z.ZodTypeAny = z.ZodTypeAny,
>(
+ registry: Registry,
options: {
name: string;
displayName: string;
@@ -143,6 +144,7 @@ export function defineEvaluator<
metadata[EVALUATOR_METADATA_KEY_DISPLAY_NAME] = options.displayName;
metadata[EVALUATOR_METADATA_KEY_DEFINITION] = options.definition;
const evaluator = defineAction(
+ registry,
{
actionType: 'evaluator',
name: options.name,
@@ -239,12 +241,17 @@ export type EvaluatorArgument<
export async function evaluate<
DataPoint extends typeof BaseDataPointSchema = typeof BaseDataPointSchema,
CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
->(params: EvaluatorParams): Promise {
+>(
+ registry: Registry,
+ params: EvaluatorParams
+): Promise {
let evaluator: EvaluatorAction;
if (typeof params.evaluator === 'string') {
- evaluator = await lookupAction(`/evaluator/${params.evaluator}`);
+ evaluator = await registry.lookupAction(`/evaluator/${params.evaluator}`);
} else if (Object.hasOwnProperty.call(params.evaluator, 'info')) {
- evaluator = await lookupAction(`/evaluator/${params.evaluator.name}`);
+ evaluator = await registry.lookupAction(
+ `/evaluator/${params.evaluator.name}`
+ );
} else {
evaluator = params.evaluator as EvaluatorAction;
}
diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts
index b4e95716a..cf5231520 100755
--- a/js/ai/src/generate.ts
+++ b/js/ai/src/generate.ts
@@ -21,7 +21,7 @@ import {
StreamingCallback,
z,
} from '@genkit-ai/core';
-import { lookupAction } from '@genkit-ai/core/registry';
+import { Registry } from '@genkit-ai/core/registry';
import { parseSchema, toJsonSchema } from '@genkit-ai/core/schema';
import { DocumentData } from './document.js';
import { extractJson } from './extract.js';
@@ -365,6 +365,7 @@ export class GenerateResponseChunk
}
export async function toGenerateRequest(
+ registry: Registry,
options: GenerateOptions
): Promise {
const messages: MessageData[] = [];
@@ -402,7 +403,7 @@ export async function toGenerateRequest(
}
let tools: Action[] | undefined;
if (options.tools) {
- tools = await resolveTools(options.tools);
+ tools = await resolveTools(registry, options.tools);
}
const out = {
@@ -464,21 +465,28 @@ interface ResolvedModel {
version?: string;
}
-async function resolveModel(options: GenerateOptions): Promise {
+async function resolveModel(
+ registry: Registry,
+ options: GenerateOptions
+): Promise {
let model = options.model;
if (!model) {
throw new Error('Model is required.');
}
if (typeof model === 'string') {
return {
- modelAction: (await lookupAction(`/model/${model}`)) as ModelAction,
+ modelAction: (await registry.lookupAction(
+ `/model/${model}`
+ )) as ModelAction,
};
} else if (model.hasOwnProperty('__action')) {
return { modelAction: model as ModelAction };
} else {
const ref = model as ModelReference;
return {
- modelAction: (await lookupAction(`/model/${ref.name}`)) as ModelAction,
+ modelAction: (await registry.lookupAction(
+ `/model/${ref.name}`
+ )) as ModelAction,
config: {
...ref.config,
},
@@ -525,13 +533,14 @@ export async function generate<
O extends z.ZodTypeAny = z.ZodTypeAny,
CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema,
>(
+ registry: Registry,
options:
| GenerateOptions
| PromiseLike>
): Promise>> {
const resolvedOptions: GenerateOptions =
await Promise.resolve(options);
- const resolvedModel = await resolveModel(resolvedOptions);
+ const resolvedModel = await resolveModel(registry, resolvedOptions);
const model = resolvedModel.modelAction;
if (!model) {
let modelId: string;
@@ -603,9 +612,9 @@ export async function generate<
messages,
tools,
config: {
- ...resolvedModel.config,
version: resolvedModel.version,
- ...resolvedOptions.config,
+ ...stripUndefinedOptions(resolvedModel.config),
+ ...stripUndefinedOptions(resolvedOptions.config),
},
output: resolvedOptions.output && {
format: resolvedOptions.output.format,
@@ -623,12 +632,23 @@ export async function generate<
resolvedOptions.streamingCallback,
async () =>
new GenerateResponse(
- await generateHelper(params, resolvedOptions.use),
- await toGenerateRequest(resolvedOptions)
+ await generateHelper(registry, params, resolvedOptions.use),
+ await toGenerateRequest(registry, resolvedOptions)
)
);
}
+function stripUndefinedOptions(input?: any): any {
+ if (!input) return input;
+ const copy = { ...input };
+ Object.keys(input).forEach((key) => {
+ if (copy[key] === undefined) {
+ delete copy[key];
+ }
+ });
+ return copy;
+}
+
export type GenerateStreamOptions<
O extends z.ZodTypeAny = z.ZodTypeAny,
CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema,
@@ -653,6 +673,7 @@ export async function generateStream<
O extends z.ZodTypeAny = z.ZodTypeAny,
CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema,
>(
+ registry: Registry,
options:
| GenerateOptions
| PromiseLike>
@@ -678,7 +699,7 @@ export async function generateStream<
}
try {
- generate({
+ generate(registry, {
...options,
streamingCallback: (chunk) => {
firstChunkSent = true;
diff --git a/js/ai/src/generateAction.ts b/js/ai/src/generateAction.ts
index 996f77af3..7cb4f2d71 100644
--- a/js/ai/src/generateAction.ts
+++ b/js/ai/src/generateAction.ts
@@ -21,7 +21,7 @@ import {
runWithStreamingCallback,
z,
} from '@genkit-ai/core';
-import { lookupAction } from '@genkit-ai/core/registry';
+import { Registry } from '@genkit-ai/core/registry';
import { toJsonSchema, validateSchema } from '@genkit-ai/core/schema';
import { runInNewSpan, SPAN_TYPE_ATTR } from '@genkit-ai/core/tracing';
import * as clc from 'colorette';
@@ -70,6 +70,7 @@ export const GenerateUtilParamSchema = z.object({
* Encapsulates all generate logic. This is similar to `generateAction` except not an action and can take middleware.
*/
export async function generateHelper(
+ registry: Registry,
input: z.infer,
middleware?: Middleware[]
): Promise {
@@ -86,7 +87,7 @@ export async function generateHelper(
async (metadata) => {
metadata.name = 'generate';
metadata.input = input;
- const output = await generate(input, middleware);
+ const output = await generate(registry, input, middleware);
metadata.output = JSON.stringify(output);
return output;
}
@@ -94,10 +95,11 @@ export async function generateHelper(
}
async function generate(
+ registry: Registry,
rawRequest: z.infer,
middleware?: Middleware[]
): Promise {
- const model = (await lookupAction(
+ const model = (await registry.lookupAction(
`/model/${rawRequest.model}`
)) as ModelAction;
if (!model) {
@@ -120,7 +122,7 @@ async function generate(
tools = await Promise.all(
rawRequest.tools.map(async (toolRef) => {
if (typeof toolRef === 'string') {
- const tool = (await lookupAction(toolRef)) as ToolAction;
+ const tool = (await registry.lookupAction(toolRef)) as ToolAction;
if (!tool) {
throw new Error(`Tool ${toolRef} not found`);
}
@@ -203,7 +205,7 @@ async function generate(
messages: [...request.messages, message],
prompt: toolResponses,
};
- return await generateHelper(nextRequest, middleware);
+ return await generateHelper(registry, nextRequest, middleware);
}
async function actionToGenerateRequest(
diff --git a/js/ai/src/index.ts b/js/ai/src/index.ts
index 75981dafc..b0945b40c 100644
--- a/js/ai/src/index.ts
+++ b/js/ai/src/index.ts
@@ -72,9 +72,11 @@ export {
export {
definePrompt,
renderPrompt,
+ type ExecutablePrompt,
type PromptAction,
type PromptConfig,
type PromptFn,
+ type PromptGenerateOptions,
} from './prompt.js';
export {
rerank,
diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts
index 01aa159f7..4a176e70b 100644
--- a/js/ai/src/model.ts
+++ b/js/ai/src/model.ts
@@ -22,6 +22,7 @@ import {
StreamingCallback,
z,
} from '@genkit-ai/core';
+import { Registry } from '@genkit-ai/core/registry';
import { toJsonSchema } from '@genkit-ai/core/schema';
import { performance } from 'node:perf_hooks';
import { DocumentDataSchema } from './document.js';
@@ -330,6 +331,7 @@ export type DefineModelOptions<
export function defineModel<
CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny,
>(
+ registry: Registry,
options: DefineModelOptions,
runner: (
request: GenerateRequest,
@@ -344,6 +346,7 @@ export function defineModel<
if (!options?.supports?.context) middleware.push(augmentWithContext());
middleware.push(conformOutput());
const act = defineAction(
+ registry,
{
actionType: 'model',
name: options.name,
@@ -386,15 +389,36 @@ export interface ModelReference {
info?: ModelInfo;
version?: string;
config?: z.infer;
+
+ withConfig(cfg: z.infer): ModelReference;
+ withVersion(version: string): ModelReference;
}
/** Cretes a model reference. */
export function modelRef<
CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny,
>(
- options: ModelReference
+ options: Omit<
+ ModelReference,
+ 'withConfig' | 'withVersion'
+ >
): ModelReference {
- return { ...options };
+ const ref: Partial> = { ...options };
+ ref.withConfig = (
+ cfg: z.infer
+ ): ModelReference => {
+ return modelRef({
+ ...options,
+ config: cfg,
+ });
+ };
+ ref.withVersion = (version: string): ModelReference => {
+ return modelRef({
+ ...options,
+ version,
+ });
+ };
+ return ref as ModelReference;
}
/** Container for counting usage stats for a single input/output {Part} */
@@ -433,16 +457,20 @@ export function getBasicUsageStats(
function getPartCounts(parts: Part[]): PartCounts {
return parts.reduce(
(counts, part) => {
+ const isImage =
+ part.media?.contentType?.startsWith('image') ||
+ part.media?.url?.startsWith('data:image');
+ const isVideo =
+ part.media?.contentType?.startsWith('video') ||
+ part.media?.url?.startsWith('data:video');
+ const isAudio =
+ part.media?.contentType?.startsWith('audio') ||
+ part.media?.url?.startsWith('data:audio');
return {
characters: counts.characters + (part.text?.length || 0),
- images:
- counts.images +
- (part.media?.contentType?.startsWith('image') ? 1 : 0),
- videos:
- counts.videos +
- (part.media?.contentType?.startsWith('video') ? 1 : 0),
- audio:
- counts.audio + (part.media?.contentType?.startsWith('audio') ? 1 : 0),
+ images: counts.images + (isImage ? 1 : 0),
+ videos: counts.videos + (isVideo ? 1 : 0),
+ audio: counts.audio + (isAudio ? 1 : 0),
};
},
{ characters: 0, images: 0, videos: 0, audio: 0 }
diff --git a/js/ai/src/prompt.ts b/js/ai/src/prompt.ts
index b520b5257..f497dca23 100644
--- a/js/ai/src/prompt.ts
+++ b/js/ai/src/prompt.ts
@@ -15,14 +15,19 @@
*/
import { Action, defineAction, JSONSchema7, z } from '@genkit-ai/core';
-import { lookupAction } from '@genkit-ai/core/registry';
+import { Registry } from '@genkit-ai/core/registry';
import { DocumentData } from './document.js';
-import { GenerateOptions } from './generate.js';
+import {
+ GenerateOptions,
+ GenerateResponse,
+ GenerateStreamResponse,
+} from './generate.js';
import {
GenerateRequest,
GenerateRequestSchema,
ModelArgument,
} from './model.js';
+import { ToolAction } from './tool.js';
export type PromptFn<
I extends z.ZodTypeAny = z.ZodTypeAny,
@@ -58,6 +63,84 @@ export function isPrompt(arg: any): boolean {
);
}
+export type PromptGenerateOptions<
+ I = undefined,
+ CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
+> = Omit<
+ GenerateOptions,
+ 'prompt' | 'input' | 'model'
+> & {
+ model?: ModelArgument;
+ input?: I;
+};
+
+/**
+ * A prompt that can be executed as a function.
+ */
+export interface ExecutablePrompt<
+ I = undefined,
+ O extends z.ZodTypeAny = z.ZodTypeAny,
+ CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
+> {
+ /**
+ * Generates a response by rendering the prompt template with given user input and then calling the model.
+ *
+ * @param input Prompt inputs.
+ * @param opt Options for the prompt template, including user input variables and custom model configuration options.
+ * @returns the model response as a promise of `GenerateStreamResponse`.
+ */
+ (
+ input?: I,
+ opts?: PromptGenerateOptions
+ ): Promise>>;
+
+ /**
+ * Generates a response by rendering the prompt template with given user input and then calling the model.
+ * @param input Prompt inputs.
+ * @param opt Options for the prompt template, including user input variables and custom model configuration options.
+ * @returns the model response as a promise of `GenerateStreamResponse`.
+ */
+ stream(
+ input?: I,
+ opts?: PromptGenerateOptions
+ ): Promise>>;
+
+ /**
+ * Generates a response by rendering the prompt template with given user input and additional generate options and then calling the model.
+ *
+ * @param opt Options for the prompt template, including user input variables and custom model configuration options.
+ * @returns the model response as a promise of `GenerateResponse`.
+ */
+ generate(
+ opt: PromptGenerateOptions
+ ): Promise>>;
+
+ /**
+ * Generates a streaming response by rendering the prompt template with given user input and additional generate options and then calling the model.
+ *
+ * @param opt Options for the prompt template, including user input variables and custom model configuration options.
+ * @returns the model response as a promise of `GenerateStreamResponse`.
+ */
+ generateStream(
+ opt: PromptGenerateOptions
+ ): Promise>>;
+
+ /**
+ * Renders the prompt template based on user input.
+ *
+ * @param opt Options for the prompt template, including user input variables and custom model configuration options.
+ * @returns a `GenerateOptions` object to be used with the `generate()` function from @genkit-ai/ai.
+ */
+ render(
+ opt: PromptGenerateOptions
+ ): Promise>;
+
+ /**
+ * Returns the prompt usable as a tool.
+ */
+ asTool(): ToolAction;
+}
+
/**
* Defines and registers a prompt action. The action can be called to obtain
* a `GenerateRequest` which can be passed to a model action. The given
@@ -67,10 +150,12 @@ export function isPrompt(arg: any): boolean {
* @returns The new `PromptAction`.
*/
export function definePrompt(
+ registry: Registry,
config: PromptConfig,
fn: PromptFn
): PromptAction {
const a = defineAction(
+ registry,
{
...config,
actionType: 'prompt',
@@ -94,16 +179,19 @@ export async function renderPrompt<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
->(params: {
- prompt: PromptArgument;
- input: z.infer;
- docs?: DocumentData[];
- model: ModelArgument;
- config?: z.infer;
-}): Promise> {
+>(
+ registry: Registry,
+ params: {
+ prompt: PromptArgument;
+ input: z.infer;
+ docs?: DocumentData[];
+ model: ModelArgument;
+ config?: z.infer;
+ }
+): Promise> {
let prompt: PromptAction;
if (typeof params.prompt === 'string') {
- prompt = await lookupAction(`/prompt/${params.prompt}`);
+ prompt = await registry.lookupAction(`/prompt/${params.prompt}`);
} else {
prompt = params.prompt as PromptAction;
}
diff --git a/js/ai/src/reranker.ts b/js/ai/src/reranker.ts
index 54428d0cb..35d3b2505 100644
--- a/js/ai/src/reranker.ts
+++ b/js/ai/src/reranker.ts
@@ -15,7 +15,7 @@
*/
import { Action, defineAction, z } from '@genkit-ai/core';
-import { lookupAction } from '@genkit-ai/core/registry';
+import { Registry } from '@genkit-ai/core/registry';
import { Part, PartSchema } from './document.js';
import { Document, DocumentData, DocumentDataSchema } from './retriever.js';
@@ -101,6 +101,7 @@ function rerankerWithMetadata<
* Creates a reranker action for the provided {@link RerankerFn} implementation.
*/
export function defineReranker(
+ registry: Registry,
options: {
name: string;
configSchema?: OptionsType;
@@ -109,6 +110,7 @@ export function defineReranker(
runner: RerankerFn
) {
const reranker = defineAction(
+ registry,
{
actionType: 'reranker',
name: options.name,
@@ -157,13 +159,14 @@ export type RerankerArgument<
* Reranks documents from a {@link RerankerArgument} based on the provided query.
*/
export async function rerank(
+ registry: Registry,
params: RerankerParams
): Promise> {
let reranker: RerankerAction;
if (typeof params.reranker === 'string') {
- reranker = await lookupAction(`/reranker/${params.reranker}`);
+ reranker = await registry.lookupAction(`/reranker/${params.reranker}`);
} else if (Object.hasOwnProperty.call(params.reranker, 'info')) {
- reranker = await lookupAction(`/reranker/${params.reranker.name}`);
+ reranker = await registry.lookupAction(`/reranker/${params.reranker.name}`);
} else {
reranker = params.reranker as RerankerAction;
}
diff --git a/js/ai/src/retriever.ts b/js/ai/src/retriever.ts
index 0d3f23689..0623e297f 100644
--- a/js/ai/src/retriever.ts
+++ b/js/ai/src/retriever.ts
@@ -15,7 +15,7 @@
*/
import { Action, GenkitError, defineAction, z } from '@genkit-ai/core';
-import { lookupAction } from '@genkit-ai/core/registry';
+import { Registry } from '@genkit-ai/core/registry';
import { Document, DocumentData, DocumentDataSchema } from './document.js';
import { EmbedderInfo } from './embedder.js';
@@ -111,6 +111,7 @@ function indexerWithMetadata<
export function defineRetriever<
OptionsType extends z.ZodTypeAny = z.ZodTypeAny,
>(
+ registry: Registry,
options: {
name: string;
configSchema?: OptionsType;
@@ -119,6 +120,7 @@ export function defineRetriever<
runner: RetrieverFn
) {
const retriever = defineAction(
+ registry,
{
actionType: 'retriever',
name: options.name,
@@ -149,6 +151,7 @@ export function defineRetriever<
* Creates an indexer action for the provided {@link IndexerFn} implementation.
*/
export function defineIndexer(
+ registry: Registry,
options: {
name: string;
embedderInfo?: EmbedderInfo;
@@ -157,6 +160,7 @@ export function defineIndexer(
runner: IndexerFn
) {
const indexer = defineAction(
+ registry,
{
actionType: 'indexer',
name: options.name,
@@ -200,13 +204,16 @@ export type RetrieverArgument<
* Retrieves documents from a {@link RetrieverArgument} based on the provided query.
*/
export async function retrieve(
+ registry: Registry,
params: RetrieverParams
): Promise> {
let retriever: RetrieverAction;
if (typeof params.retriever === 'string') {
- retriever = await lookupAction(`/retriever/${params.retriever}`);
+ retriever = await registry.lookupAction(`/retriever/${params.retriever}`);
} else if (Object.hasOwnProperty.call(params.retriever, 'info')) {
- retriever = await lookupAction(`/retriever/${params.retriever.name}`);
+ retriever = await registry.lookupAction(
+ `/retriever/${params.retriever.name}`
+ );
} else {
retriever = params.retriever as RetrieverAction;
}
@@ -239,13 +246,14 @@ export interface IndexerParams<
* Indexes documents using a {@link IndexerArgument}.
*/
export async function index(
+ registry: Registry,
params: IndexerParams
): Promise {
let indexer: IndexerAction;
if (typeof params.indexer === 'string') {
- indexer = await lookupAction(`/indexer/${params.indexer}`);
+ indexer = await registry.lookupAction(`/indexer/${params.indexer}`);
} else if (Object.hasOwnProperty.call(params.indexer, 'info')) {
- indexer = await lookupAction(`/indexer/${params.indexer.name}`);
+ indexer = await registry.lookupAction(`/indexer/${params.indexer.name}`);
} else {
indexer = params.indexer as IndexerAction;
}
@@ -381,10 +389,12 @@ export function defineSimpleRetriever<
C extends z.ZodTypeAny = z.ZodTypeAny,
R = any,
>(
+ registry: Registry,
options: SimpleRetrieverOptions,
handler: (query: Document, config: z.infer) => Promise
) {
return defineRetriever(
+ registry,
{
name: options.name,
configSchema: options.configSchema,
diff --git a/js/ai/src/testing/model-tester.ts b/js/ai/src/testing/model-tester.ts
index 3cf041ae9..7caa4b0cc 100644
--- a/js/ai/src/testing/model-tester.ts
+++ b/js/ai/src/testing/model-tester.ts
@@ -15,7 +15,7 @@
*/
import { z } from '@genkit-ai/core';
-import { lookupAction } from '@genkit-ai/core/registry';
+import { Registry } from '@genkit-ai/core/registry';
import { runInNewSpan } from '@genkit-ai/core/tracing';
import assert from 'node:assert';
import { generate } from '../generate';
@@ -23,8 +23,8 @@ import { ModelAction } from '../model';
import { defineTool } from '../tool';
const tests: Record = {
- 'basic hi': async (model: string) => {
- const response = await generate({
+ 'basic hi': async (registry: Registry, model: string) => {
+ const response = await generate(registry, {
model,
prompt: 'just say "Hi", literally',
});
@@ -32,14 +32,14 @@ const tests: Record = {
const got = response.text.trim();
assert.match(got, /Hi/i);
},
- multimodal: async (model: string) => {
- const resolvedModel = (await lookupAction(
+ multimodal: async (registry: Registry, model: string) => {
+ const resolvedModel = (await registry.lookupAction(
`/model/${model}`
)) as ModelAction;
if (!resolvedModel.__action.metadata?.model.supports?.media) {
skip();
}
- const response = await generate({
+ const response = await generate(registry, {
model,
prompt: [
{
@@ -57,18 +57,18 @@ const tests: Record = {
const got = response.text.trim();
assert.match(got, /plus/i);
},
- history: async (model: string) => {
- const resolvedModel = (await lookupAction(
+ history: async (registry: Registry, model: string) => {
+ const resolvedModel = (await registry.lookupAction(
`/model/${model}`
)) as ModelAction;
if (!resolvedModel.__action.metadata?.model.supports?.multiturn) {
skip();
}
- const response1 = await generate({
+ const response1 = await generate(registry, {
model,
prompt: 'My name is Glorb',
});
- const response = await generate({
+ const response = await generate(registry, {
model,
prompt: "What's my name?",
messages: response1.messages,
@@ -77,8 +77,8 @@ const tests: Record = {
const got = response.text.trim();
assert.match(got, /Glorb/);
},
- 'system prompt': async (model: string) => {
- const { text } = await generate({
+ 'system prompt': async (registry: Registry, model: string) => {
+ const { text } = await generate(registry, {
model,
prompt: 'Hi',
messages: [
@@ -97,8 +97,8 @@ const tests: Record = {
const got = text.trim();
assert.equal(got, want);
},
- 'structured output': async (model: string) => {
- const response = await generate({
+ 'structured output': async (registry: Registry, model: string) => {
+ const response = await generate(registry, {
model,
prompt: 'extract data as json from: Jack was a Lumberjack',
output: {
@@ -117,15 +117,15 @@ const tests: Record = {
const got = response.output;
assert.deepEqual(want, got);
},
- 'tool calling': async (model: string) => {
- const resolvedModel = (await lookupAction(
+ 'tool calling': async (registry: Registry, model: string) => {
+ const resolvedModel = (await registry.lookupAction(
`/model/${model}`
)) as ModelAction;
if (!resolvedModel.__action.metadata?.model.supports?.tools) {
skip();
}
- const { text } = await generate({
+ const { text } = await generate(registry, {
model,
prompt: 'what is a gablorken of 2? use provided tool',
tools: ['gablorkenTool'],
@@ -149,10 +149,14 @@ type TestReport = {
}[];
}[];
-type TestCase = (model: string) => Promise;
+type TestCase = (ai: Registry, model: string) => Promise;
-export async function testModels(models: string[]): Promise {
+export async function testModels(
+ registry: Registry,
+ models: string[]
+): Promise {
const gablorkenTool = defineTool(
+ registry,
{
name: 'gablorkenTool',
description: 'use when need to calculate a gablorken',
@@ -182,7 +186,7 @@ export async function testModels(models: string[]): Promise {
});
const modelReport = caseReport.models[caseReport.models.length - 1];
try {
- await tests[test](model);
+ await tests[test](registry, model);
} catch (e) {
modelReport.passed = false;
if (e instanceof SkipTestError) {
diff --git a/js/ai/src/tool.ts b/js/ai/src/tool.ts
index 9dcb61c4f..a0d85340c 100644
--- a/js/ai/src/tool.ts
+++ b/js/ai/src/tool.ts
@@ -15,7 +15,7 @@
*/
import { Action, defineAction, JSONSchema7, z } from '@genkit-ai/core';
-import { lookupAction } from '@genkit-ai/core/registry';
+import { Registry } from '@genkit-ai/core/registry';
import { toJsonSchema } from '@genkit-ai/core/schema';
import { setCustomMetadataAttributes } from '@genkit-ai/core/tracing';
import { ToolDefinition } from './model.js';
@@ -89,11 +89,11 @@ export function asTool(
export async function resolveTools<
O extends z.ZodTypeAny = z.ZodTypeAny,
CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
->(tools: ToolArgument[] = []): Promise {
+>(registry: Registry, tools: ToolArgument[] = []): Promise {
return await Promise.all(
tools.map(async (ref): Promise => {
if (typeof ref === 'string') {
- const tool = await lookupAction(`/tool/${ref}`);
+ const tool = await registry.lookupAction(`/tool/${ref}`);
if (!tool) {
throw new Error(`Tool ${ref} not found`);
}
@@ -101,7 +101,7 @@ export async function resolveTools<
} else if ((ref as Action).__action) {
return asTool(ref as Action);
} else if (ref.name) {
- const tool = await lookupAction(`/tool/${ref.name}`);
+ const tool = await registry.lookupAction(`/tool/${ref.name}`);
if (!tool) {
throw new Error(`Tool ${ref} not found`);
}
@@ -137,10 +137,12 @@ export function toToolDefinition(
* A tool is an action that can be passed to a model to be called automatically if it so chooses.
*/
export function defineTool(
+ registry: Registry,
config: ToolConfig,
fn: (input: z.infer) => Promise>
): ToolAction {
const a = defineAction(
+ registry,
{
...config,
actionType: 'tool',
diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts
index cae8e419c..9a02b0f6e 100644
--- a/js/ai/tests/generate/generate_test.ts
+++ b/js/ai/tests/generate/generate_test.ts
@@ -15,7 +15,7 @@
*/
import { z } from '@genkit-ai/core';
-import { Registry, runWithRegistry } from '@genkit-ai/core/registry';
+import { Registry } from '@genkit-ai/core/registry';
import { toJsonSchema } from '@genkit-ai/core/schema';
import assert from 'node:assert';
import { beforeEach, describe, it } from 'node:test';
@@ -262,19 +262,18 @@ describe('GenerateResponse', () => {
describe('toGenerateRequest', () => {
const registry = new Registry();
// register tools
- const tellAFunnyJoke = runWithRegistry(registry, () =>
- defineTool(
- {
- name: 'tellAFunnyJoke',
- description:
- 'Tells jokes about an input topic. Use this tool whenever user asks you to tell a joke.',
- inputSchema: z.object({ topic: z.string() }),
- outputSchema: z.string(),
- },
- async (input) => {
- return `Why did the ${input.topic} cross the road?`;
- }
- )
+ const tellAFunnyJoke = defineTool(
+ registry,
+ {
+ name: 'tellAFunnyJoke',
+ description:
+ 'Tells jokes about an input topic. Use this tool whenever user asks you to tell a joke.',
+ inputSchema: z.object({ topic: z.string() }),
+ outputSchema: z.string(),
+ },
+ async (input) => {
+ return `Why did the ${input.topic} cross the road?`;
+ }
);
const testCases = [
@@ -442,9 +441,7 @@ describe('toGenerateRequest', () => {
for (const test of testCases) {
it(test.should, async () => {
assert.deepStrictEqual(
- await runWithRegistry(registry, () =>
- toGenerateRequest(test.prompt as GenerateOptions)
- ),
+ await toGenerateRequest(registry, test.prompt as GenerateOptions),
test.expectedOutput
);
});
@@ -530,29 +527,28 @@ describe('generate', () => {
beforeEach(() => {
registry = new Registry();
- echoModel = runWithRegistry(registry, () =>
- defineModel(
- {
- name: 'echoModel',
- },
- async (request) => {
- return {
- message: {
- role: 'model',
- content: [
- {
- text:
- 'Echo: ' +
- request.messages
- .map((m) => m.content.map((c) => c.text).join())
- .join(),
- },
- ],
- },
- finishReason: 'stop',
- };
- }
- )
+ echoModel = defineModel(
+ registry,
+ {
+ name: 'echoModel',
+ },
+ async (request) => {
+ return {
+ message: {
+ role: 'model',
+ content: [
+ {
+ text:
+ 'Echo: ' +
+ request.messages
+ .map((m) => m.content.map((c) => c.text).join())
+ .join(),
+ },
+ ],
+ },
+ finishReason: 'stop',
+ };
+ }
);
});
@@ -592,14 +588,11 @@ describe('generate', () => {
};
};
- const response = await runWithRegistry(registry, () =>
- generate({
- prompt: 'banana',
- model: echoModel,
- use: [wrapRequest, wrapResponse],
- })
- );
-
+ const response = await generate(registry, {
+ prompt: 'banana',
+ model: echoModel,
+ use: [wrapRequest, wrapResponse],
+ });
const want = '[Echo: (banana)]';
assert.deepStrictEqual(response.text, want);
});
@@ -609,24 +602,21 @@ describe('generate', () => {
let registry: Registry;
beforeEach(() => {
registry = new Registry();
- runWithRegistry(registry, () =>
- defineModel(
- { name: 'echo', supports: { tools: true } },
- async (input) => ({
- message: input.messages[0],
- finishReason: 'stop',
- })
- )
+
+ defineModel(
+ registry,
+ { name: 'echo', supports: { tools: true } },
+ async (input) => ({
+ message: input.messages[0],
+ finishReason: 'stop',
+ })
);
});
it('should preserve the request in the returned response, enabling .messages', async () => {
- const response = await runWithRegistry(registry, () =>
- generate({
- model: 'echo',
- prompt: 'Testing messages',
- })
- );
-
+ const response = await generate(registry, {
+ model: 'echo',
+ prompt: 'Testing messages',
+ });
assert.deepEqual(
response.messages.map((m) => m.content[0].text),
['Testing messages', 'Testing messages']
diff --git a/js/ai/tests/model/middleware_test.ts b/js/ai/tests/model/middleware_test.ts
index 9b3eb7054..3c9aaaffa 100644
--- a/js/ai/tests/model/middleware_test.ts
+++ b/js/ai/tests/model/middleware_test.ts
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-import { Registry, runWithRegistry } from '@genkit-ai/core/registry';
+import { Registry } from '@genkit-ai/core/registry';
import assert from 'node:assert';
import { beforeEach, describe, it } from 'node:test';
import { DocumentData } from '../../src/document.js';
@@ -147,24 +147,21 @@ describe('validateSupport', () => {
});
const registry = new Registry();
-const echoModel = runWithRegistry(registry, () =>
- defineModel({ name: 'echo' }, async (req) => {
- return {
- finishReason: 'stop',
- message: {
- role: 'model',
- content: [{ data: req }],
- },
- };
- })
-);
-
+const echoModel = defineModel(registry, { name: 'echo' }, async (req) => {
+ return {
+ finishReason: 'stop',
+ message: {
+ role: 'model',
+ content: [{ data: req }],
+ },
+ };
+});
describe('conformOutput (default middleware)', () => {
const schema = { type: 'object', properties: { test: { type: 'boolean' } } };
// return the output tagged part from the request
async function testRequest(req: GenerateRequest): Promise {
- const response = await runWithRegistry(registry, () => echoModel(req));
+ const response = await echoModel(req);
const treq = response.message!.content[0].data as GenerateRequest;
const lastUserMessage = treq.messages
diff --git a/js/ai/tests/prompt/prompt_test.ts b/js/ai/tests/prompt/prompt_test.ts
index c35c951c6..702f85444 100644
--- a/js/ai/tests/prompt/prompt_test.ts
+++ b/js/ai/tests/prompt/prompt_test.ts
@@ -15,7 +15,7 @@
*/
import { z } from '@genkit-ai/core';
-import { Registry, runWithRegistry } from '@genkit-ai/core/registry';
+import { Registry } from '@genkit-ai/core/registry';
import assert from 'node:assert';
import { describe, it } from 'node:test';
import { definePrompt, renderPrompt } from '../../src/prompt.ts';
@@ -23,38 +23,37 @@ import { definePrompt, renderPrompt } from '../../src/prompt.ts';
describe('prompt', () => {
let registry = new Registry();
describe('render()', () => {
- runWithRegistry(registry, () => {
- it('respects output schema in the definition', async () => {
- const schema1 = z.object({
- puppyName: z.string({ description: 'A cute name for a puppy' }),
- });
- const prompt1 = definePrompt(
- {
- name: 'prompt1',
- inputSchema: z.string({ description: 'Dog breed' }),
- },
- async (breed) => {
- return {
- messages: [
- {
- role: 'user',
- content: [{ text: `Pick a name for a ${breed} puppy` }],
- },
- ],
- output: {
- format: 'json',
- schema: schema1,
+ it('respects output schema in the definition', async () => {
+ const schema1 = z.object({
+ puppyName: z.string({ description: 'A cute name for a puppy' }),
+ });
+ const prompt1 = definePrompt(
+ registry,
+ {
+ name: 'prompt1',
+ inputSchema: z.string({ description: 'Dog breed' }),
+ },
+ async (breed) => {
+ return {
+ messages: [
+ {
+ role: 'user',
+ content: [{ text: `Pick a name for a ${breed} puppy` }],
},
- };
- }
- );
- const generateRequest = await renderPrompt({
- prompt: prompt1,
- input: 'poodle',
- model: 'geminiPro',
- });
- assert.equal(generateRequest.output?.schema, schema1);
+ ],
+ output: {
+ format: 'json',
+ schema: schema1,
+ },
+ };
+ }
+ );
+ const generateRequest = await renderPrompt(registry, {
+ prompt: prompt1,
+ input: 'poodle',
+ model: 'geminiPro',
});
+ assert.equal(generateRequest.output?.schema, schema1);
});
});
});
diff --git a/js/ai/tests/reranker/reranker_test.ts b/js/ai/tests/reranker/reranker_test.ts
index 4942e02b6..63a8b25e4 100644
--- a/js/ai/tests/reranker/reranker_test.ts
+++ b/js/ai/tests/reranker/reranker_test.ts
@@ -15,7 +15,7 @@
*/
import { GenkitError, z } from '@genkit-ai/core';
-import { Registry, runWithRegistry } from '@genkit-ai/core/registry';
+import { Registry } from '@genkit-ai/core/registry';
import assert from 'node:assert';
import { beforeEach, describe, it } from 'node:test';
import { defineReranker, rerank } from '../../src/reranker';
@@ -28,34 +28,32 @@ describe('reranker', () => {
registry = new Registry();
});
it('reranks documents based on custom logic', async () => {
- const customReranker = runWithRegistry(registry, () =>
- defineReranker(
- {
- name: 'reranker',
- configSchema: z.object({
- k: z.number().optional(),
- }),
- },
- async (query, documents, options) => {
- // Custom reranking logic: score based on string length similarity to query
- const queryLength = query.text.length;
- const rerankedDocs = documents.map((doc) => {
- const score = Math.abs(queryLength - doc.text.length);
- return {
- ...doc,
- metadata: { ...doc.metadata, score },
- };
- });
-
+ const customReranker = defineReranker(
+ registry,
+ {
+ name: 'reranker',
+ configSchema: z.object({
+ k: z.number().optional(),
+ }),
+ },
+ async (query, documents, options) => {
+ // Custom reranking logic: score based on string length similarity to query
+ const queryLength = query.text.length;
+ const rerankedDocs = documents.map((doc) => {
+ const score = Math.abs(queryLength - doc.text.length);
return {
- documents: rerankedDocs
- .sort((a, b) => a.metadata.score - b.metadata.score)
- .slice(0, options.k || 3),
+ ...doc,
+ metadata: { ...doc.metadata, score },
};
- }
- )
+ });
+
+ return {
+ documents: rerankedDocs
+ .sort((a, b) => a.metadata.score - b.metadata.score)
+ .slice(0, options.k || 3),
+ };
+ }
);
-
// Sample documents for testing
const documents = [
Document.fromText('short'),
@@ -64,15 +62,12 @@ describe('reranker', () => {
];
const query = Document.fromText('medium length');
- const rerankedDocuments = await runWithRegistry(registry, () =>
- rerank({
- reranker: customReranker,
- query,
- documents,
- options: { k: 2 },
- })
- );
-
+ const rerankedDocuments = await rerank(registry, {
+ reranker: customReranker,
+ query,
+ documents,
+ options: { k: 2 },
+ });
// Validate the reranked results
assert.equal(rerankedDocuments.length, 2);
assert(rerankedDocuments[0].text.includes('a bit longer'));
@@ -80,85 +75,76 @@ describe('reranker', () => {
});
it('handles missing options gracefully', async () => {
- const customReranker = runWithRegistry(registry, () =>
- defineReranker(
- {
- name: 'reranker',
- configSchema: z.object({
- k: z.number().optional(),
- }),
- },
- async (query, documents, options) => {
- const rerankedDocs = documents.map((doc) => {
- const score = Math.random(); // Simplified scoring for testing
- return {
- ...doc,
- metadata: { ...doc.metadata, score },
- };
- });
-
+ const customReranker = defineReranker(
+ registry,
+ {
+ name: 'reranker',
+ configSchema: z.object({
+ k: z.number().optional(),
+ }),
+ },
+ async (query, documents, options) => {
+ const rerankedDocs = documents.map((doc) => {
+ const score = Math.random(); // Simplified scoring for testing
return {
- documents: rerankedDocs.sort(
- (a, b) => b.metadata.score - a.metadata.score
- ),
+ ...doc,
+ metadata: { ...doc.metadata, score },
};
- }
- )
+ });
+
+ return {
+ documents: rerankedDocs.sort(
+ (a, b) => b.metadata.score - a.metadata.score
+ ),
+ };
+ }
);
-
const documents = [Document.fromText('doc1'), Document.fromText('doc2')];
const query = Document.fromText('test query');
- const rerankedDocuments = await runWithRegistry(registry, () =>
- rerank({
- reranker: customReranker,
- query,
- documents,
- options: { k: 2 },
- })
- );
-
+ const rerankedDocuments = await rerank(registry, {
+ reranker: customReranker,
+ query,
+ documents,
+ options: { k: 2 },
+ });
assert.equal(rerankedDocuments.length, 2);
assert(typeof rerankedDocuments[0].metadata.score === 'number');
});
it('validates config schema and throws error on invalid input', async () => {
- const customReranker = runWithRegistry(registry, () =>
- defineReranker(
- {
- name: 'reranker',
- configSchema: z.object({
- k: z.number().min(1),
- }),
- },
- async (query, documents, options) => {
- // Simplified scoring for testing
- const rerankedDocs = documents.map((doc) => ({
- ...doc,
- metadata: { score: Math.random() },
- }));
- return {
- documents: rerankedDocs.sort(
- (a, b) => b.metadata.score - a.metadata.score
- ),
- };
- }
- )
+ const customReranker = defineReranker(
+ registry,
+ {
+ name: 'reranker',
+ configSchema: z.object({
+ k: z.number().min(1),
+ }),
+ },
+ async (query, documents, options) => {
+ // Simplified scoring for testing
+ const rerankedDocs = documents.map((doc) => ({
+ ...doc,
+ metadata: { score: Math.random() },
+ }));
+ return {
+ documents: rerankedDocs.sort(
+ (a, b) => b.metadata.score - a.metadata.score
+ ),
+ };
+ }
);
-
const documents = [Document.fromText('doc1')];
const query = Document.fromText('test query');
try {
- await runWithRegistry(registry, () =>
- rerank({
- reranker: customReranker,
- query,
- documents,
- options: { k: 0 }, // Invalid input: k must be at least 1
- })
- );
+ await rerank(registry, {
+ reranker: customReranker,
+ query,
+ documents,
+ options: { k: 0 }, // Invalid input: k must be at least 1
+ });
assert.fail('Expected validation error');
} catch (err) {
assert(err instanceof GenkitError);
@@ -167,71 +153,62 @@ describe('reranker', () => {
});
it('preserves document metadata after reranking', async () => {
- const customReranker = runWithRegistry(registry, () =>
- defineReranker(
- {
- name: 'reranker',
- },
- async (query, documents) => {
- const rerankedDocs = documents.map((doc, i) => ({
- ...doc,
- metadata: { ...doc.metadata, score: 2 - i },
- }));
-
- return {
- documents: rerankedDocs.sort(
- (a, b) => b.metadata.score - a.metadata.score
- ),
- };
- }
- )
+ const customReranker = defineReranker(
+ registry,
+ {
+ name: 'reranker',
+ },
+ async (query, documents) => {
+ const rerankedDocs = documents.map((doc, i) => ({
+ ...doc,
+ metadata: { ...doc.metadata, score: 2 - i },
+ }));
+
+ return {
+ documents: rerankedDocs.sort(
+ (a, b) => b.metadata.score - a.metadata.score
+ ),
+ };
+ }
);
-
const documents = [
new Document({ content: [], metadata: { originalField: 'test1' } }),
new Document({ content: [], metadata: { originalField: 'test2' } }),
];
const query = Document.fromText('test query');
- const rerankedDocuments = await runWithRegistry(registry, () =>
- rerank({
- reranker: customReranker,
- query,
- documents,
- })
- );
-
+ const rerankedDocuments = await rerank(registry, {
+ reranker: customReranker,
+ query,
+ documents,
+ });
assert.equal(rerankedDocuments[0].metadata.originalField, 'test1');
assert.equal(rerankedDocuments[1].metadata.originalField, 'test2');
});
it('handles errors thrown by the reranker', async () => {
- const customReranker = runWithRegistry(registry, () =>
- defineReranker(
- {
- name: 'reranker',
- },
- async (query, documents) => {
- // Simulate an error in the reranker logic
- throw new GenkitError({
- message: 'Something went wrong during reranking',
- status: 'INTERNAL',
- });
- }
- )
+ const customReranker = defineReranker(
+ registry,
+ {
+ name: 'reranker',
+ },
+ async (query, documents) => {
+ // Simulate an error in the reranker logic
+ throw new GenkitError({
+ message: 'Something went wrong during reranking',
+ status: 'INTERNAL',
+ });
+ }
);
-
const documents = [Document.fromText('doc1'), Document.fromText('doc2')];
const query = Document.fromText('test query');
try {
- await runWithRegistry(registry, () =>
- rerank({
- reranker: customReranker,
- query,
- documents,
- })
- );
+ await rerank(registry, {
+ reranker: customReranker,
+ query,
+ documents,
+ });
assert.fail('Expected an error to be thrown');
} catch (err) {
assert(err instanceof GenkitError);
diff --git a/js/core/package.json b/js/core/package.json
index 3a26e2536..03fdcd451 100644
--- a/js/core/package.json
+++ b/js/core/package.json
@@ -7,7 +7,7 @@
"genai",
"generative-ai"
],
- "version": "0.6.0-dev.2",
+ "version": "0.9.0-dev.1",
"type": "commonjs",
"scripts": {
"check": "tsc",
diff --git a/js/core/src/action.ts b/js/core/src/action.ts
index 1bf8f7a1b..382b80d14 100644
--- a/js/core/src/action.ts
+++ b/js/core/src/action.ts
@@ -17,12 +17,7 @@
import { JSONSchema7 } from 'json-schema';
import { AsyncLocalStorage } from 'node:async_hooks';
import * as z from 'zod';
-import {
- ActionType,
- initializeAllPlugins,
- lookupPlugin,
- registerAction,
-} from './registry.js';
+import { ActionType, Registry } from './registry.js';
import { parseSchema } from './schema.js';
import {
SPAN_TYPE_ATTR,
@@ -122,8 +117,8 @@ export function action<
): Action {
const actionName =
typeof config.name === 'string'
- ? validateActionName(config.name)
- : `${config.name.pluginId}/${validateActionId(config.name.actionId)}`;
+ ? config.name
+ : `${config.name.pluginId}/${config.name.actionId}`;
const actionFn = async (input: I) => {
input = parseSchema(input, {
schema: config.inputSchema,
@@ -168,16 +163,16 @@ export function action<
return actionFn;
}
-function validateActionName(name: string) {
+function validateActionName(registry: Registry, name: string) {
if (name.includes('/')) {
- validatePluginName(name.split('/', 1)[0]);
+ validatePluginName(registry, name.split('/', 1)[0]);
validateActionId(name.substring(name.indexOf('/') + 1));
}
return name;
}
-function validatePluginName(pluginId: string) {
- if (!lookupPlugin(pluginId)) {
+function validatePluginName(registry: Registry, pluginId: string) {
+ if (!registry.lookupPlugin(pluginId)) {
throw new Error(
`Unable to find plugin name used in the action name: ${pluginId}`
);
@@ -200,6 +195,7 @@ export function defineAction<
O extends z.ZodTypeAny,
M extends Record = Record,
>(
+ registry: Registry,
config: ActionParams & {
actionType: ActionType;
},
@@ -211,13 +207,18 @@ export function defineAction<
'See: https://github.com/firebase/genkit/blob/main/docs/errors/no_new_actions_at_runtime.md'
);
}
+ if (typeof config.name === 'string') {
+ validateActionName(registry, config.name);
+ } else {
+ validateActionId(config.name.actionId);
+ }
const act = action(config, async (i: I): Promise> => {
setCustomMetadataAttributes({ subtype: config.actionType });
- await initializeAllPlugins();
+ await registry.initializeAllPlugins();
return await runInActionRuntimeContext(() => fn(i));
});
act.__action.actionType = config.actionType;
- registerAction(config.actionType, act);
+ registry.registerAction(config.actionType, act);
return act;
}
diff --git a/js/core/src/flow.ts b/js/core/src/flow.ts
index 0061e0cde..107459585 100644
--- a/js/core/src/flow.ts
+++ b/js/core/src/flow.ts
@@ -31,12 +31,7 @@ import { runWithAuthContext } from './auth.js';
import { getErrorMessage, getErrorStack } from './error.js';
import { FlowActionInputSchema } from './flowTypes.js';
import { logger } from './logging.js';
-import {
- getRegistryInstance,
- initializeAllPlugins,
- Registry,
- runWithRegistry,
-} from './registry.js';
+import { Registry } from './registry.js';
import { toJsonSchema } from './schema.js';
import {
newTrace,
@@ -181,6 +176,7 @@ export class Flow<
readonly flowFn: FlowFn;
constructor(
+ private registry: Registry,
config: FlowConfig | StreamingFlowConfig,
flowFn: FlowFn
) {
@@ -207,7 +203,7 @@ export class Flow<
auth?: unknown;
}
): Promise>> {
- await initializeAllPlugins();
+ await this.registry.initializeAllPlugins();
return await runWithAuthContext(opts.auth, () =>
newTrace(
{
@@ -336,84 +332,79 @@ export class Flow<
}
async expressHandler(
- registry: Registry,
request: __RequestWithAuth,
response: express.Response
): Promise {
- await runWithRegistry(registry, async () => {
- const { stream } = request.query;
- const auth = request.auth;
-
- let input = request.body.data;
+ const { stream } = request.query;
+ const auth = request.auth;
+
+ let input = request.body.data;
+
+ try {
+ await this.authPolicy?.(auth, input);
+ } catch (e: any) {
+ const respBody = {
+ error: {
+ status: 'PERMISSION_DENIED',
+ message: e.message || 'Permission denied to resource',
+ },
+ };
+ response.status(403).send(respBody).end();
+ return;
+ }
+ if (stream === 'true') {
+ response.writeHead(200, {
+ 'Content-Type': 'text/plain',
+ 'Transfer-Encoding': 'chunked',
+ });
try {
- await this.authPolicy?.(auth, input);
- } catch (e: any) {
- const respBody = {
+ const result = await this.invoke(input, {
+ streamingCallback: ((chunk: z.infer) => {
+ response.write(JSON.stringify(chunk) + streamDelimiter);
+ }) as S extends z.ZodVoid ? undefined : StreamingCallback>,
+ auth,
+ });
+ response.write({
+ result: result.result, // Need more results!!!!
+ });
+ response.end();
+ } catch (e) {
+ response.write({
error: {
- status: 'PERMISSION_DENIED',
- message: e.message || 'Permission denied to resource',
+ status: 'INTERNAL',
+ message: getErrorMessage(e),
+ details: getErrorStack(e),
},
- };
- response.status(403).send(respBody).end();
- return;
- }
-
- if (stream === 'true') {
- response.writeHead(200, {
- 'Content-Type': 'text/plain',
- 'Transfer-Encoding': 'chunked',
});
- try {
- const result = await this.invoke(input, {
- streamingCallback: ((chunk: z.infer) => {
- response.write(JSON.stringify(chunk) + streamDelimiter);
- }) as S extends z.ZodVoid
- ? undefined
- : StreamingCallback>,
- auth,
- });
- response.write({
- result: result.result, // Need more results!!!!
- });
- response.end();
- } catch (e) {
- response.write({
+ response.end();
+ }
+ } else {
+ try {
+ const result = await this.invoke(input, { auth });
+ response.setHeader('x-genkit-trace-id', result.traceId);
+ response.setHeader('x-genkit-span-id', result.spanId);
+ // Responses for non-streaming flows are passed back with the flow result stored in a field called "result."
+ response
+ .status(200)
+ .send({
+ result: result.result,
+ })
+ .end();
+ } catch (e) {
+ // Errors for non-streaming flows are passed back as standard API errors.
+ response
+ .status(500)
+ .send({
error: {
status: 'INTERNAL',
message: getErrorMessage(e),
details: getErrorStack(e),
},
- });
- response.end();
- }
- } else {
- try {
- const result = await this.invoke(input, { auth });
- response.setHeader('x-genkit-trace-id', result.traceId);
- response.setHeader('x-genkit-span-id', result.spanId);
- // Responses for non-streaming flows are passed back with the flow result stored in a field called "result."
- response
- .status(200)
- .send({
- result: result.result,
- })
- .end();
- } catch (e) {
- // Errors for non-streaming flows are passed back as standard API errors.
- response
- .status(500)
- .send({
- error: {
- status: 'INTERNAL',
- message: getErrorMessage(e),
- details: getErrorStack(e),
- },
- })
- .end();
- }
+ })
+ .end();
}
- });
+ }
}
}
@@ -496,9 +487,7 @@ export class FlowServer {
flow.middleware?.forEach((middleware) =>
server.post(flowPath, middleware)
);
- server.post(flowPath, (req, res) =>
- flow.expressHandler(this.registry, req, res)
- );
+ server.post(flowPath, (req, res) => flow.expressHandler(req, res));
});
} else {
logger.warn('No flows registered in flow server.');
@@ -557,17 +546,17 @@ export function defineFlow<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
>(
+ registry: Registry,
config: FlowConfig | string,
fn: FlowFn
): CallableFlow {
const resolvedConfig: FlowConfig =
typeof config === 'string' ? { name: config } : config;
- const flow = new Flow(resolvedConfig, fn);
- registerFlowAction(flow);
- const registry = getRegistryInstance();
+ const flow = new Flow(registry, resolvedConfig, fn);
+ registerFlowAction(registry, flow);
const callableFlow: CallableFlow = async (input, opts) => {
- return runWithRegistry(registry, () => flow.run(input, opts));
+ return flow.run(input, opts);
};
callableFlow.flow = flow;
return callableFlow;
@@ -581,14 +570,14 @@ export function defineStreamingFlow<
O extends z.ZodTypeAny = z.ZodTypeAny,
S extends z.ZodTypeAny = z.ZodTypeAny,
>(
+ registry: Registry,
config: StreamingFlowConfig,
fn: FlowFn
): StreamableFlow {
- const flow = new Flow(config, fn);
- registerFlowAction(flow);
- const registry = getRegistryInstance();
+ const flow = new Flow(registry, config, fn);
+ registerFlowAction(registry, flow);
const streamableFlow: StreamableFlow = (input, opts) => {
- return runWithRegistry(registry, () => flow.stream(input, opts));
+ return flow.stream(input, opts);
};
streamableFlow.flow = flow;
return streamableFlow;
@@ -601,8 +590,12 @@ function registerFlowAction<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
S extends z.ZodTypeAny = z.ZodTypeAny,
->(flow: Flow): Action {
+>(
+ registry: Registry,
+ flow: Flow
+): Action {
return defineAction(
+ registry,
{
actionType: 'flow',
name: flow.name,
diff --git a/js/core/src/reflection.ts b/js/core/src/reflection.ts
index e74a7fa69..e66c73630 100644
--- a/js/core/src/reflection.ts
+++ b/js/core/src/reflection.ts
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-import express, { NextFunction, Request, Response } from 'express';
+import express from 'express';
import fs from 'fs/promises';
import getPort, { makeRange } from 'get-port';
import { Server } from 'http';
@@ -23,7 +23,7 @@ import z from 'zod';
import { Status, StatusCodes, runWithStreamingCallback } from './action.js';
import { GENKIT_VERSION } from './index.js';
import { logger } from './logging.js';
-import { Registry, runWithRegistry } from './registry.js';
+import { Registry } from './registry.js';
import { toJsonSchema } from './schema.js';
import {
flushTracing,
@@ -113,16 +113,6 @@ export class ReflectionServer {
next();
});
- server.use((req: Request, res: Response, next: NextFunction) => {
- runWithRegistry(this.registry, async () => {
- try {
- next();
- } catch (err) {
- next(err);
- }
- });
- });
-
server.get('/api/__health', async (_, response) => {
await this.registry.listActions();
response.status(200).send('OK');
diff --git a/js/core/src/registry.ts b/js/core/src/registry.ts
index afa679ac6..f7cd0f532 100644
--- a/js/core/src/registry.ts
+++ b/js/core/src/registry.ts
@@ -14,7 +14,6 @@
* limitations under the License.
*/
-import { AsyncLocalStorage } from 'async_hooks';
import * as z from 'zod';
import { Action } from './action.js';
import { logger } from './logging.js';
@@ -47,17 +46,6 @@ export interface Schema {
jsonSchema?: JSONSchema;
}
-/**
- * Looks up a registry key (action type and key) in the registry.
- */
-export function lookupAction<
- I extends z.ZodTypeAny,
- O extends z.ZodTypeAny,
- R extends Action,
->(key: string): Promise {
- return getRegistryInstance().lookupAction(key);
-}
-
function parsePluginName(registryKey: string) {
const tokens = registryKey.split('/');
if (tokens.length === 4) {
@@ -66,99 +54,8 @@ function parsePluginName(registryKey: string) {
return undefined;
}
-/**
- * Registers an action in the registry.
- */
-export function registerAction(
- type: ActionType,
- action: Action
-) {
- return getRegistryInstance().registerAction(type, action);
-}
-
type ActionsRecord = Record>;
-/**
- * Initialize all plugins in the registry.
- */
-export async function initializeAllPlugins() {
- await getRegistryInstance().initializeAllPlugins();
-}
-
-/**
- * Returns all actions in the registry.
- */
-export function listActions(): Promise {
- return getRegistryInstance().listActions();
-}
-
-/**
- * Registers a plugin provider.
- * @param name The name of the plugin to register.
- * @param provider The plugin provider.
- */
-export function registerPluginProvider(name: string, provider: PluginProvider) {
- return getRegistryInstance().registerPluginProvider(name, provider);
-}
-
-/**
- * Looks up a plugin.
- * @param name The name of the plugin to lookup.
- * @returns The plugin.
- */
-export function lookupPlugin(name: string) {
- return getRegistryInstance().lookupPlugin(name);
-}
-
-/**
- * Initializes a plugin that has already been registered.
- * @param name The name of the plugin to initialize.
- * @returns The plugin.
- */
-export async function initializePlugin(name: string) {
- return getRegistryInstance().initializePlugin(name);
-}
-
-/**
- * Registers a schema.
- * @param name The name of the schema to register.
- * @param data The schema to register (either a Zod schema or a JSON schema).
- */
-export function registerSchema(name: string, data: Schema) {
- return getRegistryInstance().registerSchema(name, data);
-}
-
-/**
- * Looks up a schema.
- * @param name The name of the schema to lookup.
- * @returns The schema.
- */
-export function lookupSchema(name: string) {
- return getRegistryInstance().lookupSchema(name);
-}
-
-const registryAls = new AsyncLocalStorage();
-
-/**
- * @returns The active registry instance.
- */
-export function getRegistryInstance(): Registry {
- const registry = registryAls.getStore();
- if (!registry) {
- throw new Error('getRegistryInstance() called before runWithRegistry()');
- }
- return registry;
-}
-
-/**
- * Runs a function with a specific registry instance.
- * @param registry The registry instance to use.
- * @param fn The function to run.
- */
-export function runWithRegistry(registry: Registry, fn: () => R) {
- return registryAls.run(registry, fn);
-}
-
/**
* The registry is used to store and lookup actions, trace stores, flow state stores, plugins, and schemas.
*/
@@ -170,14 +67,6 @@ export class Registry {
constructor(public parent?: Registry) {}
- /**
- * Creates a new registry overlaid onto the currently active registry.
- * @returns The new overlaid registry.
- */
- static withCurrent() {
- return new Registry(getRegistryInstance());
- }
-
/**
* Creates a new registry overlaid onto the provided registry.
* @param parent The parent registry.
diff --git a/js/core/src/schema.ts b/js/core/src/schema.ts
index 16a45160d..a53da8acb 100644
--- a/js/core/src/schema.ts
+++ b/js/core/src/schema.ts
@@ -19,7 +19,7 @@ import addFormats from 'ajv-formats';
import { z } from 'zod';
import zodToJsonSchema from 'zod-to-json-schema';
import { GenkitError } from './error.js';
-import { registerSchema } from './registry.js';
+import { Registry } from './registry.js';
const ajv = new Ajv();
addFormats(ajv);
@@ -112,14 +112,19 @@ export function parseSchema(
}
export function defineSchema(
+ registry: Registry,
name: string,
schema: T
): T {
- registerSchema(name, { schema });
+ registry.registerSchema(name, { schema });
return schema;
}
-export function defineJsonSchema(name: string, jsonSchema: JSONSchema) {
- registerSchema(name, { jsonSchema });
+export function defineJsonSchema(
+ registry: Registry,
+ name: string,
+ jsonSchema: JSONSchema
+) {
+ registry.registerSchema(name, { jsonSchema });
return jsonSchema;
}
diff --git a/js/core/tests/flow_test.ts b/js/core/tests/flow_test.ts
index 7d5b74646..cce14e2ee 100644
--- a/js/core/tests/flow_test.ts
+++ b/js/core/tests/flow_test.ts
@@ -18,10 +18,11 @@ import assert from 'node:assert';
import { beforeEach, describe, it } from 'node:test';
import { defineFlow, defineStreamingFlow } from '../src/flow.js';
import { z } from '../src/index.js';
-import { Registry, runWithRegistry } from '../src/registry.js';
+import { Registry } from '../src/registry.js';
-function createTestFlow() {
+function createTestFlow(registry: Registry) {
return defineFlow(
+ registry,
{
name: 'testFlow',
inputSchema: z.string(),
@@ -33,8 +34,9 @@ function createTestFlow() {
);
}
-function createTestStreamingFlow() {
+function createTestStreamingFlow(registry: Registry) {
return defineStreamingFlow(
+ registry,
{
name: 'testFlow',
inputSchema: z.number(),
@@ -63,7 +65,7 @@ describe('flow', () => {
describe('runFlow', () => {
it('should run the flow', async () => {
- const testFlow = runWithRegistry(registry, createTestFlow);
+ const testFlow = createTestFlow(registry);
const result = await testFlow('foo');
@@ -71,10 +73,8 @@ describe('flow', () => {
});
it('should run simple sync flow', async () => {
- const testFlow = runWithRegistry(registry, () => {
- return defineFlow('testFlow', (input) => {
- return `bar ${input}`;
- });
+ const testFlow = defineFlow(registry, 'testFlow', (input) => {
+ return `bar ${input}`;
});
const result = await testFlow('foo');
@@ -83,17 +83,16 @@ describe('flow', () => {
});
it('should rethrow the error', async () => {
- const testFlow = runWithRegistry(registry, () =>
- defineFlow(
- {
- name: 'throwing',
- inputSchema: z.string(),
- outputSchema: z.string(),
- },
- async (input) => {
- throw new Error(`bad happened: ${input}`);
- }
- )
+ const testFlow = defineFlow(
+ registry,
+ {
+ name: 'throwing',
+ inputSchema: z.string(),
+ outputSchema: z.string(),
+ },
+ async (input) => {
+ throw new Error(`bad happened: ${input}`);
+ }
);
await assert.rejects(() => testFlow('foo'), {
@@ -103,17 +102,16 @@ describe('flow', () => {
});
it('should validate input', async () => {
- const testFlow = runWithRegistry(registry, () =>
- defineFlow(
- {
- name: 'validating',
- inputSchema: z.object({ foo: z.string(), bar: z.number() }),
- outputSchema: z.string(),
- },
- async (input) => {
- return `ok ${input}`;
- }
- )
+ const testFlow = defineFlow(
+ registry,
+ {
+ name: 'validating',
+ inputSchema: z.object({ foo: z.string(), bar: z.number() }),
+ outputSchema: z.string(),
+ },
+ async (input) => {
+ return `ok ${input}`;
+ }
);
await assert.rejects(
@@ -132,7 +130,7 @@ describe('flow', () => {
describe('streamFlow', () => {
it('should run the flow', async () => {
- const testFlow = runWithRegistry(registry, createTestStreamingFlow);
+ const testFlow = createTestStreamingFlow(registry);
const response = testFlow(3);
@@ -146,16 +144,15 @@ describe('flow', () => {
});
it('should rethrow the error', async () => {
- const testFlow = runWithRegistry(registry, () =>
- defineStreamingFlow(
- {
- name: 'throwing',
- inputSchema: z.string(),
- },
- async (input) => {
- throw new Error(`stream bad happened: ${input}`);
- }
- )
+ const testFlow = defineStreamingFlow(
+ registry,
+ {
+ name: 'throwing',
+ inputSchema: z.string(),
+ },
+ async (input) => {
+ throw new Error(`stream bad happened: ${input}`);
+ }
);
const response = testFlow('foo');
diff --git a/js/core/tests/registry_test.ts b/js/core/tests/registry_test.ts
index 9542cf779..d54fdd415 100644
--- a/js/core/tests/registry_test.ts
+++ b/js/core/tests/registry_test.ts
@@ -17,175 +17,7 @@
import assert from 'node:assert';
import { beforeEach, describe, it } from 'node:test';
import { action } from '../src/action.js';
-import {
- Registry,
- listActions,
- lookupAction,
- registerAction,
- registerPluginProvider,
- runWithRegistry,
-} from '../src/registry.js';
-
-describe('global registry', () => {
- let registry: Registry;
-
- beforeEach(() => {
- registry = new Registry();
- });
-
- describe('listActions', () => {
- it('returns all registered actions', async () => {
- await runWithRegistry(registry, async () => {
- const fooSomethingAction = action(
- { name: 'foo_something' },
- async () => null
- );
- registerAction('model', fooSomethingAction);
- const barSomethingAction = action(
- { name: 'bar_something' },
- async () => null
- );
- registerAction('model', barSomethingAction);
-
- assert.deepEqual(await listActions(), {
- '/model/foo_something': fooSomethingAction,
- '/model/bar_something': barSomethingAction,
- });
- });
- });
-
- it('returns all registered actions by plugins', async () => {
- await runWithRegistry(registry, async () => {
- registerPluginProvider('foo', {
- name: 'foo',
- async initializer() {
- registerAction('model', fooSomethingAction);
- return {};
- },
- });
- const fooSomethingAction = action(
- {
- name: {
- pluginId: 'foo',
- actionId: 'something',
- },
- },
- async () => null
- );
- registerPluginProvider('bar', {
- name: 'bar',
- async initializer() {
- registerAction('model', barSomethingAction);
- return {};
- },
- });
- const barSomethingAction = action(
- {
- name: {
- pluginId: 'bar',
- actionId: 'something',
- },
- },
- async () => null
- );
-
- assert.deepEqual(await listActions(), {
- '/model/foo/something': fooSomethingAction,
- '/model/bar/something': barSomethingAction,
- });
- });
- });
- });
-
- describe('lookupAction', () => {
- it('initializes plugin for action first', async () => {
- await runWithRegistry(registry, async () => {
- let fooInitialized = false;
- registerPluginProvider('foo', {
- name: 'foo',
- async initializer() {
- fooInitialized = true;
- return {};
- },
- });
- let barInitialized = false;
- registerPluginProvider('bar', {
- name: 'bar',
- async initializer() {
- barInitialized = true;
- return {};
- },
- });
-
- await lookupAction('/model/foo/something');
-
- assert.strictEqual(fooInitialized, true);
- assert.strictEqual(barInitialized, false);
-
- await lookupAction('/model/bar/something');
-
- assert.strictEqual(fooInitialized, true);
- assert.strictEqual(barInitialized, true);
- });
- });
- });
-
- it('returns registered action', async () => {
- await runWithRegistry(registry, async () => {
- const fooSomethingAction = action(
- { name: 'foo_something' },
- async () => null
- );
- registerAction('model', fooSomethingAction);
- const barSomethingAction = action(
- { name: 'bar_something' },
- async () => null
- );
- registerAction('model', barSomethingAction);
-
- assert.strictEqual(
- await lookupAction('/model/foo_something'),
- fooSomethingAction
- );
- assert.strictEqual(
- await lookupAction('/model/bar_something'),
- barSomethingAction
- );
- });
- });
-
- it('returns action registered by plugin', async () => {
- await runWithRegistry(registry, async () => {
- registerPluginProvider('foo', {
- name: 'foo',
- async initializer() {
- registerAction('model', somethingAction);
- return {};
- },
- });
- const somethingAction = action(
- {
- name: {
- pluginId: 'foo',
- actionId: 'something',
- },
- },
- async () => null
- );
-
- assert.strictEqual(
- await lookupAction('/model/foo/something'),
- somethingAction
- );
- });
- });
-
- it('returns undefined for unknown action', async () => {
- await runWithRegistry(registry, async () => {
- assert.strictEqual(await lookupAction('/model/foo/something'), undefined);
- });
- });
-});
+import { Registry } from '../src/registry.js';
describe('registry class', () => {
var registry: Registry;
diff --git a/js/genkit/package.json b/js/genkit/package.json
index fbed927e3..c5d18f544 100644
--- a/js/genkit/package.json
+++ b/js/genkit/package.json
@@ -7,7 +7,7 @@
"genai",
"generative-ai"
],
- "version": "0.6.0-dev.2",
+ "version": "0.9.0-dev.1",
"type": "commonjs",
"main": "./lib/cjs/index.js",
"scripts": {
@@ -16,7 +16,8 @@
"build:clean": "rimraf ./lib",
"build": "npm-run-all build:clean check compile",
"build:watch": "tsup-node --watch",
- "test": "node --import tsx --test tests/*_test.ts"
+ "test": "node --import tsx --test tests/*_test.ts",
+ "test:watch": "node --watch --import tsx --test tests/*_test.ts"
},
"repository": {
"type": "git",
diff --git a/js/genkit/src/chat.ts b/js/genkit/src/chat.ts
index 97405b1b9..7587c941a 100644
--- a/js/genkit/src/chat.ts
+++ b/js/genkit/src/chat.ts
@@ -15,6 +15,7 @@
*/
import {
+ ExecutablePrompt,
GenerateOptions,
GenerateResponse,
GenerateStreamOptions,
@@ -24,24 +25,25 @@ import {
Part,
} from '@genkit-ai/ai';
import { z } from '@genkit-ai/core';
-import { v4 as uuidv4 } from 'uuid';
import { Genkit } from './genkit';
-import {
- Session,
- SessionData,
- SessionStore,
- inMemorySessionStore,
-} from './session';
+import { Session, SessionStore } from './session';
export const MAIN_THREAD = 'main';
export type BaseGenerateOptions = Omit;
-export type ChatOptions =
- BaseGenerateOptions & {
- store?: SessionStore;
- sessionId?: string;
- };
+export interface PromptRenderOptions {
+ prompt: ExecutablePrompt;
+ input?: I;
+}
+
+export type ChatOptions<
+ I = undefined,
+ S extends z.ZodTypeAny = z.ZodTypeAny,
+> = (PromptRenderOptions | BaseGenerateOptions) & {
+ store?: SessionStore;
+ sessionId?: string;
+};
/**
* Chat encapsulates a statful execution environment for chat.
@@ -56,75 +58,55 @@ export type ChatOptions =
* ```
*/
export class Chat {
+ readonly requestBase?: Promise;
readonly sessionId: string;
readonly schema?: S;
- private sessionData?: SessionData;
- private store: SessionStore;
+ private _messages?: MessageData[];
private threadName: string;
constructor(
- readonly parent: Genkit | Session | Chat,
- readonly requestBase?: BaseGenerateOptions,
- options?: {
- id?: string;
- sessionData?: SessionData;
- store?: SessionStore;
- thread?: string;
+ readonly session: Session,
+ requestBase: Promise,
+ options: {
+ id: string;
+ thread: string;
+ messages?: MessageData[];
}
) {
- this.sessionId = options?.id ?? uuidv4();
- this.threadName = options?.thread ?? MAIN_THREAD;
- this.sessionData = options?.sessionData;
- if (!this.sessionData) {
- this.sessionData = { id: this.sessionId };
- }
- if (!this.sessionData.threads) {
- this.sessionData!.threads = {};
- }
- // this is handling dotprompt render case
- if (requestBase && requestBase['prompt']) {
- const basePrompt = requestBase['prompt'] as string | Part | Part[];
- let promptMessage: MessageData;
- if (typeof basePrompt === 'string') {
- promptMessage = {
- role: 'user',
- content: [{ text: basePrompt }],
- };
- } else if (Array.isArray(basePrompt)) {
- promptMessage = {
- role: 'user',
- content: basePrompt,
- };
- } else {
- promptMessage = {
- role: 'user',
- content: [basePrompt],
- };
- }
- requestBase.messages = [...(requestBase.messages ?? []), promptMessage];
- }
- if (parent instanceof Chat) {
- if (!this.sessionData.threads[this.threadName]) {
- this!.sessionData.threads[this.threadName] = [
- ...(parent.messages ?? []),
- ...(requestBase?.messages ?? []),
- ];
- }
- } else if (parent instanceof Session) {
- if (!this.sessionData.threads[this.threadName]) {
- this!.sessionData.threads[this.threadName] = [
- ...(requestBase?.messages ?? []),
- ];
+ this.sessionId = options.id;
+ this.threadName = options.thread;
+ this.requestBase = requestBase?.then((rb) => {
+ const requestBase = { ...rb };
+ // this is handling dotprompt render case
+ if (requestBase && requestBase['prompt']) {
+ const basePrompt = requestBase['prompt'] as string | Part | Part[];
+ let promptMessage: MessageData;
+ if (typeof basePrompt === 'string') {
+ promptMessage = {
+ role: 'user',
+ content: [{ text: basePrompt }],
+ };
+ } else if (Array.isArray(basePrompt)) {
+ promptMessage = {
+ role: 'user',
+ content: basePrompt,
+ };
+ } else {
+ promptMessage = {
+ role: 'user',
+ content: [basePrompt],
+ };
+ }
+ requestBase.messages = [...(requestBase.messages ?? []), promptMessage];
}
- } else {
- // Genkit
- if (!this.sessionData.threads[this.threadName]) {
- this.sessionData.threads[this.threadName] = [
- ...(requestBase?.messages ?? []),
- ];
- }
- }
- this.store = options?.store ?? (inMemorySessionStore() as SessionStore);
+ requestBase.messages = [
+ ...(options.messages ?? []),
+ ...(requestBase.messages ?? []),
+ ];
+ this._messages = requestBase.messages;
+ return requestBase;
+ });
+ this._messages = options.messages;
}
async send<
@@ -146,7 +128,7 @@ export class Chat {
} as GenerateOptions;
}
const response = await this.genkit.generate({
- ...this.requestBase,
+ ...(await this.requestBase),
messages: this.messages,
...options,
});
@@ -173,7 +155,7 @@ export class Chat {
} as GenerateOptions;
}
const { response, stream } = await this.genkit.generateStream({
- ...this.requestBase,
+ ...(await this.requestBase),
messages: this.messages,
...options,
});
@@ -187,66 +169,15 @@ export class Chat {
}
private get genkit(): Genkit {
- if (this.parent instanceof Genkit) {
- return this.parent;
- }
- if (this.parent instanceof Session) {
- return this.parent.genkit;
- }
- return this.parent.genkit;
+ return this.session.genkit;
}
- get state(): z.infer {
- // We always get state from the parent. Parent session is the source of truth.
- if (this.parent instanceof Session) {
- return this.parent.state;
- }
- return this.sessionData!.state;
- }
-
- async updateState(data: z.infer): Promise {
- // We always update the state on the parent. Parent session is the source of truth.
- if (this.parent instanceof Session) {
- return this.parent.updateState(data);
- }
- let sessionData = await this.store.get(this.sessionId);
- if (!sessionData) {
- sessionData = {} as SessionData;
- }
- sessionData.state = data;
- this.sessionData = sessionData;
-
- await this.store.save(this.sessionId, sessionData);
- }
-
- get messages(): MessageData[] | undefined {
- if (!this.sessionData?.threads) {
- return undefined;
- }
- return this.sessionData?.threads[this.threadName];
+ get messages(): MessageData[] {
+ return this._messages ?? [];
}
async updateMessages(messages: MessageData[]): Promise {
- let sessionData = await this.store.get(this.sessionId);
- if (!sessionData) {
- sessionData = { id: this.sessionId, threads: {} };
- }
- if (!sessionData.threads) {
- sessionData.threads = {};
- }
- sessionData.threads[this.threadName] = messages;
- this.sessionData = sessionData;
- await this.store.save(this.sessionId, sessionData);
- }
-
- toJSON() {
- if (this.parent instanceof Session) {
- return this.parent.toJSON();
- }
- return this.sessionData;
- }
-
- static fromJSON(data: SessionData) {
- //return new Session();
+ this._messages = messages;
+ await this.session.updateMessages(this.threadName, messages);
}
}
diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts
index 85ad82d8b..c9ff26a75 100644
--- a/js/genkit/src/genkit.ts
+++ b/js/genkit/src/genkit.ts
@@ -26,6 +26,7 @@ import {
EvalResponses,
evaluate,
EvaluatorParams,
+ ExecutablePrompt,
generate,
GenerateOptions,
GenerateRequest,
@@ -35,13 +36,13 @@ import {
GenerateStreamOptions,
GenerateStreamResponse,
GenerationCommonConfigSchema,
- index,
IndexerParams,
ModelArgument,
ModelReference,
Part,
PromptAction,
PromptFn,
+ PromptGenerateOptions,
RankedDocument,
rerank,
RerankerParams,
@@ -81,6 +82,7 @@ import {
defineRetriever,
defineSimpleRetriever,
DocumentData,
+ index,
IndexerAction,
IndexerFn,
RetrieverFn,
@@ -109,10 +111,7 @@ import {
defineDotprompt,
defineHelper,
definePartial,
- Dotprompt,
loadPromptFolder,
- prompt,
- PromptGenerateOptions,
PromptMetadata,
} from '@genkit-ai/dotprompt';
import { v4 as uuidv4 } from 'uuid';
@@ -120,7 +119,7 @@ import { Chat, ChatOptions } from './chat.js';
import { BaseEvalDataPointSchema } from './evaluator.js';
import { logger } from './logging.js';
import { GenkitPlugin, genkitPlugin } from './plugin.js';
-import { lookupAction, Registry, runWithRegistry } from './registry.js';
+import { Registry } from './registry.js';
import {
getCurrentSession,
Session,
@@ -144,65 +143,6 @@ export interface GenkitOptions {
flowServer?: FlowServerOptions | boolean;
}
-export interface ExecutablePrompt<
- I extends z.ZodTypeAny = z.ZodTypeAny,
- O extends z.ZodTypeAny = z.ZodTypeAny,
- CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
-> {
- /**
- * Generates a response by rendering the prompt template with given user input and then calling the model.
- *
- * @param input Prompt inputs.
- * @param opt Options for the prompt template, including user input variables and custom model configuration options.
- * @returns the model response as a promise of `GenerateStreamResponse`.
- */
- (
- input?: z.infer,
- opts?: z.infer
- ): Promise>>;
-
- /**
- * Generates a response by rendering the prompt template with given user input and then calling the model.
- * @param input Prompt inputs.
- * @param opt Options for the prompt template, including user input variables and custom model configuration options.
- * @returns the model response as a promise of `GenerateStreamResponse`.
- */
- stream(
- input?: z.infer,
- opts?: z.infer
- ): Promise>>;
-
- /**
- * Generates a response by rendering the prompt template with given user input and additional generate options and then calling the model.
- *
- * @param opt Options for the prompt template, including user input variables and custom model configuration options.
- * @returns the model response as a promise of `GenerateResponse`.
- */
- generate(
- opt: PromptGenerateOptions, CustomOptions>
- ): Promise>>;
-
- /**
- * Generates a streaming response by rendering the prompt template with given user input and additional generate options and then calling the model.
- *
- * @param opt Options for the prompt template, including user input variables and custom model configuration options.
- * @returns the model response as a promise of `GenerateStreamResponse`.
- */
- generateStream(
- opt: PromptGenerateOptions, CustomOptions>
- ): Promise>>;
-
- /**
- * Renders the prompt template based on user input.
- *
- * @param opt Options for the prompt template, including user input variables and custom model configuration options.
- * @returns a `GenerateOptions` object to be used with the `generate()` function from @genkit-ai/ai.
- */
- render(
- opt: PromptGenerateOptions, CustomOptions>
- ): Promise>;
-}
-
/**
* `Genkit` encapsulates a single Genkit instance including the {@link Registry}, {@link ReflectionServer}, {@link FlowServer}, and configuration.
*
@@ -253,7 +193,7 @@ export class Genkit {
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
>(config: FlowConfig | string, fn: FlowFn): CallableFlow {
- const flow = runWithRegistry(this.registry, () => defineFlow(config, fn));
+ const flow = defineFlow(this.registry, config, fn);
this.registeredFlows.push(flow.flow);
return flow;
}
@@ -271,9 +211,7 @@ export class Genkit {
config: StreamingFlowConfig,
fn: FlowFn
): StreamableFlow {
- const flow = runWithRegistry(this.registry, () =>
- defineStreamingFlow(config, fn)
- );
+ const flow = defineStreamingFlow(this.registry, config, fn);
this.registeredFlows.push(flow.flow);
return flow;
}
@@ -287,7 +225,7 @@ export class Genkit {
config: ToolConfig,
fn: (input: z.infer) => Promise>
): ToolAction {
- return runWithRegistry(this.registry, () => defineTool(config, fn));
+ return defineTool(this.registry, config, fn);
}
/**
@@ -296,7 +234,7 @@ export class Genkit {
* Defined schemas can be referenced by `name` in prompts in place of inline schemas.
*/
defineSchema(name: string, schema: T): T {
- return runWithRegistry(this.registry, () => defineSchema(name, schema));
+ return defineSchema(this.registry, name, schema);
}
/**
@@ -305,9 +243,7 @@ export class Genkit {
* Defined schemas can be referenced by `name` in prompts in place of inline schemas.
*/
defineJsonSchema(name: string, jsonSchema: JSONSchema) {
- return runWithRegistry(this.registry, () =>
- defineJsonSchema(name, jsonSchema)
- );
+ return defineJsonSchema(this.registry, name, jsonSchema);
}
/**
@@ -320,7 +256,7 @@ export class Genkit {
streamingCallback?: StreamingCallback
) => Promise
): ModelAction {
- return runWithRegistry(this.registry, () => defineModel(options, runner));
+ return defineModel(this.registry, options, runner);
}
/**
@@ -328,33 +264,21 @@ export class Genkit {
*
* @todo TODO: Show an example of a name and variant.
*/
- prompt<
+ async prompt<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
>(
name: string,
options?: { variant?: string }
- ): Promise> {
- return runWithRegistry(this.registry, async () => {
- const action = (await lookupAction(`/prompt/${name}`)) as PromptAction;
- if (
- action.__action?.metadata?.prompt &&
- Object.keys(action.__action.metadata.prompt).length > 0
- ) {
- const p = await prompt(name, options);
- return this.wrapDotpromptInExecutablePrompt(p, {}) as ExecutablePrompt<
- I,
- O,
- CustomOptions
- >;
- } else {
- return this.wrapPromptActionInExecutablePrompt(
- action,
- {}
- ) as ExecutablePrompt;
- }
- });
+ ): Promise, O, CustomOptions>> {
+ const action = (await this.registry.lookupAction(
+ `/prompt/${name}`
+ )) as PromptAction