Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace authn.Request for *http.Request #9

Merged
merged 11 commits into from
Sep 30, 2024
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (
"connectrpc.com/authn/internal/gen/authn/ping/v1/pingv1connect"
)

func authenticate(_ context.Context, req authn.Request) (any, error) {
func authenticate(_ context.Context, req *http.Request) (any, error) {
username, password, ok := req.BasicAuth()
if !ok {
return nil, authn.Errorf("invalid authorization")
Expand Down
155 changes: 91 additions & 64 deletions authn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ package authn

import (
"context"
"crypto/tls"
"fmt"
"mime"
"net/http"
"net/url"
"strings"

"connectrpc.com/connect"
Expand All @@ -38,7 +39,7 @@ const infoKey key = iota
// the information is automatically attached to the context using [SetInfo].
//
// Implementations must be safe to call concurrently.
type AuthFunc func(ctx context.Context, req Request) (any, error)
type AuthFunc func(ctx context.Context, req *http.Request) (any, error)

// SetInfo attaches authentication information to the context. It's often
// useful in tests.
Expand Down Expand Up @@ -71,77 +72,63 @@ func Errorf(template string, args ...any) *connect.Error {
return connect.NewError(connect.CodeUnauthenticated, fmt.Errorf(template, args...))
}

// Request describes a single RPC invocation.
type Request struct {
request *http.Request
}

// BasicAuth returns the username and password provided in the request's
// Authorization header, if any.
func (r Request) BasicAuth() (username string, password string, ok bool) {
return r.request.BasicAuth()
}

// Cookies parses and returns the HTTP cookies sent with the request, if any.
func (r Request) Cookies() []*http.Cookie {
return r.request.Cookies()
}

// Cookie returns the named cookie provided in the request or
// [http.ErrNoCookie] if not found. If multiple cookies match the given name,
// only one cookie will be returned.
func (r Request) Cookie(name string) (*http.Cookie, error) {
return r.request.Cookie(name)
// InferProtocol returns the inferred RPC protocol. It is one of
// [connect.ProtocolConnect], [connect.ProtocolGRPC], or [connect.ProtocolGRPCWeb].
func InferProtocol(request *http.Request) (string, bool) {
const (
grpcContentTypeDefault = "application/grpc"
grpcContentTypePrefix = "application/grpc+"
grpcWebContentTypeDefault = "application/grpc-web"
grpcWebContentTypePrefix = "application/grpc-web+"
connectStreamingContentTypePrefix = "application/connect+"
connectUnaryContentTypePrefix = "application/"
connectUnaryMessageQueryParameter = "message"
connectUnaryEncodingQueryParameter = "encoding"
)
ctype := canonicalizeContentType(request.Header.Get("Content-Type"))
isPost := request.Method == http.MethodPost
isGet := request.Method == http.MethodGet
switch {
case isPost && (ctype == grpcContentTypeDefault || strings.HasPrefix(ctype, grpcContentTypePrefix)):
return connect.ProtocolGRPC, true
case isPost && (ctype == grpcWebContentTypeDefault || strings.HasPrefix(ctype, grpcWebContentTypePrefix)):
return connect.ProtocolGRPCWeb, true
case isPost && strings.HasPrefix(ctype, connectStreamingContentTypePrefix):
return connect.ProtocolConnect, true
case isPost && strings.HasPrefix(ctype, connectUnaryContentTypePrefix):
return connect.ProtocolConnect, true
case isGet:
query := request.URL.Query()
hasMessage := query.Has(connectUnaryMessageQueryParameter)
hasEncoding := query.Has(connectUnaryEncodingQueryParameter)
if !hasMessage || !hasEncoding {
return "", false
}
return connect.ProtocolConnect, true
default:
return "", false
}
}

// Procedure returns the RPC procedure name, in the form "/service/method". If
// the request path does not contain a procedure name, the entire path is
// returned.
func (r Request) Procedure() string {
path := strings.TrimSuffix(r.request.URL.Path, "/")
// InferProcedure returns the inferred RPC procedure. It's returned in the form
// "/service/method" if a valid suffix is found. If the request doesn't contain
// a service and method, the entire path and false is returned.
func InferProcedure(url *url.URL) (string, bool) {
path := url.Path
ultimate := strings.LastIndex(path, "/")
if ultimate < 0 {
return r.request.URL.Path
return url.Path, false
}
penultimate := strings.LastIndex(path[:ultimate], "/")
if penultimate < 0 {
return r.request.URL.Path
return url.Path, false
}
procedure := path[penultimate:]
if len(procedure) < 4 { // two slashes + service + method
return r.request.URL.Path
}
return procedure
}

// ClientAddr returns the client address, in IP:port format.
func (r Request) ClientAddr() string {
return r.request.RemoteAddr
}

// Protocol returns the RPC protocol. It is one of [connect.ProtocolConnect],
// [connect.ProtocolGRPC], or [connect.ProtocolGRPCWeb].
func (r Request) Protocol() string {
ct := r.request.Header.Get("Content-Type")
switch {
case strings.HasPrefix(ct, "application/grpc-web"):
return connect.ProtocolGRPCWeb
case strings.HasPrefix(ct, "application/grpc"):
return connect.ProtocolGRPC
default:
return connect.ProtocolConnect
// Ensure that the service and method are non-empty.
if ultimate == len(path)-1 || penultimate == ultimate-1 {
return url.Path, false
}
}

// Header returns the HTTP request headers.
func (r Request) Header() http.Header {
return r.request.Header
}

// TLS returns the TLS connection state, if any. It may be nil if the connection
// is not using TLS.
func (r Request) TLS() *tls.ConnectionState {
return r.request.TLS
return procedure, true
}

// Middleware is server-side HTTP middleware that authenticates RPC requests.
Expand Down Expand Up @@ -175,7 +162,7 @@ func NewMiddleware(auth AuthFunc, opts ...connect.HandlerOption) *Middleware {
func (m *Middleware) Wrap(handler http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
ctx := request.Context()
info, err := m.auth(ctx, Request{request: request})
info, err := m.auth(ctx, request)
if err != nil {
_ = m.errW.Write(writer, request, err)
return
Expand All @@ -187,3 +174,43 @@ func (m *Middleware) Wrap(handler http.Handler) http.Handler {
handler.ServeHTTP(writer, request)
})
}

func canonicalizeContentType(contentType string) string {
// Typically, clients send Content-Type in canonical form, without
// parameters. In those cases, we'd like to avoid parsing and
// canonicalization overhead.
//
// See https://www.rfc-editor.org/rfc/rfc2045.html#section-5.1 for a full
// grammar.
var slashes int
for _, r := range contentType {
switch {
case r >= 'a' && r <= 'z':
case r == '.' || r == '+' || r == '-':
case r == '/':
slashes++
default:
return canonicalizeContentTypeSlow(contentType)
}
}
if slashes == 1 {
return contentType
}
return canonicalizeContentTypeSlow(contentType)
}

func canonicalizeContentTypeSlow(contentType string) string {
base, params, err := mime.ParseMediaType(contentType)
if err != nil {
return contentType
}
// According to RFC 9110 Section 8.3.2, the charset parameter value should be treated as case-insensitive.
// mime.FormatMediaType canonicalizes parameter names, but not parameter values,
// because the case sensitivity of a parameter value depends on its semantics.
// Therefore, the charset parameter value should be canonicalized here.
// ref.) https://httpwg.org/specs/rfc9110.html#rfc.section.8.3.2
if charset, ok := params["charset"]; ok {
params["charset"] = strings.ToLower(charset)
}
return mime.FormatMediaType(base, params)
}
154 changes: 152 additions & 2 deletions authn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"

"connectrpc.com/authn"
"connectrpc.com/connect"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -94,8 +96,8 @@ func assertInfo(ctx context.Context, tb testing.TB) {
}
}

func authenticate(_ context.Context, req authn.Request) (any, error) {
parts := strings.SplitN(req.Header().Get("Authorization"), " ", 2)
func authenticate(_ context.Context, req *http.Request) (any, error) {
parts := strings.SplitN(req.Header.Get("Authorization"), " ", 2)
if len(parts) < 2 || parts[0] != "Bearer" {
err := authn.Errorf("expected Bearer authentication scheme")
err.Meta().Set("WWW-Authenticate", "Bearer")
Expand All @@ -106,3 +108,151 @@ func authenticate(_ context.Context, req authn.Request) (any, error) {
}
return hero, nil
}

func TestInferProcedures(t *testing.T) {
t.Parallel()
tests := []struct {
name string
url string
want string
valid bool
}{
{name: "simple", url: "http://localhost:8080/foo", want: "/foo", valid: false},
{name: "service", url: "http://localhost:8080/service/bar", want: "/service/bar", valid: true},
{name: "trailing", url: "http://localhost:8080/service/bar/", want: "/service/bar/", valid: false},
{name: "subroute", url: "http://localhost:8080/api/service/bar", want: "/service/bar", valid: true},
{name: "subrouteTrailing", url: "http://localhost:8080/api/service/bar/", want: "/api/service/bar/", valid: false},
{name: "missingService", url: "http://localhost:8080//foo", want: "//foo", valid: false},
{name: "missingMethod", url: "http://localhost:8080/foo//", want: "/foo//", valid: false},
{
name: "real",
url: "http://localhost:8080/connect.ping.v1.PingService/Ping",
want: "/connect.ping.v1.PingService/Ping",
valid: true,
},
}
for _, testcase := range tests {
testcase := testcase
t.Run(testcase.name, func(t *testing.T) {
t.Parallel()
url, err := url.Parse(testcase.url)
require.NoError(t, err)
got, valid := authn.InferProcedure(url)
assert.Equal(t, testcase.want, got)
assert.Equal(t, testcase.valid, valid)
})
}
}

func TestInferProtocol(t *testing.T) {
t.Parallel()
tests := []struct {
name string
contentType string
method string
params url.Values
want string
valid bool
}{{
name: "connectUnary",
contentType: "application/json",
method: http.MethodPost,
params: nil,
want: connect.ProtocolConnect,
valid: true,
}, {
name: "connectStreaming",
contentType: "application/connec+json",
method: http.MethodPost,
params: nil,
want: connect.ProtocolConnect,
valid: true,
}, {
name: "grpcWeb",
contentType: "application/grpc-web",
method: http.MethodPost,
params: nil,
want: connect.ProtocolGRPCWeb,
valid: true,
}, {
name: "grpc",
contentType: "application/grpc",
method: http.MethodPost,
params: nil,
want: connect.ProtocolGRPC,
valid: true,
}, {
name: "connectGet",
contentType: "",
method: http.MethodGet,
params: url.Values{"message": []string{"{}"}, "encoding": []string{"json"}},
want: connect.ProtocolConnect,
valid: true,
}, {
name: "connectGetProto",
contentType: "",
method: http.MethodGet,
params: url.Values{"message": []string{""}, "encoding": []string{"proto"}},
want: connect.ProtocolConnect,
valid: true,
}, {
name: "connectGetMissingParams",
contentType: "",
method: http.MethodGet,
params: nil,
want: "",
valid: false,
}, {
name: "connectGetMissingParam-Message",
contentType: "",
method: http.MethodGet,
params: url.Values{"encoding": []string{"json"}},
want: "",
valid: false,
}, {
name: "connectGetMissingParam-Encoding",
contentType: "",
method: http.MethodGet,
params: url.Values{"message": []string{"{}"}},
want: "",
valid: false,
}, {
name: "connectPutContentType",
contentType: "application/connect+json",
method: http.MethodPut,
params: nil,
want: "",
valid: false,
}, {
name: "nakedGet",
contentType: "",
method: http.MethodGet,
params: nil,
want: "",
valid: false,
}, {
name: "unknown",
contentType: "text/html",
method: http.MethodPost,
params: nil,
want: "",
valid: false,
}}
for _, testcase := range tests {
testcase := testcase
t.Run(testcase.name, func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(testcase.method, "http://localhost:8080/service/Method", nil)
if testcase.contentType != "" {
req.Header.Set("Content-Type", testcase.contentType)
}
if testcase.params != nil {
req.URL.RawQuery = testcase.params.Encode()
}
req.Method = testcase.method
got, valid := authn.InferProtocol(req)
assert.Equal(t, testcase.want, got, "protocol")
assert.Equal(t, testcase.valid, valid, "valid")
})
}
}
Loading