Sweep: Connection to openai via OPENAI_API_BASE doesn't seem to work #108
Comments
🚀 Here's the PR! #109. See Sweep's progress at the progress dashboard.

The relevant `create` logic in `src/models.rs` before any changes:
```rust
pub struct OpenAIClientManager {}

#[async_trait]
impl Manager for OpenAIClientManager {
    type Type = AnyOpenAIClient;
    type Error = MotorheadError;

    async fn create(&self) -> Result<AnyOpenAIClient, MotorheadError> {
        // The Azure variant is only chosen when all four AZURE_* variables are set;
        // every other combination falls through to the `_` arm below.
        let openai_client = match (
            env::var("AZURE_API_KEY"),
            env::var("AZURE_DEPLOYMENT_ID"),
            env::var("AZURE_DEPLOYMENT_ID_ADA"),
            env::var("AZURE_API_BASE"),
        ) {
            (
                Ok(azure_api_key),
                Ok(azure_deployment_id),
                Ok(azure_deployment_id_ada),
                Ok(azure_api_base),
            ) => {
                let config = AzureConfig::new()
                    .with_api_base(&azure_api_base)
                    .with_api_key(&azure_api_key)
                    .with_deployment_id(azure_deployment_id)
                    .with_api_version("2023-05-15");

                let config_ada = AzureConfig::new()
                    .with_api_base(&azure_api_base)
                    .with_api_key(&azure_api_key)
                    .with_deployment_id(azure_deployment_id_ada)
                    .with_api_version("2023-05-15");

                AnyOpenAIClient::Azure {
                    embedding_client: Client::with_config(config_ada),
                    completion_client: Client::with_config(config),
                }
            }
            // No complete Azure configuration: use OPENAI_API_BASE if set,
            // otherwise fall back to the default OpenAI configuration.
            _ => {
                let openai_api_base = env::var("OPENAI_API_BASE");

                if let Ok(openai_api_base) = openai_api_base {
                    let embedding_config = OpenAIConfig::default().with_api_base(&openai_api_base);
                    let completion_config = OpenAIConfig::default().with_api_base(&openai_api_base);

                    AnyOpenAIClient::OpenAI {
                        embedding_client: Client::with_config(embedding_config),
                        completion_client: Client::with_config(completion_config),
                    }
                } else {
                    AnyOpenAIClient::OpenAI {
                        embedding_client: Client::new(),
                        completion_client: Client::new(),
                    }
                }
            }
        };

        Ok(openai_client)
    }
    // ...
}
```
Step 2: ⌨️ Coding
Modify src/models.rs with contents:
• Change the logic in the `create` method to first check for the `OPENAI_API_BASE` environment variable and configure the OpenAI clients accordingly.
• If `OPENAI_API_BASE` is set, use it to configure both the `embedding_client` and `completion_client` for the `AnyOpenAIClient::OpenAI` variant.
• If `OPENAI_API_BASE` is not set, then proceed to check for Azure-related environment variables and configure the clients for the `AnyOpenAIClient::Azure` variant.
• If neither `OPENAI_API_BASE` nor Azure-related environment variables are set, default to creating OpenAI clients with the default configuration.
• The new logic should look like this:

```rust
    async fn create(&self) -> Result<AnyOpenAIClient, MotorheadError> {
        let openai_api_base = env::var("OPENAI_API_BASE").ok();
        let azure_api_key = env::var("AZURE_API_KEY").ok();
        let azure_deployment_id = env::var("AZURE_DEPLOYMENT_ID").ok();
        let azure_deployment_id_ada = env::var("AZURE_DEPLOYMENT_ID_ADA").ok();
        let azure_api_base = env::var("AZURE_API_BASE").ok();

        let openai_client = if let Some(api_base) = openai_api_base {
            let embedding_config = OpenAIConfig::default().with_api_base(&api_base);
            let completion_config = OpenAIConfig::default().with_api_base(&api_base);

            AnyOpenAIClient::OpenAI {
                embedding_client: Client::with_config(embedding_config),
                completion_client: Client::with_config(completion_config),
            }
        } else if azure_api_key.is_some()
            && azure_deployment_id.is_some()
            && azure_deployment_id_ada.is_some()
            && azure_api_base.is_some()
        {
            let config = AzureConfig::new()
                .with_api_base(azure_api_base.as_ref().unwrap())
                .with_api_key(azure_api_key.as_ref().unwrap())
                .with_deployment_id(azure_deployment_id.unwrap())
                .with_api_version("2023-05-15");

            let config_ada = AzureConfig::new()
                .with_api_base(azure_api_base.as_ref().unwrap())
                .with_api_key(azure_api_key.as_ref().unwrap())
                .with_deployment_id(azure_deployment_id_ada.unwrap())
                .with_api_version("2023-05-15");

            AnyOpenAIClient::Azure {
                embedding_client: Client::with_config(config_ada),
                completion_client: Client::with_config(config),
            }
        } else {
            AnyOpenAIClient::OpenAI {
                embedding_client: Client::new(),
                completion_client: Client::new(),
            }
        };

        Ok(openai_client)
    }
```

• This change ensures that `OPENAI_API_BASE` is prioritized for configuration, allowing users to connect to a custom OpenAI API base if provided.

- [X] Running GitHub Actions for `src/models.rs` ✓ [Edit](https://github.com/getmetal/motorhead/edit/sweep/connection_to_openai_via_openai_api_base/src/models.rs#L28-L76)

  Ran GitHub Actions for [abf0e546bda733eaa3948c35fd99dd7dae847174](https://github.com/getmetal/motorhead/commit/abf0e546bda733eaa3948c35fd99dd7dae847174).

Step 3: 🔁 Code Review
I have finished reviewing the code for completeness. I did not find errors for [`sweep/connection_to_openai_via_openai_api_base`](https://github.com/getmetal/motorhead/commits/sweep/connection_to_openai_via_openai_api_base).
No, I think I got the issue wrong - the logic here is correct, but I'm still seeing motorhead try to connect to the main openai servers instead of to my base URL. This PR is not correct, but I'm still debugging why it doesn't honour OPENAI_API_BASE properly...
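One way to narrow this down is to confirm that `OPENAI_API_BASE` is actually visible to the motorhead process itself, since a variable set in the shell but not passed into a container or service manager silently leaves the default OpenAI base in place. Below is a minimal, standalone sketch of such a check; it is not part of motorhead, and only the environment variable names are taken from `src/models.rs`:

```rust
use std::env;

// Standalone sanity check: report which configuration branch the `create`
// logic in src/models.rs would take, given the environment this runs in.
fn main() {
    let azure_vars = [
        "AZURE_API_KEY",
        "AZURE_DEPLOYMENT_ID",
        "AZURE_DEPLOYMENT_ID_ADA",
        "AZURE_API_BASE",
    ];
    let azure_complete = azure_vars.iter().all(|name| env::var(name).is_ok());

    match (azure_complete, env::var("OPENAI_API_BASE")) {
        (true, _) => println!("All AZURE_* variables set: the Azure branch would be used."),
        (false, Ok(base)) => println!("OPENAI_API_BASE is visible: clients would target {base}"),
        (false, Err(_)) => println!("OPENAI_API_BASE is NOT visible: the default OpenAI base would be used."),
    }
}
```

Running this (or simply `printenv OPENAI_API_BASE`) in the same environment motorhead runs in, for example inside the container, quickly separates an environment problem from a problem in the client configuration code.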
Details
I believe there's a logic error in models.rs, where it creates a connection: if OPENAI_API_BASE is provided, it does not properly connect to the specified server; instead it still creates a default openai connection. This means you can't use a local or self-hosted model. I think the logic is wrong in that it only reaches that block if the AZURE env variables are set, but I'm not confident enough in Rust to say for sure.
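For reference, the `match` in `models.rs` over the four `AZURE_*` lookups behaves like this: the Azure arm only matches when all four lookups succeed, and every other combination falls through to the wildcard `_` arm, where `OPENAI_API_BASE` is then checked. Here is a minimal standalone sketch of that dispatch, rewritten over plain `Option`s so the branching can be exercised without touching the real environment; the function name and test values are made up for illustration and are not part of motorhead:

```rust
// Mirrors the dispatch in src/models.rs over plain Options, so the branching
// can be exercised without touching the real environment.
fn branch(
    azure: (Option<&str>, Option<&str>, Option<&str>, Option<&str>),
    openai_api_base: Option<&str>,
) -> &'static str {
    match azure {
        // Only matches when all four Azure values are present.
        (Some(_), Some(_), Some(_), Some(_)) => "azure",
        // Every other combination lands here, where OPENAI_API_BASE is honoured.
        _ => match openai_api_base {
            Some(_) => "openai with custom base",
            None => "openai default",
        },
    }
}

fn main() {
    // Only OPENAI_API_BASE set: the wildcard arm is taken and the custom base wins.
    assert_eq!(
        branch((None, None, None, None), Some("http://localhost:8080/v1")),
        "openai with custom base"
    );
    // Nothing set: default OpenAI configuration.
    assert_eq!(branch((None, None, None, None), None), "openai default");
    // All four Azure values set: the Azure arm wins regardless of OPENAI_API_BASE.
    assert_eq!(
        branch(
            (Some("key"), Some("dep"), Some("dep-ada"), Some("https://example.azure.com")),
            Some("http://localhost:8080/v1"),
        ),
        "azure"
    );
    println!("all branches behave as described");
}
```

If all three assertions hold, the wildcard arm does reach the `OPENAI_API_BASE` check whenever the Azure variables are not all set, which matches the follow-up comment above that the branching itself is correct.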
Checklist
- [X] `src/models.rs` ✓ abf0e54 Edit
- [X] `src/models.rs` ✓ Edit