Skip to content

Commit d4e04bc

Browse files
authored
Merge pull request #157 from EmmEff/release-1.3-fixes
Fixes
2 parents 9205dca + 521a909 commit d4e04bc

File tree

2 files changed

+602
-23
lines changed

2 files changed

+602
-23
lines changed

client/pull.go

+88-22
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package client
77

88
import (
99
"context"
10+
"errors"
1011
"fmt"
1112
"io"
1213
"net/http"
@@ -19,6 +20,13 @@ import (
1920
"golang.org/x/sync/errgroup"
2021
)
2122

23+
var (
24+
errUnauthorized = errors.New("unauthorized")
25+
errMissingLocationHeader = errors.New("missing HTTP Location header")
26+
errInvalidArguments = errors.New("invalid argument(s)")
27+
errUnknownContentLength = errors.New("unknown content length")
28+
)
29+
2230
// DownloadImage will retrieve an image from the Container Library, saving it
2331
// into the specified io.Writer. The timeout value for this operation is set
2432
// within the context. It is recommended to use a large value (ie. 1800 seconds)
@@ -102,9 +110,22 @@ type Downloader struct {
102110
BufferSize int64
103111
}
104112

105-
// httpGetRangeRequest performs HTTP GET range request to URL specified by 'u' in range start-end.
106-
func (c *Client) httpGetRangeRequest(ctx context.Context, url string, start, end int64) (*http.Response, error) {
107-
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
113+
// httpGetRangeRequest performs HTTP GET range request to URL 'u' in range start-end.
114+
func (c *Client) httpGetRangeRequest(ctx context.Context, endpoint, authToken string, start, end int64) (*http.Response, error) {
115+
return c.httpRangeRequest(ctx, http.MethodGet, endpoint, authToken, start, end)
116+
}
117+
118+
func (c *Client) httpRangeRequest(ctx context.Context, method, endpoint, authToken string, start, end int64) (*http.Response, error) {
119+
if start >= end || start < 0 || end < 0 {
120+
return nil, errInvalidArguments
121+
}
122+
123+
u, err := url.Parse(endpoint)
124+
if err != nil {
125+
return nil, err
126+
}
127+
128+
req, err := http.NewRequestWithContext(ctx, method, endpoint, nil)
108129
if err != nil {
109130
return nil, err
110131
}
@@ -113,14 +134,28 @@ func (c *Client) httpGetRangeRequest(ctx context.Context, url string, start, end
113134
req.Header.Set("User-Agent", v)
114135
}
115136

137+
if authToken != "" && samehost(c.BaseURL, u) {
138+
// Include authorization header if request being made to host specified by base URL
139+
req.Header.Add("Authorization", fmt.Sprintf("Bearer %v", authToken))
140+
}
141+
116142
req.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", start, end))
117143

118144
return c.HTTPClient.Do(req)
119145
}
120146

147+
// samehost returns true if host1 and host2 are, in fact, the same host by
148+
// comparing scheme (https == https) and host (which includes port).
149+
//
150+
// Hosts will be treated as dissimilar if one host includes domain suffix
151+
// and the other does not, even if the host names match.
152+
func samehost(host1, host2 *url.URL) bool {
153+
return host1.Scheme == host2.Scheme && host1.Host == host2.Host
154+
}
155+
121156
// downloadFilePart writes range to dst as specified in bufferSpec.
122-
func (c *Client) downloadFilePart(ctx context.Context, dst *os.File, url string, ps *partSpec, pb ProgressBar) error {
123-
resp, err := c.httpGetRangeRequest(ctx, url, ps.Start, ps.End)
157+
func (c *Client) downloadFilePart(ctx context.Context, dst *os.File, endpoint, authToken string, ps *partSpec, pb ProgressBar) error {
158+
resp, err := c.httpGetRangeRequest(ctx, endpoint, authToken, ps.Start, ps.End)
124159
if err != nil {
125160
return err
126161
}
@@ -152,20 +187,37 @@ func (c *Client) downloadFilePart(ctx context.Context, dst *os.File, url string,
152187
}
153188

154189
// downloadWorker is a worker func for processing jobs in stripes channel.
155-
func (c *Client) downloadWorker(ctx context.Context, dst *os.File, url string, parts <-chan partSpec, pb ProgressBar) func() error {
190+
func (c *Client) downloadWorker(ctx context.Context, dst *os.File, endpoint, authToken string, parts <-chan partSpec, pb ProgressBar) func() error {
156191
return func() error {
157192
for ps := range parts {
158-
if err := c.downloadFilePart(ctx, dst, url, &ps, pb); err != nil {
193+
if err := c.downloadFilePart(ctx, dst, endpoint, authToken, &ps, pb); err != nil {
159194
return err
160195
}
161196
}
162197
return nil
163198
}
164199
}
165200

166-
func (c *Client) getContentLength(ctx context.Context, url string) (int64, error) {
201+
// parseContentRangeHeader returns size returned in Content-Range response HTTP header
202+
func parseContentRangeHeader(value string) (int64, error) {
203+
if value == "" {
204+
return -1, nil
205+
}
206+
207+
vals := strings.Split(value, "/")
208+
if len(vals) < 2 {
209+
return 0, errUnknownContentLength
210+
}
211+
if vals[1] == "*" {
212+
// Server reports size is unknown
213+
return 0, fmt.Errorf("indeterminant size")
214+
}
215+
return strconv.ParseInt(vals[1], 0, 64)
216+
}
217+
218+
func (c *Client) getContentLength(ctx context.Context, endpoint, authToken string) (int64, error) {
167219
// Perform short request to determine content length.
168-
resp, err := c.httpGetRangeRequest(ctx, url, 0, 1024)
220+
resp, err := c.httpRangeRequest(ctx, http.MethodGet, endpoint, authToken, 0, 1024)
169221
if err != nil {
170222
return 0, err
171223
}
@@ -178,8 +230,8 @@ func (c *Client) getContentLength(ctx context.Context, url string) (int64, error
178230
return 0, fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode)
179231
}
180232

181-
vals := strings.Split(resp.Header.Get("Content-Range"), "/")
182-
return strconv.ParseInt(vals[1], 0, 64)
233+
// Extract size from Content-Range header
234+
return parseContentRangeHeader(resp.Header.Get("Content-Range"))
183235
}
184236

185237
// NoopProgressBar implements ProgressBarInterface to allow disabling the progress bar
@@ -292,12 +344,29 @@ func (c *Client) ConcurrentDownloadImage(ctx context.Context, dst *os.File, arch
292344
}
293345

294346
if res.StatusCode != http.StatusSeeOther {
347+
if res.StatusCode == http.StatusUnauthorized {
348+
return errUnauthorized
349+
}
350+
295351
return fmt.Errorf("unexpected HTTP status %d: %v", res.StatusCode, err)
296352
}
297353

298-
url := res.Header.Get("Location")
354+
location := res.Header.Get("Location")
355+
if location == "" {
356+
return errMissingLocationHeader
357+
}
358+
359+
u, err := url.Parse(location)
360+
if err != nil {
361+
return fmt.Errorf("parsing redirect URL %v: %v", location, err)
362+
}
299363

300-
contentLength, err := c.getContentLength(ctx, url)
364+
authToken := ""
365+
if samehost(c.BaseURL, u) {
366+
authToken = c.AuthToken
367+
}
368+
369+
contentLength, err := c.getContentLength(ctx, u.String(), authToken)
301370
if err != nil {
302371
return err
303372
}
@@ -308,12 +377,13 @@ func (c *Client) ConcurrentDownloadImage(ctx context.Context, dst *os.File, arch
308377
contentLength, numParts, spec.Concurrency, spec.PartSize, spec.BufferSize,
309378
)
310379

311-
jobs := make(chan partSpec, numParts)
380+
jobs := make(chan partSpec)
312381

313382
g, ctx := errgroup.WithContext(ctx)
314383

315384
// initialize progress bar
316385
pb.Init(contentLength)
386+
defer pb.Wait()
317387

318388
// if spec.Requests is greater than number of parts for requested file,
319389
// set concurrency to number of parts
@@ -324,7 +394,7 @@ func (c *Client) ConcurrentDownloadImage(ctx context.Context, dst *os.File, arch
324394

325395
// start workers to manage concurrent HTTP requests
326396
for workerID := uint(0); workerID <= concurrency; workerID++ {
327-
g.Go(c.downloadWorker(ctx, dst, url, jobs, pb))
397+
g.Go(c.downloadWorker(ctx, dst, u.String(), authToken, jobs, pb))
328398
}
329399

330400
// iterate over parts, adding to job queue
@@ -346,16 +416,12 @@ func (c *Client) ConcurrentDownloadImage(ctx context.Context, dst *os.File, arch
346416
close(jobs)
347417

348418
// wait on errgroup
349-
err = g.Wait()
350-
if err != nil {
419+
if err := g.Wait(); err != nil {
351420
// cancel/remove progress bar on error
352421
pb.Abort(true)
422+
return err
353423
}
354-
355-
// wait on progress bar
356-
pb.Wait()
357-
358-
return err
424+
return nil
359425
}
360426

361427
func (c *Client) singleStreamDownload(ctx context.Context, fp *os.File, res *http.Response, pb ProgressBar) error {

0 commit comments

Comments
 (0)