Skip to content

Commit

Permalink
Drop authn.Request for *http.Request
Browse files Browse the repository at this point in the history
Signed-off-by: Edward McFarlane <[email protected]>
  • Loading branch information
emcfarlane committed Jun 14, 2024
1 parent 2ce323f commit 87caeaf
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 54 deletions.
58 changes: 9 additions & 49 deletions authn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package authn

import (
"context"
"crypto/tls"
"fmt"
"net/http"
"strings"
Expand All @@ -38,7 +37,7 @@ const infoKey key = iota
// the information is automatically attached to the context using [SetInfo].
//
// Implementations must be safe to call concurrently.
type AuthFunc func(ctx context.Context, req Request) (any, error)
type AuthFunc func(ctx context.Context, req *http.Request) (any, error)

// SetInfo attaches authentication information to the context. It's often
// useful in tests.
Expand Down Expand Up @@ -71,58 +70,30 @@ func Errorf(template string, args ...any) *connect.Error {
return connect.NewError(connect.CodeUnauthenticated, fmt.Errorf(template, args...))
}

// Request describes a single RPC invocation.
type Request struct {
Request *http.Request
}

// BasicAuth returns the username and password provided in the request's
// Authorization header, if any.
func (r Request) BasicAuth() (username string, password string, ok bool) {
return r.Request.BasicAuth()
}

// Cookies parses and returns the HTTP cookies sent with the request, if any.
func (r Request) Cookies() []*http.Cookie {
return r.Request.Cookies()
}

// Cookie returns the named cookie provided in the request or
// [http.ErrNoCookie] if not found. If multiple cookies match the given name,
// only one cookie will be returned.
func (r Request) Cookie(name string) (*http.Cookie, error) {
return r.Request.Cookie(name)
}

// Procedure returns the RPC procedure name, in the form "/service/method". If
// the request path does not contain a procedure name, the entire path is
// returned.
func (r Request) Procedure() string {
path := strings.TrimSuffix(r.Request.URL.Path, "/")
func Procedure(request *http.Request) string {
path := strings.TrimSuffix(request.URL.Path, "/")
ultimate := strings.LastIndex(path, "/")
if ultimate < 0 {
return r.Request.URL.Path
return request.URL.Path
}
penultimate := strings.LastIndex(path[:ultimate], "/")
if penultimate < 0 {
return r.Request.URL.Path
return request.URL.Path
}
procedure := path[penultimate:]
if len(procedure) < 4 { // two slashes + service + method
return r.Request.URL.Path
return request.URL.Path
}
return procedure
}

// ClientAddr returns the client address, in IP:port format.
func (r Request) ClientAddr() string {
return r.Request.RemoteAddr
}

// Protocol returns the RPC protocol. It is one of [connect.ProtocolConnect],
// [connect.ProtocolGRPC], or [connect.ProtocolGRPCWeb].
func (r Request) Protocol() string {
ct := r.Request.Header.Get("Content-Type")
func Protocol(request *http.Request) string {
ct := request.Header.Get("Content-Type")
switch {
case strings.HasPrefix(ct, "application/grpc-web"):
return connect.ProtocolGRPCWeb
Expand All @@ -133,17 +104,6 @@ func (r Request) Protocol() string {
}
}

// Header returns the HTTP request headers.
func (r Request) Header() http.Header {
return r.Request.Header
}

// TLS returns the TLS connection state, if any. It may be nil if the connection
// is not using TLS.
func (r Request) TLS() *tls.ConnectionState {
return r.Request.TLS
}

// Middleware is server-side HTTP middleware that authenticates RPC requests.
// In addition to rejecting unauthenticated requests, it can optionally attach
// arbitrary information about the authenticated identity to the context.
Expand Down Expand Up @@ -175,7 +135,7 @@ func NewMiddleware(auth AuthFunc, opts ...connect.HandlerOption) *Middleware {
func (m *Middleware) Wrap(handler http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
ctx := request.Context()
info, err := m.auth(ctx, Request{Request: request})
info, err := m.auth(ctx, request)
if err != nil {
_ = m.errW.Write(writer, request, err)
return
Expand Down
4 changes: 2 additions & 2 deletions authn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ func assertInfo(ctx context.Context, tb testing.TB) {
}
}

func authenticate(_ context.Context, req authn.Request) (any, error) {
parts := strings.SplitN(req.Header().Get("Authorization"), " ", 2)
func authenticate(_ context.Context, req *http.Request) (any, error) {
parts := strings.SplitN(req.Header.Get("Authorization"), " ", 2)
if len(parts) < 2 || parts[0] != "Bearer" {
err := authn.Errorf("expected Bearer authentication scheme")
err.Meta().Set("WWW-Authenticate", "Bearer")
Expand Down
6 changes: 3 additions & 3 deletions examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func Example_basicAuth() {
// works similarly.

// First, we define our authentication logic and use it to build middleware.
authenticate := func(_ context.Context, req authn.Request) (any, error) {
authenticate := func(_ context.Context, req *http.Request) (any, error) {
username, password, ok := req.BasicAuth()
if !ok {
return nil, authn.Errorf("invalid authorization")
Expand Down Expand Up @@ -95,8 +95,8 @@ func Example_basicAuth() {
func Example_mutualTLS() {
// This example shows how to use this package with mutual TLS.
// First, we define our authentication logic and use it to build middleware.
authenticate := func(_ context.Context, req authn.Request) (any, error) {
tls := req.TLS()
authenticate := func(_ context.Context, req *http.Request) (any, error) {
tls := req.TLS
if tls == nil {
return nil, authn.Errorf("TLS required")
}
Expand Down

0 comments on commit 87caeaf

Please sign in to comment.