-
Notifications
You must be signed in to change notification settings - Fork 2
/
recorder.go
378 lines (332 loc) · 9.87 KB
/
recorder.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
package recorder
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"path"
"strings"
"sync"
"time"
"gopkg.in/yaml.v2"
)
// NoRequestError is returned when the recorder mode is ReplayOnly and no
// recorded entry is found for the current request. Request is the request
// that could not be matched.
//
// Because the error is returned from the transport, it may be wrapped (for
// example by http.Client); use errors.As to detect it.
type NoRequestError struct{ Request *http.Request }

// Error implements the error interface.
func (e NoRequestError) Error() string { return "no recorded entry" }
// Mode controls how a Recorder treats requests and its recording file.
type Mode int

// Possible values:
const (
	// Auto replays a response from disk if a matching recording exists. If
	// one does not exist, the request is performed and the result saved to
	// disk.
	Auto Mode = iota

	// ReplayOnly only replays recordings from disk and never sends network
	// traffic. If no matching recording exists, NoRequestError is returned.
	ReplayOnly

	// Record sends every request over the network even if a matching
	// recording exists. The new requests and responses overwrite any
	// existing recordings.
	Record

	// Passthrough disables recording to disk and passes all traffic
	// directly to the underlying transport. Responses are not written to
	// disk, but they are kept in memory and can be retrieved later with
	// Lookup.
	Passthrough
)
// Selector chooses the recorded Entry used to respond to a given request.
// Select reports false when none of entries should be used; the Recorder then
// sends a real request (Auto) or returns NoRequestError (ReplayOnly).
type Selector interface {
	Select(entries []Entry, req *http.Request) (Entry, bool)
}
// New returns a Recorder that stores its entries in filename, runs in Auto
// mode, sends real requests through http.DefaultTransport, and applies the
// given filters, in order, to every recorded entry.
func New(filename string, filters ...Filter) *Recorder {
	rec := new(Recorder)
	rec.Filename = filename
	rec.Mode = Auto
	rec.Transport = http.DefaultTransport
	rec.Filters = filters
	return rec
}
// Recorder wraps an http.RoundTripper, recording the requests that go through
// it and replaying previously recorded responses.
//
// Unless Mode is Passthrough, each real round trip is written to Filename as
// a YAML document once its response has been read. See RoundTrip for how each
// Mode reads and writes that file.
//
// NOTE(review): index and entries are not protected by a lock, so calling
// RoundTrip concurrently on one Recorder is a data race.
type Recorder struct {
	// Filename to use for saved entries. A .yml extension is added if not set.
	// Missing parent directories are created when an entry is saved.
	//
	// Required if mode is not Passthrough.
	Filename string

	// Mode to use. The default (zero) mode is Auto.
	Mode Mode

	// Filters to apply before saving to disk.
	// Filters are executed in the order specified.
	Filters []Filter

	// Transport to use for real requests.
	// If nil, http.DefaultTransport is used.
	Transport http.RoundTripper

	// Selector optionally controls which recorded Entry is used to respond to
	// a request in Auto and ReplayOnly modes. If nil, Lookup is used, which
	// picks the first recorded entry with a matching method and URL.
	Selector Selector

	once    sync.Once // runs loadFromDisk at most once
	index   int       // counts entries written to Filename by this Recorder
	entries []Entry   // entries loaded from disk, followed by newly recorded ones
}

// Compile-time check that *Recorder implements http.RoundTripper.
var _ http.RoundTripper = (*Recorder)(nil)
// loadFromDisk reads previously recorded entries from r.Filename into
// r.entries. It is invoked through r.once and does nothing in Passthrough
// mode. A ".yml" suffix is appended to r.Filename if missing.
//
// Documents are separated by "\n---\n". Documents that contain nothing but
// whitespace or comments (for example after a trailing separator) are
// skipped. A document that fails to decode, or that has only one of
// request/response, causes a panic: sync.Once offers no way to return an
// error, and such an entry would otherwise fail later with a nil dereference.
//
// NOTE(review): any read error (not just "file does not exist") is treated
// as "no recordings yet".
func (r *Recorder) loadFromDisk() {
	if r.Mode == Passthrough {
		return
	}
	if !strings.HasSuffix(r.Filename, ".yml") {
		r.Filename += ".yml"
	}
	existing, err := ioutil.ReadFile(r.Filename)
	if err != nil {
		return
	}
	for i, doc := range bytes.Split(existing, []byte("\n---\n")) {
		if len(bytes.TrimSpace(doc)) == 0 {
			continue
		}
		var e Entry
		if err := yaml.Unmarshal(doc, &e); err != nil {
			panic(fmt.Sprintf("unmarshal session %d from %s: %v", i, r.Filename, err))
		}
		if e.Request == nil && e.Response == nil {
			// Comment-only document: nothing was recorded in it.
			continue
		}
		if e.Request == nil || e.Response == nil {
			panic(fmt.Sprintf("session %d from %s is missing its request or response", i, r.Filename))
		}
		r.entries = append(r.entries, e)
	}
}
// RoundTrip implements http.RoundTripper and does the actual request.
//
// The behavior depends on the mode set:
//
//	Auto:        If a matching entry exists, its recorded response is
//	             returned. Otherwise the real request is sent and the
//	             exchange is appended to Filename, so entries already
//	             recorded there are kept.
//	ReplayOnly:  Returns a previously recorded response, or NoRequestError
//	             if no entry is found for the request.
//	Record:      Always sends the real request and records the response.
//	             The first recorded request truncates Filename, replacing
//	             any previous recordings.
//	Passthrough: The request is passed through to the underlying
//	             transport; the exchange is kept in memory (see Lookup)
//	             but not written to disk.
//
// Attempting to set another mode will cause a panic.
//
// As the http.RoundTripper contract requires, the request body is closed on
// every return path and the caller's *http.Request is not modified: a
// shallow copy carrying a buffered body is sent instead.
func (r *Recorder) RoundTrip(req *http.Request) (*http.Response, error) {
	if r.Mode < Auto || r.Mode > Passthrough {
		panic("Unsupported mode")
	}
	r.once.Do(r.loadFromDisk)

	if r.Mode == Auto || r.Mode == ReplayOnly {
		var e Entry
		var ok bool
		if r.Selector != nil {
			e, ok = r.Selector.Select(r.entries, req)
		} else {
			e, ok = r.Lookup(req.Method, req.URL.String())
		}
		if ok {
			closeRequestBody(req)
			return replayResponse(req, e.Response), nil
		}
		if r.Mode == ReplayOnly {
			closeRequestBody(req)
			return nil, NoRequestError{Request: req}
		}
	}

	// Resolve the transport locally instead of writing r.Transport on
	// every call.
	transport := r.Transport
	if transport == nil {
		transport = http.DefaultTransport
	}

	// Buffer the request body so it can be both recorded and sent.
	outReq, reqBody, err := bufferRequest(req)
	if err != nil {
		return nil, err
	}
	out := &Request{
		Method:  req.Method,
		URL:     req.URL.String(),
		Headers: flattenHeader(req.Header),
		Body:    string(reqBody),
	}

	// Send request
	start := time.Now()
	resp, err := transport.RoundTrip(outReq)
	if err != nil {
		return nil, err
	}
	dur := time.Since(start)

	// Drain the real body and always close it, even if reading failed, so
	// the underlying connection is not leaked.
	respBody, readErr := ioutil.ReadAll(resp.Body)
	closeErr := resp.Body.Close()
	if readErr != nil {
		return nil, readErr
	}
	if closeErr != nil {
		return nil, closeErr
	}
	in := &Response{
		StatusCode: resp.StatusCode,
		Headers:    flattenHeader(resp.Header),
		Body:       string(respBody),
	}

	e := Entry{Request: out, Response: in}
	for _, apply := range r.Filters {
		apply(&e)
	}
	r.entries = append(r.entries, e)

	if r.Mode == Auto || r.Mode == Record {
		if err := r.saveEntry(e, start, dur); err != nil {
			return nil, err
		}
	}
	// Rebuild the response from in after the filters ran, so in-place
	// filter edits are visible to the caller.
	return replayResponse(req, in), nil
}

// bufferRequest reads and closes req.Body and returns a shallow copy of req
// whose Body and GetBody replay the buffered bytes, together with those
// bytes. A request without a body is returned as-is.
func bufferRequest(req *http.Request) (*http.Request, []byte, error) {
	if req.Body == nil {
		return req, nil, nil
	}
	body, readErr := ioutil.ReadAll(req.Body)
	closeErr := req.Body.Close()
	if readErr != nil {
		return nil, nil, readErr
	}
	if closeErr != nil {
		return nil, nil, closeErr
	}
	clone := new(http.Request)
	*clone = *req
	clone.ContentLength = int64(len(body))
	clone.GetBody = func() (io.ReadCloser, error) {
		if len(body) == 0 {
			// NoBody tells the transport the length is known to be zero.
			return http.NoBody, nil
		}
		return ioutil.NopCloser(bytes.NewReader(body)), nil
	}
	clone.Body, _ = clone.GetBody() // the GetBody above never fails
	return clone, body, nil
}

// closeRequestBody closes req.Body, if any. The body must be closed even
// when the request is answered from a recording and never sent.
func closeRequestBody(req *http.Request) {
	if req.Body != nil {
		_ = req.Body.Close() // nothing useful can be done with this error
	}
}

// replayResponse converts a recorded Response into the *http.Response
// returned to the caller, linked to the request it answers.
func replayResponse(req *http.Request, rec *Response) *http.Response {
	return &http.Response{
		StatusCode:    rec.StatusCode,
		Header:        expandHeader(rec.Headers),
		Body:          ioutil.NopCloser(strings.NewReader(rec.Body)),
		ContentLength: int64(len(rec.Body)),
		Request:       req,
	}
}

// saveEntry writes e to r.Filename as a YAML document preceded by comment
// lines with its index, timestamp and round-trip duration. In Record mode the
// first saved entry truncates the file; all other writes append, so entries
// already on disk in Auto mode are preserved. The separator is only written
// when the file is non-empty. r.index advances once the data is written.
func (r *Recorder) saveEntry(e Entry, start time.Time, dur time.Duration) error {
	doc, err := yaml.Marshal(e)
	if err != nil {
		return err
	}
	if err := os.MkdirAll(path.Dir(r.Filename), 0750); err != nil {
		return err
	}
	flags := os.O_WRONLY | os.O_CREATE | os.O_APPEND
	if r.Mode == Record && r.index == 0 {
		flags = os.O_WRONLY | os.O_CREATE | os.O_TRUNC
	}
	f, err := os.OpenFile(r.Filename, flags, 0644)
	if err != nil {
		return err
	}
	info, err := f.Stat()
	if err != nil {
		_ = f.Close() // report the Stat error instead
		return err
	}

	var buf bytes.Buffer
	if info.Size() > 0 {
		// Document separator that loadFromDisk splits on.
		buf.WriteString("\n---\n\n")
	}
	fmt.Fprintf(&buf, "# request %d\n", r.index)
	fmt.Fprintf(&buf, "# timestamp %s\n", start.UTC().Round(time.Second))
	fmt.Fprintf(&buf, "# roundtrip %s\n", dur.Round(time.Millisecond))
	buf.Write(doc)

	if _, err := f.Write(buf.Bytes()); err != nil {
		_ = f.Close() // report the Write error instead
		return err
	}
	r.index++
	return f.Close()
}
// Lookup returns the first entry whose recorded method and URL match the
// given method and url.
//
// Both comparisons are case-insensitive (strings.EqualFold), including the
// URL path and query. On first use this triggers loading recordings from
// disk, which is a no-op in Passthrough mode. Entries without a recorded
// request are skipped.
//
// Returns false if no such entry exists.
func (r *Recorder) Lookup(method, url string) (Entry, bool) {
	r.once.Do(r.loadFromDisk)
	for _, e := range r.entries {
		if e.Request == nil {
			continue
		}
		if strings.EqualFold(e.Request.Method, method) && strings.EqualFold(e.Request.URL, url) {
			return e, true
		}
	}
	return Entry{}, false
}
// Filter edits a recorded Entry in place. The Recorder runs its Filters, in
// order, after the real request has completed and before the entry is
// stored, mainly so that sensitive data never reaches the saved file.
type Filter func(e *Entry)
// RemoveRequestHeader returns a Filter that removes the header with the given
// name from the recorded request.
//
// HTTP header names are case-insensitive, so every key equal to name under
// strings.EqualFold is removed: "authorization" also removes the canonical
// "Authorization" key. Entries without a recorded request are left untouched.
func RemoveRequestHeader(name string) Filter {
	return func(e *Entry) {
		if e.Request == nil {
			return
		}
		for k := range e.Request.Headers {
			if strings.EqualFold(k, name) {
				delete(e.Request.Headers, k) // deleting during range is safe
			}
		}
	}
}
// RemoveResponseHeader returns a Filter that removes the header with the
// given name from the recorded response.
//
// HTTP header names are case-insensitive, so every key equal to name under
// strings.EqualFold is removed: "set-cookie" also removes the canonical
// "Set-Cookie" key. Entries without a recorded response are left untouched.
func RemoveResponseHeader(name string) Filter {
	return func(e *Entry) {
		if e.Response == nil {
			return
		}
		for k := range e.Response.Headers {
			if strings.EqualFold(k, name) {
				delete(e.Response.Headers, k) // deleting during range is safe
			}
		}
	}
}
// Entry pairs one recorded request with the response it received; each Entry
// is stored as its own YAML document in the recording file.
type Entry struct {
	Request  *Request  `yaml:"request"`  // the outgoing request
	Response *Response `yaml:"response"` // the response it produced
}
// A Request is a recorded outgoing request.
//
// The headers are flattened to a simple key-value map that keeps only the
// first value of each header. The underlying request may contain multiple
// values for each key, but in practice this is uncommon and a simple
// key-value map is much more convenient to work with.
type Request struct {
	Method  string            `yaml:"method"`
	URL     string            `yaml:"url"`
	Headers map[string]string `yaml:"headers,omitempty"`
	Body    string            `yaml:"body,omitempty"`
}
// A Response is a recorded incoming response.
//
// The headers are flattened to a simple key-value map that keeps only the
// first value of each header. The underlying response may contain multiple
// values for each key, but in practice this is uncommon and a simple
// key-value map is much more convenient to work with.
type Response struct {
	StatusCode int               `yaml:"status_code"`
	Headers    map[string]string `yaml:"headers,omitempty"`
	Body       string            `yaml:"body,omitempty"`
}
// flattenHeader converts an http.Header into a single-valued map, keeping
// only the first value of each header. Keys are copied as-is. Headers whose
// value list is empty or nil are omitted instead of causing an index panic.
func flattenHeader(in http.Header) map[string]string {
	out := make(map[string]string, len(in))
	for k, vv := range in {
		if len(vv) == 0 {
			continue
		}
		out[k] = vv[0]
	}
	return out
}
// expandHeader turns a flattened header map back into an http.Header with one
// value per key, canonicalizing each key the same way Header.Set does.
func expandHeader(in map[string]string) http.Header {
	h := http.Header{}
	for name, value := range in {
		h[http.CanonicalHeaderKey(name)] = []string{value}
	}
	return h
}
// OncePerCall is a Selector that matches entries by method and URL
// (case-insensitively) and hands out each recorded entry at most once, so
// repeated identical requests are answered by successive recordings.
// All state is guarded by a mutex, and the zero value is ready to use.
type OncePerCall struct {
	mu   sync.Mutex
	used map[int]bool
}

// Select implements Selector. It returns the first entry, by position in
// entries, whose method and URL match req and that has not been returned
// before; it reports false when no such entry remains.
func (s *OncePerCall) Select(entries []Entry, req *http.Request) (Entry, bool) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.used == nil {
		s.used = make(map[int]bool)
	}
	for i := range entries {
		candidate := entries[i]
		matches := strings.EqualFold(candidate.Request.Method, req.Method) &&
			strings.EqualFold(candidate.Request.URL, req.URL.String())
		if !matches || s.used[i] {
			continue
		}
		s.used[i] = true
		return candidate, true
	}
	return Entry{}, false
}