Skip to content

Commit

Permalink
Merge pull request #1442 from Firbydude/koswald/voyageai
Browse files Browse the repository at this point in the history
feat: Add support for VoyageAI embeddings API
  • Loading branch information
odilitime authored Jan 14, 2025
2 parents 8f92573 + ca4f01c commit 172d645
Show file tree
Hide file tree
Showing 4 changed files with 349 additions and 80 deletions.
6 changes: 6 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,12 @@ SMALL_ANTHROPIC_MODEL= # Default: claude-3-haiku-20240307
MEDIUM_ANTHROPIC_MODEL= # Default: claude-3-5-sonnet-20241022
LARGE_ANTHROPIC_MODEL= # Default: claude-3-5-sonnet-20241022

# VoyageAI Configuration
VOYAGEAI_API_KEY=
USE_VOYAGEAI_EMBEDDING= # Set to TRUE for VoyageAI, leave blank for local
VOYAGEAI_EMBEDDING_MODEL= # Default: voyage-3-lite
VOYAGEAI_EMBEDDING_DIMENSIONS= # Default: 512

# Heurist Configuration
HEURIST_API_KEY= # Get from https://heurist.ai/dev-access
SMALL_HEURIST_MODEL= # Default: meta-llama/llama-3-70b-instruct
Expand Down
165 changes: 85 additions & 80 deletions packages/core/src/embedding.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import { getEmbeddingModelSettings, getEndpoint } from "./models.ts";
import { IAgentRuntime, ModelProviderName } from "./types.ts";
import path from "node:path";
import settings from "./settings.ts";
import elizaLogger from "./logger.ts";
import { getVoyageAIEmbeddingConfig } from "./voyageai.ts";
import { models, getEmbeddingModelSettings, getEndpoint } from "./models.ts";
import { IAgentRuntime, ModelProviderName } from "./types.ts";
import LocalEmbeddingModelManager from "./localembeddingManager.ts";

interface EmbeddingOptions {
Expand All @@ -20,63 +22,93 @@ export const EmbeddingProvider = {
GaiaNet: "GaiaNet",
Heurist: "Heurist",
BGE: "BGE",
VoyageAI: "VoyageAI",
} as const;

export type EmbeddingProviderType =
(typeof EmbeddingProvider)[keyof typeof EmbeddingProvider];

// Convenience type aliases so callers can use e.g. `EmbeddingProvider.OpenAI`
// in type positions. One alias per key of the EmbeddingProvider const;
// `Heurist` was missing even though the const declares it — restored here
// to keep the namespace and the const in sync.
export namespace EmbeddingProvider {
    export type OpenAI = typeof EmbeddingProvider.OpenAI;
    export type Ollama = typeof EmbeddingProvider.Ollama;
    export type GaiaNet = typeof EmbeddingProvider.GaiaNet;
    export type Heurist = typeof EmbeddingProvider.Heurist;
    export type BGE = typeof EmbeddingProvider.BGE;
    export type VoyageAI = typeof EmbeddingProvider.VoyageAI;
}

export type EmbeddingConfig = {
readonly dimensions: number;
readonly model: string;
readonly provider: EmbeddingProviderType;
readonly provider: EmbeddingProvider;
readonly endpoint?: string;
readonly apiKey?: string;
readonly maxInputTokens?: number;
};

export const getEmbeddingConfig = (): EmbeddingConfig => ({
dimensions:
settings.USE_OPENAI_EMBEDDING?.toLowerCase() === "true"
? getEmbeddingModelSettings(ModelProviderName.OPENAI).dimensions
: settings.USE_OLLAMA_EMBEDDING?.toLowerCase() === "true"
? getEmbeddingModelSettings(ModelProviderName.OLLAMA).dimensions
: settings.USE_GAIANET_EMBEDDING?.toLowerCase() === "true"
? getEmbeddingModelSettings(ModelProviderName.GAIANET)
.dimensions
: settings.USE_HEURIST_EMBEDDING?.toLowerCase() === "true"
? getEmbeddingModelSettings(ModelProviderName.HEURIST)
.dimensions
: 384, // BGE
model:
settings.USE_OPENAI_EMBEDDING?.toLowerCase() === "true"
? getEmbeddingModelSettings(ModelProviderName.OPENAI).name
: settings.USE_OLLAMA_EMBEDDING?.toLowerCase() === "true"
? getEmbeddingModelSettings(ModelProviderName.OLLAMA).name
: settings.USE_GAIANET_EMBEDDING?.toLowerCase() === "true"
? getEmbeddingModelSettings(ModelProviderName.GAIANET).name
: settings.USE_HEURIST_EMBEDDING?.toLowerCase() === "true"
? getEmbeddingModelSettings(ModelProviderName.HEURIST).name
: "BGE-small-en-v1.5",
provider:
settings.USE_OPENAI_EMBEDDING?.toLowerCase() === "true"
? "OpenAI"
: settings.USE_OLLAMA_EMBEDDING?.toLowerCase() === "true"
? "Ollama"
: settings.USE_GAIANET_EMBEDDING?.toLowerCase() === "true"
? "GaiaNet"
: settings.USE_HEURIST_EMBEDDING?.toLowerCase() === "true"
? "Heurist"
: "BGE",
});
// Get embedding config based on settings
export function getEmbeddingConfig(): EmbeddingConfig {
if (settings.USE_OPENAI_EMBEDDING?.toLowerCase() === "true") {
return {
dimensions: 1536,
model: "text-embedding-3-small",
provider: "OpenAI",
endpoint: "https://api.openai.com/v1",
apiKey: settings.OPENAI_API_KEY,
maxInputTokens: 1000000,
};
}
if (settings.USE_OLLAMA_EMBEDDING?.toLowerCase() === "true") {
return {
dimensions: 1024,
model: settings.OLLAMA_EMBEDDING_MODEL || "mxbai-embed-large",
provider: "Ollama",
endpoint: "https://ollama.eliza.ai/",
apiKey: settings.OLLAMA_API_KEY,
maxInputTokens: 1000000,
};
}
if (settings.USE_GAIANET_EMBEDDING?.toLowerCase() === "true") {
return {
dimensions: 768,
model: settings.GAIANET_EMBEDDING_MODEL || "nomic-embed",
provider: "GaiaNet",
endpoint: settings.SMALL_GAIANET_SERVER_URL || settings.MEDIUM_GAIANET_SERVER_URL || settings.LARGE_GAIANET_SERVER_URL,
apiKey: settings.GAIANET_API_KEY,
maxInputTokens: 1000000,
};
}
if (settings.USE_VOYAGEAI_EMBEDDING?.toLowerCase() === "true") {
return getVoyageAIEmbeddingConfig();
}

// Fallback to local BGE
return {
dimensions: 384,
model: "BGE-small-en-v1.5",
provider: "BGE",
maxInputTokens: 1000000,
};
};

async function getRemoteEmbedding(
input: string,
options: EmbeddingOptions
options: EmbeddingConfig
): Promise<number[]> {
// Ensure endpoint ends with /v1 for OpenAI
const baseEndpoint = options.endpoint.endsWith("/v1")
? options.endpoint
: `${options.endpoint}${options.isOllama ? "/v1" : ""}`;
elizaLogger.debug("Getting remote embedding using provider:", options.provider);

// Construct full URL
const fullUrl = `${baseEndpoint}/embeddings`;
const fullUrl = `${options.endpoint}/embeddings`;

// jank. voyageai is the only one that doesn't use "dimensions".
const body = options.provider === "VoyageAI" ? {
input,
model: options.model,
output_dimension: options.dimensions,
} : {
input,
model: options.model,
dimensions: options.dimensions,
};

const requestOptions = {
method: "POST",
Expand All @@ -88,14 +120,7 @@ async function getRemoteEmbedding(
}
: {}),
},
body: JSON.stringify({
input,
model: options.model,
dimensions:
options.dimensions ||
options.length ||
getEmbeddingConfig().dimensions, // Prefer dimensions, fallback to length
}),
body: JSON.stringify(body),
};

try {
Expand Down Expand Up @@ -141,44 +166,18 @@ export function getEmbeddingType(runtime: IAgentRuntime): "local" | "remote" {
}

// Returns an all-zero vector sized to the active embedding provider, so
// placeholder embeddings stored in the database stay dimension-compatible
// with real ones (e.g. 384 for the local BGE fallback, 1536 for OpenAI).
// Dimension is taken from getEmbeddingConfig() rather than re-deriving it
// from individual USE_*_EMBEDDING flags, keeping the two in lockstep.
export function getEmbeddingZeroVector(): number[] {
    return Array(getEmbeddingConfig().dimensions).fill(0);
}

/**
* Gets embeddings from a remote API endpoint. Falls back to local BGE/384
*
* @param {string} input - The text to generate embeddings for
* @param {EmbeddingOptions} options - Configuration options including:
* - model: The model name to use
* - endpoint: Base API endpoint URL
* - apiKey: Optional API key for authentication
* - isOllama: Whether this is an Ollama endpoint
* - dimensions: Desired embedding dimensions
* @param {IAgentRuntime} runtime - The agent runtime context
* @returns {Promise<number[]>} Array of embedding values
* @throws {Error} If the API request fails
* @throws {Error} If the API request fails or configuration is invalid
*/

export async function embed(runtime: IAgentRuntime, input: string) {
elizaLogger.debug("Embedding request:", {
modelProvider: runtime.character.modelProvider,
Expand Down Expand Up @@ -207,6 +206,11 @@ export async function embed(runtime: IAgentRuntime, input: string) {
const config = getEmbeddingConfig();
const isNode = typeof process !== "undefined" && process.versions?.node;

// Attempt remote embedding if it is configured.
if (config.provider !== EmbeddingProvider.BGE) {
return await getRemoteEmbedding(input, config);
}

// Determine which embedding path to use
if (config.provider === EmbeddingProvider.OpenAI) {
return await getRemoteEmbedding(input, {
Expand Down Expand Up @@ -271,6 +275,7 @@ export async function embed(runtime: IAgentRuntime, input: string) {
getEndpoint(runtime.character.modelProvider),
apiKey: runtime.token,
dimensions: config.dimensions,
provider: config.provider,
});

async function getLocalEmbedding(input: string): Promise<number[]> {
Expand Down
102 changes: 102 additions & 0 deletions packages/core/src/tests/embeddings.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@

import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { getEmbeddingConfig } from '../embedding';
import settings from '../settings';

// Replace the real settings module with an auto-mock so each test can
// populate exactly the env flags it needs and nothing else.
vi.mock("../settings");
const mockedSettings = vi.mocked(settings);

describe('getEmbeddingConfig', () => {
    // Wipe the mocked settings object before every case so each test fully
    // controls which USE_*_EMBEDDING flags are present.
    beforeEach(() => {
        for (const key of Object.keys(mockedSettings)) {
            delete mockedSettings[key];
        }
        vi.clearAllMocks();
    });

    afterEach(() => {
        vi.clearAllMocks();
    });

    it('should return BGE config by default', () => {
        mockedSettings.USE_OPENAI_EMBEDDING = 'false';
        mockedSettings.USE_OLLAMA_EMBEDDING = 'false';
        mockedSettings.USE_GAIANET_EMBEDDING = 'false';
        mockedSettings.USE_VOYAGEAI_EMBEDDING = 'false';

        // With every provider flag off, the local BGE fallback is used
        // (note: no endpoint or apiKey in the local config).
        expect(getEmbeddingConfig()).toEqual({
            dimensions: 384,
            model: 'BGE-small-en-v1.5',
            provider: 'BGE',
            maxInputTokens: 1000000,
        });
    });

    it('should return GaiaNet config when USE_GAIANET_EMBEDDING is true', () => {
        mockedSettings.USE_GAIANET_EMBEDDING = 'true';
        mockedSettings.GAIANET_API_KEY = 'test-key';
        mockedSettings.GAIANET_EMBEDDING_MODEL = 'test-model';
        mockedSettings.SMALL_GAIANET_SERVER_URL = 'https://test.gaianet.ai';

        expect(getEmbeddingConfig()).toEqual({
            dimensions: 768,
            model: 'test-model',
            provider: 'GaiaNet',
            endpoint: 'https://test.gaianet.ai',
            apiKey: 'test-key',
            maxInputTokens: 1000000,
        });
    });

    it('should return VoyageAI config when USE_VOYAGEAI_EMBEDDING is true', () => {
        mockedSettings.USE_VOYAGEAI_EMBEDDING = 'true';
        mockedSettings.VOYAGEAI_API_KEY = 'test-key';

        // Model, dimensions, and endpoint come from the VoyageAI defaults.
        expect(getEmbeddingConfig()).toEqual({
            dimensions: 512,
            model: 'voyage-3-lite',
            provider: 'VoyageAI',
            endpoint: 'https://api.voyageai.com/v1',
            apiKey: 'test-key',
            maxInputTokens: 1000000,
        });
    });

    it('should return OpenAI config when USE_OPENAI_EMBEDDING is true', () => {
        mockedSettings.USE_OPENAI_EMBEDDING = 'true';
        mockedSettings.OPENAI_API_KEY = 'test-key';

        expect(getEmbeddingConfig()).toEqual({
            dimensions: 1536,
            model: 'text-embedding-3-small',
            provider: 'OpenAI',
            endpoint: 'https://api.openai.com/v1',
            apiKey: 'test-key',
            maxInputTokens: 1000000,
        });
    });

    it('should return Ollama config when USE_OLLAMA_EMBEDDING is true', () => {
        mockedSettings.USE_OLLAMA_EMBEDDING = 'true';
        mockedSettings.OLLAMA_API_KEY = 'test-key';
        mockedSettings.OLLAMA_EMBEDDING_MODEL = 'test-model';

        expect(getEmbeddingConfig()).toEqual({
            dimensions: 1024,
            model: 'test-model',
            provider: 'Ollama',
            endpoint: 'https://ollama.eliza.ai/v1',
            apiKey: 'test-key',
            maxInputTokens: 1000000,
        });
    });
});
Loading

0 comments on commit 172d645

Please sign in to comment.