@@ -7,6 +7,7 @@ package client
7
7
8
8
import (
9
9
"context"
10
+ "errors"
10
11
"fmt"
11
12
"io"
12
13
"net/http"
@@ -19,6 +20,13 @@ import (
19
20
"golang.org/x/sync/errgroup"
20
21
)
21
22
23
+ var (
24
+ errUnauthorized = errors .New ("unauthorized" )
25
+ errMissingLocationHeader = errors .New ("missing HTTP Location header" )
26
+ errInvalidArguments = errors .New ("invalid argument(s)" )
27
+ errUnknownContentLength = errors .New ("unknown content length" )
28
+ )
29
+
22
30
// DownloadImage will retrieve an image from the Container Library, saving it
23
31
// into the specified io.Writer. The timeout value for this operation is set
24
32
// within the context. It is recommended to use a large value (ie. 1800 seconds)
@@ -102,9 +110,22 @@ type Downloader struct {
102
110
BufferSize int64
103
111
}
104
112
105
- // httpGetRangeRequest performs HTTP GET range request to URL specified by 'u' in range start-end.
106
- func (c * Client ) httpGetRangeRequest (ctx context.Context , url string , start , end int64 ) (* http.Response , error ) {
107
- req , err := http .NewRequestWithContext (ctx , http .MethodGet , url , nil )
113
+ // httpGetRangeRequest performs HTTP GET range request to URL 'u' in range start-end.
114
+ func (c * Client ) httpGetRangeRequest (ctx context.Context , endpoint , authToken string , start , end int64 ) (* http.Response , error ) {
115
+ return c .httpRangeRequest (ctx , http .MethodGet , endpoint , authToken , start , end )
116
+ }
117
+
118
+ func (c * Client ) httpRangeRequest (ctx context.Context , method , endpoint , authToken string , start , end int64 ) (* http.Response , error ) {
119
+ if start >= end || start < 0 || end < 0 {
120
+ return nil , errInvalidArguments
121
+ }
122
+
123
+ u , err := url .Parse (endpoint )
124
+ if err != nil {
125
+ return nil , err
126
+ }
127
+
128
+ req , err := http .NewRequestWithContext (ctx , method , endpoint , nil )
108
129
if err != nil {
109
130
return nil , err
110
131
}
@@ -113,14 +134,28 @@ func (c *Client) httpGetRangeRequest(ctx context.Context, url string, start, end
113
134
req .Header .Set ("User-Agent" , v )
114
135
}
115
136
137
+ if authToken != "" && samehost (c .BaseURL , u ) {
138
+ // Include authorization header if request being made to host specified by base URL
139
+ req .Header .Add ("Authorization" , fmt .Sprintf ("Bearer %v" , authToken ))
140
+ }
141
+
116
142
req .Header .Add ("Range" , fmt .Sprintf ("bytes=%d-%d" , start , end ))
117
143
118
144
return c .HTTPClient .Do (req )
119
145
}
120
146
147
+ // samehost returns true if host1 and host2 are, in fact, the same host by
148
+ // comparing scheme (https == https) and host (which includes port).
149
+ //
150
+ // Hosts will be treated as dissimilar if one host includes domain suffix
151
+ // and the other does not, even if the host names match.
152
+ func samehost (host1 , host2 * url.URL ) bool {
153
+ return host1 .Scheme == host2 .Scheme && host1 .Host == host2 .Host
154
+ }
155
+
121
156
// downloadFilePart writes range to dst as specified in bufferSpec.
122
- func (c * Client ) downloadFilePart (ctx context.Context , dst * os.File , url string , ps * partSpec , pb ProgressBar ) error {
123
- resp , err := c .httpGetRangeRequest (ctx , url , ps .Start , ps .End )
157
+ func (c * Client ) downloadFilePart (ctx context.Context , dst * os.File , endpoint , authToken string , ps * partSpec , pb ProgressBar ) error {
158
+ resp , err := c .httpGetRangeRequest (ctx , endpoint , authToken , ps .Start , ps .End )
124
159
if err != nil {
125
160
return err
126
161
}
@@ -152,20 +187,37 @@ func (c *Client) downloadFilePart(ctx context.Context, dst *os.File, url string,
152
187
}
153
188
154
189
// downloadWorker is a worker func for processing jobs in stripes channel.
155
- func (c * Client ) downloadWorker (ctx context.Context , dst * os.File , url string , parts <- chan partSpec , pb ProgressBar ) func () error {
190
+ func (c * Client ) downloadWorker (ctx context.Context , dst * os.File , endpoint , authToken string , parts <- chan partSpec , pb ProgressBar ) func () error {
156
191
return func () error {
157
192
for ps := range parts {
158
- if err := c .downloadFilePart (ctx , dst , url , & ps , pb ); err != nil {
193
+ if err := c .downloadFilePart (ctx , dst , endpoint , authToken , & ps , pb ); err != nil {
159
194
return err
160
195
}
161
196
}
162
197
return nil
163
198
}
164
199
}
165
200
166
- func (c * Client ) getContentLength (ctx context.Context , url string ) (int64 , error ) {
201
+ // parseContentRangeHeader returns size returned in Content-Range response HTTP header
202
+ func parseContentRangeHeader (value string ) (int64 , error ) {
203
+ if value == "" {
204
+ return - 1 , nil
205
+ }
206
+
207
+ vals := strings .Split (value , "/" )
208
+ if len (vals ) < 2 {
209
+ return 0 , errUnknownContentLength
210
+ }
211
+ if vals [1 ] == "*" {
212
+ // Server reports size is unknown
213
+ return 0 , fmt .Errorf ("indeterminant size" )
214
+ }
215
+ return strconv .ParseInt (vals [1 ], 0 , 64 )
216
+ }
217
+
218
+ func (c * Client ) getContentLength (ctx context.Context , endpoint , authToken string ) (int64 , error ) {
167
219
// Perform short request to determine content length.
168
- resp , err := c .httpGetRangeRequest (ctx , url , 0 , 1024 )
220
+ resp , err := c .httpRangeRequest (ctx , http . MethodGet , endpoint , authToken , 0 , 1024 )
169
221
if err != nil {
170
222
return 0 , err
171
223
}
@@ -178,8 +230,8 @@ func (c *Client) getContentLength(ctx context.Context, url string) (int64, error
178
230
return 0 , fmt .Errorf ("unexpected HTTP status: %d" , resp .StatusCode )
179
231
}
180
232
181
- vals := strings . Split ( resp . Header . Get ( " Content-Range" ), "/" )
182
- return strconv . ParseInt ( vals [ 1 ], 0 , 64 )
233
+ // Extract size from Content-Range header
234
+ return parseContentRangeHeader ( resp . Header . Get ( "Content-Range" ) )
183
235
}
184
236
185
237
// NoopProgressBar implements ProgressBarInterface to allow disabling the progress bar
@@ -292,12 +344,29 @@ func (c *Client) ConcurrentDownloadImage(ctx context.Context, dst *os.File, arch
292
344
}
293
345
294
346
if res .StatusCode != http .StatusSeeOther {
347
+ if res .StatusCode == http .StatusUnauthorized {
348
+ return errUnauthorized
349
+ }
350
+
295
351
return fmt .Errorf ("unexpected HTTP status %d: %v" , res .StatusCode , err )
296
352
}
297
353
298
- url := res .Header .Get ("Location" )
354
+ location := res .Header .Get ("Location" )
355
+ if location == "" {
356
+ return errMissingLocationHeader
357
+ }
358
+
359
+ u , err := url .Parse (location )
360
+ if err != nil {
361
+ return fmt .Errorf ("parsing redirect URL %v: %v" , location , err )
362
+ }
299
363
300
- contentLength , err := c .getContentLength (ctx , url )
364
+ authToken := ""
365
+ if samehost (c .BaseURL , u ) {
366
+ authToken = c .AuthToken
367
+ }
368
+
369
+ contentLength , err := c .getContentLength (ctx , u .String (), authToken )
301
370
if err != nil {
302
371
return err
303
372
}
@@ -308,12 +377,13 @@ func (c *Client) ConcurrentDownloadImage(ctx context.Context, dst *os.File, arch
308
377
contentLength , numParts , spec .Concurrency , spec .PartSize , spec .BufferSize ,
309
378
)
310
379
311
- jobs := make (chan partSpec , numParts )
380
+ jobs := make (chan partSpec )
312
381
313
382
g , ctx := errgroup .WithContext (ctx )
314
383
315
384
// initialize progress bar
316
385
pb .Init (contentLength )
386
+ defer pb .Wait ()
317
387
318
388
// if spec.Requests is greater than number of parts for requested file,
319
389
// set concurrency to number of parts
@@ -324,7 +394,7 @@ func (c *Client) ConcurrentDownloadImage(ctx context.Context, dst *os.File, arch
324
394
325
395
// start workers to manage concurrent HTTP requests
326
396
for workerID := uint (0 ); workerID <= concurrency ; workerID ++ {
327
- g .Go (c .downloadWorker (ctx , dst , url , jobs , pb ))
397
+ g .Go (c .downloadWorker (ctx , dst , u . String (), authToken , jobs , pb ))
328
398
}
329
399
330
400
// iterate over parts, adding to job queue
@@ -346,16 +416,12 @@ func (c *Client) ConcurrentDownloadImage(ctx context.Context, dst *os.File, arch
346
416
close (jobs )
347
417
348
418
// wait on errgroup
349
- err = g .Wait ()
350
- if err != nil {
419
+ if err := g .Wait (); err != nil {
351
420
// cancel/remove progress bar on error
352
421
pb .Abort (true )
422
+ return err
353
423
}
354
-
355
- // wait on progress bar
356
- pb .Wait ()
357
-
358
- return err
424
+ return nil
359
425
}
360
426
361
427
func (c * Client ) singleStreamDownload (ctx context.Context , fp * os.File , res * http.Response , pb ProgressBar ) error {
0 commit comments