Skip to content

Commit

Permalink
Merge pull request #157 from EmmEff/release-1.3-fixes
Browse files Browse the repository at this point in the history
Fixes
  • Loading branch information
EmmEff authored Jan 10, 2023
2 parents 9205dca + 521a909 commit d4e04bc
Show file tree
Hide file tree
Showing 2 changed files with 602 additions and 23 deletions.
110 changes: 88 additions & 22 deletions client/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package client

import (
"context"
"errors"
"fmt"
"io"
"net/http"
Expand All @@ -19,6 +20,13 @@ import (
"golang.org/x/sync/errgroup"
)

// Sentinel errors returned by the download helpers in this file.
var (
// errUnauthorized is returned when the server responds with HTTP 401
// (Unauthorized) instead of the expected redirect.
errUnauthorized = errors.New("unauthorized")
// errMissingLocationHeader is returned when a redirect response carries an
// empty Location header.
errMissingLocationHeader = errors.New("missing HTTP Location header")
// errInvalidArguments is returned when a range request is asked for with a
// negative bound or with start >= end.
errInvalidArguments = errors.New("invalid argument(s)")
// errUnknownContentLength is reported when a total size cannot be extracted
// from a Content-Range header value.
errUnknownContentLength = errors.New("unknown content length")
)

// DownloadImage will retrieve an image from the Container Library, saving it
// into the specified io.Writer. The timeout value for this operation is set
// within the context. It is recommended to use a large value (ie. 1800 seconds)
Expand Down Expand Up @@ -102,9 +110,22 @@ type Downloader struct {
BufferSize int64
}

// httpGetRangeRequest performs HTTP GET range request to URL specified by 'u' in range start-end.
func (c *Client) httpGetRangeRequest(ctx context.Context, url string, start, end int64) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
// httpGetRangeRequest issues an HTTP GET range request for bytes start-end of
// the resource at endpoint. It is a convenience wrapper that delegates to
// httpRangeRequest with the method fixed to GET; authToken is passed through
// unchanged.
func (c *Client) httpGetRangeRequest(ctx context.Context, endpoint, authToken string, start, end int64) (*http.Response, error) {
	const method = http.MethodGet

	return c.httpRangeRequest(ctx, method, endpoint, authToken, start, end)
}

func (c *Client) httpRangeRequest(ctx context.Context, method, endpoint, authToken string, start, end int64) (*http.Response, error) {
if start >= end || start < 0 || end < 0 {
return nil, errInvalidArguments
}

u, err := url.Parse(endpoint)
if err != nil {
return nil, err
}

req, err := http.NewRequestWithContext(ctx, method, endpoint, nil)
if err != nil {
return nil, err
}
Expand All @@ -113,14 +134,28 @@ func (c *Client) httpGetRangeRequest(ctx context.Context, url string, start, end
req.Header.Set("User-Agent", v)
}

if authToken != "" && samehost(c.BaseURL, u) {
// Include authorization header if request being made to host specified by base URL
req.Header.Add("Authorization", fmt.Sprintf("Bearer %v", authToken))
}

req.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", start, end))

return c.HTTPClient.Do(req)
}

// samehost reports whether host1 and host2 are, in fact, the same host by
// comparing scheme (https == https) and host (which includes port).
//
// Scheme and host names are case-insensitive, so the comparison ignores
// letter case ("Example.COM" matches "example.com").
//
// Hosts will be treated as dissimilar if one host includes domain suffix
// and the other does not, even if the host names match. A nil URL is never
// considered the same host as anything, so callers that use the result to
// decide whether to send credentials fail closed instead of panicking.
func samehost(host1, host2 *url.URL) bool {
	if host1 == nil || host2 == nil {
		return false
	}

	return strings.EqualFold(host1.Scheme, host2.Scheme) &&
		strings.EqualFold(host1.Host, host2.Host)
}

// downloadFilePart writes the byte range described by ps (a partSpec) to dst.
func (c *Client) downloadFilePart(ctx context.Context, dst *os.File, url string, ps *partSpec, pb ProgressBar) error {
resp, err := c.httpGetRangeRequest(ctx, url, ps.Start, ps.End)
func (c *Client) downloadFilePart(ctx context.Context, dst *os.File, endpoint, authToken string, ps *partSpec, pb ProgressBar) error {
resp, err := c.httpGetRangeRequest(ctx, endpoint, authToken, ps.Start, ps.End)
if err != nil {
return err
}
Expand Down Expand Up @@ -152,20 +187,37 @@ func (c *Client) downloadFilePart(ctx context.Context, dst *os.File, url string,
}

// downloadWorker returns a worker func that processes jobs received from the
// parts channel. For every partSpec received it calls downloadFilePart; the
// first error encountered is returned immediately, and nil is returned once
// the channel has been closed and drained.
func (c *Client) downloadWorker(ctx context.Context, dst *os.File, url string, parts <-chan partSpec, pb ProgressBar) func() error {
func (c *Client) downloadWorker(ctx context.Context, dst *os.File, endpoint, authToken string, parts <-chan partSpec, pb ProgressBar) func() error {
return func() error {
for ps := range parts {
if err := c.downloadFilePart(ctx, dst, url, &ps, pb); err != nil {
if err := c.downloadFilePart(ctx, dst, endpoint, authToken, &ps, pb); err != nil {
return err
}
}
return nil
}
}

func (c *Client) getContentLength(ctx context.Context, url string) (int64, error) {
// parseContentRangeHeader returns the total size reported in a Content-Range
// HTTP response header value of the form "bytes <start>-<end>/<size>".
//
// Return values:
//   - An empty value yields (-1, nil): no header was supplied, so no size is
//     known.
//   - A value without exactly one "/" separator, a size of "*" (server
//     reports the size as unknown), or a negative size yields an error that
//     matches errUnknownContentLength via errors.Is.
//   - A size that is not a valid decimal integer yields the strconv error.
//
// The size is always parsed as base 10; prefixes such as "0x" or a leading
// zero are not interpreted as hexadecimal or octal.
func parseContentRangeHeader(value string) (int64, error) {
	if value == "" {
		return -1, nil
	}

	vals := strings.Split(value, "/")
	if len(vals) != 2 {
		return 0, errUnknownContentLength
	}

	if vals[1] == "*" {
		// Server reports size is unknown
		return 0, fmt.Errorf("%w: indeterminate size", errUnknownContentLength)
	}

	// Content-Range sizes are plain decimal digits, so force base 10 rather
	// than letting ParseInt guess the base from a prefix.
	size, err := strconv.ParseInt(vals[1], 10, 64)
	if err != nil {
		return 0, err
	}
	if size < 0 {
		return 0, fmt.Errorf("%w: negative size %d", errUnknownContentLength, size)
	}

	return size, nil
}

func (c *Client) getContentLength(ctx context.Context, endpoint, authToken string) (int64, error) {
// Perform short request to determine content length.
resp, err := c.httpGetRangeRequest(ctx, url, 0, 1024)
resp, err := c.httpRangeRequest(ctx, http.MethodGet, endpoint, authToken, 0, 1024)
if err != nil {
return 0, err
}
Expand All @@ -178,8 +230,8 @@ func (c *Client) getContentLength(ctx context.Context, url string) (int64, error
return 0, fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode)
}

vals := strings.Split(resp.Header.Get("Content-Range"), "/")
return strconv.ParseInt(vals[1], 0, 64)
// Extract size from Content-Range header
return parseContentRangeHeader(resp.Header.Get("Content-Range"))
}

// NoopProgressBar implements ProgressBarInterface to allow disabling the progress bar
Expand Down Expand Up @@ -292,12 +344,29 @@ func (c *Client) ConcurrentDownloadImage(ctx context.Context, dst *os.File, arch
}

if res.StatusCode != http.StatusSeeOther {
if res.StatusCode == http.StatusUnauthorized {
return errUnauthorized
}

return fmt.Errorf("unexpected HTTP status %d: %v", res.StatusCode, err)
}

url := res.Header.Get("Location")
location := res.Header.Get("Location")
if location == "" {
return errMissingLocationHeader
}

u, err := url.Parse(location)
if err != nil {
return fmt.Errorf("parsing redirect URL %v: %v", location, err)
}

contentLength, err := c.getContentLength(ctx, url)
authToken := ""
if samehost(c.BaseURL, u) {
authToken = c.AuthToken
}

contentLength, err := c.getContentLength(ctx, u.String(), authToken)
if err != nil {
return err
}
Expand All @@ -308,12 +377,13 @@ func (c *Client) ConcurrentDownloadImage(ctx context.Context, dst *os.File, arch
contentLength, numParts, spec.Concurrency, spec.PartSize, spec.BufferSize,
)

jobs := make(chan partSpec, numParts)
jobs := make(chan partSpec)

g, ctx := errgroup.WithContext(ctx)

// initialize progress bar
pb.Init(contentLength)
defer pb.Wait()

// if spec.Requests is greater than number of parts for requested file,
// set concurrency to number of parts
Expand All @@ -324,7 +394,7 @@ func (c *Client) ConcurrentDownloadImage(ctx context.Context, dst *os.File, arch

// start workers to manage concurrent HTTP requests
for workerID := uint(0); workerID <= concurrency; workerID++ {
g.Go(c.downloadWorker(ctx, dst, url, jobs, pb))
g.Go(c.downloadWorker(ctx, dst, u.String(), authToken, jobs, pb))
}

// iterate over parts, adding to job queue
Expand All @@ -346,16 +416,12 @@ func (c *Client) ConcurrentDownloadImage(ctx context.Context, dst *os.File, arch
close(jobs)

// wait on errgroup
err = g.Wait()
if err != nil {
if err := g.Wait(); err != nil {
// cancel/remove progress bar on error
pb.Abort(true)
return err
}

// wait on progress bar
pb.Wait()

return err
return nil
}

func (c *Client) singleStreamDownload(ctx context.Context, fp *os.File, res *http.Response, pb ProgressBar) error {
Expand Down
Loading

0 comments on commit d4e04bc

Please sign in to comment.