Skip to content

Commit

Permalink
Gracefully handle missing SRV records
Browse files Browse the repository at this point in the history
This code did not communicate its intentions well.

ConsistentHashingMode has a CacheHosts member which corresponds to the SRV
records for the consistent hash pods, in the correct order.  This slice can
contain gaps of empty string where there are gaps in the corresponding SRV
records. For example, if a statefulset `foo` has 4 pods but foo-2 is not
currently ready, CacheHosts will contain this slice:

- 0: "foo-0"
- 1: "foo-1"
- 2: ""
- 3: "foo-3"

The code had a guard to ensure that it only overrides the request Host and
Scheme if the CacheHost exists.  This meant that if a cache host was not ready,
the request would be made directly against upstream.

This interacted poorly with CacheUsePathProxy, which also munged the HTTP path.

This commit changes the `cacheHost[i] == ""` case to treat it as a regular
failure, and to perform fallbacks as necessary.  This is more correct in a few
ways:

- if a single cache pod is not ready, we should retry against another cache pod
  before going upstream, but the current code goes straight to upstream
- it handles CacheUsePathProxy better, by ensuring we only ever munge the path
  if we're also sending the request to a real cache host

I've written a test for this case. It's not ideal: in the failure case it
makes a real request upstream to https://weights.replicate.delivery/hello.txt.
Ideally we'd spin up an origin server in the test, but all the hashes depend on
the origin URL and the current way we spin up test servers chooses a new port
each time.  I'd welcome improvements but it's good enough for now.
  • Loading branch information
philandstuff authored and tempusfrangit committed Feb 13, 2024
1 parent 4517037 commit 7593122
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 23 deletions.
61 changes: 38 additions & 23 deletions pkg/download/consistent_hashing.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,37 +250,20 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
}

// DoRequest issues a GET for urlString restricted to the byte range
// start..end, routed through a cache host. If the attempt fails with
// client.ErrStrategyFallback, it is retried once with the pod index of the
// first attempt passed along as a previous pod index; if the retry also fails,
// the error from the first attempt is returned.
//
// NOTE(review): this listing appears to interleave the pre- and post-change
// lines of a diff, so some statements show up twice (e.g. the direct
// m.Client.Do call next to the doRequestToCacheHost call). Part of the
// function body is also not visible here.
func (m *ConsistentHashingMode) DoRequest(ctx context.Context, start, end int64, urlString string) (*http.Response, error) {
logger := logging.GetLogger()
// Tag the context with config.ConsistentHashingStrategyKey; the same context
// is reused when building the retry request below.
chContext := context.WithValue(ctx, config.ConsistentHashingStrategyKey, true)
req, err := http.NewRequestWithContext(chContext, "GET", urlString, nil)
if err != nil {
// NOTE(review): req is nil when NewRequestWithContext returns an error, so
// req.URL.String() would dereference a nil pointer here; urlString is the
// safe value to report.
return nil, fmt.Errorf("failed to download %s: %w", req.URL.String(), err)
}
cachePodIndex, err := m.rewriteRequestToCacheHost(req, start, end)
if err != nil {
return nil, err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))

logger.Debug().Str("url", urlString).Str("munged_url", req.URL.String()).Str("host", req.Host).Int64("start", start).Int64("end", end).Msg("request")

resp, err := m.Client.Do(req)
resp, cachePodIndex, err := m.doRequestToCacheHost(req, urlString, start, end)
if err != nil {
// Only client.ErrStrategyFallback triggers a second attempt against a
// cache host.
if errors.Is(err, client.ErrStrategyFallback) {
origErr := err
// Build a fresh request from the original urlString for the retry; req and
// err declared here shadow the outer ones for the rest of this branch.
req, err := http.NewRequestWithContext(chContext, "GET", urlString, nil)
if err != nil {
// NOTE(review): same nil req dereference as above.
return nil, fmt.Errorf("failed to download %s: %w", req.URL.String(), err)
}
_, err = m.rewriteRequestToCacheHost(req, start, end, cachePodIndex)
if err != nil {
// return origErr so that we can use our regular fallback strategy
return nil, origErr
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
logger.Debug().Str("url", urlString).Str("munged_url", req.URL.String()).Str("host", req.Host).Int64("start", start).Int64("end", end).Msg("retry request")

resp, err = m.Client.Do(req)
// cachePodIndex (the pod tried first) is handed over as a previous pod index.
resp, _, err = m.doRequestToCacheHost(req, urlString, start, end, cachePodIndex)
if err != nil {
// return origErr so that we can use our regular fallback strategy
return nil, origErr
Expand All @@ -296,6 +279,20 @@ func (m *ConsistentHashingMode) DoRequest(ctx context.Context, start, end int64,
return resp, nil
}

// doRequestToCacheHost rewrites req via rewriteRequestToCacheHost, sets its
// Range header to the bytes start..end, logs the request at debug level and
// executes it with m.Client.
//
// previousPodIndexes is forwarded unchanged to rewriteRequestToCacheHost.
// urlString is used only for logging.
//
// The returned int is the pod index reported by rewriteRequestToCacheHost. It
// is returned on every path, including when the rewrite or the HTTP call
// fails. If the rewrite fails, no request is sent and the response is nil.
func (m *ConsistentHashingMode) doRequestToCacheHost(req *http.Request, urlString string, start int64, end int64, previousPodIndexes ...int) (*http.Response, int, error) {
	logger := logging.GetLogger()

	podIndex, rewriteErr := m.rewriteRequestToCacheHost(req, start, end, previousPodIndexes...)
	if rewriteErr != nil {
		return nil, podIndex, rewriteErr
	}

	byteRange := fmt.Sprintf("bytes=%d-%d", start, end)
	req.Header.Set("Range", byteRange)

	logger.Debug().
		Str("url", urlString).
		Str("munged_url", req.URL.String()).
		Str("host", req.Host).
		Int64("start", start).
		Int64("end", end).
		Msg("request")

	response, doErr := m.Client.Do(req)
	return response, podIndex, doErr
}

// rewriteRequestToCacheHost points req at the cache host chosen for the byte
// range start..end and returns that host's index in m.CacheHosts. When the
// chosen entry is empty, req's scheme and host are left untouched and the
// index is returned together with client.ErrStrategyFallback.
//
// previousPodIndexes is logged here; presumably it also steers pod selection
// away from pods that were already tried, but that logic is not visible in
// this listing — confirm against the full file.
//
// NOTE(review): this listing appears to interleave the pre- and post-change
// lines of a diff and omits part of the body.
func (m *ConsistentHashingMode) rewriteRequestToCacheHost(req *http.Request, start int64, end int64, previousPodIndexes ...int) (int, error) {
logger := logging.GetLogger()
// start and end are compared by slice number (offset / m.SliceSize); this
// branch is taken when they fall in different slices.
if start/m.SliceSize != end/m.SliceSize {
Expand All @@ -319,10 +316,28 @@ func (m *ConsistentHashingMode) rewriteRequestToCacheHost(req *http.Request, sta
req.URL.Path = fmt.Sprintf("/%s", newPath)
}
// m.CacheHosts is indexed by pod index; an entry may be the empty string.
cacheHost := m.CacheHosts[cachePodIndex]
logger.Debug().Str("cache_key", fmt.Sprintf("%+v", key)).Int64("start", start).Int64("end", end).Int64("slice_size", m.SliceSize).Int("bucket", cachePodIndex).Msg("consistent hashing")
if cacheHost != "" {
req.URL.Scheme = "http"
req.URL.Host = cacheHost
if cacheHost == "" {
// this can happen if an SRV record is missing due to a not-ready pod
logger.Debug().
Str("cache_key", fmt.Sprintf("%+v", key)).
Int64("start", start).
Int64("end", end).
Int64("slice_size", m.SliceSize).
Int("bucket", cachePodIndex).
Ints("previous_pod_indexes", previousPodIndexes).
Msg("cache host for bucket not ready, falling back")
// The index is returned with the error so the caller knows which bucket was
// unavailable.
return cachePodIndex, client.ErrStrategyFallback
}
logger.Debug().
Str("cache_key", fmt.Sprintf("%+v", key)).
Int64("start", start).
Int64("end", end).
Int64("slice_size", m.SliceSize).
Int("bucket", cachePodIndex).
Ints("previous_pod_indexes", previousPodIndexes).
Msg("consistent hashing")
// Cache hosts are addressed over plain http, whatever scheme the original
// URL used.
req.URL.Scheme = "http"
req.URL.Host = cacheHost

return cachePodIndex, nil
}
43 changes: 43 additions & 0 deletions pkg/download/consistent_hashing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,49 @@ func TestConsistentHashRetries(t *testing.T) {
assert.Equal(t, "3344761726165516", string(bytes))
}

// TestConsistentHashRetriesMissingHostname checks that a cache host whose
// entry is empty (as if its SRV record were missing) is handled by retrying
// against another cache host rather than by fetching from the origin.
func TestConsistentHashRetriesMissingHostname(t *testing.T) {
	// The origin acts as a canary: the test only passes if none of the bytes
	// we read came from it.
	const origin = "https://weights.replicate.delivery/hello.txt"

	cacheHosts := make([]string, len(testFSes))
	for idx, fsys := range testFSes {
		srv := httptest.NewServer(http.FileServer(http.FS(fsys)))
		defer srv.Close()
		parsed, err := url.Parse(srv.URL)
		require.NoError(t, err)
		cacheHosts[idx] = parsed.Host
	}

	// Deliberately blank out the first cache host, as if its SRV record were
	// missing.
	cacheHosts[0] = ""

	opts := download.Options{
		Client:               client.Options{},
		CacheHosts:           cacheHosts,
		CacheableURIPrefixes: makeCacheableURIPrefixes("https://weights.replicate.delivery"),
		SliceSize:            1,
		MinChunkSize:         1,
		MaxConcurrency:       8,
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	strategy, err := download.GetConsistentHashingMode(opts)
	require.NoError(t, err)

	reader, _, err := strategy.Fetch(ctx, origin)
	require.NoError(t, err)
	got, err := io.ReadAll(reader)
	require.NoError(t, err)

	// If the blanked-out host were skipped in favour of upstream we would read
	// `5"37132251231713`, the `"` being the byte served by the origin. Every
	// slice that used to map to host 0 must instead be served by some other
	// cache host, and those slices should not all land on the same
	// replacement host.
	assert.Equal(t, "5337132251231713", string(got))
}

// with only two hosts, we should *always* fall back to the other host
func TestConsistentHashRetriesTwoHosts(t *testing.T) {
hostnames := make([]string, 2)
Expand Down

0 comments on commit 7593122

Please sign in to comment.