@@ -12,31 +12,48 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
12
12
. api_endpoint
13
13
. as_deref ( )
14
14
. expect ( "api_endpoint is required" ) ;
15
- let config = OpenAIConfig :: default ( )
16
- . with_api_base ( api_endpoint)
17
- . with_api_key ( model. api_key . clone ( ) . unwrap_or_default ( ) ) ;
18
-
19
- let mut builder = ExtendedOpenAIConfig :: builder ( ) ;
20
-
21
- builder
22
- . base ( config)
23
- . supported_models ( model. supported_models . clone ( ) )
24
- . model_name ( model. model_name . as_deref ( ) . expect ( "Model name is required" ) ) ;
25
-
26
- if model. kind == "openai/chat" {
27
- // Do nothing
28
- } else if model. kind == "mistral/chat" {
29
- builder. fields_to_remove ( ExtendedOpenAIConfig :: mistral_fields_to_remove ( ) ) ;
30
- } else {
31
- panic ! ( "Unsupported model kind: {}" , model. kind) ;
32
- }
33
-
34
- let config = builder. build ( ) . expect ( "Failed to build config" ) ;
35
-
36
- let engine = Box :: new (
37
- async_openai_alt:: Client :: with_config ( config)
38
- . with_http_client ( create_reqwest_client ( api_endpoint) ) ,
39
- ) ;
15
+
16
+ let engine: Box < dyn ChatCompletionStream > = match model. kind . as_str ( ) {
17
+ "azure/chat" => {
18
+ let config = async_openai_alt:: config:: AzureConfig :: new ( )
19
+ . with_api_base ( api_endpoint)
20
+ . with_api_key ( model. api_key . clone ( ) . unwrap_or_default ( ) )
21
+ . with_api_version (
22
+ model
23
+ . api_version
24
+ . clone ( )
25
+ . unwrap_or ( "2024-05-01-preview" . to_string ( ) ) ,
26
+ )
27
+ . with_deployment_id ( model. model_name . as_deref ( ) . expect ( "Model name is required" ) ) ;
28
+ Box :: new (
29
+ async_openai_alt:: Client :: with_config ( config)
30
+ . with_http_client ( create_reqwest_client ( api_endpoint) ) ,
31
+ )
32
+ }
33
+ "openai/chat" | "mistral/chat" => {
34
+ let config = OpenAIConfig :: default ( )
35
+ . with_api_base ( api_endpoint)
36
+ . with_api_key ( model. api_key . clone ( ) . unwrap_or_default ( ) ) ;
37
+
38
+ let mut builder = ExtendedOpenAIConfig :: builder ( ) ;
39
+ builder
40
+ . base ( config)
41
+ . supported_models ( model. supported_models . clone ( ) )
42
+ . model_name ( model. model_name . as_deref ( ) . expect ( "Model name is required" ) ) ;
43
+
44
+ if model. kind == "mistral/chat" {
45
+ builder. fields_to_remove ( ExtendedOpenAIConfig :: mistral_fields_to_remove ( ) ) ;
46
+ }
47
+
48
+ Box :: new (
49
+ async_openai_alt:: Client :: with_config (
50
+ builder. build ( ) . expect ( "Failed to build config" ) ,
51
+ )
52
+ . with_http_client ( create_reqwest_client ( api_endpoint) ) ,
53
+ )
54
+ }
55
+ _ => panic ! ( "Unsupported model kind: {}" , model. kind) ,
56
+ } ;
40
57
41
58
Arc :: new ( rate_limit:: new_chat (
42
59
engine,
0 commit comments