
Commit 6191482

feat(chat): add support for Azure API with versioning and refactor model handling
1 parent: c3f552e

File tree

3 files changed: +64 -25 lines changed

crates/http-api-bindings/src/chat/mod.rs
crates/tabby-common/src/config.rs
crates/tabby-inference/src/chat.rs

crates/http-api-bindings/src/chat/mod.rs

+42 -25

@@ -12,31 +12,48 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
         .api_endpoint
         .as_deref()
         .expect("api_endpoint is required");
-    let config = OpenAIConfig::default()
-        .with_api_base(api_endpoint)
-        .with_api_key(model.api_key.clone().unwrap_or_default());
-
-    let mut builder = ExtendedOpenAIConfig::builder();
-
-    builder
-        .base(config)
-        .supported_models(model.supported_models.clone())
-        .model_name(model.model_name.as_deref().expect("Model name is required"));
-
-    if model.kind == "openai/chat" {
-        // Do nothing
-    } else if model.kind == "mistral/chat" {
-        builder.fields_to_remove(ExtendedOpenAIConfig::mistral_fields_to_remove());
-    } else {
-        panic!("Unsupported model kind: {}", model.kind);
-    }
-
-    let config = builder.build().expect("Failed to build config");
-
-    let engine = Box::new(
-        async_openai_alt::Client::with_config(config)
-            .with_http_client(create_reqwest_client(api_endpoint)),
-    );
+
+    let engine: Box<dyn ChatCompletionStream> = match model.kind.as_str() {
+        "azure/chat" => {
+            let config = async_openai_alt::config::AzureConfig::new()
+                .with_api_base(api_endpoint)
+                .with_api_key(model.api_key.clone().unwrap_or_default())
+                .with_api_version(
+                    model
+                        .api_version
+                        .clone()
+                        .unwrap_or("2024-05-01-preview".to_string()),
+                )
+                .with_deployment_id(model.model_name.as_deref().expect("Model name is required"));
+            Box::new(
+                async_openai_alt::Client::with_config(config)
+                    .with_http_client(create_reqwest_client(api_endpoint)),
+            )
+        }
+        "openai/chat" | "mistral/chat" => {
+            let config = OpenAIConfig::default()
+                .with_api_base(api_endpoint)
+                .with_api_key(model.api_key.clone().unwrap_or_default());
+
+            let mut builder = ExtendedOpenAIConfig::builder();
+            builder
+                .base(config)
+                .supported_models(model.supported_models.clone())
+                .model_name(model.model_name.as_deref().expect("Model name is required"));
+
+            if model.kind == "mistral/chat" {
+                builder.fields_to_remove(ExtendedOpenAIConfig::mistral_fields_to_remove());
+            }
+
+            Box::new(
+                async_openai_alt::Client::with_config(
+                    builder.build().expect("Failed to build config"),
+                )
+                .with_http_client(create_reqwest_client(api_endpoint)),
+            )
+        }
+        _ => panic!("Unsupported model kind: {}", model.kind),
+    };
 
     Arc::new(rate_limit::new_chat(
         engine,
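
Note (editor's sketch, not part of the commit): with an async-openai-style AzureConfig, the deployment id and api-version are typically carried in the request URL rather than in the JSON body, which is why the new "azure/chat" branch maps model_name to with_deployment_id and threads api_version through. Roughly, the resulting chat-completions URL looks like the sketch below; the endpoint and deployment name are placeholders, not values from this commit.

    // Illustrative only: typical shape of an Azure OpenAI chat-completions URL.
    fn azure_chat_completions_url(api_base: &str, deployment_id: &str, api_version: &str) -> String {
        format!(
            "{}/openai/deployments/{}/chat/completions?api-version={}",
            api_base.trim_end_matches('/'),
            deployment_id,
            api_version
        )
    }

    fn main() {
        let url = azure_chat_completions_url(
            "https://my-resource.openai.azure.com", // placeholder Azure endpoint (api_endpoint)
            "my-gpt-4o-deployment",                 // placeholder deployment (model_name)
            "2024-05-01-preview",                   // default api_version used by the commit
        );
        println!("{url}");
    }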

crates/tabby-common/src/config.rs

+5
@@ -310,6 +310,11 @@ pub struct HttpModelConfig {
 
     #[builder(default)]
     pub additional_stop_words: Option<Vec<String>>,
+
+    /// Used for the Azure API to specify the API version.
+    #[builder(default)]
+    #[serde(default)]
+    pub api_version: Option<String>,
 }
 
 #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
crates/tabby-inference/src/chat.rs

+17
@@ -125,3 +125,20 @@ impl ChatCompletionStream for async_openai_alt::Client<ExtendedOpenAIConfig> {
         self.chat().create_stream(request).await
     }
 }
+
+#[async_trait]
+impl ChatCompletionStream for async_openai_alt::Client<async_openai_alt::config::AzureConfig> {
+    async fn chat(
+        &self,
+        request: CreateChatCompletionRequest,
+    ) -> Result<CreateChatCompletionResponse, OpenAIError> {
+        self.chat().create(request).await
+    }
+
+    async fn chat_stream(
+        &self,
+        request: CreateChatCompletionRequest,
+    ) -> Result<ChatCompletionResponseStream, OpenAIError> {
+        self.chat().create_stream(request).await
+    }
+}
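
For orientation (editor's sketch, not code from the commit): async_openai_alt::Client<C> is generic over its config type, so Client<AzureConfig> and Client<ExtendedOpenAIConfig> are distinct types and the pre-existing impl does not cover the new client. Adding this second impl lets both be handed around as the same Box<dyn ChatCompletionStream>, which is what chat/mod.rs and rate_limit::new_chat rely on. A reduced illustration of that pattern, with stand-in types and assuming the async-trait and tokio crates:

    // Reduced illustration of per-config-type impls behind one trait object.
    use async_trait::async_trait;

    #[allow(dead_code)]
    struct Client<C> { config: C }          // stand-in for async_openai_alt::Client<C>
    struct OpenAiConfig;                    // stand-in config types
    struct AzureConfig;

    #[async_trait]
    trait ChatBackend {
        async fn chat(&self, prompt: &str) -> String;
    }

    #[async_trait]
    impl ChatBackend for Client<OpenAiConfig> {
        async fn chat(&self, prompt: &str) -> String { format!("openai: {prompt}") }
    }

    // Without this second impl, Client<AzureConfig> could not be boxed as dyn ChatBackend.
    #[async_trait]
    impl ChatBackend for Client<AzureConfig> {
        async fn chat(&self, prompt: &str) -> String { format!("azure: {prompt}") }
    }

    #[tokio::main]
    async fn main() {
        let engine: Box<dyn ChatBackend> = Box::new(Client { config: AzureConfig });
        println!("{}", engine.chat("hello").await);
    }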
