Skip to content

Commit d702209

Browse files
gyliu513AlexsJonesmatthisholleville
authored
fix: enabled auth add support watsonx backend (#1190)
Signed-off-by: Guangya Liu <[email protected]> Signed-off-by: Alex Jones <[email protected]> Co-authored-by: Alex Jones <[email protected]> Co-authored-by: Matthis <[email protected]>
1 parent 7019d0b commit d702209

File tree

3 files changed

+11
-11
lines changed

3 files changed

+11
-11
lines changed

cmd/auth/add.go

+5-2
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ var addCmd = &cobra.Command{
4848
if strings.ToLower(backend) == "amazonbedrock" {
4949
_ = cmd.MarkFlagRequired("providerRegion")
5050
}
51+
if strings.ToLower(backend) == "watsonxai" {
52+
_ = cmd.MarkFlagRequired("providerId")
53+
}
5154
},
5255
Run: func(cmd *cobra.Command, args []string) {
5356

@@ -173,8 +176,8 @@ func init() {
173176
addCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name (only for azureopenai backend)")
174177
//add flag for amazonbedrock region name
175178
addCmd.Flags().StringVarP(&providerRegion, "providerRegion", "r", "", "Provider Region name (only for amazonbedrock, googlevertexai backend)")
176-
//add flag for vertexAI Project ID
177-
addCmd.Flags().StringVarP(&providerId, "providerId", "i", "", "Provider specific ID for e.g. project (only for googlevertexai backend)")
179+
//add flag for vertexAI/WatsonxAI Project ID
180+
addCmd.Flags().StringVarP(&providerId, "providerId", "i", "", "Provider specific ID for e.g. project (only for googlevertexai/watsonxai backend)")
178181
//add flag for OCI Compartment ID
179182
addCmd.Flags().StringVarP(&compartmentId, "compartmentId", "k", "", "Compartment ID for generative AI model (only for oci backend)")
180183
// add flag for openai organization

pkg/ai/iai.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ func (p *AIProvider) GetCustomHeaders() []http.Header {
181181
return p.CustomHeaders
182182
}
183183

184-
var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci", "watsonxai"}
184+
var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci"}
185185

186186
func NeedPassword(backend string) bool {
187187
for _, b := range passwordlessProviders {

pkg/ai/watsonxai.go

+5-8
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ import (
44
"context"
55
"errors"
66
"fmt"
7-
"os"
8-
97
wx "github.com/IBM/watsonx-go/pkg/models"
108
)
119

@@ -42,20 +40,19 @@ func (c *WatsonxAIClient) Configure(config IAIConfig) error {
4240
c.topP = config.GetTopP()
4341
c.topK = config.GetTopK()
4442

45-
// WatsonxAPIKeyEnvVarName = "WATSONX_API_KEY"
46-
// WatsonxProjectIDEnvVarName = "WATSONX_PROJECT_ID"
47-
apiKey, projectID := os.Getenv(wx.WatsonxAPIKeyEnvVarName), os.Getenv(wx.WatsonxProjectIDEnvVarName)
48-
43+
apiKey := config.GetPassword()
4944
if apiKey == "" {
5045
return errors.New("No watsonx API key provided")
5146
}
52-
if projectID == "" {
47+
48+
projectId := config.GetProviderId()
49+
if projectId == "" {
5350
return errors.New("No watsonx project ID provided")
5451
}
5552

5653
client, err := wx.NewClient(
5754
wx.WithWatsonxAPIKey(apiKey),
58-
wx.WithWatsonxProjectID(projectID),
55+
wx.WithWatsonxProjectID(projectId),
5956
)
6057
if err != nil {
6158
return fmt.Errorf("Failed to create client for testing. Error: %v", err)

0 commit comments

Comments
 (0)