diff --git a/lib/discord.go b/lib/discord.go index cf99f86..7efb4d5 100644 --- a/lib/discord.go +++ b/lib/discord.go @@ -5,15 +5,17 @@ import ( "crypto/tls" "encoding/json" "errors" - "github.com/sirupsen/logrus" "io" "io/ioutil" "math" "net" "net/http" + "net/url" "strconv" "strings" "time" + + "github.com/sirupsen/logrus" ) var client *http.Client @@ -213,6 +215,21 @@ func GetBotUser(token string) (*BotUserResponse, error) { } func doDiscordReq(ctx context.Context, path string, method string, body io.ReadCloser, header http.Header, query string) (*http.Response, error) { + route := GetMetricsPath(path) + if route == "/channels/!/messages/!/reactions/!/!" { + segs := strings.Split(path, "/") + emojiIdx := 7 + if strings.HasPrefix(path, "/") { + emojiIdx = 8 + } + if emojiIdx < len(segs) { + unescaped, _ := url.PathUnescape(segs[emojiIdx]) + if segs[emojiIdx] == unescaped { + segs[emojiIdx] = url.PathEscape(segs[emojiIdx]) + } + path = strings.Join(segs, "/") + } + } discordReq, err := http.NewRequestWithContext(ctx, method, "https://discord.com"+path+"?"+query, body) if err != nil { return nil, err @@ -229,7 +246,6 @@ func doDiscordReq(ctx context.Context, path string, method string, body io.ReadC } if err == nil { - route := GetMetricsPath(path) status := discordResp.Status method := discordResp.Request.Method elapsed := time.Since(startTime).Seconds() diff --git a/lib/http.go b/lib/http.go index eaa1ecd..67a2784 100644 --- a/lib/http.go +++ b/lib/http.go @@ -1,6 +1,7 @@ package lib import ( + "bytes" "io/ioutil" "net/http" "strings" @@ -34,5 +35,6 @@ func CopyResponseToResponseWriter(resp *http.Response, respWriter *http.Response if err != nil { return err } + resp.Body = ioutil.NopCloser(bytes.NewReader(body)) return nil -} \ No newline at end of file +} diff --git a/lib/process_request_test.go b/lib/process_request_test.go new file mode 100644 index 0000000..c34de15 --- /dev/null +++ b/lib/process_request_test.go @@ -0,0 +1,139 @@ +package lib + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/sirupsen/logrus" +) + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func ensureTestLogger() { + if logger == nil { + SetLogger(logrus.New()) + } +} + +func TestDoDiscordReqForwardsRequestBody(t *testing.T) { + ensureTestLogger() + + expectedBody := "payload" + header := make(http.Header) + header.Set("Content-Type", "application/json") + + var capturedBody string + var capturedContentType string + + originalClient := client + client = &http.Client{Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + data, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + capturedBody = string(data) + capturedContentType = req.Header.Get("Content-Type") + _ = req.Body.Close() + + return &http.Response{ + StatusCode: http.StatusNoContent, + Status: "204 No Content", + Header: make(http.Header), + Body: io.NopCloser(bytes.NewBuffer(nil)), + Request: req, + }, nil + })} + t.Cleanup(func() { client = originalClient }) + + resp, err := doDiscordReq(context.Background(), "/api/v10/webhooks/123", http.MethodPost, io.NopCloser(bytes.NewBufferString(expectedBody)), header, "") + if err != nil { + t.Fatalf("doDiscordReq returned error: %v", err) + } + defer resp.Body.Close() + + if capturedBody != expectedBody { + t.Fatalf("expected upstream to receive body %q, got %q", expectedBody, capturedBody) + } + if capturedContentType != "application/json" { + t.Fatalf("expected Content-Type header to be forwarded, got %q", capturedContentType) + } +} + +func TestProcessRequestPreservesResponseBody(t *testing.T) { + ensureTestLogger() + + const ( + expectedRequestBody = "request body" + upstreamBody = "{\"message\":\"Unknown Webhook\",\"code\":10015}" + ) + + var capturedBody string + + originalClient := client + client = &http.Client{Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + data, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + capturedBody = string(data) + _ = req.Body.Close() + + return &http.Response{ + StatusCode: http.StatusNotFound, + Status: "404 Not Found", + Header: http.Header{ + "Content-Type": {"application/json"}, + "X-Ratelimit-Scope": {"user"}, + }, + Body: io.NopCloser(bytes.NewBufferString(upstreamBody)), + Request: req, + }, nil + })} + t.Cleanup(func() { client = originalClient }) + + originalTimeout := contextTimeout + contextTimeout = time.Second + t.Cleanup(func() { contextTimeout = originalTimeout }) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/api/v10/webhooks/123", bytes.NewBufferString(expectedRequestBody)) + recorder := httptest.NewRecorder() + writer := http.ResponseWriter(recorder) + + item := &QueueItem{ + Req: req, + Res: &writer, + } + + resp, err := ProcessRequest(context.Background(), item) + if err != nil { + t.Fatalf("ProcessRequest returned error: %v", err) + } + defer resp.Body.Close() + + if capturedBody != expectedRequestBody { + t.Fatalf("expected upstream to receive body %q, got %q", expectedRequestBody, capturedBody) + } + if recorder.Code != http.StatusNotFound { + t.Fatalf("unexpected status written to client: got %d want %d", recorder.Code, http.StatusNotFound) + } + if recorder.Body.String() != upstreamBody { + t.Fatalf("unexpected body written to client: got %q want %q", recorder.Body.String(), upstreamBody) + } + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } + if string(bodyBytes) != upstreamBody { + t.Fatalf("expected preserved response body %q, got %q", upstreamBody, string(bodyBytes)) + } +} diff --git a/lib/queue.go b/lib/queue.go index 203383a..eab5eb8 100644 --- a/lib/queue.go +++ b/lib/queue.go @@ -1,17 +1,21 @@ package lib import ( + "bytes" "context" "errors" - "github.com/Clever/leakybucket" - "github.com/Clever/leakybucket/memory" - "github.com/sirupsen/logrus" + "io" + "io/ioutil" "net/http" "strconv" "strings" "sync" "sync/atomic" "time" + + "github.com/Clever/leakybucket" + "github.com/Clever/leakybucket/memory" + "github.com/sirupsen/logrus" ) type QueueItem struct { @@ -272,6 +276,11 @@ func return401(item *QueueItem) { item.doneChan <- nil } +func isUnknownWebhook(_body io.ReadCloser) bool { + body, _ := ioutil.ReadAll(_body); + return bytes.Contains(body, []byte("\"code\": 10015")) +} + func isInteraction(url string) bool { parts := strings.Split(strings.SplitN(url, "?", 1)[0], "/") for _, p := range parts { @@ -348,7 +357,7 @@ func (q *RequestQueue) subscribe(ch *QueueChannel, path string, pathHash uint64) }).Warn("Unexpected 429") } - if resp.StatusCode == 404 && strings.HasPrefix(path, "/webhooks/") && !isInteraction(item.Req.URL.String()) { + if resp.StatusCode == 404 && isUnknownWebhook(resp.Body) && !isInteraction(item.Req.URL.String()) { logger.WithFields(logrus.Fields{ "bucket": path, "route": item.Req.URL.String(),