Skip to content

Commit 4ffaa60

Browse files
committed
Improve auth flow error handling
Signed-off-by: Tomasz Kleczek <[email protected]>
1 parent d4bd371 commit 4ffaa60

File tree

4 files changed

+107
-86
lines changed

4 files changed

+107
-86
lines changed

server/handlers.go

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
158158
return
159159
}
160160
}
161-
s.tokenErrHelper(w, errInvalidConnectorID, "Connector ID does not match a valid Connector", http.StatusNotFound)
161+
s.renderError(r, w, http.StatusBadRequest, "Connector ID does not match a valid Connector")
162162
return
163163
}
164164

@@ -187,21 +187,16 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
187187
authReq, err := s.parseAuthorizationRequest(r)
188188
if err != nil {
189189
s.logger.Errorf("Failed to parse authorization request: %v", err)
190-
status := http.StatusInternalServerError
191-
192-
// If this is an authErr, let's let it handle the error, or update the HTTP
193-
// status code
194-
if err, ok := err.(*authErr); ok {
195-
if handler, ok := err.Handle(); ok {
196-
// client_id and redirect_uri checked out and we can redirect back to
197-
// the client with the error.
198-
handler.ServeHTTP(w, r)
199-
return
200-
}
201-
status = err.Status()
190+
191+
switch authErr := err.(type) {
192+
case *redirectedAuthErr:
193+
authErr.Handler().ServeHTTP(w, r)
194+
case *displayedAuthErr:
195+
s.renderError(r, w, authErr.Status, err.Error())
196+
default:
197+
panic("unsupported error type")
202198
}
203199

204-
s.renderError(r, w, status, err.Error())
205200
return
206201
}
207202

@@ -770,7 +765,7 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
770765
case grantTypePassword:
771766
s.withClientFromStorage(w, r, s.handlePasswordGrant)
772767
default:
773-
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
768+
s.tokenErrHelper(w, errUnsupportedGrantType, "", http.StatusBadRequest)
774769
}
775770
}
776771

server/oauth2.go

Lines changed: 50 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -29,32 +29,35 @@ import (
2929

3030
// TODO(ericchiang): clean this file up and figure out more idiomatic error handling.
3131

32-
// authErr is an error response to an authorization request.
3332
// See: https://tools.ietf.org/html/rfc6749#section-4.1.2.1
34-
type authErr struct {
33+
34+
// displayedAuthErr is an error that should be displayed to the user as a web page
35+
type displayedAuthErr struct {
36+
Status int
37+
Description string
38+
}
39+
40+
func (err *displayedAuthErr) Error() string {
41+
return err.Description
42+
}
43+
44+
func newDisplayedErr(status int, format string, a ...interface{}) *displayedAuthErr {
45+
return &displayedAuthErr{status, fmt.Sprintf(format, a...)}
46+
}
47+
48+
// redirectedAuthErr is an error that should be reported back to the client by 302 redirect
49+
type redirectedAuthErr struct {
3550
State string
3651
RedirectURI string
3752
Type string
3853
Description string
3954
}
4055

41-
func (err *authErr) Status() int {
42-
if err.State == errServerError {
43-
return http.StatusInternalServerError
44-
}
45-
return http.StatusBadRequest
46-
}
47-
48-
func (err *authErr) Error() string {
56+
func (err *redirectedAuthErr) Error() string {
4957
return err.Description
5058
}
5159

52-
func (err *authErr) Handle() (http.Handler, bool) {
53-
// Didn't get a valid redirect URI.
54-
if err.RedirectURI == "" {
55-
return nil, false
56-
}
57-
60+
func (err *redirectedAuthErr) Handler() http.Handler {
5861
hf := func(w http.ResponseWriter, r *http.Request) {
5962
v := url.Values{}
6063
v.Add("state", err.State)
@@ -70,7 +73,7 @@ func (err *authErr) Handle() (http.Handler, bool) {
7073
}
7174
http.Redirect(w, r, redirectURI, http.StatusSeeOther)
7275
}
73-
return http.HandlerFunc(hf), true
76+
return http.HandlerFunc(hf)
7477
}
7578

7679
func tokenErr(w http.ResponseWriter, typ, description string, statusCode int) error {
@@ -102,7 +105,6 @@ const (
102105
errUnsupportedGrantType = "unsupported_grant_type"
103106
errInvalidGrant = "invalid_grant"
104107
errInvalidClient = "invalid_client"
105-
errInvalidConnectorID = "invalid_connector_id"
106108
)
107109

108110
const (
@@ -408,12 +410,12 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str
408410
// parse the initial request from the OAuth2 client.
409411
func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthRequest, error) {
410412
if err := r.ParseForm(); err != nil {
411-
return nil, &authErr{"", "", errInvalidRequest, "Failed to parse request body."}
413+
return nil, newDisplayedErr(http.StatusBadRequest, "Failed to parse request.")
412414
}
413415
q := r.Form
414416
redirectURI, err := url.QueryUnescape(q.Get("redirect_uri"))
415417
if err != nil {
416-
return nil, &authErr{"", "", errInvalidRequest, "No redirect_uri provided."}
418+
return nil, newDisplayedErr(http.StatusBadRequest, "No redirect_uri provided.")
417419
}
418420

419421
clientID := q.Get("client_id")
@@ -434,45 +436,44 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques
434436
client, err := s.storage.GetClient(clientID)
435437
if err != nil {
436438
if err == storage.ErrNotFound {
437-
description := fmt.Sprintf("Invalid client_id (%q).", clientID)
438-
return nil, &authErr{"", "", errUnauthorizedClient, description}
439+
return nil, newDisplayedErr(http.StatusNotFound, "Invalid client_id (%q).", clientID)
439440
}
440441
s.logger.Errorf("Failed to get client: %v", err)
441-
return nil, &authErr{"", "", errServerError, ""}
442-
}
443-
444-
if connectorID != "" {
445-
connectors, err := s.storage.ListConnectors()
446-
if err != nil {
447-
return nil, &authErr{"", "", errServerError, "Unable to retrieve connectors"}
448-
}
449-
if !validateConnectorID(connectors, connectorID) {
450-
return nil, &authErr{"", "", errInvalidRequest, "Invalid ConnectorID"}
451-
}
442+
return nil, newDisplayedErr(http.StatusInternalServerError, "Database error.")
452443
}
453444

454445
if !validateRedirectURI(client, redirectURI) {
455-
description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI)
456-
return nil, &authErr{"", "", errInvalidRequest, description}
446+
return nil, newDisplayedErr(http.StatusBadRequest, "Unregistered redirect_uri (%q).", redirectURI)
457447
}
458448
if redirectURI == deviceCallbackURI && client.Public {
459449
redirectURI = s.issuerURL.Path + deviceCallbackURI
460450
}
461451

462452
// From here on out, we want to redirect back to the client with an error.
463-
newErr := func(typ, format string, a ...interface{}) *authErr {
464-
return &authErr{state, redirectURI, typ, fmt.Sprintf(format, a...)}
453+
newRedirectedErr := func(typ, format string, a ...interface{}) *redirectedAuthErr {
454+
return &redirectedAuthErr{state, redirectURI, typ, fmt.Sprintf(format, a...)}
455+
}
456+
457+
if connectorID != "" {
458+
connectors, err := s.storage.ListConnectors()
459+
if err != nil {
460+
s.logger.Errorf("Failed to list connectors: %v", err)
461+
return nil, newRedirectedErr(errServerError, "Unable to retrieve connectors")
462+
}
463+
if !validateConnectorID(connectors, connectorID) {
464+
return nil, newRedirectedErr(errInvalidRequest, "Invalid ConnectorID")
465+
}
465466
}
466467

467468
// dex doesn't support request parameter and must return request_not_supported error
468469
// https://openid.net/specs/openid-connect-core-1_0.html#6.1
469470
if q.Get("request") != "" {
470-
return nil, newErr(errRequestNotSupported, "Server does not support request parameter.")
471+
return nil, newRedirectedErr(errRequestNotSupported, "Server does not support request parameter.")
471472
}
472473

473474
if codeChallengeMethod != codeChallengeMethodS256 && codeChallengeMethod != codeChallengeMethodPlain {
474475
description := fmt.Sprintf("Unsupported PKCE challenge method (%q).", codeChallengeMethod)
475-
return nil, newErr(errInvalidRequest, description)
476+
return nil, newRedirectedErr(errInvalidRequest, description)
476477
}
477478

478479
var (
@@ -494,21 +495,21 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques
494495

495496
isTrusted, err := s.validateCrossClientTrust(clientID, peerID)
496497
if err != nil {
497-
return nil, newErr(errServerError, "Internal server error.")
498+
return nil, newRedirectedErr(errServerError, "Internal server error.")
498499
}
499500
if !isTrusted {
500501
invalidScopes = append(invalidScopes, scope)
501502
}
502503
}
503504
}
504505
if !hasOpenIDScope {
505-
return nil, newErr(errInvalidScope, `Missing required scope(s) ["openid"].`)
506+
return nil, newRedirectedErr(errInvalidScope, `Missing required scope(s) ["openid"].`)
506507
}
507508
if len(unrecognized) > 0 {
508-
return nil, newErr(errInvalidScope, "Unrecognized scope(s) %q", unrecognized)
509+
return nil, newRedirectedErr(errInvalidScope, "Unrecognized scope(s) %q", unrecognized)
509510
}
510511
if len(invalidScopes) > 0 {
511-
return nil, newErr(errInvalidScope, "Client can't request scope(s) %q", invalidScopes)
512+
return nil, newRedirectedErr(errInvalidScope, "Client can't request scope(s) %q", invalidScopes)
512513
}
513514

514515
var rt struct {
@@ -526,37 +527,37 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques
526527
case responseTypeToken:
527528
rt.token = true
528529
default:
529-
return nil, newErr(errInvalidRequest, "Invalid response type %q", responseType)
530+
return nil, newRedirectedErr(errInvalidRequest, "Invalid response type %q", responseType)
530531
}
531532

532533
if !s.supportedResponseTypes[responseType] {
533-
return nil, newErr(errUnsupportedResponseType, "Unsupported response type %q", responseType)
534+
return nil, newRedirectedErr(errUnsupportedResponseType, "Unsupported response type %q", responseType)
534535
}
535536
}
536537

537538
if len(responseTypes) == 0 {
538-
return nil, newErr(errInvalidRequest, "No response_type provided")
539+
return nil, newRedirectedErr(errInvalidRequest, "No response_type provided")
539540
}
540541

541542
if rt.token && !rt.code && !rt.idToken {
542543
// "token" can't be provided by its own.
543544
//
544545
// https://openid.net/specs/openid-connect-core-1_0.html#Authentication
545-
return nil, newErr(errInvalidRequest, "Response type 'token' must be provided with type 'id_token' and/or 'code'")
546+
return nil, newRedirectedErr(errInvalidRequest, "Response type 'token' must be provided with type 'id_token' and/or 'code'")
546547
}
547548
if !rt.code {
548549
// Either "id_token token" or "id_token" has been provided which implies the
549550
// implicit flow. Implicit flow requires a nonce value.
550551
//
551552
// https://openid.net/specs/openid-connect-core-1_0.html#ImplicitAuthRequest
552553
if nonce == "" {
553-
return nil, newErr(errInvalidRequest, "Response type 'token' requires a 'nonce' value.")
554+
return nil, newRedirectedErr(errInvalidRequest, "Response type 'token' requires a 'nonce' value.")
554555
}
555556
}
556557
if rt.token {
557558
if redirectURI == redirectURIOOB {
558559
err := fmt.Sprintf("Cannot use response type 'token' with redirect_uri '%s'.", redirectURIOOB)
559-
return nil, newErr(errInvalidRequest, err)
560+
return nil, newRedirectedErr(errInvalidRequest, err)
560561
}
561562
}
562563

server/oauth2_test.go

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"strings"
1111
"testing"
1212

13-
"github.com/stretchr/testify/require"
1413
"gopkg.in/square/go-jose.v2"
1514

1615
"github.com/dexidp/dex/storage"
@@ -27,8 +26,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
2726

2827
queryParams map[string]string
2928

30-
wantErr bool
31-
exactError *authErr
29+
expectedError error
3230
}{
3331
{
3432
name: "normal request",
@@ -78,7 +76,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
7876
"response_type": "code",
7977
"scope": "openid email profile",
8078
},
81-
wantErr: true,
79+
expectedError: &displayedAuthErr{Status: http.StatusNotFound},
8280
},
8381
{
8482
name: "invalid redirect uri",
@@ -95,7 +93,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
9593
"response_type": "code",
9694
"scope": "openid email profile",
9795
},
98-
wantErr: true,
96+
expectedError: &displayedAuthErr{Status: http.StatusBadRequest},
9997
},
10098
{
10199
name: "implicit flow",
@@ -128,7 +126,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
128126
"response_type": "code id_token",
129127
"scope": "openid email profile",
130128
},
131-
wantErr: true,
129+
expectedError: &redirectedAuthErr{Type: errUnsupportedResponseType},
132130
},
133131
{
134132
name: "only token response type",
@@ -145,7 +143,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
145143
"response_type": "token",
146144
"scope": "openid email profile",
147145
},
148-
wantErr: true,
146+
expectedError: &redirectedAuthErr{Type: errInvalidRequest},
149147
},
150148
{
151149
name: "choose connector_id",
@@ -197,7 +195,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
197195
"response_type": "code id_token",
198196
"scope": "openid email profile",
199197
},
200-
wantErr: true,
198+
expectedError: &redirectedAuthErr{Type: errInvalidRequest},
201199
},
202200
{
203201
name: "PKCE code_challenge_method plain",
@@ -269,7 +267,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
269267
"code_challenge_method": "invalid_method",
270268
"scope": "openid email profile",
271269
},
272-
wantErr: true,
270+
expectedError: &redirectedAuthErr{Type: errInvalidRequest},
273271
},
274272
{
275273
name: "No response type",
@@ -287,12 +285,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
287285
"code_challenge_method": "plain",
288286
"scope": "openid email profile",
289287
},
290-
wantErr: true,
291-
exactError: &authErr{
292-
RedirectURI: "https://example.com/bar",
293-
Type: "invalid_request",
294-
Description: "No response_type provided",
295-
},
288+
expectedError: &redirectedAuthErr{Type: errInvalidRequest},
296289
},
297290
}
298291

@@ -321,13 +314,34 @@ func TestParseAuthorizationRequest(t *testing.T) {
321314
}
322315

323316
_, err := server.parseAuthorizationRequest(req)
324-
if tc.wantErr {
325-
require.Error(t, err)
326-
if tc.exactError != nil {
327-
require.Equal(t, tc.exactError, err)
317+
if tc.expectedError == nil {
318+
if err != nil {
319+
t.Errorf("%s: expected no error", tc.name)
328320
}
329321
} else {
330-
require.NoError(t, err)
322+
switch expectedErr := tc.expectedError.(type) {
323+
case *redirectedAuthErr:
324+
e, ok := err.(*redirectedAuthErr)
325+
if !ok {
326+
t.Fatalf("%s: expected redirectedAuthErr error", tc.name)
327+
}
328+
if e.Type != expectedErr.Type {
329+
t.Errorf("%s: expected error type %v, got %v", tc.name, expectedErr.Type, e.Type)
330+
}
331+
if e.RedirectURI != tc.queryParams["redirect_uri"] {
332+
t.Errorf("%s: expected error to be returned in redirect to %v", tc.name, tc.queryParams["redirect_uri"])
333+
}
334+
case *displayedAuthErr:
335+
e, ok := err.(*displayedAuthErr)
336+
if !ok {
337+
t.Fatalf("%s: expected displayedAuthErr error", tc.name)
338+
}
339+
if e.Status != expectedErr.Status {
340+
t.Errorf("%s: expected http status %v, got %v", tc.name, expectedErr.Status, e.Status)
341+
}
342+
default:
343+
t.Fatalf("%s: unsupported error type", tc.name)
344+
}
331345
}
332346
}()
333347
}

0 commit comments

Comments
 (0)