Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace authn.Request for *http.Request #9

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (
"connectrpc.com/authn/internal/gen/authn/ping/v1/pingv1connect"
)

func authenticate(_ context.Context, req authn.Request) (any, error) {
func authenticate(_ context.Context, req *http.Request) (any, error) {
username, password, ok := req.BasicAuth()
if !ok {
return nil, authn.Errorf("invalid authorization")
Expand Down
84 changes: 22 additions & 62 deletions authn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package authn

import (
"context"
"crypto/tls"
"fmt"
"net/http"
"strings"
Expand All @@ -38,7 +37,7 @@ const infoKey key = iota
// the information is automatically attached to the context using [SetInfo].
//
// Implementations must be safe to call concurrently.
type AuthFunc func(ctx context.Context, req Request) (any, error)
type AuthFunc func(ctx context.Context, req *http.Request) (any, error)

// SetInfo attaches authentication information to the context. It's often
// useful in tests.
Expand Down Expand Up @@ -71,79 +70,40 @@ func Errorf(template string, args ...any) *connect.Error {
return connect.NewError(connect.CodeUnauthenticated, fmt.Errorf(template, args...))
}

// Request describes a single RPC invocation.
type Request struct {
request *http.Request
}

// BasicAuth returns the username and password provided in the request's
// Authorization header, if any.
func (r Request) BasicAuth() (username string, password string, ok bool) {
return r.request.BasicAuth()
}

// Cookies parses and returns the HTTP cookies sent with the request, if any.
func (r Request) Cookies() []*http.Cookie {
return r.request.Cookies()
}

// Cookie returns the named cookie provided in the request or
// [http.ErrNoCookie] if not found. If multiple cookies match the given name,
// only one cookie will be returned.
func (r Request) Cookie(name string) (*http.Cookie, error) {
return r.request.Cookie(name)
// InferProtocol returns the inferred RPC protocol. It is one of
// [connect.ProtocolConnect], [connect.ProtocolGRPC], or [connect.ProtocolGRPCWeb].
func InferProtocol(request *http.Request) string {
ct := request.Header.Get("Content-Type")
switch {
case strings.HasPrefix(ct, "application/grpc-web"):
return connect.ProtocolGRPCWeb
case strings.HasPrefix(ct, "application/grpc"):
return connect.ProtocolGRPC
default:
return connect.ProtocolConnect
Comment on lines +82 to +83
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there no value in returning "unknown" (or empty string, etc) when the request doesn't look like any of these? Since this is middleware, it seems highly likely it could be used with a mux that has both connect and non-connect routes, so I think we do need better classification here.

}
}

// Procedure returns the RPC procedure name, in the form "/service/method". If
// the request path does not contain a procedure name, the entire path is
// returned.
func (r Request) Procedure() string {
path := strings.TrimSuffix(r.request.URL.Path, "/")
// InferProcedure returns the inferred RPC procedure. It is of the form
// "/service/method". If the request path does not contain a procedure name, the
// entire path is returned.
func InferProcedure(request *http.Request) string {
path := strings.TrimSuffix(request.URL.Path, "/")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we do this? Do connect RPC servers actually accept an invalid trailing slash like this? Pretty sure gRPC servers are usually strict and do not allow this.

ultimate := strings.LastIndex(path, "/")
if ultimate < 0 {
return r.request.URL.Path
return request.URL.Path
}
penultimate := strings.LastIndex(path[:ultimate], "/")
if penultimate < 0 {
return r.request.URL.Path
return request.URL.Path
}
procedure := path[penultimate:]
if len(procedure) < 4 { // two slashes + service + method
return r.request.URL.Path
return request.URL.Path
}
return procedure
}

// ClientAddr returns the client address, in IP:port format.
func (r Request) ClientAddr() string {
return r.request.RemoteAddr
}

// Protocol returns the RPC protocol. It is one of [connect.ProtocolConnect],
// [connect.ProtocolGRPC], or [connect.ProtocolGRPCWeb].
func (r Request) Protocol() string {
ct := r.request.Header.Get("Content-Type")
switch {
case strings.HasPrefix(ct, "application/grpc-web"):
return connect.ProtocolGRPCWeb
case strings.HasPrefix(ct, "application/grpc"):
return connect.ProtocolGRPC
default:
return connect.ProtocolConnect
}
}

// Header returns the HTTP request headers.
func (r Request) Header() http.Header {
return r.request.Header
}

// TLS returns the TLS connection state, if any. It may be nil if the connection
// is not using TLS.
func (r Request) TLS() *tls.ConnectionState {
return r.request.TLS
}

// Middleware is server-side HTTP middleware that authenticates RPC requests.
// In addition to rejecting unauthenticated requests, it can optionally attach
// arbitrary information about the authenticated identity to the context.
Expand Down Expand Up @@ -175,7 +135,7 @@ func NewMiddleware(auth AuthFunc, opts ...connect.HandlerOption) *Middleware {
func (m *Middleware) Wrap(handler http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
ctx := request.Context()
info, err := m.auth(ctx, Request{request: request})
info, err := m.auth(ctx, request)
if err != nil {
_ = m.errW.Write(writer, request, err)
return
Expand Down
61 changes: 59 additions & 2 deletions authn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"testing"

"connectrpc.com/authn"
"connectrpc.com/connect"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -93,8 +94,8 @@
}
}

func authenticate(_ context.Context, req authn.Request) (any, error) {
parts := strings.SplitN(req.Header().Get("Authorization"), " ", 2)
func authenticate(_ context.Context, req *http.Request) (any, error) {
parts := strings.SplitN(req.Header.Get("Authorization"), " ", 2)
if len(parts) < 2 || parts[0] != "Bearer" {
err := authn.Errorf("expected Bearer authentication scheme")
err.Meta().Set("WWW-Authenticate", "Bearer")
Expand All @@ -105,3 +106,59 @@
}
return hero, nil
}

func TestInferProcedures(t *testing.T) {
t.Parallel()
testProcedures := [][2]string{
{"/empty.v1/GetEmpty", "/empty.v1/GetEmpty"},
{"/empty.v1/GetEmpty/", "/empty.v1/GetEmpty"},
{"/empty.v1/GetEmpty/", "/empty.v1/GetEmpty"},
{"/prefix/empty.v1/GetEmpty/", "/empty.v1/GetEmpty"},
{"/", "/"},
{"/invalid/", "/invalid/"},
}
for _, tt := range testProcedures {
req := httptest.NewRequest(http.MethodPost, tt[0], strings.NewReader("{}"))
assert.Equal(t, tt[1], authn.InferProcedure(req))
}
}

func TestInferProtocol(t *testing.T) {
t.Parallel()
tests := []struct {
name string
contentType string
method string
wantProtocol string
}{{
name: "connect",
contentType: "application/json",
wantProtocol: connect.ProtocolConnect,
}, {
name: "connectSubPath",
contentType: "application/connect+json",
wantProtocol: connect.ProtocolConnect,
}, {
name: "grpc",
contentType: "application/grpc+proto",
wantProtocol: connect.ProtocolGRPC,
}, {
name: "grpcWeb",
contentType: "application/grpc-web",
wantProtocol: connect.ProtocolGRPCWeb,
}, {
name: "grpcWeb",
contentType: "application/grpc-web+json",
wantProtocol: connect.ProtocolGRPCWeb,
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(http.MethodPost, "/service/Method", strings.NewReader("{}"))
if tt.contentType != "" {

Check failure on line 158 in authn_test.go

View workflow job for this annotation

GitHub Actions / ci (1.22.x)

loop variable tt captured by func literal
req.Header.Set("Content-Type", tt.contentType)

Check failure on line 159 in authn_test.go

View workflow job for this annotation

GitHub Actions / ci (1.22.x)

loop variable tt captured by func literal
}
assert.Equal(t, tt.wantProtocol, authn.InferProtocol(req))

Check failure on line 161 in authn_test.go

View workflow job for this annotation

GitHub Actions / ci (1.22.x)

loop variable tt captured by func literal
})
}
}
6 changes: 3 additions & 3 deletions examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func Example_basicAuth() {
// works similarly.

// First, we define our authentication logic and use it to build middleware.
authenticate := func(_ context.Context, req authn.Request) (any, error) {
authenticate := func(_ context.Context, req *http.Request) (any, error) {
username, password, ok := req.BasicAuth()
if !ok {
return nil, authn.Errorf("invalid authorization")
Expand Down Expand Up @@ -95,8 +95,8 @@ func Example_basicAuth() {
func Example_mutualTLS() {
// This example shows how to use this package with mutual TLS.
// First, we define our authentication logic and use it to build middleware.
authenticate := func(_ context.Context, req authn.Request) (any, error) {
tls := req.TLS()
authenticate := func(_ context.Context, req *http.Request) (any, error) {
tls := req.TLS
if tls == nil {
return nil, authn.Errorf("TLS required")
}
Expand Down
Loading