Commit b991514

feat: add Clone support and HTTP client pooling for providers
1 parent 0867c71 commit b991514

16 files changed: +1446, -243 lines

src/backends/anthropic.rs

Lines changed: 115 additions & 1 deletion
@@ -29,7 +29,7 @@ use serde_json::Value;
 /// Client for interacting with Anthropic's API.
 ///
 /// Provides methods for chat and completion requests using Anthropic's models.
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Anthropic {
     pub api_key: String,
     pub model: String,
@@ -506,6 +506,43 @@ impl Anthropic {
             client: builder.build().expect("Failed to build reqwest Client"),
         }
     }
+
+    /// Creates a new Anthropic client with a pre-configured HTTP client.
+    ///
+    /// This allows sharing a single `reqwest::Client` across multiple providers,
+    /// enabling connection pooling and reducing resource usage.
+    #[allow(clippy::too_many_arguments)]
+    pub fn with_client(
+        client: Client,
+        api_key: impl Into<String>,
+        model: Option<String>,
+        max_tokens: Option<u32>,
+        temperature: Option<f32>,
+        timeout_seconds: Option<u64>,
+        system: Option<String>,
+        top_p: Option<f32>,
+        top_k: Option<u32>,
+        tools: Option<Vec<Tool>>,
+        tool_choice: Option<ToolChoice>,
+        reasoning: Option<bool>,
+        thinking_budget_tokens: Option<u32>,
+    ) -> Self {
+        Self {
+            api_key: api_key.into(),
+            model: model.unwrap_or_else(|| "claude-3-sonnet-20240229".to_string()),
+            max_tokens: max_tokens.unwrap_or(300),
+            temperature: temperature.unwrap_or(0.7),
+            system: system.unwrap_or_else(|| "You are a helpful assistant.".to_string()),
+            timeout_seconds: timeout_seconds.unwrap_or(30),
+            top_p,
+            top_k,
+            tools,
+            tool_choice,
+            reasoning: reasoning.unwrap_or(false),
+            thinking_budget_tokens,
+            client,
+        }
+    }
 }
 
 #[async_trait]
@@ -1391,4 +1428,81 @@ data: {"type": "ping"}
         let result = parse_anthropic_sse_chunk_with_tools(chunk, &mut tool_states).unwrap();
         assert!(result.is_none());
     }
+
+    #[test]
+    fn test_anthropic_clone() {
+        let anthropic = Anthropic::new(
+            "test-api-key",
+            Some("claude-3-sonnet".to_string()),
+            Some(1000),
+            Some(0.7),
+            Some(30),
+            Some("You are helpful.".to_string()),
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        );
+
+        // Clone the provider
+        let cloned = anthropic.clone();
+
+        // Verify both have the same configuration
+        assert_eq!(anthropic.api_key, cloned.api_key);
+        assert_eq!(anthropic.model, cloned.model);
+        assert_eq!(anthropic.max_tokens, cloned.max_tokens);
+        assert_eq!(anthropic.temperature, cloned.temperature);
+        assert_eq!(anthropic.system, cloned.system);
+    }
+
+    #[test]
+    fn test_anthropic_with_client() {
+        let shared_client = Client::builder()
+            .timeout(std::time::Duration::from_secs(60))
+            .build()
+            .expect("Failed to build client");
+
+        let anthropic = Anthropic::with_client(
+            shared_client.clone(),
+            "test-api-key",
+            Some("claude-3-sonnet".to_string()),
+            Some(1000),
+            Some(0.7),
+            Some(30),
+            Some("You are helpful.".to_string()),
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        );
+
+        // Verify configuration
+        assert_eq!(anthropic.api_key, "test-api-key");
+        assert_eq!(anthropic.model, "claude-3-sonnet");
+        assert_eq!(anthropic.max_tokens, 1000);
+
+        // Create another provider with the same client
+        let anthropic2 = Anthropic::with_client(
+            shared_client,
+            "test-api-key-2",
+            Some("claude-3-haiku".to_string()),
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        );
+
+        assert_eq!(anthropic2.api_key, "test-api-key-2");
+        assert_eq!(anthropic2.model, "claude-3-haiku");
+    }
 }
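With `Clone` derived, a provider can be handed to several tasks without rebuilding its HTTP client: `reqwest::Client` is an `Arc`-backed handle, so every clone of the struct shares one connection pool. A minimal sketch (not part of the commit), assuming a tokio runtime and that `Anthropic` is in scope; the key and model are placeholders:

```rust
#[tokio::main]
async fn main() {
    // Build one provider; its inner reqwest::Client owns the connection pool.
    let anthropic = Anthropic::new(
        "test-api-key",
        Some("claude-3-sonnet".to_string()),
        None, None, None, None, None, None, None, None, None, None,
    );

    let mut handles = Vec::new();
    for i in 0..4 {
        // Cheap clone: the Strings are copied, the Client handle shares its pool.
        let provider = anthropic.clone();
        handles.push(tokio::spawn(async move {
            // Each task can issue requests through its own clone.
            println!("task {i} using model {}", provider.model);
        }));
    }
    for handle in handles {
        handle.await.expect("task panicked");
    }
}
```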

src/backends/azure_openai.rs

Lines changed: 51 additions & 0 deletions
@@ -27,6 +27,7 @@ use serde::{Deserialize, Serialize};
 /// Client for interacting with Azure OpenAI's API.
 ///
 /// Provides methods for chat and completion requests using Azure OpenAI's models.
+#[derive(Clone)]
 pub struct AzureOpenAI {
     pub api_key: String,
     pub api_version: String,
@@ -381,6 +382,56 @@ impl AzureOpenAI {
             json_schema,
         }
     }
+
+    /// Creates a new Azure OpenAI client with a pre-configured HTTP client.
+    ///
+    /// This allows sharing a single `reqwest::Client` across multiple providers,
+    /// enabling connection pooling and reducing resource usage.
+    #[allow(clippy::too_many_arguments)]
+    pub fn with_client(
+        client: Client,
+        api_key: impl Into<String>,
+        api_version: impl Into<String>,
+        deployment_id: impl Into<String>,
+        endpoint: impl Into<String>,
+        model: Option<String>,
+        max_tokens: Option<u32>,
+        temperature: Option<f32>,
+        timeout_seconds: Option<u64>,
+        system: Option<String>,
+        top_p: Option<f32>,
+        top_k: Option<u32>,
+        embedding_encoding_format: Option<String>,
+        embedding_dimensions: Option<u32>,
+        tools: Option<Vec<Tool>>,
+        tool_choice: Option<ToolChoice>,
+        reasoning_effort: Option<String>,
+        json_schema: Option<StructuredOutputFormat>,
+    ) -> Self {
+        let endpoint = endpoint.into();
+        let deployment_id = deployment_id.into();
+
+        Self {
+            api_key: api_key.into(),
+            api_version: api_version.into(),
+            base_url: Url::parse(&format!("{endpoint}/openai/deployments/{deployment_id}/"))
+                .expect("Failed to parse base Url"),
+            model: model.unwrap_or("gpt-3.5-turbo".to_string()),
+            max_tokens,
+            temperature,
+            system,
+            timeout_seconds,
+            top_p,
+            top_k,
+            tools,
+            tool_choice,
+            embedding_encoding_format,
+            embedding_dimensions,
+            client,
+            reasoning_effort,
+            json_schema,
+        }
+    }
 }
 
 #[async_trait]
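A minimal usage sketch (not from the commit); the key, API version, endpoint, and deployment name below are placeholders, and the import path of `AzureOpenAI` depends on the crate layout. `with_client` derives the base URL as `{endpoint}/openai/deployments/{deployment_id}/`, so only the resource endpoint and deployment name need to be supplied:

```rust
fn build_azure(shared: reqwest::Client) -> AzureOpenAI {
    AzureOpenAI::with_client(
        shared,
        "azure-api-key",                        // api_key (placeholder)
        "2024-02-01",                           // api_version (placeholder)
        "my-gpt4o-deployment",                  // deployment_id (placeholder)
        "https://my-resource.openai.azure.com", // endpoint (placeholder)
        None,                                   // model: defaults to "gpt-3.5-turbo"
        None, None, None, None, None, None,     // max_tokens .. top_k
        None, None,                             // embedding_encoding_format / embedding_dimensions
        None, None,                             // tools / tool_choice
        None, None,                             // reasoning_effort / json_schema
    )
}
```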

src/backends/cohere.rs

Lines changed: 48 additions & 0 deletions
@@ -17,6 +17,7 @@ use async_trait::async_trait;
 use serde::{Deserialize, Serialize};
 
 /// Cohere configuration for the generic provider
+#[derive(Clone)]
 pub struct CohereConfig;
 
 impl OpenAIProviderConfig for CohereConfig {
@@ -78,6 +79,53 @@ impl Cohere {
             embedding_dimensions,
         )
     }
+
+    /// Creates a new Cohere client with a pre-configured HTTP client.
+    #[allow(clippy::too_many_arguments)]
+    pub fn with_config_and_client(
+        client: reqwest::Client,
+        api_key: impl Into<String>,
+        base_url: Option<String>,
+        model: Option<String>,
+        max_tokens: Option<u32>,
+        temperature: Option<f32>,
+        timeout_seconds: Option<u64>,
+        system: Option<String>,
+        top_p: Option<f32>,
+        top_k: Option<u32>,
+        tools: Option<Vec<Tool>>,
+        tool_choice: Option<ToolChoice>,
+        extra_body: Option<serde_json::Value>,
+        embedding_encoding_format: Option<String>,
+        embedding_dimensions: Option<u32>,
+        reasoning_effort: Option<String>,
+        json_schema: Option<StructuredOutputFormat>,
+        parallel_tool_calls: Option<bool>,
+        normalize_response: Option<bool>,
+    ) -> Self {
+        <OpenAICompatibleProvider<CohereConfig>>::with_client(
+            client,
+            api_key,
+            base_url,
+            model,
+            max_tokens,
+            temperature,
+            timeout_seconds,
+            system,
+            top_p,
+            top_k,
+            tools,
+            tool_choice,
+            reasoning_effort,
+            json_schema,
+            None,
+            extra_body,
+            parallel_tool_calls,
+            normalize_response,
+            embedding_encoding_format,
+            embedding_dimensions,
+        )
+    }
 }
 
 // Cohere-specific implementations that don't fit in the generic OpenAI-compatible provider
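The `#[derive(Clone)]` on the otherwise empty `CohereConfig` marker matters because a derived `Clone` on a generic struct adds a `Clone` bound on every type parameter, so `OpenAICompatibleProvider<CohereConfig>` (presumably made cloneable elsewhere in this commit) only gets a `Clone` impl if the config type is itself `Clone`. A standalone illustration of that rule, using hypothetical names rather than the crate's own types:

```rust
use std::marker::PhantomData;

// Marker type standing in for a provider config such as `CohereConfig`.
#[derive(Clone)]
struct MarkerConfig;

// `#[derive(Clone)]` expands to `impl<C: Clone> Clone for Provider<C>`,
// even though `PhantomData<C>` is cloneable regardless of `C`.
#[derive(Clone)]
struct Provider<C> {
    api_key: String,
    _config: PhantomData<C>,
}

fn assert_clone<T: Clone>() {}

fn main() {
    // Compiles only because `MarkerConfig` derives `Clone`.
    assert_clone::<Provider<MarkerConfig>>();
}
```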

src/backends/deepseek.rs

Lines changed: 25 additions & 0 deletions
@@ -21,6 +21,7 @@ use serde::{Deserialize, Serialize};
 
 use crate::ToolCall;
 
+#[derive(Clone)]
 pub struct DeepSeek {
     pub api_key: String,
     pub model: String,
@@ -105,6 +106,30 @@ impl DeepSeek {
             client: builder.build().expect("Failed to build reqwest Client"),
         }
     }
+
+    /// Creates a new DeepSeek client with a pre-configured HTTP client.
+    ///
+    /// This allows sharing a single `reqwest::Client` across multiple providers,
+    /// enabling connection pooling and reducing resource usage.
+    pub fn with_client(
+        client: Client,
+        api_key: impl Into<String>,
+        model: Option<String>,
+        max_tokens: Option<u32>,
+        temperature: Option<f32>,
+        timeout_seconds: Option<u64>,
+        system: Option<String>,
+    ) -> Self {
+        Self {
+            api_key: api_key.into(),
+            model: model.unwrap_or("deepseek-chat".to_string()),
+            max_tokens,
+            temperature,
+            system,
+            timeout_seconds,
+            client,
+        }
+    }
 }
 
 #[async_trait]
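When a provider is built through `with_client`, the caller owns the HTTP configuration: as this hunk shows, the passed-in client is stored as-is and `timeout_seconds` is simply recorded on the struct, so request timeouts belong on the shared client. A minimal sketch (not from the commit) with a placeholder key and an assumed import path:

```rust
use std::time::Duration;

fn build_deepseek() -> DeepSeek {
    // Configure HTTP behaviour once, on the client that will be shared.
    let shared = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()
        .expect("failed to build shared client");

    DeepSeek::with_client(
        shared,
        "deepseek-api-key", // placeholder
        None,               // model: defaults to "deepseek-chat"
        None,               // max_tokens
        None,               // temperature
        None,               // timeout_seconds (stored only; the client above is used as-is)
        None,               // system
    )
}
```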

src/backends/elevenlabs.rs

Lines changed: 28 additions & 1 deletion
@@ -16,6 +16,7 @@ use std::time::Duration;
 ///
 /// This struct provides functionality for speech-to-text transcription using the ElevenLabs API.
 /// It implements various LLM provider traits but only supports speech-to-text functionality.
+#[derive(Clone)]
 pub struct ElevenLabs {
     /// API key for ElevenLabs authentication
     api_key: String,
@@ -91,13 +92,39 @@ impl ElevenLabs {
         base_url: String,
         timeout_seconds: Option<u64>,
         voice: Option<String>,
+    ) -> Self {
+        let mut builder = Client::builder();
+        if let Some(sec) = timeout_seconds {
+            builder = builder.timeout(Duration::from_secs(sec));
+        }
+        Self {
+            api_key,
+            model_id,
+            base_url,
+            timeout_seconds,
+            client: builder.build().expect("Failed to build reqwest Client"),
+            voice,
+        }
+    }
+
+    /// Creates a new ElevenLabs client with a pre-configured HTTP client.
+    ///
+    /// This allows sharing a single `reqwest::Client` across multiple providers,
+    /// enabling connection pooling and reducing resource usage.
+    pub fn with_client(
+        client: Client,
+        api_key: String,
+        model_id: String,
+        base_url: String,
+        timeout_seconds: Option<u64>,
+        voice: Option<String>,
     ) -> Self {
         Self {
             api_key,
             model_id,
             base_url,
             timeout_seconds,
-            client: Client::new(),
+            client,
             voice,
         }
     }
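Two details in this hunk: the existing constructor previously stored `timeout_seconds` but built its client with `Client::new()`, so no request timeout was applied to the HTTP client (the new body applies it via `Client::builder()`), and `with_client` takes owned `String` parameters rather than `impl Into<String>`. A minimal sketch (not from the commit) with placeholder values and an assumed import path:

```rust
fn build_elevenlabs(shared: reqwest::Client) -> ElevenLabs {
    ElevenLabs::with_client(
        shared,
        "elevenlabs-api-key".to_string(),           // api_key (placeholder)
        "scribe_v1".to_string(),                    // model_id (placeholder)
        "https://api.elevenlabs.io/v1".to_string(), // base_url (placeholder)
        None,                                       // timeout_seconds
        None,                                       // voice
    )
}
```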

src/backends/google.rs

Lines changed: 34 additions & 0 deletions
@@ -66,6 +66,7 @@ use serde_json::Value;
 ///
 /// This struct holds the configuration and state needed to make requests to the Gemini API.
 /// It implements the [`ChatProvider`], [`CompletionProvider`], and [`EmbeddingProvider`] traits.
+#[derive(Clone)]
 pub struct Google {
     /// API key for authentication with Google's API
     pub api_key: String,
@@ -514,6 +515,39 @@ impl Google {
             client: builder.build().expect("Failed to build reqwest Client"),
         }
     }
+
+    /// Creates a new Google client with a pre-configured HTTP client.
+    ///
+    /// This allows sharing a single `reqwest::Client` across multiple providers,
+    /// enabling connection pooling and reducing resource usage.
+    #[allow(clippy::too_many_arguments)]
+    pub fn with_client(
+        client: Client,
+        api_key: impl Into<String>,
+        model: Option<String>,
+        max_tokens: Option<u32>,
+        temperature: Option<f32>,
+        timeout_seconds: Option<u64>,
+        system: Option<String>,
+        top_p: Option<f32>,
+        top_k: Option<u32>,
+        json_schema: Option<StructuredOutputFormat>,
+        tools: Option<Vec<Tool>>,
+    ) -> Self {
+        Self {
+            api_key: api_key.into(),
+            model: model.unwrap_or_else(|| "gemini-1.5-flash".to_string()),
+            max_tokens,
+            temperature,
+            system,
+            timeout_seconds,
+            top_p,
+            top_k,
+            json_schema,
+            tools,
+            client,
+        }
+    }
 }
 
 #[async_trait]
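Taken together, the new constructors let a single `reqwest::Client` back every provider in a process, so they all draw from one connection pool. A minimal sketch (not from the commit) wiring three of the providers above to one shared client; the keys are placeholders and import paths depend on the crate layout:

```rust
use std::time::Duration;

fn build_providers() -> (Google, Anthropic, DeepSeek) {
    // One pool and one timeout policy, shared by every provider below.
    let shared = reqwest::Client::builder()
        .timeout(Duration::from_secs(60))
        .build()
        .expect("failed to build shared client");

    let google = Google::with_client(
        shared.clone(),
        "google-api-key", // placeholder
        None, None, None, None, None, None, None, None, None,
    );
    let anthropic = Anthropic::with_client(
        shared.clone(),
        "anthropic-api-key", // placeholder
        None, None, None, None, None, None, None, None, None, None, None,
    );
    let deepseek = DeepSeek::with_client(
        shared,
        "deepseek-api-key", // placeholder
        None, None, None, None, None,
    );

    (google, anthropic, deepseek)
}
```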
