From 7b554f57aded603ed0e0d018b3338b958a7b0671 Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Tue, 22 Oct 2024 13:28:26 -0700 Subject: [PATCH] revert googleai --- js/plugins/googleai/package.json | 2 +- js/plugins/googleai/src/embedder.ts | 42 +++++--- js/plugins/googleai/src/gemini.ts | 143 +++++++++++++++------------- js/plugins/googleai/src/index.ts | 50 ++++++---- 4 files changed, 140 insertions(+), 97 deletions(-) diff --git a/js/plugins/googleai/package.json b/js/plugins/googleai/package.json index 46e56bc46..3ee9417f1 100644 --- a/js/plugins/googleai/package.json +++ b/js/plugins/googleai/package.json @@ -13,7 +13,7 @@ "genai", "generative-ai" ], - "version": "0.6.0-dev.2", + "version": "0.9.0-dev.1", "type": "commonjs", "scripts": { "check": "tsc", diff --git a/js/plugins/googleai/src/embedder.ts b/js/plugins/googleai/src/embedder.ts index 78d729fe1..9e0818488 100644 --- a/js/plugins/googleai/src/embedder.ts +++ b/js/plugins/googleai/src/embedder.ts @@ -15,7 +15,7 @@ */ import { EmbedContentRequest, GoogleGenerativeAI } from '@google/generative-ai'; -import { Genkit, z } from 'genkit'; +import { EmbedderReference, Genkit, z } from 'genkit'; import { embedderRef } from 'genkit/embedder'; import { PluginOptions } from './index.js'; @@ -28,22 +28,21 @@ export const TaskTypeSchema = z.enum([ ]); export type TaskType = z.infer; -export const TextEmbeddingGeckoConfigSchema = z.object({ +export const GeminiEmbeddingConfigSchema = z.object({ /** * The `task_type` parameter is defined as the intended downstream application to help the model * produce better quality embeddings. **/ taskType: TaskTypeSchema.optional(), title: z.string().optional(), + version: z.string().optional(), }); -export type TextEmbeddingGeckoConfig = z.infer< - typeof TextEmbeddingGeckoConfigSchema ->; +export type GeminiEmbeddingConfig = z.infer; export const textEmbeddingGecko001 = embedderRef({ name: 'googleai/embedding-001', - configSchema: TextEmbeddingGeckoConfigSchema, + configSchema: GeminiEmbeddingConfigSchema, info: { dimensions: 768, label: 'Google Gen AI - Text Embedding Gecko (Legacy)', @@ -57,7 +56,7 @@ export const SUPPORTED_MODELS = { 'embedding-001': textEmbeddingGecko001, }; -export function textEmbeddingGeckoEmbedder( +export function defineGoogleAIEmbedder( ai: Genkit, name: string, options: PluginOptions @@ -71,17 +70,36 @@ export function textEmbeddingGeckoEmbedder( 'Please pass in the API key or set either GOOGLE_GENAI_API_KEY or GOOGLE_API_KEY environment variable.\n' + 'For more details see https://firebase.google.com/docs/genkit/plugins/google-genai' ); - const client = new GoogleGenerativeAI(apiKey).getGenerativeModel({ - model: name, - }); - const embedder = SUPPORTED_MODELS[name]; + const embedder: EmbedderReference = + SUPPORTED_MODELS[name] ?? + embedderRef({ + name: name, + configSchema: GeminiEmbeddingConfigSchema, + info: { + dimensions: 768, + label: `Google AI - ${name}`, + supports: { + input: ['text'], + }, + }, + }); + const apiModelName = embedder.name.startsWith('googleai/') + ? embedder.name.substring('googleai/'.length) + : embedder.name; return ai.defineEmbedder( { name: embedder.name, - configSchema: TextEmbeddingGeckoConfigSchema, + configSchema: GeminiEmbeddingConfigSchema, info: embedder.info!, }, async (input, options) => { + const client = new GoogleGenerativeAI(apiKey!).getGenerativeModel({ + model: + options?.version || + embedder.config?.version || + embedder.version || + apiModelName, + }); const embeddings = await Promise.all( input.map(async (doc) => { const response = await client.embedContent({ diff --git a/js/plugins/googleai/src/gemini.ts b/js/plugins/googleai/src/gemini.ts index e6bde1a20..28ac77db2 100644 --- a/js/plugins/googleai/src/gemini.ts +++ b/js/plugins/googleai/src/gemini.ts @@ -38,6 +38,7 @@ import { MediaPart, MessageData, ModelAction, + ModelInfo, ModelMiddleware, ModelReference, Part, @@ -69,16 +70,16 @@ const SafetySettingsSchema = z.object({ ]), }); -const GeminiConfigSchema = GenerationCommonConfigSchema.extend({ +export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({ safetySettings: z.array(SafetySettingsSchema).optional(), codeExecution: z.union([z.boolean(), z.object({}).strict()]).optional(), }); -export const geminiPro = modelRef({ - name: 'googleai/gemini-pro', +export const gemini10Pro = modelRef({ + name: 'googleai/gemini-1.0-pro', info: { label: 'Google AI - Gemini Pro', - versions: ['gemini-1.0-pro', 'gemini-1.0-pro-latest', 'gemini-1.0-pro-001'], + versions: ['gemini-pro', 'gemini-1.0-pro-latest', 'gemini-1.0-pro-001'], supports: { multiturn: true, media: false, @@ -89,28 +90,8 @@ export const geminiPro = modelRef({ configSchema: GeminiConfigSchema, }); -/** - * @deprecated Use `gemini15Pro` or `gemini15Flash` instead. - */ -export const geminiProVision = modelRef({ - name: 'googleai/gemini-pro-vision', - info: { - label: 'Google AI - Gemini Pro Vision', - // none declared on https://ai.google.dev/models/gemini#model-variations - versions: [], - supports: { - multiturn: true, - media: true, - tools: false, - systemRole: false, - }, - stage: 'deprecated', - }, - configSchema: GeminiConfigSchema, -}); - export const gemini15Pro = modelRef({ - name: 'googleai/gemini-1.5-pro-latest', + name: 'googleai/gemini-1.5-pro', info: { label: 'Google AI - Gemini 1.5 Pro', supports: { @@ -120,13 +101,17 @@ export const gemini15Pro = modelRef({ systemRole: true, output: ['text', 'json'], }, - versions: ['gemini-1.5-pro-001'], + versions: [ + 'gemini-1.5-pro-latest', + 'gemini-1.5-pro-001', + 'gemini-1.5-pro-002', + ], }, configSchema: GeminiConfigSchema, }); export const gemini15Flash = modelRef({ - name: 'googleai/gemini-1.5-flash-latest', + name: 'googleai/gemini-1.5-flash', info: { label: 'Google AI - Gemini 1.5 Flash', supports: { @@ -136,44 +121,45 @@ export const gemini15Flash = modelRef({ systemRole: true, output: ['text', 'json'], }, - versions: ['gemini-1.5-flash-001'], + versions: [ + 'gemini-1.5-flash-latest', + 'gemini-1.5-flash-001', + 'gemini-1.5-flash-002', + ], }, configSchema: GeminiConfigSchema, }); -export const geminiUltra = modelRef({ - name: 'googleai/gemini-ultra', +export const gemini15Flash8b = modelRef({ + name: 'googleai/gemini-1.5-flash-8b', info: { - label: 'Google AI - Gemini Ultra', - versions: [], + label: 'Google AI - Gemini 1.5 Flash', supports: { multiturn: true, - media: false, + media: true, tools: true, systemRole: true, + output: ['text', 'json'], }, + versions: ['gemini-1.5-flash-8b-latest', 'gemini-1.5-flash-8b-001'], }, configSchema: GeminiConfigSchema, }); -export const SUPPORTED_V1_MODELS: Record< - string, - ModelReference -> = { - 'gemini-pro': geminiPro, - 'gemini-pro-vision': geminiProVision, - // 'gemini-ultra': geminiUltra, +export const SUPPORTED_V1_MODELS = { + 'gemini-1.0-pro': gemini10Pro, }; -export const SUPPORTED_V15_MODELS: Record< - string, - ModelReference -> = { - 'gemini-1.5-pro-latest': gemini15Pro, - 'gemini-1.5-flash-latest': gemini15Flash, +export const SUPPORTED_V15_MODELS = { + 'gemini-1.5-pro': gemini15Pro, + 'gemini-1.5-flash': gemini15Flash, + 'gemini-1.5-flash-8b': gemini15Flash8b, }; -const SUPPORTED_MODELS = { +export const SUPPORTED_GEMINI_MODELS: Record< + string, + ModelReference +> = { ...SUPPORTED_V1_MODELS, ...SUPPORTED_V15_MODELS, }; @@ -453,17 +439,17 @@ export function fromGeminiCandidate( } /** - * + * Defines a new GoogleAI model. */ -export function googleAIModel( +export function defineGoogleAIModel( ai: Genkit, name: string, apiKey?: string, apiVersion?: string, - baseUrl?: string + baseUrl?: string, + info?: ModelInfo, + defaultConfig?: z.infer ): ModelAction { - const modelName = `googleai/${name}`; - if (!apiKey) { apiKey = process.env.GOOGLE_GENAI_API_KEY || process.env.GOOGLE_API_KEY; } @@ -473,15 +459,33 @@ export function googleAIModel( 'For more details see https://firebase.google.com/docs/genkit/plugins/google-genai' ); } - - const model: ModelReference = SUPPORTED_MODELS[name]; - if (!model) throw new Error(`Unsupported model: ${name}`); + const apiModelName = name.startsWith('googleai/') + ? name.substring('googleai/'.length) + : name; + + const model: ModelReference = + SUPPORTED_GEMINI_MODELS[name] ?? + modelRef({ + name, + info: { + label: `Google AI - ${apiModelName}`, + supports: { + multiturn: true, + media: true, + tools: true, + systemRole: true, + output: ['text', 'json'], + }, + ...info, + }, + configSchema: GeminiConfigSchema, + }); const middleware: ModelMiddleware[] = []; if (SUPPORTED_V1_MODELS[name]) { middleware.push(simulateSystemPrompt()); } - if (model?.info?.supports?.media) { + if (model.info?.supports?.media) { // the gemini api doesn't support downloading media from http(s) middleware.push( downloadRequestMedia({ @@ -497,7 +501,7 @@ export function googleAIModel( return ai.defineModel( { - name: modelName, + name: model.name, ...model.info, configSchema: model.configSchema, use: middleware, @@ -510,9 +514,18 @@ export function googleAIModel( if (apiVersion) { options.baseUrl = baseUrl; } + const requestConfig = { + ...defaultConfig, + ...request.config, + }; + const client = new GoogleGenerativeAI(apiKey!).getGenerativeModel( { - model: request.config?.version || model.version || name, + model: + requestConfig.version || + model.config?.version || + model.version || + apiModelName, }, options ); @@ -542,7 +555,7 @@ export function googleAIModel( }); } - if (request.config?.codeExecution) { + if (requestConfig.codeExecution) { tools.push({ codeExecution: request.config.codeExecution === true @@ -558,11 +571,11 @@ export function googleAIModel( const generationConfig: GenerationConfig = { candidateCount: request.candidates || undefined, - temperature: request.config?.temperature, - maxOutputTokens: request.config?.maxOutputTokens, - topK: request.config?.topK, - topP: request.config?.topP, - stopSequences: request.config?.stopSequences, + temperature: requestConfig.temperature, + maxOutputTokens: requestConfig.maxOutputTokens, + topK: requestConfig.topK, + topP: requestConfig.topP, + stopSequences: requestConfig.stopSequences, responseMimeType: jsonMode ? 'application/json' : undefined, }; @@ -573,7 +586,7 @@ export function googleAIModel( history: messages .slice(0, -1) .map((message) => toGeminiMessage(message, model)), - safetySettings: request.config?.safetySettings, + safetySettings: requestConfig.safetySettings, } as StartChatParams; const msg = toGeminiMessage(messages[messages.length - 1], model); const fromJSONModeScopedGeminiCandidate = ( diff --git a/js/plugins/googleai/src/index.ts b/js/plugins/googleai/src/index.ts index 2a62bca9d..abb731426 100644 --- a/js/plugins/googleai/src/index.ts +++ b/js/plugins/googleai/src/index.ts @@ -18,25 +18,18 @@ import { Genkit } from 'genkit'; import { GenkitPlugin, genkitPlugin } from 'genkit/plugin'; import { SUPPORTED_MODELS as EMBEDDER_MODELS, + defineGoogleAIEmbedder, textEmbeddingGecko001, - textEmbeddingGeckoEmbedder, } from './embedder.js'; import { SUPPORTED_V15_MODELS, SUPPORTED_V1_MODELS, + defineGoogleAIModel, + gemini10Pro, gemini15Flash, gemini15Pro, - geminiPro, - geminiProVision, - googleAIModel, } from './gemini.js'; -export { - gemini15Flash, - gemini15Pro, - geminiPro, - geminiProVision, - textEmbeddingGecko001, -}; +export { gemini10Pro, gemini15Flash, gemini15Pro, textEmbeddingGecko001 }; export interface PluginOptions { apiKey?: string; @@ -56,21 +49,40 @@ export function googleAI(options?: PluginOptions): GenkitPlugin { } } if (apiVersions.includes('v1beta')) { - Object.keys(SUPPORTED_V15_MODELS).map((name) => - googleAIModel(ai, name, options?.apiKey, 'v1beta', options?.baseUrl) + Object.keys(SUPPORTED_V15_MODELS).forEach((name) => + defineGoogleAIModel( + ai, + name, + options?.apiKey, + 'v1beta', + options?.baseUrl + ) ); } if (apiVersions.includes('v1')) { - Object.keys(SUPPORTED_V1_MODELS).map((name) => - googleAIModel(ai, name, options?.apiKey, undefined, options?.baseUrl) + Object.keys(SUPPORTED_V1_MODELS).forEach((name) => + defineGoogleAIModel( + ai, + name, + options?.apiKey, + undefined, + options?.baseUrl + ) ); - Object.keys(SUPPORTED_V15_MODELS).map((name) => - googleAIModel(ai, name, options?.apiKey, undefined, options?.baseUrl) + Object.keys(SUPPORTED_V15_MODELS).forEach((name) => + defineGoogleAIModel( + ai, + name, + options?.apiKey, + undefined, + options?.baseUrl + ) ); - Object.keys(EMBEDDER_MODELS).map((name) => - textEmbeddingGeckoEmbedder(ai, name, { apiKey: options?.apiKey }) + Object.keys(EMBEDDER_MODELS).forEach((name) => + defineGoogleAIEmbedder(ai, name, { apiKey: options?.apiKey }) ); } }); } + export default googleAI;