Skip to content

Commit 02e754e

Browse files
arbreezyrenovate[bot]AlexsJones
authored
feat: add custom http headers to openai related api backends (#1174)
* feat: add custom http headers to openai related api backends Signed-off-by: Aris Boutselis <[email protected]> * ci: add custom headers test Signed-off-by: Aris Boutselis <[email protected]> * add error handling Signed-off-by: Aris Boutselis <[email protected]> * chore(deps): update docker/setup-buildx-action digest to 4fd8129 (#1173) Signed-off-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Signed-off-by: Aris Boutselis <[email protected]> * fix(deps): update module buf.build/gen/go/k8sgpt-ai/k8sgpt/grpc-ecosystem/gateway/v2 to v2.20.0-20240406062209-1cc152efbf5c.1 (#1147) Signed-off-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Signed-off-by: Aris Boutselis <[email protected]> * chore(deps): update anchore/sbom-action action to v0.16.0 (#1146) Signed-off-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Co-authored-by: Alex Jones <[email protected]> Signed-off-by: Aris Boutselis <[email protected]> * Update README.md Signed-off-by: Aris Boutselis <[email protected]> --------- Signed-off-by: Aris Boutselis <[email protected]> Signed-off-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Co-authored-by: Alex Jones <[email protected]>
1 parent fef8539 commit 02e754e

File tree

8 files changed

+211
-26
lines changed

8 files changed

+211
-26
lines changed

README.md

+6
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,12 @@ _Analysis with serve mode_
299299
```
300300
grpcurl -plaintext -d '{"namespace": "k8sgpt", "explain": false}' localhost:8080 schema.v1.ServerService/Analyze
301301
```
302+
303+
_Analysis with custom headers_
304+
305+
```
306+
k8sgpt analyze --explain --custom-headers CustomHeaderKey:CustomHeaderValue
307+
```
302308
</details>
303309

304310
## LLM AI Backends

cmd/analyze/analyze.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ var (
3838
withDoc bool
3939
interactiveMode bool
4040
customAnalysis bool
41+
customHeaders []string
4142
)
4243

4344
// AnalyzeCmd represents the problems command
@@ -59,6 +60,7 @@ var AnalyzeCmd = &cobra.Command{
5960
maxConcurrency,
6061
withDoc,
6162
interactiveMode,
63+
customHeaders,
6264
)
6365

6466
if err != nil {
@@ -138,5 +140,6 @@ func init() {
138140
AnalyzeCmd.Flags().BoolVarP(&interactiveMode, "interactive", "i", false, "Enable interactive mode that allows further conversation with LLM about the problem. Works only with --explain flag")
139141
// custom analysis flag
140142
AnalyzeCmd.Flags().BoolVarP(&customAnalysis, "custom-analysis", "z", false, "Enable custom analyzers")
141-
143+
// add custom headers flag
144+
AnalyzeCmd.Flags().StringSliceVarP(&customHeaders, "custom-headers", "r", []string{}, "Custom Headers, <key>:<value> (e.g CustomHeaderKey:CustomHeaderValue AnotherHeader:AnotherValue)")
142145
}

pkg/ai/iai.go

+23-16
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ package ai
1515

1616
import (
1717
"context"
18+
"net/http"
1819
)
1920

2021
var (
@@ -83,6 +84,7 @@ type IAIConfig interface {
8384
GetProviderId() string
8485
GetCompartmentId() string
8586
GetOrganizationId() string
87+
GetCustomHeaders() []http.Header
8688
}
8789

8890
func NewClient(provider string) IAI {
@@ -101,22 +103,23 @@ type AIConfiguration struct {
101103
}
102104

103105
type AIProvider struct {
104-
Name string `mapstructure:"name"`
105-
Model string `mapstructure:"model"`
106-
Password string `mapstructure:"password" yaml:"password,omitempty"`
107-
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty"`
108-
ProxyEndpoint string `mapstructure:"proxyEndpoint" yaml:"proxyEndpoint,omitempty"`
109-
ProxyPort string `mapstructure:"proxyPort" yaml:"proxyPort,omitempty"`
110-
EndpointName string `mapstructure:"endpointname" yaml:"endpointname,omitempty"`
111-
Engine string `mapstructure:"engine" yaml:"engine,omitempty"`
112-
Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"`
113-
ProviderRegion string `mapstructure:"providerregion" yaml:"providerregion,omitempty"`
114-
ProviderId string `mapstructure:"providerid" yaml:"providerid,omitempty"`
115-
CompartmentId string `mapstructure:"compartmentid" yaml:"compartmentid,omitempty"`
116-
TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"`
117-
TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"`
118-
MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"`
119-
OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty"`
106+
Name string `mapstructure:"name"`
107+
Model string `mapstructure:"model"`
108+
Password string `mapstructure:"password" yaml:"password,omitempty"`
109+
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty"`
110+
ProxyEndpoint string `mapstructure:"proxyEndpoint" yaml:"proxyEndpoint,omitempty"`
111+
ProxyPort string `mapstructure:"proxyPort" yaml:"proxyPort,omitempty"`
112+
EndpointName string `mapstructure:"endpointname" yaml:"endpointname,omitempty"`
113+
Engine string `mapstructure:"engine" yaml:"engine,omitempty"`
114+
Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"`
115+
ProviderRegion string `mapstructure:"providerregion" yaml:"providerregion,omitempty"`
116+
ProviderId string `mapstructure:"providerid" yaml:"providerid,omitempty"`
117+
CompartmentId string `mapstructure:"compartmentid" yaml:"compartmentid,omitempty"`
118+
TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"`
119+
TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"`
120+
MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"`
121+
OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty"`
122+
CustomHeaders []http.Header `mapstructure:"customHeaders"`
120123
}
121124

122125
func (p *AIProvider) GetBaseURL() string {
@@ -174,6 +177,10 @@ func (p *AIProvider) GetOrganizationId() string {
174177
return p.OrganizationId
175178
}
176179

180+
func (p *AIProvider) GetCustomHeaders() []http.Header {
181+
return p.CustomHeaders
182+
}
183+
177184
var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci", "watsonxai"}
178185

179186
func NeedPassword(backend string) bool {

pkg/ai/openai.go

+32-7
Original file line numberDiff line numberDiff line change
@@ -52,24 +52,27 @@ func (c *OpenAIClient) Configure(config IAIConfig) error {
5252
defaultConfig.BaseURL = baseURL
5353
}
5454

55+
transport := &http.Transport{}
5556
if proxyEndpoint != "" {
5657
proxyUrl, err := url.Parse(proxyEndpoint)
5758
if err != nil {
5859
return err
5960
}
60-
transport := &http.Transport{
61-
Proxy: http.ProxyURL(proxyUrl),
62-
}
63-
64-
defaultConfig.HTTPClient = &http.Client{
65-
Transport: transport,
66-
}
61+
transport.Proxy = http.ProxyURL(proxyUrl)
6762
}
6863

6964
if orgId != "" {
7065
defaultConfig.OrgID = orgId
7166
}
7267

68+
customHeaders := config.GetCustomHeaders()
69+
defaultConfig.HTTPClient = &http.Client{
70+
Transport: &OpenAIHeaderTransport{
71+
Origin: transport,
72+
Headers: customHeaders,
73+
},
74+
}
75+
7376
client := openai.NewClientWithConfig(defaultConfig)
7477
if client == nil {
7578
return errors.New("error creating OpenAI client")
@@ -106,3 +109,25 @@ func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string) (string
106109
func (c *OpenAIClient) GetName() string {
107110
return openAIClientName
108111
}
112+
113+
// OpenAIHeaderTransport is an http.RoundTripper that adds the given headers to each request.
114+
type OpenAIHeaderTransport struct {
115+
Origin http.RoundTripper
116+
Headers []http.Header
117+
}
118+
119+
// RoundTrip implements the http.RoundTripper interface.
120+
func (t *OpenAIHeaderTransport) RoundTrip(req *http.Request) (*http.Response, error) {
121+
// Clone the request to avoid modifying the original request
122+
clonedReq := req.Clone(req.Context())
123+
for _, header := range t.Headers {
124+
for key, values := range header {
125+
// Possible values per header: RFC 2616
126+
for _, value := range values {
127+
clonedReq.Header.Add(key, value)
128+
}
129+
}
130+
}
131+
132+
return t.Origin.RoundTrip(clonedReq)
133+
}
+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package ai
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
)
11+
12+
// Mock configuration
13+
type mockConfig struct {
14+
baseURL string
15+
}
16+
17+
func (m *mockConfig) GetPassword() string {
18+
return ""
19+
}
20+
21+
func (m *mockConfig) GetOrganizationId() string {
22+
return ""
23+
}
24+
25+
func (m *mockConfig) GetProxyEndpoint() string {
26+
return ""
27+
}
28+
29+
func (m *mockConfig) GetBaseURL() string {
30+
return m.baseURL
31+
}
32+
33+
func (m *mockConfig) GetCustomHeaders() []http.Header {
34+
return []http.Header{
35+
{"X-Custom-Header-1": []string{"Value1"}},
36+
{"X-Custom-Header-2": []string{"Value2"}},
37+
{"X-Custom-Header-2": []string{"Value3"}}, // Testing multiple values for the same header
38+
}
39+
}
40+
41+
func (m *mockConfig) GetModel() string {
42+
return ""
43+
}
44+
45+
func (m *mockConfig) GetTemperature() float32 {
46+
return 0.0
47+
}
48+
49+
func (m *mockConfig) GetTopP() float32 {
50+
return 0.0
51+
}
52+
func (m *mockConfig) GetCompartmentId() string {
53+
return ""
54+
}
55+
56+
func (m *mockConfig) GetTopK() int32 {
57+
return 0.0
58+
}
59+
60+
func (m *mockConfig) GetMaxTokens() int {
61+
return 0
62+
}
63+
64+
func (m *mockConfig) GetEndpointName() string {
65+
return ""
66+
}
67+
func (m *mockConfig) GetEngine() string {
68+
return ""
69+
}
70+
71+
func (m *mockConfig) GetProviderId() string {
72+
return ""
73+
}
74+
75+
func (m *mockConfig) GetProviderRegion() string {
76+
return ""
77+
}
78+
79+
func TestOpenAIClient_CustomHeaders(t *testing.T) {
80+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
81+
assert.Equal(t, "Value1", r.Header.Get("X-Custom-Header-1"))
82+
assert.ElementsMatch(t, []string{"Value2", "Value3"}, r.Header["X-Custom-Header-2"])
83+
w.WriteHeader(http.StatusOK)
84+
// Mock response for openai completion
85+
mockResponse := `{"choices": [{"message": {"content": "test"}}]}`
86+
n, err := w.Write([]byte(mockResponse))
87+
if err != nil {
88+
t.Fatalf("error writing response: %v", err)
89+
}
90+
if n != len(mockResponse) {
91+
t.Fatalf("expected to write %d bytes but wrote %d bytes", len(mockResponse), n)
92+
}
93+
}))
94+
defer server.Close()
95+
96+
config := &mockConfig{baseURL: server.URL}
97+
98+
client := &OpenAIClient{}
99+
err := client.Configure(config)
100+
assert.NoError(t, err)
101+
102+
// Make a completion request to trigger the headers
103+
ctx := context.Background()
104+
_, err = client.GetCompletion(ctx, "foo prompt")
105+
assert.NoError(t, err)
106+
}

pkg/analysis/analysis.go

+3
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ func NewAnalysis(
7979
maxConcurrency int,
8080
withDoc bool,
8181
interactiveMode bool,
82+
httpHeaders []string,
8283
) (*Analysis, error) {
8384
// Get kubernetes client from viper.
8485
kubecontext := viper.GetString("kubecontext")
@@ -146,6 +147,8 @@ func NewAnalysis(
146147
}
147148

148149
aiClient := ai.NewClient(aiProvider.Name)
150+
customHeaders := util.NewHeaders(httpHeaders)
151+
aiProvider.CustomHeaders = customHeaders
149152
if err := aiClient.Configure(&aiProvider); err != nil {
150153
return nil, err
151154
}

pkg/server/analyze.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@ func (h *handler) Analyze(ctx context.Context, i *schemav1.AnalyzeRequest) (
2828
i.Nocache,
2929
i.Explain,
3030
int(i.MaxConcurrency),
31-
false, // Kubernetes Doc disabled in server mode
32-
false, // Interactive mode disabled in server mode
31+
false, // Kubernetes Doc disabled in server mode
32+
false, // Interactive mode disabled in server mode
33+
[]string{}, //TODO: add custom http headers in server mode
3334
)
3435
config.Context = ctx // Replace context for correct timeouts.
3536
if err != nil {

pkg/util/util.go

+34
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"encoding/hex"
2222
"errors"
2323
"fmt"
24+
"net/http"
2425
"os"
2526
"regexp"
2627
"strings"
@@ -261,3 +262,36 @@ func FetchLatestEvent(ctx context.Context, kubernetesClient *kubernetes.Client,
261262
}
262263
return latestEvent, nil
263264
}
265+
266+
// NewHeaders parses a slice of strings in the format "key:value" into []http.Header
267+
// It handles headers with the same key by appending values
268+
func NewHeaders(customHeaders []string) []http.Header {
269+
headers := make(map[string][]string)
270+
271+
for _, header := range customHeaders {
272+
vals := strings.SplitN(header, ":", 2)
273+
if len(vals) != 2 {
274+
//TODO: Handle error instead of ignoring it
275+
continue
276+
}
277+
key := strings.TrimSpace(vals[0])
278+
value := strings.TrimSpace(vals[1])
279+
280+
if _, ok := headers[key]; !ok {
281+
headers[key] = []string{}
282+
}
283+
headers[key] = append(headers[key], value)
284+
}
285+
286+
// Convert map to []http.Header format
287+
var result []http.Header
288+
for key, values := range headers {
289+
header := make(http.Header)
290+
for _, value := range values {
291+
header.Add(key, value)
292+
}
293+
result = append(result, header)
294+
}
295+
296+
return result
297+
}

0 commit comments

Comments
 (0)