diff --git a/auth/jwt.go b/auth/jwt.go index 1f5c50a..38f010e 100644 --- a/auth/jwt.go +++ b/auth/jwt.go @@ -1,6 +1,7 @@ package auth import ( + "encoding/json" "errors" "fmt" "time" @@ -15,27 +16,22 @@ var ( ErrInvalidToken = errors.New("Your token is an invalid JWT token") ) -type jwtClaims struct { - Data interface{} `json:"claim"` - jwt.Claims -} - // EncodeJWT creates a JWT token for some given struct using the HMAC algorithm. func EncodeJWT(secret []byte, t time.Duration, v interface{}) (string, error) { - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwtClaims{ - Data: v, - Claims: jwt.StandardClaims{ - ExpiresAt: time.Now().Add(t).Unix(), - }, + str, _ := json.Marshal(v) + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "claims": string(str), + "iat": time.Now().Unix(), + "exp": time.Now().Add(t).Unix(), }) return token.SignedString(secret) } // DecodeJWT extracts a struct from a JWT token using the HMAC algorithm -func DecodeJWT(secret []byte, token []byte, v *interface{}) error { - claim := new(jwtClaims) - t, err := jwt.ParseWithClaims(string(token), claim, func(token *jwt.Token) (interface{}, error) { +func DecodeJWT(secret []byte, token []byte, v interface{}) error { + t, err := jwt.Parse(string(token), func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } @@ -55,7 +51,10 @@ func DecodeJWT(secret []byte, token []byte, v *interface{}) error { } } - *v = claim.Data - - return nil + if claims, ok := t.Claims.(jwt.MapClaims); !ok { + return errors.New("Could not convert JWT to map claims") + } else { + b := claims["claims"].(string) + return json.Unmarshal([]byte(b), v) + } } diff --git a/auth/jwt_test.go b/auth/jwt_test.go new file mode 100644 index 0000000..329bf95 --- /dev/null +++ b/auth/jwt_test.go @@ -0,0 +1,39 @@ +package auth + +import ( + "testing" + "time" +) + +type jwtStruct struct { + Name string `json:"name"` +} + +func TestEncodeJWT(t *testing.T) { + jwt := jwtStruct{Name: "Olakunkle"} + token, err := EncodeJWT([]byte("mysecret"), time.Minute, jwt) + if err != nil { + t.Fatal(err) + } + + if token == "" { + t.Error("Expected EncodeJWT to generate a token") + } +} + +func TestDecodeJWT(t *testing.T) { + jwt := jwtStruct{Name: "Olakunle"} + token, err := EncodeJWT([]byte("mysecret"), time.Minute, jwt) + if err != nil { + t.Fatal(err) + } + + var loaded jwtStruct + if err := DecodeJWT([]byte("mysecret"), []byte(token), &loaded); err != nil { + t.Fatal(err) + } + + if loaded.Name != "Olakunle" { + t.Errorf("Expected Name to be %s, got %s", "Olakunle", loaded.Name) + } +} diff --git a/auth/sessions.go b/auth/sessions.go index b3a3119..70d74c6 100644 --- a/auth/sessions.go +++ b/auth/sessions.go @@ -1,7 +1,6 @@ package auth import ( - "context" "errors" "net/http" "strings" @@ -14,18 +13,17 @@ var ( ErrAuthorisationFormat = errors.New("Your authorization header format is invalid") ErrUnsupportedScheme = errors.New("Your scheme is not supported") ErrEmptyToken = errors.New("There was no token supplied to the authorization header") + ErrHeaderNotSet = errors.New("Authorization header is not set") ) type SessionStore struct { store *TokenStore timeout time.Duration cookieKey string + scheme string } -type sessionKey struct{} -type errorKey struct{} - -func NewSessionStore(store *TokenStore, sCycle, cookieKey string) *SessionStore { +func NewSessionStore(store *TokenStore, sCycle, cookieKey, scheme string) *SessionStore { var timeout time.Duration var err error @@ -33,115 +31,142 @@ func NewSessionStore(store *TokenStore, sCycle, cookieKey string) *SessionStore panic(err) } - return &SessionStore{store, timeout, cookieKey} + return &SessionStore{store, timeout, cookieKey, scheme} } // Load retrieves a user's session object based on the session key from the Authorization // header or the session cookie and fails with an error if it faces any issue parsing any of them. -func (s *SessionStore) load(r *http.Request, w http.ResponseWriter, session *interface{}) error { - var token string +func (s *SessionStore) Load(r *http.Request, w http.ResponseWriter, session interface{}) { var err error var cookie *http.Cookie if cookie, err = r.Cookie(s.cookieKey); err != nil { authHeader := r.Header.Get("Authorization") + // if there's no authorisation header, then there's no use going further if len(authHeader) == 0 { - return nil + panic(anansi.APIError{ + Code: http.StatusUnauthorized, + Message: ErrHeaderNotSet.Error(), + Err: ErrHeaderNotSet, + }) } splitAuth := strings.Split(authHeader, " ") // we are expecting "${Scheme} ${Token}" if len(splitAuth) != 2 { - return ErrAuthorisationFormat + panic(anansi.APIError{ + Code: http.StatusUnauthorized, + Message: ErrAuthorisationFormat.Error(), + Err: ErrAuthorisationFormat, + }) } - scheme := strings.ToLower(splitAuth[0]) - if scheme != "bearer" && scheme != "headless" { - return ErrUnsupportedScheme + scheme := splitAuth[0] + if scheme != s.scheme && scheme != "Bearer" { + panic(anansi.APIError{ + Code: http.StatusUnauthorized, + Message: ErrUnsupportedScheme.Error(), + Err: ErrUnsupportedScheme, + }) } - token = splitAuth[1] + token := splitAuth[1] if len(token) == 0 { - return ErrEmptyToken + panic(anansi.APIError{ + Code: http.StatusUnauthorized, + Message: ErrEmptyToken.Error(), + Err: ErrEmptyToken, + }) } - if scheme == "headless" { - err = DecodeJWT(s.store.secret, []byte(token), session) - return err - } else { + if scheme == "Bearer" { err = s.store.Refresh(token, s.timeout, session) - return err + } else { + err = DecodeJWT(s.store.secret, []byte(token), session) + } + + if err != nil { + panic(anansi.APIError{ + Code: http.StatusUnauthorized, + Message: err.Error(), + Err: err, + }) } } else { - token = cookie.Value - if err = s.store.Refresh(token, s.timeout, session); err != nil { - return err + err = s.store.Refresh(cookie.Value, s.timeout, session) + + if err != nil { + panic(anansi.APIError{ + Code: http.StatusUnauthorized, + Message: err.Error(), + Err: err, + }) } // extend the cookie's lifetime cookie.Expires = time.Now().Add(s.timeout) http.SetCookie(w, cookie) } - - if err = s.store.Refresh(token, s.timeout, session); err != nil { - return err - } - - return nil } // Secure loads a user session into the request context -func (s *SessionStore) Secure() func(http.Handler) http.Handler { +func (s *SessionStore) Headless() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var session interface{} - var ctx context.Context - if err := s.load(r, w, &session); err != nil { - ctx = context.WithValue(r.Context(), errorKey{}, err) - } else { - ctx = context.WithValue(r.Context(), sessionKey{}, session) + authHeader := r.Header.Get("Authorization") + // if there's no authorisation header, then there's no use going further + if len(authHeader) == 0 { + panic(anansi.APIError{ + Code: http.StatusUnauthorized, + Message: ErrHeaderNotSet.Error(), + Err: ErrHeaderNotSet, + }) } - r = r.WithContext(ctx) - next.ServeHTTP(w, r) - }) - } -} + splitAuth := strings.Split(authHeader, " ") -// Get retrieves a user session stored in the context, panics with an appropriate error if the -// error value has been set on the request context(for failed session loads). Make sure to use this -// for handlers protected by the Secure method. -func Get(r *http.Request) interface{} { - ctx := r.Context() - err := ctx.Value(errorKey{}).(error) - session := ctx.Value(sessionKey{}) - - if err != nil { - switch err { - case ErrAuthorisationFormat, ErrEmptyToken, ErrUnsupportedScheme, ErrInvalidToken, ErrJWTExpired, ErrTokenNotFound: - panic(anansi.APIError{ - Code: http.StatusUnauthorized, - Message: err.Error(), - }) - default: - panic(anansi.APIError{ - Code: http.StatusUnauthorized, - Message: "We could not load your session. Please reach out to support to check the problem", - Err: err, - }) - } - } + // we are expecting "${Scheme} ${Token}" + if len(splitAuth) != 2 { + panic(anansi.APIError{ + Code: http.StatusUnauthorized, + Message: ErrAuthorisationFormat.Error(), + Err: ErrAuthorisationFormat, + }) + } + + scheme := splitAuth[0] + if scheme != s.scheme { + panic(anansi.APIError{ + Code: http.StatusUnauthorized, + Message: ErrUnsupportedScheme.Error(), + Err: ErrUnsupportedScheme, + }) + } - if session == nil { - panic(anansi.APIError{ - Code: http.StatusUnauthorized, - Message: "We could not find a session for your request", + token := splitAuth[1] + + if len(token) == 0 { + panic(anansi.APIError{ + Code: http.StatusUnauthorized, + Message: ErrEmptyToken.Error(), + Err: ErrEmptyToken, + }) + } + + if err := DecodeJWT(s.store.secret, []byte(token), session); err != nil { + panic(anansi.APIError{ + Code: http.StatusUnauthorized, + Message: err.Error(), + Err: err, + }) + } + + next.ServeHTTP(w, r) }) } - - return session } diff --git a/auth/token.go b/auth/token.go index a3e28b4..de8520f 100644 --- a/auth/token.go +++ b/auth/token.go @@ -74,6 +74,37 @@ func (ts *TokenStore) Refresh(token string, timeout time.Duration, data interfac return nil } +// Reset changes the contents of the token without changing it's TTL +func (ts *TokenStore) Reset(key string, timeout time.Duration, data interface{}) error { + var err error + var encoded []byte + var token string + + sig := hmac.New(sha256.New, ts.secret) + if _, err := sig.Write([]byte(key)); err != nil { + return err + } + + token = hex.EncodeToString(sig.Sum(nil)) + + if _, err = ts.redis.Get(token).Result(); err != nil { + return ErrTokenNotFound + } + + // we already know the key exists + ttl, _ := ts.redis.TTL(token).Result() + + // TODO: replace this something lighter and faster + if encoded, err = json.Marshal(data); err != nil { + return err + } + + if _, err = ts.redis.Set(token, encoded, ttl).Result(); err != nil { + return err + } + return nil +} + // Decommission loads the value referenced by the token and dispenses of the token, // making it unvailable for further use. func (ts *TokenStore) Decommission(token string, data interface{}) error { diff --git a/errors.go b/errors.go index 71fe551..b180c22 100644 --- a/errors.go +++ b/errors.go @@ -38,11 +38,11 @@ func Recoverer(next http.Handler) http.Handler { } else { fmt.Fprintf(os.Stderr, "Panic: %+v\n", rvr) } - // debug.PrintStack() if e, ok := rvr.(APIError); ok { SendError(r, w, e) } else { + debug.PrintStack() http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) } } diff --git a/go.mod b/go.mod index c7ca7bc..b1dded7 100644 --- a/go.mod +++ b/go.mod @@ -14,5 +14,5 @@ require ( github.com/kelseyhightower/envconfig v1.4.0 github.com/rs/cors v1.7.0 github.com/rs/zerolog v1.19.0 - golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 + golang.org/x/crypto v0.0.0-20200709230013-948cd5f35899 ) diff --git a/go.sum b/go.sum index 663e809..10e55f1 100644 --- a/go.sum +++ b/go.sum @@ -329,6 +329,8 @@ golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200709230013-948cd5f35899 h1:DZhuSZLsGlFL4CmhA8BcRA0mnthyA/nZ00AqCUo7vHg= +golang.org/x/crypto v0.0.0-20200709230013-948cd5f35899/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=