Skip to content

Commit

Permalink
Resume interrupted connections
Browse files Browse the repository at this point in the history
Resume interrupted connections. This makes the client.Client an
interface as a test seam (so we can easily mock it). Additionally this
now allows pget to resume downloads if a connection is interruped before
the chunk is delivered. This should reduce the potential for
unexpectedEOF errors being bubbled up to the end user.
  • Loading branch information
tempusfrangit committed Jun 6, 2024
1 parent 807a85a commit 02450b0
Show file tree
Hide file tree
Showing 6 changed files with 291 additions and 14 deletions.
14 changes: 9 additions & 5 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,17 @@ const (

var ErrStrategyFallback = errors.New("fallback to next strategy")

// HTTPClient is a wrapper around http.Client that allows for limiting the number of concurrent connections per host
type HTTPClient interface {
Do(req *http.Request) (*http.Response, error)
}

// PGetHTTPClient is a wrapper around http.Client that allows for limiting the number of concurrent connections per host
// utilizing a client pool. If the OptMaxConnPerHost option is not set, the client pool will not be used.
type HTTPClient struct {
type PGetHTTPClient struct {
*http.Client
}

func (c *HTTPClient) Do(req *http.Request) (*http.Response, error) {
func (c *PGetHTTPClient) Do(req *http.Request) (*http.Response, error) {
req.Header.Set("User-Agent", fmt.Sprintf("pget/%s", version.GetVersion()))
return c.Client.Do(req)
}
Expand All @@ -51,7 +55,7 @@ type TransportOptions struct {

// NewHTTPClient factory function returns a new http.Client with the appropriate settings and can limit number of clients
// per host if the OptMaxConnPerHost option is set.
func NewHTTPClient(opts Options) *HTTPClient {
func NewHTTPClient(opts Options) HTTPClient {

transport := opts.Transport

Expand Down Expand Up @@ -94,7 +98,7 @@ func NewHTTPClient(opts Options) *HTTPClient {
}

client := retryClient.StandardClient()
return &HTTPClient{Client: client}
return &PGetHTTPClient{Client: client}
}

// RetryPolicy wraps retryablehttp.DefaultRetryPolicy and included additional logic:
Expand Down
2 changes: 1 addition & 1 deletion pkg/consumer/null.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ type NullWriter struct{}

var _ Consumer = &NullWriter{}

func (f *NullWriter) Consume(reader io.Reader, destPath string) error {
func (NullWriter) Consume(reader io.Reader, destPath string) error {
// io.Discard is explicitly designed to always succeed, ignore errors.
_, _ = io.Copy(io.Discard, reader)
return nil
Expand Down
8 changes: 7 additions & 1 deletion pkg/download/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
)

type BufferMode struct {
Client *client.HTTPClient
Client client.HTTPClient
Options

queue *priorityWorkQueue
Expand Down Expand Up @@ -81,6 +81,9 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e

contentLength := firstChunkResp.ContentLength
n, err := io.ReadFull(firstChunkResp.Body, buf[0:contentLength])
if err == io.ErrUnexpectedEOF {
_, err = resumeDownload(firstChunkResp.Request, buf[n:], m.Client, int64(n))
}
firstChunk.Deliver(buf[0:n], err)
})

Expand Down Expand Up @@ -144,6 +147,9 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e

contentLength := resp.ContentLength
n, err := io.ReadFull(resp.Body, buf[0:contentLength])
if err == io.ErrUnexpectedEOF {
_, err = resumeDownload(resp.Request, buf[n:], m.Client, int64(n))
}
chunk.Deliver(buf[0:n], err)
})
}
Expand Down
79 changes: 78 additions & 1 deletion pkg/download/common.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,88 @@
package download

import (
"errors"
"fmt"
"io"
"net/http"
"regexp"
"strconv"
"strings"

"github.com/dustin/go-humanize"

"github.com/replicate/pget/pkg/client"
)

const defaultChunkSize = 125 * humanize.MiByte

var contentRangeRegexp = regexp.MustCompile(`^bytes .*/([0-9]+)$`)
var (
contentRangeRegexp = regexp.MustCompile(`^bytes .*/([0-9]+)$`)

errMalformedRangeHeader = errors.New("malformed range header")
errMissingRangeHeader = errors.New("missing range header")
errInvalidContentRange = errors.New("invalid content range")
)

func resumeDownload(req *http.Request, buffer []byte, client client.HTTPClient, bytesReceived int64) (*http.Response, error) {
var startByte int
for {
var n int
if err := updateRangeRequestHeader(req, bytesReceived); err != nil {
return nil, err
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusPartialContent {
return nil, fmt.Errorf("expected status code %d, got %d", http.StatusPartialContent, resp.StatusCode)
}
n, err = io.ReadFull(resp.Body, buffer[startByte:])
if err == io.ErrUnexpectedEOF {
startByte = n
continue
}
return nil, err

}
}

func updateRangeRequestHeader(req *http.Request, receivedBytes int64) error {
rangeHeader := req.Header.Get("Range")
if rangeHeader == "" {
return errMissingRangeHeader
}

// Expected format: "bytes=start-end"
if !strings.HasPrefix(rangeHeader, "bytes=") {
return fmt.Errorf("%w: %s", errMalformedRangeHeader, rangeHeader)
}

rangeValues := strings.TrimPrefix(rangeHeader, "bytes=")
parts := strings.Split(rangeValues, "-")
if len(parts) != 2 {
return fmt.Errorf("%w: %s", errMalformedRangeHeader, rangeHeader)
}

start, err := strconv.ParseInt(parts[0], 10, 64)
if err != nil {
return fmt.Errorf("%w: %s", errMalformedRangeHeader, rangeHeader)
}

end, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
return fmt.Errorf("%w: %s", errMalformedRangeHeader, rangeHeader)
}

start = start + receivedBytes
if start > end {
return fmt.Errorf("%w: %s", errInvalidContentRange, rangeHeader)
}

newRangeHeader := fmt.Sprintf("bytes=%d-%d", start, end)
req.Header.Set("Range", newRangeHeader)

return nil
}
189 changes: 189 additions & 0 deletions pkg/download/common_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
package download

import (
"bytes"
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
)

type mockHTTPClient struct {
doFunc func(req *http.Request) (*http.Response, error)
callCount atomic.Int32
}

func (m *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
m.callCount.Add(1)
return m.doFunc(req)
}

func TestResumeDownload(t *testing.T) {
tests := []struct {
name string
serverContent string
bytesReceived int64
initialRange string
expectedError error
expectedOutput []byte
expectedCalls int32
}{
{
name: "successful download",
serverContent: "Hello, world!",
bytesReceived: 0,
initialRange: "bytes=0-12",
expectedError: nil,
expectedOutput: []byte("Hello, world!"),
expectedCalls: 1,
},
{
name: "partial download",
serverContent: "Hello, world!",
bytesReceived: 3,
initialRange: "bytes=7-12",
expectedError: nil,
expectedOutput: []byte("world!"),
expectedCalls: 1,
},
{
name: "network error",
serverContent: "Hello, world!",
bytesReceived: 0,
initialRange: "bytes=0-12",
expectedError: errors.New("network error"),
expectedOutput: nil,
expectedCalls: 1,
},
{
name: "multi-pass download",
serverContent: "12345678901234567890",
bytesReceived: 3,
initialRange: "bytes=10-19",
expectedError: nil,
expectedOutput: []byte("0123456789"),
expectedCalls: 2,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.ServeContent(w, r, "", time.Time{}, bytes.NewReader([]byte(tt.serverContent)))
}))
defer server.Close()

req, err := http.NewRequest("GET", server.URL, nil)
assert.NoError(t, err)

// Set the initial Range header from the test case
req.Header.Set("Range", tt.initialRange)

buffer := make([]byte, len(tt.expectedOutput))
copy(buffer, tt.expectedOutput[:tt.bytesReceived])
mockClient := &mockHTTPClient{
doFunc: func(req *http.Request) (*http.Response, error) {
if tt.name == "network error" {
return nil, errors.New("network error")
}
if tt.name == "multi-pass download" {
switch req.Header.Get("Range") {
case "bytes=16-19":
return &http.Response{
StatusCode: http.StatusPartialContent,
Body: io.NopCloser(bytes.NewReader([]byte("56789"))),
Header: http.Header{"Content-Range": []string{"bytes 15-20/21"}},
}, nil
case "bytes=13-19":
return &http.Response{
StatusCode: http.StatusPartialContent,
Body: io.NopCloser(bytes.NewReader([]byte("34"))),
Header: http.Header{"Content-Range": []string{"bytes 13-20/21"}},
}, nil
}
}
return http.DefaultClient.Do(req)
},
}

_, err = resumeDownload(req, buffer[tt.bytesReceived:], mockClient, tt.bytesReceived)
if tt.expectedError != nil {
assert.Error(t, err)
assert.Equal(t, tt.expectedError.Error(), err.Error())
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectedOutput, buffer[:len(tt.expectedOutput)])
}
assert.Equal(t, tt.expectedCalls, mockClient.callCount.Load(), "Unexpected number of HTTP client calls")
})
}
}

func TestUpdateRangeRequestHeader(t *testing.T) {
tests := []struct {
name string
initialRange string
receivedBytes int64
expectedRange string
expectedError error
}{
{
name: "valid range header",
initialRange: "bytes=0-10",
receivedBytes: 5,
expectedRange: "bytes=5-10",
expectedError: nil,
},
{
name: "non-zero initial range",
initialRange: "bytes=7-12",
receivedBytes: 3,
expectedRange: "bytes=10-12",
expectedError: nil,
},
{
name: "missing range header",
initialRange: "",
receivedBytes: 5,
expectedRange: "",
expectedError: errMissingRangeHeader,
},
{
name: "malformed range header",
initialRange: "bytes=malformed",
receivedBytes: 5,
expectedRange: "",
expectedError: errMalformedRangeHeader,
},
{
name: "receivedBytes exceeds range",
initialRange: "bytes=0-10",
receivedBytes: 15,
expectedRange: "",
expectedError: errInvalidContentRange,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := http.NewRequest("GET", "http://example.com", nil)
assert.NoError(t, err)
req.Header.Set("Range", tt.initialRange)

err = updateRangeRequestHeader(req, tt.receivedBytes)
if tt.expectedError != nil {
require.Error(t, err)
assert.ErrorIs(t, err, tt.expectedError)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectedRange, req.Header.Get("Range"))
}
})
}
}
13 changes: 7 additions & 6 deletions pkg/download/consistent_hashing.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
)

type ConsistentHashingMode struct {
Client *client.HTTPClient
Client client.HTTPClient
Options
// TODO: allow this to be configured and not just "BufferMode"
FallbackStrategy Strategy
Expand Down Expand Up @@ -116,6 +116,9 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io

contentLength := firstChunkResp.ContentLength
n, err := io.ReadFull(firstChunkResp.Body, buf[0:contentLength])
if err == io.ErrUnexpectedEOF {
_, err = resumeDownload(firstChunkResp.Request, buf[n:], m.Client, int64(n))
}
firstChunk.Deliver(buf[0:n], err)
})
firstReqResult, ok := <-firstReqResultCh
Expand Down Expand Up @@ -205,11 +208,6 @@ func (m *ConsistentHashingMode) downloadRemainingChunks(ctx context.Context, url
// for the specified chunk instead of the whole file.
if errors.Is(err, client.ErrStrategyFallback) {
// TODO(morgan): we should indicate the fallback strategy we're using in the logs
logger.Info().
Str("url", urlString).
Str("type", "chunk").
Err(err).
Msg("consistent hash fallback")
resp, err = m.FallbackStrategy.DoRequest(ctx, chunkStart, chunkEnd, urlString)
}
if err != nil {
Expand All @@ -220,6 +218,9 @@ func (m *ConsistentHashingMode) downloadRemainingChunks(ctx context.Context, url
defer resp.Body.Close()
contentLength := resp.ContentLength
n, err := io.ReadFull(resp.Body, buf[0:contentLength])
if err == io.ErrUnexpectedEOF {
_, err = resumeDownload(resp.Request, buf[n:], m.Client, int64(n))
}
chunk.Deliver(buf[0:n], err)
})
}
Expand Down

0 comments on commit 02450b0

Please sign in to comment.