feat(backend): adding Azure AI Foundry OpenAI support #3683

Merged · 6 commits · Jan 15, 2025
62 changes: 37 additions & 25 deletions crates/http-api-bindings/src/chat/mod.rs
@@ -12,31 +12,43 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
.api_endpoint
.as_deref()
.expect("api_endpoint is required");
let config = OpenAIConfig::default()
.with_api_base(api_endpoint)
.with_api_key(model.api_key.clone().unwrap_or_default());

let mut builder = ExtendedOpenAIConfig::builder();

builder
.base(config)
.supported_models(model.supported_models.clone())
.model_name(model.model_name.as_deref().expect("Model name is required"));

if model.kind == "openai/chat" {
// Do nothing
} else if model.kind == "mistral/chat" {
builder.fields_to_remove(ExtendedOpenAIConfig::mistral_fields_to_remove());
} else {
panic!("Unsupported model kind: {}", model.kind);
}

let config = builder.build().expect("Failed to build config");

let engine = Box::new(
async_openai_alt::Client::with_config(config)
.with_http_client(create_reqwest_client(api_endpoint)),
);

let engine: Box<dyn ChatCompletionStream> = match model.kind.as_str() {
"azure/chat" => {
let config = async_openai_alt::config::AzureConfig::new()
.with_api_base(api_endpoint)
.with_api_key(model.api_key.clone().unwrap_or_default())
.with_api_version("2024-08-01-preview")
.with_deployment_id(model.model_name.as_deref().expect("Model name is required"));
Box::new(
async_openai_alt::Client::with_config(config)
.with_http_client(create_reqwest_client(api_endpoint)),
)
}
"openai/chat" | "mistral/chat" => {
let config = OpenAIConfig::default()
.with_api_base(api_endpoint)
.with_api_key(model.api_key.clone().unwrap_or_default());

let mut builder = ExtendedOpenAIConfig::builder();
builder
.base(config)
.supported_models(model.supported_models.clone())
.model_name(model.model_name.as_deref().expect("Model name is required"));

if model.kind == "mistral/chat" {
builder.fields_to_remove(ExtendedOpenAIConfig::mistral_fields_to_remove());
}

Box::new(
async_openai_alt::Client::with_config(
builder.build().expect("Failed to build config"),
)
.with_http_client(create_reqwest_client(api_endpoint)),
)
}
_ => panic!("Unsupported model kind: {}", model.kind),
};

Arc::new(rate_limit::new_chat(
engine,
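
For context on what the new `azure/chat` branch produces on the wire: a minimal sketch of the request shape, assuming `async_openai_alt`'s `AzureConfig` follows the standard Azure OpenAI conventions (deployment-scoped path, `api-version` query parameter, `api-key` header instead of a Bearer token). The resource and deployment names below are placeholders.

// Sketch only; not part of the PR. Assumes the standard Azure OpenAI URL
// conventions. The deployment id comes from `model_name` in the code above.
fn azure_chat_url(api_endpoint: &str, deployment_id: &str, api_version: &str) -> String {
    format!(
        "{}/openai/deployments/{}/chat/completions?api-version={}",
        api_endpoint.trim_end_matches('/'),
        deployment_id,
        api_version, // "2024-08-01-preview" in the code above
    )
}

// azure_chat_url("https://my-resource.openai.azure.com", "gpt-4o", "2024-08-01-preview")
// => "https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview"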
125 changes: 125 additions & 0 deletions crates/http-api-bindings/src/embedding/azure.rs
@@ -0,0 +1,125 @@
use std::sync::Arc;

use anyhow::Result;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tabby_inference::Embedding;

/// `AzureEmbeddingEngine` is responsible for interacting with Azure's Embedding API.
///
/// **Note**: Currently, this implementation supports only OpenAI-compatible Azure deployments and specific API versions.
#[derive(Clone)]
pub struct AzureEmbeddingEngine {
client: Arc<Client>,
api_endpoint: String,
api_key: String,
api_version: String,
}

/// Structure representing the request body for embedding.
#[derive(Debug, Serialize)]
struct EmbeddingRequest {
input: String,
}

/// Structure representing the response from the embedding API.
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
data: Vec<Data>,
}

/// Structure representing individual embedding data.
#[derive(Debug, Deserialize)]
struct Data {
embedding: Vec<f32>,
}

impl AzureEmbeddingEngine {
/// Creates a new instance of `AzureEmbeddingEngine`.
///
/// **Note**: Currently, this implementation supports only OpenAI-compatible Azure deployments and specific API versions.
///
/// # Parameters
///
/// - `api_endpoint`: The base URL of the Azure Embedding API.
/// - `model_name`: The name of the deployed model, used to construct the deployment ID.
/// - `api_key`: Optional API key for authentication.
/// - `api_version`: Optional API version, defaults to "2023-05-15".
///
/// # Returns
///
/// A boxed instance that implements the `Embedding` trait.
pub fn create(
api_endpoint: &str,
model_name: &str,
api_key: Option<&str>,
api_version: Option<&str>,
) -> Box<dyn Embedding> {
let client = Client::new();
let deployment_id = model_name;
// Construct the full endpoint URL for the Azure Embedding API
let azure_endpoint = format!(
"{}/openai/deployments/{}/embeddings",
api_endpoint.trim_end_matches('/'),
deployment_id
);

Box::new(Self {
client: Arc::new(client),
api_endpoint: azure_endpoint,
api_key: api_key.unwrap_or_default().to_owned(),
// Use a specific API version; currently, only this version is supported
api_version: api_version.unwrap_or("2023-05-15").to_owned(),
})
}
}

#[async_trait]
impl Embedding for AzureEmbeddingEngine {
/// Generates an embedding vector for the given prompt.
///
/// **Note**: Currently, this implementation supports only OpenAI-compatible Azure deployments and specific API versions.
///
/// # Parameters
///
/// - `prompt`: The input text to generate embeddings for.
///
/// # Returns
///
/// A `Result` containing the embedding vector or an error.
async fn embed(&self, prompt: &str) -> Result<Vec<f32>> {
// Clone the fields used by the request so the future does not hold borrows of `&self` across await points
let client = self.client.clone();
let api_endpoint = self.api_endpoint.clone();
let api_key = self.api_key.clone();
let api_version = self.api_version.clone();
let request = EmbeddingRequest {
input: prompt.to_owned(),
};

// Send a POST request to the Azure Embedding API
let response = client
.post(&api_endpoint)
.query(&[("api-version", &api_version)])
.header("api-key", &api_key)
.header("Content-Type", "application/json")
.json(&request)
.send()
.await?;

// Check if the response status indicates success
if !response.status().is_success() {
let error_text = response.text().await?;
anyhow::bail!("Azure API error: {}", error_text);
}

// Deserialize the response body into `EmbeddingResponse`
let embedding_response: EmbeddingResponse = response.json().await?;
embedding_response
.data
.first()
.map(|data| data.embedding.clone())
.ok_or_else(|| anyhow::anyhow!("No embedding data received"))
}
}
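
A hypothetical usage sketch for the engine above. In the crate, the module is private and instances are created through the kind dispatch in `embedding/mod.rs` below; the endpoint, deployment name, and API key here are placeholders.

// Hypothetical usage sketch; not part of the PR. Assumes `AzureEmbeddingEngine`
// and the `Embedding` trait are in scope, as they are within this module.
async fn demo() -> anyhow::Result<()> {
    let engine = AzureEmbeddingEngine::create(
        "https://my-resource.openai.azure.com", // placeholder Azure resource
        "text-embedding-ada-002",               // placeholder deployment name
        Some("placeholder-key"),
        None, // falls back to "2023-05-15"
    );
    let vector = engine.embed("hello world").await?;
    println!("embedding dimension: {}", vector.len());
    Ok(())
}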
11 changes: 11 additions & 0 deletions crates/http-api-bindings/src/embedding/mod.rs
@@ -1,9 +1,11 @@
mod azure;
mod llama;
mod openai;

use core::panic;
use std::sync::Arc;

use azure::AzureEmbeddingEngine;
use llama::LlamaCppEngine;
use openai::OpenAIEmbeddingEngine;
use tabby_common::config::HttpModelConfig;
@@ -40,6 +42,15 @@ pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
.expect("model_name must be set for voyage/embedding"),
config.api_key.as_deref(),
),
"azure/embedding" => AzureEmbeddingEngine::create(
config
.api_endpoint
.as_deref()
.expect("api_endpoint is required for azure/embedding"),
config.model_name.as_deref().unwrap_or_default(), // model_name is optional; fall back to an empty deployment name
config.api_key.as_deref(),
Some("2023-05-15"),
),
unsupported_kind => panic!(
"Unsupported kind for http embedding model: {}",
unsupported_kind
17 changes: 17 additions & 0 deletions crates/tabby-inference/src/chat.rs
@@ -125,3 +125,20 @@ impl ChatCompletionStream for async_openai_alt::Client<ExtendedOpenAIConfig> {
self.chat().create_stream(request).await
}
}

#[async_trait]
impl ChatCompletionStream for async_openai_alt::Client<async_openai_alt::config::AzureConfig> {
async fn chat(
&self,
request: CreateChatCompletionRequest,
) -> Result<CreateChatCompletionResponse, OpenAIError> {
self.chat().create(request).await
}

async fn chat_stream(
&self,
request: CreateChatCompletionRequest,
) -> Result<ChatCompletionResponseStream, OpenAIError> {
self.chat().create_stream(request).await
}
}
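
The duplicate-looking impl is needed because `Client` is generic over its config type: `Client<AzureConfig>` and `Client<ExtendedOpenAIConfig>` are distinct types, and each needs its own `ChatCompletionStream` impl before it can be boxed as a trait object. A minimal sketch of the coercion this enables, with placeholder values:

// Sketch only: with the impl above in place, an Azure-backed client coerces
// to the same trait object used elsewhere. All values are placeholders.
fn make_azure_engine() -> Box<dyn ChatCompletionStream> {
    let config = async_openai_alt::config::AzureConfig::new()
        .with_api_base("https://my-resource.openai.azure.com")
        .with_api_key("placeholder-key")
        .with_api_version("2024-08-01-preview")
        .with_deployment_id("my-gpt-4o-deployment");
    Box::new(async_openai_alt::Client::with_config(config))
}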