Skip to content

Commit

Permalink
Implement download modes framework (#25)
Browse files Browse the repository at this point in the history
Implement a download modes framework that allows specifying a mechanism
to download. This commit implements "buffer" and "tar-extract" which mirror
the original behaviors of pget. The '-x'/'--extract' option automatically
sets the `tar-extract` mode.

tar-extract mode simply wraps the buffer mode for now.

Behavior and CLI options are not modified in this commit
  • Loading branch information
tempusfrangit authored Nov 15, 2023
1 parent c40b7ab commit 9581b69
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 37 deletions.
22 changes: 2 additions & 20 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"github.com/replicate/pget/cmd"
"github.com/replicate/pget/pkg/config"
"github.com/replicate/pget/pkg/download"
"github.com/replicate/pget/pkg/extract"
"github.com/replicate/pget/pkg/optname"
)

Expand Down Expand Up @@ -57,23 +56,6 @@ func execFunc(cmd *cobra.Command, args []string) error {
_ = os.WriteFile(tmpFile, []byte(""), 0644)
defer os.Remove(tmpFile)

buffer, fileSize, err := download.FileToBuffer(url)
if err != nil {
return fmt.Errorf("error downloading file: %w", err)
}

// extract the tar file if the -x flag was provided
if viper.GetBool(optname.Extract) {
err = extract.ExtractTarFile(buffer, dest, fileSize)
if err != nil {
return fmt.Errorf("error extracting file: %v", err)
}
} else {
// if -x flag is not set, save the buffer to a file
err = os.WriteFile(dest, buffer.Bytes(), 0644)
if err != nil {
return fmt.Errorf("error writing file: %v", err)
}
}
return nil
mode := download.GetMode(config.Mode)
return mode.DownloadFile(url, dest)
}
6 changes: 5 additions & 1 deletion pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ var (
Extract bool
Force bool
MinimumChunkSize string
Mode string
ResolveHosts []string
Retries int
Verbose bool
Expand Down Expand Up @@ -46,7 +47,10 @@ func AddFlags(cmd *cobra.Command) {
}

func PersistentStartupProcessFlags() error {

Mode = "buffer"
if viper.GetBool(optname.Extract) {
Mode = "tar-extract"
}
if err := convertResolveHostsToMap(); err != nil {
return err
}
Expand Down
44 changes: 28 additions & 16 deletions pkg/download/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"math"
"net/http"
"os"
"time"

"github.com/dustin/go-humanize"
Expand All @@ -16,34 +17,34 @@ import (
"github.com/replicate/pget/pkg/optname"
)

var (
fileSize int64
)
// BufferMode is the download mode that fetches a remote file into an
// in-memory buffer; its DownloadFile method then writes that buffer to the
// destination path.
type BufferMode struct {
// Client is the HTTP client used for the HEAD request that determines the
// remote file size and for the ranged GET requests that fetch each chunk.
Client *http.Client
}

func getRemoteFileSize(url string) (string, int64, error) {
// TODO: this needs a retry
resp, err := http.DefaultClient.Head(url)
func (m *BufferMode) getRemoteFileSize(url string) (string, int64, error) {
resp, err := m.Client.Head(url)
if err != nil {
return "", int64(-1), err
}
defer resp.Body.Close()
trueUrl := resp.Request.URL.String()
if trueUrl != url {
fmt.Printf("Redirected to %s\n", trueUrl)
if viper.GetBool(optname.Verbose) {
fmt.Printf("Redirected to %s\n", trueUrl)
}
}

fSize := resp.ContentLength
if fSize <= 0 {
fileSize := resp.ContentLength
if fileSize <= 0 {
return "", int64(-1), fmt.Errorf("unable to determine file size")
}
fileSize = fSize
return trueUrl, fileSize, nil
}

func FileToBuffer(url string) (*bytes.Buffer, int64, error) {
func (m *BufferMode) fileToBuffer(url string) (*bytes.Buffer, int64, error) {
maxConcurrency := viper.GetInt(optname.Concurrency)

trueURL, fileSize, err := getRemoteFileSize(url)
trueURL, fileSize, err := m.getRemoteFileSize(url)
if err != nil {
return nil, -1, err
}
Expand Down Expand Up @@ -81,7 +82,7 @@ func FileToBuffer(url string) (*bytes.Buffer, int64, error) {
}

errGroup.Go(func() error {
return downloadChunk(ctx, start, end, data[start:end+1], trueURL)
return m.downloadChunk(ctx, start, end, data[start:end+1], trueURL)
})
}

Expand All @@ -97,14 +98,13 @@ func FileToBuffer(url string) (*bytes.Buffer, int64, error) {
return buffer, fileSize, nil
}

func downloadChunk(ctx context.Context, start, end int64, dataSlice []byte, trueURL string) error {
client := newClient()
func (m *BufferMode) downloadChunk(ctx context.Context, start, end int64, dataSlice []byte, trueURL string) error {
req, err := http.NewRequestWithContext(ctx, "GET", trueURL, nil)
if err != nil {
return fmt.Errorf("failed to download %s", req.URL.String())
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
resp, err := client.Do(req)
resp, err := m.Client.Do(req)
if err != nil {
return fmt.Errorf("error executing request for %s: %w", req.URL.String(), err)
}
Expand All @@ -119,3 +119,15 @@ func downloadChunk(ctx context.Context, start, end int64, dataSlice []byte, true
}
return nil
}

// DownloadFile fetches url into memory via fileToBuffer and writes the
// resulting bytes to the file at dest with permissions 0644.
//
// A download error is returned unchanged; a write failure is wrapped with
// the "error writing file" prefix.
func (m *BufferMode) DownloadFile(url string, dest string) error {
	contents, _, err := m.fileToBuffer(url)
	if err != nil {
		return err
	}
	if writeErr := os.WriteFile(dest, contents.Bytes(), 0644); writeErr != nil {
		return fmt.Errorf("error writing file: %w", writeErr)
	}
	return nil
}
16 changes: 16 additions & 0 deletions pkg/download/modes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package download

// modeFactoryFunc constructs a new Mode value each time it is called.
type modeFactoryFunc func() Mode

// modes is the registry of available download modes, keyed by mode name.
// Each entry is a factory rather than a shared instance, so every lookup
// that invokes the factory builds a fresh Mode.
var modes = map[string]modeFactoryFunc{
// "buffer" builds a BufferMode with a client obtained from newClient().
"buffer": func() Mode { return &BufferMode{Client: newClient()} },
// "tar-extract" builds an ExtractTarMode, which has no fields to set.
"tar-extract": func() Mode { return &ExtractTarMode{} },
}

// Mode is a download strategy: a mechanism for fetching the content at a URL
// and placing it at a destination.
type Mode interface {
// DownloadFile downloads url and delivers the result to dest. How dest is
// used depends on the implementation: BufferMode writes a single file at
// dest, while ExtractTarMode passes dest to extract.ExtractTarFile.
DownloadFile(url string, dest string) error
}

// GetMode returns a newly constructed Mode for the given registered mode
// name (for example "buffer" or "tar-extract").
//
// GetMode panics if name is not present in the modes registry. The signature
// has no error return, and an unknown name is a programming error, so the
// panic carries the offending name instead of surfacing as an opaque
// nil-function call.
func GetMode(name string) Mode {
	factory, ok := modes[name]
	if !ok || factory == nil {
		panic("download: unknown mode \"" + name + "\"")
	}
	return factory()
}
23 changes: 23 additions & 0 deletions pkg/download/tar_extraction.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package download

import (
"fmt"

"github.com/replicate/pget/pkg/extract"
)

// ExtractTarMode is the download mode that fetches a remote file into memory
// (by way of a BufferMode) and hands the buffer to extract.ExtractTarFile
// instead of writing it out as a single file. It carries no state.
type ExtractTarMode struct {
}

// DownloadFile fetches url into memory using a BufferMode built with a
// client from newClient(), then passes the buffer, dest, and the reported
// file size to extract.ExtractTarFile.
//
// Download failures are wrapped with "error downloading file" and extraction
// failures with "error extracting file".
func (m *ExtractTarMode) DownloadFile(url string, dest string) error {
	bufferMode := &BufferMode{Client: newClient()}

	contents, size, downloadErr := bufferMode.fileToBuffer(url)
	if downloadErr != nil {
		return fmt.Errorf("error downloading file: %w", downloadErr)
	}

	if extractErr := extract.ExtractTarFile(contents, dest, size); extractErr != nil {
		return fmt.Errorf("error extracting file: %w", extractErr)
	}
	return nil
}

0 comments on commit 9581b69

Please sign in to comment.