Skip to content

Commit

Permalink
revert googleai
Browse files Browse the repository at this point in the history
  • Loading branch information
mbleigh committed Oct 22, 2024
1 parent de8012e commit 7b554f5
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 97 deletions.
2 changes: 1 addition & 1 deletion js/plugins/googleai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"genai",
"generative-ai"
],
"version": "0.6.0-dev.2",
"version": "0.9.0-dev.1",
"type": "commonjs",
"scripts": {
"check": "tsc",
Expand Down
42 changes: 30 additions & 12 deletions js/plugins/googleai/src/embedder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/

import { EmbedContentRequest, GoogleGenerativeAI } from '@google/generative-ai';
import { Genkit, z } from 'genkit';
import { EmbedderReference, Genkit, z } from 'genkit';
import { embedderRef } from 'genkit/embedder';
import { PluginOptions } from './index.js';

Expand All @@ -28,22 +28,21 @@ export const TaskTypeSchema = z.enum([
]);
export type TaskType = z.infer<typeof TaskTypeSchema>;

export const TextEmbeddingGeckoConfigSchema = z.object({
export const GeminiEmbeddingConfigSchema = z.object({
/**
* The `task_type` parameter is defined as the intended downstream application to help the model
* produce better quality embeddings.
**/
taskType: TaskTypeSchema.optional(),
title: z.string().optional(),
version: z.string().optional(),
});

export type TextEmbeddingGeckoConfig = z.infer<
typeof TextEmbeddingGeckoConfigSchema
>;
export type GeminiEmbeddingConfig = z.infer<typeof GeminiEmbeddingConfigSchema>;

export const textEmbeddingGecko001 = embedderRef({
name: 'googleai/embedding-001',
configSchema: TextEmbeddingGeckoConfigSchema,
configSchema: GeminiEmbeddingConfigSchema,
info: {
dimensions: 768,
label: 'Google Gen AI - Text Embedding Gecko (Legacy)',
Expand All @@ -57,7 +56,7 @@ export const SUPPORTED_MODELS = {
'embedding-001': textEmbeddingGecko001,
};

export function textEmbeddingGeckoEmbedder(
export function defineGoogleAIEmbedder(
ai: Genkit,
name: string,
options: PluginOptions
Expand All @@ -71,17 +70,36 @@ export function textEmbeddingGeckoEmbedder(
'Please pass in the API key or set either GOOGLE_GENAI_API_KEY or GOOGLE_API_KEY environment variable.\n' +
'For more details see https://firebase.google.com/docs/genkit/plugins/google-genai'
);
const client = new GoogleGenerativeAI(apiKey).getGenerativeModel({
model: name,
});
const embedder = SUPPORTED_MODELS[name];
const embedder: EmbedderReference =
SUPPORTED_MODELS[name] ??
embedderRef({
name: name,
configSchema: GeminiEmbeddingConfigSchema,
info: {
dimensions: 768,
label: `Google AI - ${name}`,
supports: {
input: ['text'],
},
},
});
const apiModelName = embedder.name.startsWith('googleai/')
? embedder.name.substring('googleai/'.length)
: embedder.name;
return ai.defineEmbedder(
{
name: embedder.name,
configSchema: TextEmbeddingGeckoConfigSchema,
configSchema: GeminiEmbeddingConfigSchema,
info: embedder.info!,
},
async (input, options) => {
const client = new GoogleGenerativeAI(apiKey!).getGenerativeModel({
model:
options?.version ||
embedder.config?.version ||
embedder.version ||
apiModelName,
});
const embeddings = await Promise.all(
input.map(async (doc) => {
const response = await client.embedContent({
Expand Down
143 changes: 78 additions & 65 deletions js/plugins/googleai/src/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import {
MediaPart,
MessageData,
ModelAction,
ModelInfo,
ModelMiddleware,
ModelReference,
Part,
Expand Down Expand Up @@ -69,16 +70,16 @@ const SafetySettingsSchema = z.object({
]),
});

const GeminiConfigSchema = GenerationCommonConfigSchema.extend({
export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({
safetySettings: z.array(SafetySettingsSchema).optional(),
codeExecution: z.union([z.boolean(), z.object({}).strict()]).optional(),
});

export const geminiPro = modelRef({
name: 'googleai/gemini-pro',
export const gemini10Pro = modelRef({
name: 'googleai/gemini-1.0-pro',
info: {
label: 'Google AI - Gemini Pro',
versions: ['gemini-1.0-pro', 'gemini-1.0-pro-latest', 'gemini-1.0-pro-001'],
versions: ['gemini-pro', 'gemini-1.0-pro-latest', 'gemini-1.0-pro-001'],
supports: {
multiturn: true,
media: false,
Expand All @@ -89,28 +90,8 @@ export const geminiPro = modelRef({
configSchema: GeminiConfigSchema,
});

/**
* @deprecated Use `gemini15Pro` or `gemini15Flash` instead.
*/
export const geminiProVision = modelRef({
name: 'googleai/gemini-pro-vision',
info: {
label: 'Google AI - Gemini Pro Vision',
// none declared on https://ai.google.dev/models/gemini#model-variations
versions: [],
supports: {
multiturn: true,
media: true,
tools: false,
systemRole: false,
},
stage: 'deprecated',
},
configSchema: GeminiConfigSchema,
});

export const gemini15Pro = modelRef({
name: 'googleai/gemini-1.5-pro-latest',
name: 'googleai/gemini-1.5-pro',
info: {
label: 'Google AI - Gemini 1.5 Pro',
supports: {
Expand All @@ -120,13 +101,17 @@ export const gemini15Pro = modelRef({
systemRole: true,
output: ['text', 'json'],
},
versions: ['gemini-1.5-pro-001'],
versions: [
'gemini-1.5-pro-latest',
'gemini-1.5-pro-001',
'gemini-1.5-pro-002',
],
},
configSchema: GeminiConfigSchema,
});

export const gemini15Flash = modelRef({
name: 'googleai/gemini-1.5-flash-latest',
name: 'googleai/gemini-1.5-flash',
info: {
label: 'Google AI - Gemini 1.5 Flash',
supports: {
Expand All @@ -136,44 +121,45 @@ export const gemini15Flash = modelRef({
systemRole: true,
output: ['text', 'json'],
},
versions: ['gemini-1.5-flash-001'],
versions: [
'gemini-1.5-flash-latest',
'gemini-1.5-flash-001',
'gemini-1.5-flash-002',
],
},
configSchema: GeminiConfigSchema,
});

export const geminiUltra = modelRef({
name: 'googleai/gemini-ultra',
export const gemini15Flash8b = modelRef({
name: 'googleai/gemini-1.5-flash-8b',
info: {
label: 'Google AI - Gemini Ultra',
versions: [],
label: 'Google AI - Gemini 1.5 Flash',
supports: {
multiturn: true,
media: false,
media: true,
tools: true,
systemRole: true,
output: ['text', 'json'],
},
versions: ['gemini-1.5-flash-8b-latest', 'gemini-1.5-flash-8b-001'],
},
configSchema: GeminiConfigSchema,
});

export const SUPPORTED_V1_MODELS: Record<
string,
ModelReference<z.ZodTypeAny>
> = {
'gemini-pro': geminiPro,
'gemini-pro-vision': geminiProVision,
// 'gemini-ultra': geminiUltra,
export const SUPPORTED_V1_MODELS = {
'gemini-1.0-pro': gemini10Pro,
};

export const SUPPORTED_V15_MODELS: Record<
string,
ModelReference<z.ZodTypeAny>
> = {
'gemini-1.5-pro-latest': gemini15Pro,
'gemini-1.5-flash-latest': gemini15Flash,
export const SUPPORTED_V15_MODELS = {
'gemini-1.5-pro': gemini15Pro,
'gemini-1.5-flash': gemini15Flash,
'gemini-1.5-flash-8b': gemini15Flash8b,
};

const SUPPORTED_MODELS = {
export const SUPPORTED_GEMINI_MODELS: Record<
string,
ModelReference<typeof GeminiConfigSchema>
> = {
...SUPPORTED_V1_MODELS,
...SUPPORTED_V15_MODELS,
};
Expand Down Expand Up @@ -453,17 +439,17 @@ export function fromGeminiCandidate(
}

/**
*
* Defines a new GoogleAI model.
*/
export function googleAIModel(
export function defineGoogleAIModel(
ai: Genkit,
name: string,
apiKey?: string,
apiVersion?: string,
baseUrl?: string
baseUrl?: string,
info?: ModelInfo,
defaultConfig?: z.infer<typeof GeminiConfigSchema>
): ModelAction {
const modelName = `googleai/${name}`;

if (!apiKey) {
apiKey = process.env.GOOGLE_GENAI_API_KEY || process.env.GOOGLE_API_KEY;
}
Expand All @@ -473,15 +459,33 @@ export function googleAIModel(
'For more details see https://firebase.google.com/docs/genkit/plugins/google-genai'
);
}

const model: ModelReference<z.ZodTypeAny> = SUPPORTED_MODELS[name];
if (!model) throw new Error(`Unsupported model: ${name}`);
const apiModelName = name.startsWith('googleai/')
? name.substring('googleai/'.length)
: name;

const model: ModelReference<z.ZodTypeAny> =
SUPPORTED_GEMINI_MODELS[name] ??
modelRef({
name,
info: {
label: `Google AI - ${apiModelName}`,
supports: {
multiturn: true,
media: true,
tools: true,
systemRole: true,
output: ['text', 'json'],
},
...info,
},
configSchema: GeminiConfigSchema,
});

const middleware: ModelMiddleware[] = [];
if (SUPPORTED_V1_MODELS[name]) {
middleware.push(simulateSystemPrompt());
}
if (model?.info?.supports?.media) {
if (model.info?.supports?.media) {
// the gemini api doesn't support downloading media from http(s)
middleware.push(
downloadRequestMedia({
Expand All @@ -497,7 +501,7 @@ export function googleAIModel(

return ai.defineModel(
{
name: modelName,
name: model.name,
...model.info,
configSchema: model.configSchema,
use: middleware,
Expand All @@ -510,9 +514,18 @@ export function googleAIModel(
if (apiVersion) {
options.baseUrl = baseUrl;
}
const requestConfig = {
...defaultConfig,
...request.config,
};

const client = new GoogleGenerativeAI(apiKey!).getGenerativeModel(
{
model: request.config?.version || model.version || name,
model:
requestConfig.version ||
model.config?.version ||
model.version ||
apiModelName,
},
options
);
Expand Down Expand Up @@ -542,7 +555,7 @@ export function googleAIModel(
});
}

if (request.config?.codeExecution) {
if (requestConfig.codeExecution) {
tools.push({
codeExecution:
request.config.codeExecution === true
Expand All @@ -558,11 +571,11 @@ export function googleAIModel(

const generationConfig: GenerationConfig = {
candidateCount: request.candidates || undefined,
temperature: request.config?.temperature,
maxOutputTokens: request.config?.maxOutputTokens,
topK: request.config?.topK,
topP: request.config?.topP,
stopSequences: request.config?.stopSequences,
temperature: requestConfig.temperature,
maxOutputTokens: requestConfig.maxOutputTokens,
topK: requestConfig.topK,
topP: requestConfig.topP,
stopSequences: requestConfig.stopSequences,
responseMimeType: jsonMode ? 'application/json' : undefined,
};

Expand All @@ -573,7 +586,7 @@ export function googleAIModel(
history: messages
.slice(0, -1)
.map((message) => toGeminiMessage(message, model)),
safetySettings: request.config?.safetySettings,
safetySettings: requestConfig.safetySettings,
} as StartChatParams;
const msg = toGeminiMessage(messages[messages.length - 1], model);
const fromJSONModeScopedGeminiCandidate = (
Expand Down
Loading

0 comments on commit 7b554f5

Please sign in to comment.