From e5972f0edc92cbe8b5d3f18c7105782f22a99523 Mon Sep 17 00:00:00 2001 From: Olakunle Arewa Date: Tue, 8 Sep 2020 23:51:30 +0100 Subject: [PATCH 1/4] feat: created better and simpler encoders/decoders for JWT --- go.mod | 1 + jwt/jwt.go | 107 ++++++++++++++++++++++++++++++-------- jwt/jwt_test.go | 135 ++++++++++++++++++++++++++++++++++++++---------- 3 files changed, 195 insertions(+), 48 deletions(-) diff --git a/go.mod b/go.mod index a00063f..d9381c3 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/joho/godotenv v1.3.0 github.com/kelseyhightower/envconfig v1.4.0 github.com/mitchellh/mapstructure v1.3.3 + github.com/pkg/errors v0.9.1 github.com/rs/cors v1.7.0 github.com/rs/zerolog v1.19.0 github.com/satori/go.uuid v1.2.0 diff --git a/jwt/jwt.go b/jwt/jwt.go index 13d6944..4e8efa9 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -1,12 +1,13 @@ package jwt import ( - "errors" "fmt" + "reflect" "time" "github.com/dgrijalva/jwt-go" "github.com/mitchellh/mapstructure" + "github.com/pkg/errors" ) const expiredErr = jwt.ValidationErrorExpired | jwt.ValidationErrorNotValidYet @@ -17,62 +18,126 @@ var ( ErrNoClaims = errors.New("There are no claims in your token") ) -// Encode creates a JWT token for some given struct using the HMAC algorithm. It uses -// the key to separate the data stored from the JWT claims to prevent clashes. This works -// best with structs, please use the jwt-go library directly for primitive types. -func Encode(key string, secret []byte, t time.Duration, v interface{}) (string, error) { +// Encodes generates and signs a JWT token for the given payload using the HMAC algorithm. +func Encode(secret []byte, t time.Duration, payload map[string]interface{}) (string, error) { jwtClaims := jwt.MapClaims{ "iat": time.Now().Unix(), "exp": time.Now().Add(t).Unix(), } - // save data using key - jwtClaims[key] = v + for k, v := range payload { + jwtClaims[k] = v + } token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwtClaims) return token.SignedString(secret) } -// Decode verifies a JWT token and extract the claims stored at key. Note that it expects v -// to be a pointer to a struct(based off what Encode does.) -func Decode(key string, secret []byte, tokenBytes []byte, v interface{}) error { - token, err := jwt.Parse(string(tokenBytes), func(token *jwt.Token) (interface{}, error) { +// EncodeStruct generates a JWT token for the given struct using the HMAC algorithm. +func EncodeStruct(secret []byte, t time.Duration, v interface{}) (string, error) { + payload := make(map[string]interface{}) + + r := reflect.ValueOf(v) + if r.Kind() == reflect.Ptr { + r = r.Elem() + } + + // we only accept structs + if r.Kind() != reflect.Struct { + return "", errors.Errorf("cannot not encode child struct; got %T", v) + } + + typ := r.Type() + for i := 0; i < r.NumField(); i++ { + ft := typ.Field(i) + + // use json tag if available + n := ft.Tag.Get("json") + if n == "" { + n = ft.Name + } + + payload[n] = r.Field(i).Interface() + } + + return Encode(secret, t, payload) +} + +// EncodeEmbedded attaches the payload as an entry to the final claim using the key +// `claim` to prevent clashes with JWT field names. +func EncodeEmbedded(secret []byte, t time.Duration, v interface{}) (string, error) { + return Encode(secret, t, map[string]interface{}{"claim": v}) +} + +// Decode validates and parses the given JWT token into a map +func Decode(secret []byte, token []byte) (map[string]interface{}, error) { + jwtToken, err := jwt.Parse(string(token), func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return secret, nil }) - if token == nil || !token.Valid { + if jwtToken == nil || !jwtToken.Valid { if verr, ok := err.(*jwt.ValidationError); ok { switch { case verr.Errors&jwt.ValidationErrorMalformed != 0: - return ErrInvalidToken + return nil, ErrInvalidToken case verr.Errors&expiredErr != 0: - return ErrJWTExpired + return nil, ErrJWTExpired default: - return err + return nil, err } + } else { + return nil, err } } - claims, ok := token.Claims.(jwt.MapClaims) + claims, ok := jwtToken.Claims.(jwt.MapClaims) if !ok { - return errors.New("could not convert JWT to map claims") + return nil, errors.New("could not convert JWT to map claims") } - // ignore tokens without claim data - if claims[key] == nil { - return nil + return claims, nil +} + +// DecodeStruct validates and parses a JWT token into a struct. +func DecodeStruct(secret []byte, tokenBytes []byte, v interface{}) error { + claims, err := Decode(secret, tokenBytes) + if err != nil { + return err } // convert claims data map to struct config := &mapstructure.DecoderConfig{Result: v, TagName: `json`} decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return errors.Wrap(err, "could not convert claims to struct") + } + + return decoder.Decode(claims) +} + +// DecodeEmbedded validates and parses a JWT token into a struct. It expects the +// struct's payload to be attached to the key `claim` of the actual JWT claim. Note that +// the struct should have json tags +func DecodeEmbedded(secret []byte, tokenBytes []byte, v interface{}) error { + claims, err := Decode(secret, tokenBytes) if err != nil { return err } - return decoder.Decode(claims[key]) + if claims["claim"] == nil { + return nil + } + + // convert claims data map to struct + config := &mapstructure.DecoderConfig{Result: v, TagName: `json`} + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return errors.Wrap(err, "could not convert claims to struct") + } + + return decoder.Decode(claims["claim"]) } diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go index 311cdb2..6b1d504 100644 --- a/jwt/jwt_test.go +++ b/jwt/jwt_test.go @@ -1,8 +1,11 @@ package jwt import ( + "fmt" "testing" "time" + + "syreclabs.com/go/faker" ) type jwtStruct struct { @@ -10,48 +13,126 @@ type jwtStruct struct { } func TestEncode(t *testing.T) { - jwt := jwtStruct{Name: "Olakunkle"} - token, err := Encode("claims", []byte("mysecret"), time.Minute, jwt) - if err != nil { - t.Fatal(err) - } - - if token == "" { - t.Error("Expected EncodeJWT to generate a token") - } -} - -func TestDecode(t *testing.T) { secret := []byte("test-secret") - t.Run("should decode token generated by encode", func(t *testing.T) { - jwt := jwtStruct{Name: "testName"} - token, err := Encode("claims", secret, time.Minute, jwt) + t.Run("should encode a map", func(t *testing.T) { + payload := map[string]interface{}{"name": faker.Name().FirstName()} + token, err := Encode(secret, time.Minute, payload) if err != nil { t.Fatal(err) } - var loaded jwtStruct - if err := Decode("claims", secret, []byte(token), &loaded); err != nil { + if token == "" { + t.Error("Expected a token, got an empty string") + } + + parsed, err := Decode(secret, []byte(token)) + if err != nil { t.Fatal(err) } - if loaded.Name != "testName" { - t.Errorf("Expected Name to be %s, got %s", "Olakunle", loaded.Name) + if parsed["name"] != payload["name"] { + t.Errorf("Expected the parsed name to be %s, got %s", payload["name"], parsed["name"]) } }) - t.Run("should decode token generated externally", func(t *testing.T) { - // don't use this token it would fail. - token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwiY2xhaW0iOnsibmFtZSI6InRva2VuIn0sImV4cCI6MTU5OTUyODEwNSwiaWF0IjoxNTE2MjM5MDIyfQ.-u5OhBHdq7L_wBlI2cBjstOXyhRS2m27bx1tLN869so" - - var loaded jwtStruct - if err := Decode("claim", secret, []byte(token), &loaded); err != nil { + t.Run("should fail with an error", func(t *testing.T) { + payload := map[string]interface{}{"name": faker.Name().FirstName()} + token, err := Encode(secret, time.Second, payload) + if err != nil { t.Fatal(err) } - if loaded.Name != "token" { - t.Errorf("Expected Name to be %s, got %s", "Olakunle", loaded.Name) + time.Sleep(time.Second * 2) + + _, err = Decode(secret, []byte(token)) + if err == nil { + t.Fatal("Expected Decode to fail with non-nil error") + } + + if err != ErrJWTExpired { + t.Errorf("Expected Decode to fail with ErrJWTExpired, failed with %v", err) } }) } + +func TestEncodeStruct(t *testing.T) { + secret := []byte("test-secret") + payload := jwtStruct{faker.Name().FirstName()} + + token, err := EncodeStruct(secret, time.Second, payload) + if err != nil { + t.Fatal(err) + } + + if token == "" { + t.Error("Expected a token, got an empty string") + } + + var parsed jwtStruct + err = DecodeStruct(secret, []byte(token), &parsed) + if err != nil { + t.Fatal(err) + } + + if parsed.Name != payload.Name { + t.Errorf("Expected the parsed name to be %s, got %s", payload.Name, parsed.Name) + } +} + +func TestEncodeEmbedded(t *testing.T) { + secret := []byte("test-secret") + payload := jwtStruct{faker.Name().FirstName()} + + token, err := EncodeEmbedded(secret, time.Second, payload) + if err != nil { + t.Fatal(err) + } + + if token == "" { + t.Error("Expected a token, got an empty string") + } + + parsed, err := Decode(secret, []byte(token)) + if err != nil { + t.Fatal(err) + } + + if parsed["claim"] == nil { + t.Fatal("Expected the claim to be non-nil") + } + + data, ok := parsed["claim"].(map[string]interface{}) + fmt.Println(data) + if !ok { + t.Fatalf("Expected claim to be a map of string to string, got %T", parsed["claim"]) + } + + if data["name"] != payload.Name { + t.Errorf("Expected the parsed name to be %s, got %s", payload.Name, data["name"]) + } +} + +func TestDecodeEmbedded(t *testing.T) { + secret := []byte("test-secret") + payload := jwtStruct{faker.Name().FirstName()} + + token, err := EncodeEmbedded(secret, time.Second, payload) + if err != nil { + t.Fatal(err) + } + + if token == "" { + t.Error("Expected a token, got an empty string") + } + + var parsed jwtStruct + err = DecodeEmbedded(secret, []byte(token), &parsed) + if err != nil { + t.Fatal(err) + } + + if parsed.Name != payload.Name { + t.Errorf("Expected the parsed name to be %s, got %s", payload.Name, parsed.Name) + } +} From 8bf556644a51999c97282872b89f309c7de59670 Mon Sep 17 00:00:00 2001 From: Olakunle Arewa Date: Sat, 12 Sep 2020 01:48:08 +0100 Subject: [PATCH 2/4] ref: removed complex middleware for controller recovery --- errors.go | 48 ++------------------------------------------ jwt/jwt.go | 2 +- middleware/errors.go | 38 ++++++++++++----------------------- responses.go | 21 +++++++++++-------- 4 files changed, 29 insertions(+), 80 deletions(-) diff --git a/errors.go b/errors.go index 592ed32..2f260e5 100644 --- a/errors.go +++ b/errors.go @@ -1,13 +1,6 @@ package anansi -import ( - "fmt" - "net/http" - "os" - "runtime/debug" - - "github.com/rs/zerolog" -) +import "fmt" // APIError is a struct describing an error type APIError struct { @@ -18,43 +11,6 @@ type APIError struct { } // implements the error interface -func (e APIError) Error() string { return e.Message } +func (e APIError) Error() string { return fmt.Sprintf("%s: %v", e.Message, e.Err) } func (e APIError) Unwrap() error { return e.Err } - -// Recoverer creates a middleware that handles panics from chi controllers. It uses -// the passed interpreters to try to convert errors to APIErrors where possible -// otherwise it returns a 500 error. When the panic is an APIError or is interpreted -// as one, it sends a response with the right error code. -// TODO: add support for wrapped errors in APIError. -func Recoverer(env string) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { - defer func() { - if rvr := recover(); rvr != nil && rvr != http.ErrAbortHandler { - - log := zerolog.Ctx(r.Context()) - if log != nil { - err := rvr.(error) // kill yourself - log.Err(err).Msg("") - } else { - fmt.Fprintf(os.Stderr, "Panic: %+v\n", rvr) - } - - if e, ok := rvr.(APIError); ok { - SendError(r, w, e) - } else { - if env == "dev" || env == "test" { - debug.PrintStack() - } - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - } - } - }() - - next.ServeHTTP(w, r) - } - - return http.HandlerFunc(fn) - } -} diff --git a/jwt/jwt.go b/jwt/jwt.go index 4e8efa9..a04bcc6 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -45,7 +45,7 @@ func EncodeStruct(secret []byte, t time.Duration, v interface{}) (string, error) // we only accept structs if r.Kind() != reflect.Struct { - return "", errors.Errorf("cannot not encode child struct; got %T", v) + return "", errors.Errorf("can only encode structs; got %T", v) } typ := r.Type() diff --git a/middleware/errors.go b/middleware/errors.go index 703e57a..1a5361b 100644 --- a/middleware/errors.go +++ b/middleware/errors.go @@ -12,37 +12,25 @@ import ( type Catch func(w http.ResponseWriter, r *http.Request, v interface{}) bool -// Recoverer creates a middleware that can detect APIError from panic. Internally uses -// RecoverWithHandler. +// RecovererWithHandler creates a middleware that handles panics from chi controllers. It +// automatically handles APIController errors passing the right error code. func Recoverer(env string) func(http.Handler) http.Handler { - return RecovererWithHandler(env, func(w http.ResponseWriter, r *http.Request, v interface{}) bool { - if e, ok := v.(anansi.APIError); ok { - anansi.SendError(r, w, e) - return true - } - return false - }) -} - -// RecovererWithHandler creates a middleware that handles panics from chi controllers. It uses -// the passed catch to interprete and handle the error(like send it as JSON) and returns 500 -// if it can't be interpreted. -// TODO: add support for wrapped errors in APIError. -func RecovererWithHandler(env string, catch Catch) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { defer func() { if rvr := recover(); rvr != nil && rvr != http.ErrAbortHandler { - // only do this if catch could not handle it. - if !catch(w, r, rvr) { - log := zerolog.Ctx(r.Context()) - if log == nil { - err := rvr.(error) // kill yourself - log.Err(err).Msg("") - } else { - fmt.Fprintf(os.Stderr, "Panic: %+v\n", rvr) - } + log := zerolog.Ctx(r.Context()) + if log != nil { + err := rvr.(error) // kill yourself + log.Err(err).Msg("") + } else { + fmt.Fprintf(os.Stderr, "Panic: %+v\n", rvr) + } + + if e, ok := rvr.(anansi.APIError); ok { + anansi.SendError(r, w, e) + } else { if env == "dev" || env == "test" { debug.PrintStack() } diff --git a/responses.go b/responses.go index c5fc041..e6c42ec 100644 --- a/responses.go +++ b/responses.go @@ -8,6 +8,17 @@ import ( "github.com/rs/zerolog" ) +func Send(w http.ResponseWriter, code int, data []byte) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("X-Content-Type-Options", "nosniff") + + w.WriteHeader(code) + _, err := w.Write(data) + if err != nil { + panic(err) + } +} + // SendSuccess sends a JSON success message with status code 200 func SendSuccess(r *http.Request, w http.ResponseWriter, v interface{}) { log := zerolog.Ctx(r.Context()) @@ -15,10 +26,7 @@ func SendSuccess(r *http.Request, w http.ResponseWriter, v interface{}) { log.Info().Msg("") - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Header().Set("X-Content-Type-Options", "nosniff") - w.WriteHeader(http.StatusOK) - _, _ = w.Write(raw) + Send(w, http.StatusOK, raw) } // SendError sends a JSON error message @@ -28,10 +36,7 @@ func SendError(r *http.Request, w http.ResponseWriter, err APIError) { log.Err(err).Msg("") - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Header().Set("X-Content-Type-Options", "nosniff") - w.WriteHeader(err.Code) - _, _ = w.Write(raw) + Send(w, err.Code, raw) } func getJSON(log *zerolog.Logger, v interface{}) []byte { From a410175688f75a33d50b098ebb0135f50ebaa73f Mon Sep 17 00:00:00 2001 From: Olakunle Arewa Date: Sat, 12 Sep 2020 01:54:49 +0100 Subject: [PATCH 3/4] ref: simplify session management API --- sessions.go | 106 +++++++++++++++++++++------------------------------- 1 file changed, 42 insertions(+), 64 deletions(-) diff --git a/sessions.go b/sessions.go index 9128467..bf4293e 100644 --- a/sessions.go +++ b/sessions.go @@ -18,13 +18,14 @@ var ( ) type SessionStore struct { - Store *tokens.Store - Timeout time.Duration - - Secret []byte + store *tokens.Store + timeout time.Duration + secret []byte + scheme string +} - Scheme string - ClaimsKey string +func NewSessionStore(secret []byte, scheme string, timeout time.Duration, store *tokens.Store) *SessionStore { + return &SessionStore{store, timeout, secret, scheme} } // Load retrieves a user's session object based on the session key from the Authorization @@ -32,30 +33,9 @@ type SessionStore struct { func (s *SessionStore) Load(r *http.Request, session interface{}) { var err error - authHeader := r.Header.Get("Authorization") - - // if there's no authorisation header, then there's no use going further - if len(authHeader) == 0 { - panic(APIError{ - Code: http.StatusUnauthorized, - Message: ErrHeaderNotSet.Error(), - Err: ErrHeaderNotSet, - }) - } - - splitAuth := strings.Split(authHeader, " ") + scheme, token := getAuthorization(r) - // we are expecting "${Scheme} ${Token}" - if len(splitAuth) != 2 { - panic(APIError{ - Code: http.StatusUnauthorized, - Message: ErrAuthorisationFormat.Error(), - Err: ErrAuthorisationFormat, - }) - } - - scheme := splitAuth[0] - if scheme != s.Scheme && scheme != "Bearer" { + if scheme != s.scheme && scheme != "bearer" { panic(APIError{ Code: http.StatusUnauthorized, Message: ErrUnsupportedScheme.Error(), @@ -63,9 +43,7 @@ func (s *SessionStore) Load(r *http.Request, session interface{}) { }) } - token := splitAuth[1] - - if len(token) == 0 { + if token == "" { panic(APIError{ Code: http.StatusUnauthorized, Message: ErrEmptyToken.Error(), @@ -73,10 +51,10 @@ func (s *SessionStore) Load(r *http.Request, session interface{}) { }) } - if scheme == "Bearer" { - err = s.Store.Extend(token, s.Timeout, session) + if scheme == "bearer" { + err = s.store.Extend(token, s.timeout, session) } else { - err = jwt.Decode(s.ClaimsKey, s.Secret, []byte(token), session) + err = jwt.DecodeEmbedded(s.secret, []byte(token), session) } if err != nil { @@ -92,31 +70,8 @@ func (s *SessionStore) Load(r *http.Request, session interface{}) { func (s *SessionStore) Headless() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var session interface{} - - authHeader := r.Header.Get("Authorization") - // if there's no authorisation header, then there's no use going further - if len(authHeader) == 0 { - panic(APIError{ - Code: http.StatusUnauthorized, - Message: ErrHeaderNotSet.Error(), - Err: ErrHeaderNotSet, - }) - } - - splitAuth := strings.Split(authHeader, " ") - - // we are expecting "${Scheme} ${Token}" - if len(splitAuth) != 2 { - panic(APIError{ - Code: http.StatusUnauthorized, - Message: ErrAuthorisationFormat.Error(), - Err: ErrAuthorisationFormat, - }) - } - - scheme := splitAuth[0] - if scheme != s.Scheme { + scheme, token := getAuthorization(r) + if scheme != s.scheme { panic(APIError{ Code: http.StatusUnauthorized, Message: ErrUnsupportedScheme.Error(), @@ -124,9 +79,7 @@ func (s *SessionStore) Headless() func(http.Handler) http.Handler { }) } - token := splitAuth[1] - - if len(token) == 0 { + if token == "" { panic(APIError{ Code: http.StatusUnauthorized, Message: ErrEmptyToken.Error(), @@ -134,7 +87,8 @@ func (s *SessionStore) Headless() func(http.Handler) http.Handler { }) } - if err := jwt.Decode(s.ClaimsKey, s.Secret, []byte(token), &session); err != nil { + // read and discard session data + if err := jwt.DecodeEmbedded(s.secret, []byte(token), &struct{}{}); err != nil { panic(APIError{ Code: http.StatusUnauthorized, Message: err.Error(), @@ -146,3 +100,27 @@ func (s *SessionStore) Headless() func(http.Handler) http.Handler { }) } } + +func getAuthorization(r *http.Request) (scheme, token string) { + authHeader := r.Header.Get("Authorization") + + if authHeader == "" { + panic(APIError{ + Code: http.StatusUnauthorized, + Message: ErrHeaderNotSet.Error(), + Err: ErrHeaderNotSet, + }) + } + + splitAuth := strings.Split(strings.TrimSpace(authHeader), " ") + + if len(splitAuth) != 2 { + panic(APIError{ + Code: http.StatusUnauthorized, + Message: ErrAuthorisationFormat.Error(), + Err: ErrAuthorisationFormat, + }) + } + + return strings.ToLower(splitAuth[0]), splitAuth[1] +} From c7d0c8f8494f9c906d476306c5cd982f38d8de67 Mon Sep 17 00:00:00 2001 From: Olakunle Arewa Date: Sat, 12 Sep 2020 02:06:12 +0100 Subject: [PATCH 4/4] ref: ReadJSON now destroys the body of the request --- requests.go | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/requests.go b/requests.go index 09b8e26..5f7cb6e 100644 --- a/requests.go +++ b/requests.go @@ -16,7 +16,7 @@ import ( "github.com/mitchellh/mapstructure" ) -// ReadBody extracts the bytes in a request body without destroying the contents of the body +// ReadBody extracts the bytes in a request body without destroying the contents of the body func ReadBody(r *http.Request) []byte { var buffer bytes.Buffer @@ -33,8 +33,8 @@ func ReadBody(r *http.Request) []byte { return body } -// ReadJSON decodes the JSON body of the request without destroying the request and -// validates it. If the content type is not JSON it fails with a 415. Otherwise it fails +// ReadJSON decodes the JSON body of the request and destroys to prevent possible issues with +// writing a response. If the content type is not JSON it fails with a 415. Otherwise it fails // with a 400 on validation errors. func ReadJSON(r *http.Request, v interface{}) { // make sure we are reading a JSON type @@ -46,14 +46,7 @@ func ReadJSON(r *http.Request, v interface{}) { }) } - // copy request body to in memory buffer while being read - var buffer bytes.Buffer - bodyReader := io.TeeReader(r.Body, &buffer) - - // make sure others can read the body - r.Body = ioutil.NopCloser(&buffer) - - err := json.NewDecoder(bodyReader).Decode(v) + err := json.NewDecoder(r.Body).Decode(v) switch { case err == io.EOF: // tell the user all the required attributes @@ -85,6 +78,8 @@ func ReadJSON(r *http.Request, v interface{}) { } } +// ReadQuery reads the requests URL query parameters into a struct. +// It doesn't support multi-value parameters func ReadQuery(r *http.Request, v interface{}) { raw := r.URL.Query() qMap := make(map[string]string)