Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions lib/discord.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@ import (
"crypto/tls"
"encoding/json"
"errors"
"github.com/sirupsen/logrus"
"io"
"io/ioutil"
"math"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"

"github.com/sirupsen/logrus"
)

var client *http.Client
Expand Down Expand Up @@ -213,6 +215,21 @@ func GetBotUser(token string) (*BotUserResponse, error) {
}

func doDiscordReq(ctx context.Context, path string, method string, body io.ReadCloser, header http.Header, query string) (*http.Response, error) {
route := GetMetricsPath(path)
if route == "/channels/!/messages/!/reactions/!/!" {
segs := strings.Split(path, "/")
emojiIdx := 7
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the path doesn't have an API Version set (e.g. /api/channels/...), this doesn't work because it's missing one slash.
I'd suggest getting the 2nd last index

if strings.HasPrefix(path, "/") {
emojiIdx = 8
}
if emojiIdx < len(segs) {
unescaped, _ := url.PathUnescape(segs[emojiIdx])
if segs[emojiIdx] == unescaped {
segs[emojiIdx] = url.PathEscape(segs[emojiIdx])
}
path = strings.Join(segs, "/")
}
}
discordReq, err := http.NewRequestWithContext(ctx, method, "https://discord.com"+path+"?"+query, body)
if err != nil {
return nil, err
Expand All @@ -229,7 +246,6 @@ func doDiscordReq(ctx context.Context, path string, method string, body io.ReadC
}

if err == nil {
route := GetMetricsPath(path)
status := discordResp.Status
method := discordResp.Request.Method
elapsed := time.Since(startTime).Seconds()
Expand Down
4 changes: 3 additions & 1 deletion lib/http.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package lib

import (
"bytes"
"io/ioutil"
"net/http"
"strings"
Expand Down Expand Up @@ -34,5 +35,6 @@ func CopyResponseToResponseWriter(resp *http.Response, respWriter *http.Response
if err != nil {
return err
}
resp.Body = ioutil.NopCloser(bytes.NewReader(body))
return nil
}
}
139 changes: 139 additions & 0 deletions lib/process_request_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package lib

import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/sirupsen/logrus"
)

type roundTripperFunc func(*http.Request) (*http.Response, error)

func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}

func ensureTestLogger() {
if logger == nil {
SetLogger(logrus.New())
}
}

func TestDoDiscordReqForwardsRequestBody(t *testing.T) {
ensureTestLogger()

expectedBody := "payload"
header := make(http.Header)
header.Set("Content-Type", "application/json")

var capturedBody string
var capturedContentType string

originalClient := client
client = &http.Client{Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
data, err := io.ReadAll(req.Body)
if err != nil {
return nil, err
}
capturedBody = string(data)
capturedContentType = req.Header.Get("Content-Type")
_ = req.Body.Close()

return &http.Response{
StatusCode: http.StatusNoContent,
Status: "204 No Content",
Header: make(http.Header),
Body: io.NopCloser(bytes.NewBuffer(nil)),
Request: req,
}, nil
})}
t.Cleanup(func() { client = originalClient })

resp, err := doDiscordReq(context.Background(), "/api/v10/webhooks/123", http.MethodPost, io.NopCloser(bytes.NewBufferString(expectedBody)), header, "")
if err != nil {
t.Fatalf("doDiscordReq returned error: %v", err)
}
defer resp.Body.Close()

if capturedBody != expectedBody {
t.Fatalf("expected upstream to receive body %q, got %q", expectedBody, capturedBody)
}
if capturedContentType != "application/json" {
t.Fatalf("expected Content-Type header to be forwarded, got %q", capturedContentType)
}
}

func TestProcessRequestPreservesResponseBody(t *testing.T) {
ensureTestLogger()

const (
expectedRequestBody = "request body"
upstreamBody = "{\"message\":\"Unknown Webhook\",\"code\":10015}"
)

var capturedBody string

originalClient := client
client = &http.Client{Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
data, err := io.ReadAll(req.Body)
if err != nil {
return nil, err
}
capturedBody = string(data)
_ = req.Body.Close()

return &http.Response{
StatusCode: http.StatusNotFound,
Status: "404 Not Found",
Header: http.Header{
"Content-Type": {"application/json"},
"X-Ratelimit-Scope": {"user"},
},
Body: io.NopCloser(bytes.NewBufferString(upstreamBody)),
Request: req,
}, nil
})}
t.Cleanup(func() { client = originalClient })

originalTimeout := contextTimeout
contextTimeout = time.Second
t.Cleanup(func() { contextTimeout = originalTimeout })

req := httptest.NewRequest(http.MethodPost, "http://localhost/api/v10/webhooks/123", bytes.NewBufferString(expectedRequestBody))
recorder := httptest.NewRecorder()
writer := http.ResponseWriter(recorder)

item := &QueueItem{
Req: req,
Res: &writer,
}

resp, err := ProcessRequest(context.Background(), item)
if err != nil {
t.Fatalf("ProcessRequest returned error: %v", err)
}
defer resp.Body.Close()

if capturedBody != expectedRequestBody {
t.Fatalf("expected upstream to receive body %q, got %q", expectedRequestBody, capturedBody)
}
if recorder.Code != http.StatusNotFound {
t.Fatalf("unexpected status written to client: got %d want %d", recorder.Code, http.StatusNotFound)
}
if recorder.Body.String() != upstreamBody {
t.Fatalf("unexpected body written to client: got %q want %q", recorder.Body.String(), upstreamBody)
}

bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response body: %v", err)
}
if string(bodyBytes) != upstreamBody {
t.Fatalf("expected preserved response body %q, got %q", upstreamBody, string(bodyBytes))
}
}
17 changes: 13 additions & 4 deletions lib/queue.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
package lib

import (
"bytes"
"context"
"errors"
"github.com/Clever/leakybucket"
"github.com/Clever/leakybucket/memory"
"github.com/sirupsen/logrus"
"io"
"io/ioutil"
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/Clever/leakybucket"
"github.com/Clever/leakybucket/memory"
"github.com/sirupsen/logrus"
)

type QueueItem struct {
Expand Down Expand Up @@ -272,6 +276,11 @@ func return401(item *QueueItem) {
item.doneChan <- nil
}

func isUnknownWebhook(_body io.ReadCloser) bool {
body, _ := ioutil.ReadAll(_body);
return bytes.Contains(body, []byte("\"code\": 10015"))
}

func isInteraction(url string) bool {
parts := strings.Split(strings.SplitN(url, "?", 1)[0], "/")
for _, p := range parts {
Expand Down Expand Up @@ -348,7 +357,7 @@ func (q *RequestQueue) subscribe(ch *QueueChannel, path string, pathHash uint64)
}).Warn("Unexpected 429")
}

if resp.StatusCode == 404 && strings.HasPrefix(path, "/webhooks/") && !isInteraction(item.Req.URL.String()) {
if resp.StatusCode == 404 && isUnknownWebhook(resp.Body) && !isInteraction(item.Req.URL.String()) {
logger.WithFields(logrus.Fields{
"bucket": path,
"route": item.Req.URL.String(),
Expand Down
Loading