Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

kadai3-2 shuheiktgw #49

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions kadai3-2/shuheiktgw/gget/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
gget
21 changes: 21 additions & 0 deletions kadai3-2/shuheiktgw/gget/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
gget
====

gget is a wget like command to download file, but downloads a file in parallel.

## Usage
```
gget [options...] URL

OPTIONS:
--parallel value, -p value specifies the amount of parallelism (default: the number of CPU)
--help, -h prints help

```

## Install

```
go build
./gget [options...] URL
```
75 changes: 75 additions & 0 deletions kadai3-2/shuheiktgw/gget/cli.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package main

import (
"flag"
"fmt"
"io"
"runtime"
)

const (
ExitCodeOK = iota
ExitCodeError
ExitCodeBadArgsError
ExitCodeParseFlagsError
ExitCodeInvalidFlagError
)

const name = "gget"

// CLI represents CLI interface for gget
type CLI struct {
outStream, errStream io.Writer
}

// Run runs gget command
func (cli *CLI) Run(args []string) int {
var parallel int

flags := flag.NewFlagSet(name, flag.ContinueOnError)
flags.Usage = func() {
fmt.Fprint(cli.outStream, usage)
}

numCPU := runtime.NumCPU()
flags.IntVar(&parallel, "parallel", numCPU, "")
flags.IntVar(&parallel, "p", numCPU, "")

if err := flags.Parse(args[1:]); err != nil {
return ExitCodeParseFlagsError
}

if parallel < 1 {
fmt.Fprintf(cli.errStream, "Failed to set up gget: The number of parallels cannot be less than one\n")
return ExitCodeInvalidFlagError
}

parsedArgs := flags.Args()
if len(parsedArgs) != 1 {
fmt.Fprintf(cli.errStream, "Invalid arguments: you need to set exactly one URL\n")
return ExitCodeBadArgsError
}

request, err := NewRequest(parsedArgs[0], parallel)
if err != nil {
fmt.Fprintf(cli.errStream, "Error occurred while initializing a request: %s\n", err)
return ExitCodeError
}

if err := request.Do(); err != nil {
fmt.Fprintf(cli.errStream, "Error occurred while downloading the file: %s\n", err)
return ExitCodeError
}

return ExitCodeOK
}

var usage = `Usage: gget [options...] URL

gget is a wget like command to download file, but downloads a file in parallel

OPTIONS:
--parallel value, -p value specifies the amount of parallelism (default: the number of CPU)
--help, -h prints help

`
8 changes: 8 additions & 0 deletions kadai3-2/shuheiktgw/gget/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package main

import "os"

func main() {
cli := &CLI{outStream: os.Stdout, errStream: os.Stderr}
os.Exit(cli.Run(os.Args))
}
176 changes: 176 additions & 0 deletions kadai3-2/shuheiktgw/gget/request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
package main

import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"

"golang.org/x/sync/errgroup"
)

// RangeRequest represents a request with a range access
type RangeRequest struct {
URL string
FName string
Ranges []*Range
}

// NonRangeRequest represents a request without a range access
type NonRangeRequest struct {
URL string
FName string
}

// Request represents a request
type Request interface {
Do() error
}

// Range tells tha range of file to download
type Range struct {
start int64
end int64
}

// NewRequest initializes Request object
func NewRequest(rawURL string, parallel int) (Request, error) {
u, err := url.Parse(rawURL)
if err != nil {
return nil, err
}

ss := strings.Split(u.Path, "/")
fname := ss[len(ss)-1]

res, err := http.Head(rawURL)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://developer.mozilla.org/ja/docs/Web/HTTP/Headers/Accept-Ranges
この時点で Accept-Ranges をチェックしておいたほうが良いと思います

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

確かに,ありがとうございます!
RangeRequestとNonRangeRequestに分けて扱うようにしました.

if err != nil {
return nil, err
}
defer res.Body.Close()

if res.Header.Get("Accept-Ranges") != "bytes" {
return &NonRangeRequest{URL: rawURL, FName: fname}, nil
}

total := res.ContentLength
unit := total / int64(parallel)
ranges := make([]*Range, parallel)

for i := 0; i < parallel; i++ {
var start int64
if i == 0 {
start = 0
} else {
start = int64(i)*unit + 1
}

var end int64
if i == parallel-1 {
end = total
} else {
end = int64(i+1) * unit
}

ranges[i] = &Range{start: start, end: end}
}

return &RangeRequest{URL: rawURL, FName: fname, Ranges: ranges}, nil
}

// Do sends a real HTTP requests in parallel
func (r *NonRangeRequest) Do() error {
req, err := http.NewRequest(http.MethodGet, r.URL, nil)

client := http.DefaultClient
res, err := client.Do(req)
if err != nil {
return err
}
defer res.Body.Close()

return saveResponseBody(r.FName, res)
}

// Do sends a real HTTP requests in parallel
func (r *RangeRequest) Do() error {
eg, ctx := errgroup.WithContext(context.TODO())

for idx := range r.Ranges {
// DO NOT refer to idx directly since function below
// is a closure and idx changes for each iterations
i := idx
eg.Go(func() error {
return r.do(i, ctx)
})
}

if err := eg.Wait(); err != nil {
return err
}

return r.mergeFiles()
}

func (r *RangeRequest) do(idx int, ctx context.Context) error {
req, err := http.NewRequest(http.MethodGet, r.URL, nil)
if err != nil {
return err
}
req = req.WithContext(ctx)

ran := r.Ranges[idx]
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", ran.start, ran.end))

client := http.DefaultClient

res, err := client.Do(req)
if err != nil {
return err
}
defer res.Body.Close()

tmpFName := fmt.Sprintf("%s.%d", r.FName, idx)
return saveResponseBody(tmpFName, res)
}

func (r *RangeRequest) mergeFiles() error {
f, err := os.Create(r.FName)
if err != nil {
return err
}
defer f.Close()

for idx := range r.Ranges {
tmpFName := fmt.Sprintf("%s.%d", r.FName, idx)
tmpFile, err := os.Open(tmpFName)
if err != nil {
return err
}

io.Copy(f, tmpFile)
tmpFile.Close()
if err := os.Remove(tmpFName); err != nil {
return err
}
}

return nil
}

func saveResponseBody(fname string, response *http.Response) error {
file, err := os.Create(fname)
if err != nil {
return err
}
defer file.Close()

if _, err := io.Copy(file, response.Body); err != nil {
return err
}

return nil
}