@@ -3,6 +3,7 @@ package llms
3
3
import (
4
4
"context"
5
5
"fmt"
6
+ "net/http"
6
7
"time"
7
8
8
9
"github.com/getzep/zep/pkg/models"
@@ -131,13 +132,32 @@ func Float64ToFloat32Matrix(in [][]float64) [][]float32 {
131
132
}
132
133
133
134
func NewRetryableHTTPClient (retryMax int , timeout time.Duration ) * retryablehttp.Client {
134
- retryableHttpClient := retryablehttp .NewClient ()
135
- retryableHttpClient .RetryMax = retryMax
136
- retryableHttpClient .HTTPClient .Timeout = timeout
137
- retryableHttpClient .Logger = log
138
- retryableHttpClient .Backoff = retryablehttp .DefaultBackoff
135
+ retryableHTTPClient := retryablehttp .NewClient ()
136
+ retryableHTTPClient .RetryMax = retryMax
137
+ retryableHTTPClient .HTTPClient .Timeout = timeout
138
+ retryableHTTPClient .Logger = log
139
+ retryableHTTPClient .Backoff = retryablehttp .DefaultBackoff
140
+ retryableHTTPClient .CheckRetry = retryPolicy
141
+
142
+ return retryableHTTPClient
143
+ }
144
+
145
+ // retryPolicy is a retryablehttp.CheckRetry function. It is used to determine
146
+ // whether a request should be retried or not.
147
+ func retryPolicy (ctx context.Context , resp * http.Response , err error ) (bool , error ) {
148
+ // do not retry on context.Canceled or context.DeadlineExceeded
149
+ if ctx .Err () != nil {
150
+ return false , ctx .Err ()
151
+ }
152
+
153
+ // Do not retry 400 errors as they're used by OpenAI to indicate maximum
154
+ // context length exceeded
155
+ if resp != nil && resp .StatusCode == 400 {
156
+ return false , err
157
+ }
139
158
140
- return retryableHttpClient
159
+ shouldRetry , _ := retryablehttp .DefaultRetryPolicy (ctx , resp , err )
160
+ return shouldRetry , nil
141
161
}
142
162
143
163
// useOpenAIEmbeddings is true if OpenAI embeddings are enabled
0 commit comments