-
-
Notifications
You must be signed in to change notification settings - Fork 52
/
recorder.go
131 lines (124 loc) · 3.55 KB
/
recorder.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
package requests
import (
"bufio"
"bytes"
"crypto/md5"
"encoding/base64"
"errors"
"fmt"
"io/fs"
"net/http"
"net/http/httputil"
"os"
"path/filepath"
)
// Record returns an http.RoundTripper that writes out its
// requests and their responses to text files in basepath.
// Requests are named according to a hash of their contents.
// Responses are named according to the request that made them.
//
// If rt is nil, http.DefaultTransport is used. For every request,
// Record does the following:
//   - creates basepath (and any missing parents) if needed;
//   - dumps the full request including its body and writes it out;
//   - sends the request through rt;
//   - dumps the full response including its body and writes it out.
//
// Errors from any step are wrapped with
// "problem while recording transport". If the response cannot be
// dumped or saved, its body is closed and no response is returned.
//
// Deprecated: Use reqtest.Record.
func Record(rt http.RoundTripper, basepath string) Transport {
	if rt == nil {
		rt = http.DefaultTransport
	}
	return RoundTripFunc(func(req *http.Request) (res *http.Response, err error) {
		defer func() {
			if err != nil {
				err = fmt.Errorf("problem while recording transport: %w", err)
			}
		}()
		// Surface directory-creation failures directly so callers see the
		// real cause rather than a follow-on WriteFile error.
		if err = os.MkdirAll(basepath, 0755); err != nil {
			return nil, err
		}
		// DumpRequest(req, true) buffers the body and restores req.Body,
		// so the request can still be sent below.
		b, err := httputil.DumpRequest(req, true)
		if err != nil {
			return nil, err
		}
		reqname, resname := buildName(b)
		if err = os.WriteFile(filepath.Join(basepath, reqname), b, 0644); err != nil {
			return nil, err
		}
		if res, err = rt.RoundTrip(req); err != nil {
			return res, err
		}
		// From here on this function owns res. On failure the caller gets a
		// nil response and could never close the body, so release it here.
		// DumpResponse leaves the original, unclosed body on res when reading
		// it fails.
		closeBody := func() {
			if res.Body != nil {
				_ = res.Body.Close() // best effort; the dump/write error is what gets reported
			}
		}
		b, err = httputil.DumpResponse(res, true)
		if err != nil {
			closeBody()
			return nil, err
		}
		if err = os.WriteFile(filepath.Join(basepath, resname), b, 0644); err != nil {
			closeBody()
			return nil, err
		}
		return res, nil
	})
}
// Replay returns an http.RoundTripper that answers requests from text
// files in the directory basepath instead of the network.
// Responses are looked up according to a hash of the request.
// It is shorthand for ReplayFS(os.DirFS(basepath)).
//
// Deprecated: Use reqtest.Replay.
func Replay(basepath string) Transport {
	dir := os.DirFS(basepath)
	return ReplayFS(dir)
}
// errNotFound is wrapped into the error ReplayFS returns when no replay
// file matches a request. Caching tests for it with errors.Is to decide
// when to fall back to recording a live response.
var (
	errNotFound = errors.New("response not found")
)
// ReplayFS returns an http.RoundTripper that serves responses from text
// files in fsys rather than contacting a server.
// Responses are looked up according to a hash of the request.
//
// How a response file is found:
//   - Each request is dumped (including its body) and hashed.
//   - The response file is the single file in the root of fsys whose name
//     ends with the hash-derived response name.
//   - Anything before that suffix is ignored. Response file names may
//     therefore be prefixed with comments for better human organization.
//
// If no file matches, the returned error wraps errNotFound. If several
// match, the lookup fails as ambiguous. Every error is wrapped with
// "problem while replaying transport".
//
// Deprecated: Use reqtest.ReplayFS.
func ReplayFS(fsys fs.FS) Transport {
	return RoundTripFunc(func(req *http.Request) (res *http.Response, err error) {
		defer func() {
			if err != nil {
				err = fmt.Errorf("problem while replaying transport: %w", err)
			}
		}()
		dump, dumpErr := httputil.DumpRequest(req, true)
		if dumpErr != nil {
			return nil, dumpErr
		}
		_, resname := buildName(dump)
		pattern := "*" + resname
		found, globErr := fs.Glob(fsys, pattern)
		if globErr != nil {
			return nil, globErr
		}
		switch {
		case len(found) == 0:
			return nil, fmt.Errorf("%w: no replay file matches %q", errNotFound, pattern)
		case len(found) > 1:
			return nil, fmt.Errorf("ambiguous response: multiple replay files match %q", pattern)
		}
		contents, readErr := fs.ReadFile(fsys, found[0])
		if readErr != nil {
			return nil, readErr
		}
		return http.ReadResponse(bufio.NewReader(bytes.NewReader(contents)), req)
	})
}
// buildName derives the recording file names for the dumped request b.
// It takes the MD5 digest of b, base64url-encodes it, and keeps the first
// 8 characters as a shared prefix. It returns "<prefix>.req.txt" for the
// request file and "<prefix>.res.txt" for the matching response file.
// The digest is used here only to derive file names.
func buildName(b []byte) (reqname, resname string) {
	digest := md5.Sum(b)
	prefix := base64.URLEncoding.EncodeToString(digest[:])[:8]
	reqname = prefix + ".req.txt"
	resname = prefix + ".res.txt"
	return reqname, resname
}
// Caching returns an http.RoundTripper that first tries to answer each
// request from the text files in basepath, exactly as Replay would.
// If no stored response matches (the replay error wraps errNotFound), it
// sends the request with rt and records both request and response in
// basepath, exactly as Record would. Any other replay error is returned
// unchanged.
// Requests are named according to a hash of their contents.
// Responses are named according to the request that made them.
//
// Deprecated: Use reqtest.Caching.
func Caching(rt http.RoundTripper, basepath string) Transport {
	replayer := Replay(basepath)
	recorder := Record(rt, basepath)
	return RoundTripFunc(func(req *http.Request) (*http.Response, error) {
		res, err := replayer.RoundTrip(req)
		if !errors.Is(err, errNotFound) {
			return res, err
		}
		return recorder.RoundTrip(req)
	})
}