Skip to content

Commit 0db445c

Browse files
danrjohnsonmreiferson
authored andcommitted
nsqd: support for multiple Auth HTTP Methods
Adds simple config option and flag to allow for auth to occur via POST request in addition to GET. Rationale: Errors from net/http requests are bubbled to nsqd when there is an error during authentication, such as if the nsq authentication server is unavailable. These errors include the full path, including any GET parameter, thus causing the authentication secret to be logged. This does not occur by default for the POST body thus helping protect secrets in transit between nsqd and the authentication server.
1 parent 62fa868 commit 0db445c

File tree

10 files changed

+81
-28
lines changed

10 files changed

+81
-28
lines changed

apps/nsqd/options.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ func nsqdFlagSet(opts *nsqd.Options) *flag.FlagSet {
134134

135135
authHTTPAddresses := app.StringArray{}
136136
flagSet.Var(&authHTTPAddresses, "auth-http-address", "<addr>:<port> or a full url to query auth server (may be given multiple times)")
137+
flagSet.String("auth-http-request-method", opts.AuthHTTPRequestMethod, "HTTP method to use for auth server requests")
137138
flagSet.String("broadcast-address", opts.BroadcastAddress, "address that will be registered with lookupd (defaults to the OS hostname)")
138139
flagSet.Int("broadcast-tcp-port", opts.BroadcastTCPPort, "TCP port that will be registered with lookupd (defaults to the TCP port that this nsqd is listening on)")
139140
flagSet.Int("broadcast-http-port", opts.BroadcastHTTPPort, "HTTP port that will be registered with lookupd (defaults to the HTTP port that this nsqd is listening on)")

internal/auth/authorizations.go

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,13 @@ func (a *State) IsExpired() bool {
7676
}
7777

7878
func QueryAnyAuthd(authd []string, remoteIP string, tlsEnabled bool, commonName string, authSecret string,
79-
clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) {
79+
clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration, httpRequestMethod string) (*State, error) {
8080
var retErr error
8181
start := rand.Int()
8282
n := len(authd)
8383
for i := 0; i < n; i++ {
8484
a := authd[(i+start)%n]
85-
authState, err := QueryAuthd(a, remoteIP, tlsEnabled, commonName, authSecret, clientTLSConfig, connectTimeout, requestTimeout)
85+
authState, err := QueryAuthd(a, remoteIP, tlsEnabled, commonName, authSecret, clientTLSConfig, connectTimeout, requestTimeout, httpRequestMethod)
8686
if err != nil {
8787
es := fmt.Sprintf("failed to auth against %s - %s", a, err)
8888
if retErr != nil {
@@ -97,7 +97,8 @@ func QueryAnyAuthd(authd []string, remoteIP string, tlsEnabled bool, commonName
9797
}
9898

9999
func QueryAuthd(authd string, remoteIP string, tlsEnabled bool, commonName string, authSecret string,
100-
clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) {
100+
clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration, httpRequestMethod string) (*State, error) {
101+
var authState State
101102
v := url.Values{}
102103
v.Set("remote_ip", remoteIP)
103104
if tlsEnabled {
@@ -110,15 +111,21 @@ func QueryAuthd(authd string, remoteIP string, tlsEnabled bool, commonName strin
110111

111112
var endpoint string
112113
if strings.Contains(authd, "://") {
113-
endpoint = fmt.Sprintf("%s?%s", authd, v.Encode())
114+
endpoint = authd
114115
} else {
115-
endpoint = fmt.Sprintf("http://%s/auth?%s", authd, v.Encode())
116+
endpoint = fmt.Sprintf("http://%s/auth", authd)
116117
}
117118

118-
var authState State
119119
client := http_api.NewClient(clientTLSConfig, connectTimeout, requestTimeout)
120-
if err := client.GETV1(endpoint, &authState); err != nil {
121-
return nil, err
120+
if httpRequestMethod == "post" {
121+
if err := client.POSTV1(endpoint, v, &authState); err != nil {
122+
return nil, err
123+
}
124+
} else {
125+
endpoint = fmt.Sprintf("%s?%s", endpoint, v.Encode())
126+
if err := client.GETV1(endpoint, &authState); err != nil {
127+
return nil, err
128+
}
122129
}
123130

124131
// validation on response

internal/clusterinfo/data.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -878,7 +878,7 @@ func (c *ClusterInfo) nsqlookupdPOST(addrs []string, uri string, qs string) erro
878878
for _, addr := range addrs {
879879
endpoint := fmt.Sprintf("http://%s/%s?%s", addr, uri, qs)
880880
c.logf("CI: querying nsqlookupd %s", endpoint)
881-
err := c.client.POSTV1(endpoint)
881+
err := c.client.POSTV1(endpoint, nil, nil)
882882
if err != nil {
883883
errs = append(errs, err)
884884
}
@@ -894,7 +894,7 @@ func (c *ClusterInfo) producersPOST(pl Producers, uri string, qs string) error {
894894
for _, p := range pl {
895895
endpoint := fmt.Sprintf("http://%s/%s?%s", p.HTTPAddress(), uri, qs)
896896
c.logf("CI: querying nsqd %s", endpoint)
897-
err := c.client.POSTV1(endpoint)
897+
err := c.client.POSTV1(endpoint, nil, nil)
898898
if err != nil {
899899
errs = append(errs, err)
900900
}

internal/http_api/api_request.go

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package http_api
22

33
import (
4+
"bytes"
45
"crypto/tls"
56
"encoding/json"
67
"fmt"
@@ -86,14 +87,26 @@ retry:
8687

8788
// PostV1 is a helper function to perform a V1 HTTP request
8889
// and parse our NSQ daemon's expected response format, with deadlines.
89-
func (c *Client) POSTV1(endpoint string) error {
90+
func (c *Client) POSTV1(endpoint string, data url.Values, v interface{}) error {
9091
retry:
91-
req, err := http.NewRequest("POST", endpoint, nil)
92+
var reqBody io.Reader
93+
if data != nil {
94+
js, err := json.Marshal(data)
95+
if err != nil {
96+
return fmt.Errorf("failed to marshal POST data to endpoint: %v", endpoint)
97+
}
98+
reqBody = bytes.NewBuffer(js)
99+
}
100+
101+
req, err := http.NewRequest("POST", endpoint, reqBody)
92102
if err != nil {
93103
return err
94104
}
95105

96106
req.Header.Add("Accept", "application/vnd.nsq; version=1.0")
107+
if reqBody != nil {
108+
req.Header.Add("Content-Type", "application/json")
109+
}
97110

98111
resp, err := c.c.Do(req)
99112
if err != nil {
@@ -116,6 +129,10 @@ retry:
116129
return fmt.Errorf("got response %s %q", resp.Status, body)
117130
}
118131

132+
if v != nil {
133+
return json.Unmarshal(body, &v)
134+
}
135+
119136
return nil
120137
}
121138

nsqd/client_v2.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,9 @@ func (c *clientV2) QueryAuthd() error {
659659
remoteIP, tlsEnabled, commonName, c.AuthSecret,
660660
c.nsqd.clientTLSConfig,
661661
c.nsqd.getOpts().HTTPClientConnectTimeout,
662-
c.nsqd.getOpts().HTTPClientRequestTimeout)
662+
c.nsqd.getOpts().HTTPClientRequestTimeout,
663+
c.nsqd.getOpts().AuthHTTPRequestMethod,
664+
)
663665
if err != nil {
664666
return err
665667
}

nsqd/nsqd.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,10 @@ func New(opts *Options) (*NSQD, error) {
135135
}
136136
n.clientTLSConfig = clientTLSConfig
137137

138+
if opts.AuthHTTPRequestMethod != "post" && opts.AuthHTTPRequestMethod != "get" {
139+
return nil, errors.New("--auth-http-request-method must be post or get")
140+
}
141+
138142
for _, v := range opts.E2EProcessingLatencyPercentiles {
139143
if v <= 0 || v > 1 {
140144
return nil, fmt.Errorf("invalid E2E processing latency percentile: %v", v)

nsqd/nsqd_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -336,11 +336,11 @@ func TestCluster(t *testing.T) {
336336
test.Nil(t, err)
337337

338338
url := fmt.Sprintf("http://%s/topic/create?topic=%s", nsqd.RealHTTPAddr(), topicName)
339-
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url)
339+
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url, nil, nil)
340340
test.Nil(t, err)
341341

342342
url = fmt.Sprintf("http://%s/channel/create?topic=%s&channel=ch", nsqd.RealHTTPAddr(), topicName)
343-
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url)
343+
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url, nil, nil)
344344
test.Nil(t, err)
345345

346346
// allow some time for nsqd to push info to nsqlookupd
@@ -394,7 +394,7 @@ func TestCluster(t *testing.T) {
394394
test.Equal(t, "ch", lr.Channels[0])
395395

396396
url = fmt.Sprintf("http://%s/topic/delete?topic=%s", nsqd.RealHTTPAddr(), topicName)
397-
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url)
397+
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url, nil, nil)
398398
test.Nil(t, err)
399399

400400
// allow some time for nsqd to push info to nsqlookupd

nsqd/options.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ type Options struct {
2727
BroadcastHTTPPort int `flag:"broadcast-http-port"`
2828
NSQLookupdTCPAddresses []string `flag:"lookupd-tcp-address" cfg:"nsqlookupd_tcp_addresses"`
2929
AuthHTTPAddresses []string `flag:"auth-http-address" cfg:"auth_http_addresses"`
30+
AuthHTTPRequestMethod string `flag:"auth-http-request-method" cfg:"auth_http_request_method"`
3031
HTTPClientConnectTimeout time.Duration `flag:"http-client-connect-timeout" cfg:"http_client_connect_timeout"`
3132
HTTPClientRequestTimeout time.Duration `flag:"http-client-request-timeout" cfg:"http_client_request_timeout"`
3233

@@ -110,6 +111,7 @@ func NewOptions() *Options {
110111

111112
NSQLookupdTCPAddresses: make([]string, 0),
112113
AuthHTTPAddresses: make([]string, 0),
114+
AuthHTTPRequestMethod: "get",
113115

114116
HTTPClientConnectTimeout: 2 * time.Second,
115117
HTTPClientRequestTimeout: 5 * time.Second,

nsqd/protocol_v2_test.go

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"os"
1919
"runtime"
2020
"strconv"
21+
"strings"
2122
"sync"
2223
"sync/atomic"
2324
"testing"
@@ -1476,24 +1477,30 @@ func TestClientAuth(t *testing.T) {
14761477
authSuccess := ""
14771478
tlsEnabled := false
14781479
commonName := ""
1479-
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName)
1480+
httpAuthRequestMethod := "get"
1481+
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod)
14801482

14811483
// now one that will succeed
14821484
authResponse = `{"ttl":10, "authorizations":
14831485
[{"topic":"test", "channels":[".*"], "permissions":["subscribe","publish"]}]
14841486
}`
14851487
authError = ""
14861488
authSuccess = `{"identity":"","identity_url":"","permission_count":1}`
1487-
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName)
1489+
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod)
14881490

14891491
// one with TLS enabled
14901492
tlsEnabled = true
14911493
commonName = "test.local"
1492-
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName)
1494+
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod)
1495+
1496+
// test POST based authentication
1497+
httpAuthRequestMethod = "post"
1498+
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod)
1499+
14931500
}
14941501

14951502
func runAuthTest(t *testing.T, authResponse string, authSecret string, authError string,
1496-
authSuccess string, tlsEnabled bool, commonName string) {
1503+
authSuccess string, tlsEnabled bool, commonName string, httpAuthRequestMethod string) {
14971504
var err error
14981505
var expectedRemoteIP string
14991506
expectedTLS := "false"
@@ -1503,11 +1510,23 @@ func runAuthTest(t *testing.T, authResponse string, authSecret string, authError
15031510

15041511
authd := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
15051512
t.Logf("in test auth handler %s", r.RequestURI)
1506-
r.ParseForm()
1507-
test.Equal(t, expectedRemoteIP, r.Form.Get("remote_ip"))
1508-
test.Equal(t, expectedTLS, r.Form.Get("tls"))
1509-
test.Equal(t, commonName, r.Form.Get("common_name"))
1510-
test.Equal(t, authSecret, r.Form.Get("secret"))
1513+
test.Equal(t, httpAuthRequestMethod, strings.ToLower(r.Method))
1514+
1515+
var values url.Values
1516+
1517+
if r.Method == "POST" {
1518+
err = json.NewDecoder(r.Body).Decode(&values)
1519+
if err != nil {
1520+
t.Error(err)
1521+
}
1522+
} else {
1523+
r.ParseForm()
1524+
values = r.Form
1525+
}
1526+
test.Equal(t, expectedRemoteIP, values.Get("remote_ip"))
1527+
test.Equal(t, expectedTLS, values.Get("tls"))
1528+
test.Equal(t, commonName, values.Get("common_name"))
1529+
test.Equal(t, authSecret, values.Get("secret"))
15111530
fmt.Fprint(w, authResponse)
15121531
}))
15131532
defer authd.Close()
@@ -1519,6 +1538,7 @@ func runAuthTest(t *testing.T, authResponse string, authSecret string, authError
15191538
opts.Logger = test.NewTestLogger(t)
15201539
opts.LogLevel = LOG_DEBUG
15211540
opts.AuthHTTPAddresses = []string{addr.Host}
1541+
opts.AuthHTTPRequestMethod = httpAuthRequestMethod
15221542
if tlsEnabled {
15231543
opts.TLSCert = "./test/certs/server.pem"
15241544
opts.TLSKey = "./test/certs/server.key"

nsqlookupd/nsqlookupd_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ func TestTombstoneRecover(t *testing.T) {
220220

221221
endpoint := fmt.Sprintf("http://%s/topic/tombstone?topic=%s&node=%s:%d",
222222
httpAddr, topicName, HostAddr, HTTPPort)
223-
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint)
223+
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint, nil, nil)
224224
test.Nil(t, err)
225225

226226
pr := ProducersDoc{}
@@ -263,7 +263,7 @@ func TestTombstoneUnregister(t *testing.T) {
263263

264264
endpoint := fmt.Sprintf("http://%s/topic/tombstone?topic=%s&node=%s:%d",
265265
httpAddr, topicName, HostAddr, HTTPPort)
266-
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint)
266+
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint, nil, nil)
267267
test.Nil(t, err)
268268

269269
pr := ProducersDoc{}
@@ -348,7 +348,7 @@ func TestTombstonedNodes(t *testing.T) {
348348

349349
endpoint := fmt.Sprintf("http://%s/topic/tombstone?topic=%s&node=%s:%d",
350350
httpAddr, topicName, HostAddr, HTTPPort)
351-
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint)
351+
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint, nil, nil)
352352
test.Nil(t, err)
353353

354354
producers, _ = ci.GetLookupdProducers(lookupdHTTPAddrs)

0 commit comments

Comments
 (0)