Skip to content

Commit

Permalink
Merge pull request #23 from tsaron/fix/jwt
Browse files Browse the repository at this point in the history
fixes JWT encoding and decoding
  • Loading branch information
noxecane authored Jul 26, 2020
2 parents afdaea9 + 5cae25b commit c921f23
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 87 deletions.
31 changes: 15 additions & 16 deletions auth/jwt.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package auth

import (
"encoding/json"
"errors"
"fmt"
"time"
Expand All @@ -15,27 +16,22 @@ var (
ErrInvalidToken = errors.New("Your token is an invalid JWT token")
)

type jwtClaims struct {
Data interface{} `json:"claim"`
jwt.Claims
}

// EncodeJWT creates a JWT token for some given struct using the HMAC algorithm.
func EncodeJWT(secret []byte, t time.Duration, v interface{}) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwtClaims{
Data: v,
Claims: jwt.StandardClaims{
ExpiresAt: time.Now().Add(t).Unix(),
},
str, _ := json.Marshal(v)

token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"claims": string(str),
"iat": time.Now().Unix(),
"exp": time.Now().Add(t).Unix(),
})

return token.SignedString(secret)
}

// DecodeJWT extracts a struct from a JWT token using the HMAC algorithm
func DecodeJWT(secret []byte, token []byte, v *interface{}) error {
claim := new(jwtClaims)
t, err := jwt.ParseWithClaims(string(token), claim, func(token *jwt.Token) (interface{}, error) {
func DecodeJWT(secret []byte, token []byte, v interface{}) error {
t, err := jwt.Parse(string(token), func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
Expand All @@ -55,7 +51,10 @@ func DecodeJWT(secret []byte, token []byte, v *interface{}) error {
}
}

*v = claim.Data

return nil
if claims, ok := t.Claims.(jwt.MapClaims); !ok {
return errors.New("Could not convert JWT to map claims")
} else {
b := claims["claims"].(string)
return json.Unmarshal([]byte(b), v)
}
}
39 changes: 39 additions & 0 deletions auth/jwt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package auth

import (
"testing"
"time"
)

type jwtStruct struct {
Name string `json:"name"`
}

func TestEncodeJWT(t *testing.T) {
jwt := jwtStruct{Name: "Olakunkle"}
token, err := EncodeJWT([]byte("mysecret"), time.Minute, jwt)
if err != nil {
t.Fatal(err)
}

if token == "" {
t.Error("Expected EncodeJWT to generate a token")
}
}

func TestDecodeJWT(t *testing.T) {
jwt := jwtStruct{Name: "Olakunle"}
token, err := EncodeJWT([]byte("mysecret"), time.Minute, jwt)
if err != nil {
t.Fatal(err)
}

var loaded jwtStruct
if err := DecodeJWT([]byte("mysecret"), []byte(token), &loaded); err != nil {
t.Fatal(err)
}

if loaded.Name != "Olakunle" {
t.Errorf("Expected Name to be %s, got %s", "Olakunle", loaded.Name)
}
}
163 changes: 94 additions & 69 deletions auth/sessions.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package auth

import (
"context"
"errors"
"net/http"
"strings"
Expand All @@ -14,134 +13,160 @@ var (
ErrAuthorisationFormat = errors.New("Your authorization header format is invalid")
ErrUnsupportedScheme = errors.New("Your scheme is not supported")
ErrEmptyToken = errors.New("There was no token supplied to the authorization header")
ErrHeaderNotSet = errors.New("Authorization header is not set")
)

type SessionStore struct {
store *TokenStore
timeout time.Duration
cookieKey string
scheme string
}

type sessionKey struct{}
type errorKey struct{}

func NewSessionStore(store *TokenStore, sCycle, cookieKey string) *SessionStore {
func NewSessionStore(store *TokenStore, sCycle, cookieKey, scheme string) *SessionStore {
var timeout time.Duration
var err error

if timeout, err = time.ParseDuration(sCycle); err != nil {
panic(err)
}

return &SessionStore{store, timeout, cookieKey}
return &SessionStore{store, timeout, cookieKey, scheme}
}

// Load retrieves a user's session object based on the session key from the Authorization
// header or the session cookie and fails with an error if it faces any issue parsing any of them.
func (s *SessionStore) load(r *http.Request, w http.ResponseWriter, session *interface{}) error {
var token string
func (s *SessionStore) Load(r *http.Request, w http.ResponseWriter, session interface{}) {
var err error
var cookie *http.Cookie

if cookie, err = r.Cookie(s.cookieKey); err != nil {
authHeader := r.Header.Get("Authorization")

// if there's no authorisation header, then there's no use going further
if len(authHeader) == 0 {
return nil
panic(anansi.APIError{
Code: http.StatusUnauthorized,
Message: ErrHeaderNotSet.Error(),
Err: ErrHeaderNotSet,
})
}

splitAuth := strings.Split(authHeader, " ")

// we are expecting "${Scheme} ${Token}"
if len(splitAuth) != 2 {
return ErrAuthorisationFormat
panic(anansi.APIError{
Code: http.StatusUnauthorized,
Message: ErrAuthorisationFormat.Error(),
Err: ErrAuthorisationFormat,
})
}

scheme := strings.ToLower(splitAuth[0])
if scheme != "bearer" && scheme != "headless" {
return ErrUnsupportedScheme
scheme := splitAuth[0]
if scheme != s.scheme && scheme != "Bearer" {
panic(anansi.APIError{
Code: http.StatusUnauthorized,
Message: ErrUnsupportedScheme.Error(),
Err: ErrUnsupportedScheme,
})
}

token = splitAuth[1]
token := splitAuth[1]

if len(token) == 0 {
return ErrEmptyToken
panic(anansi.APIError{
Code: http.StatusUnauthorized,
Message: ErrEmptyToken.Error(),
Err: ErrEmptyToken,
})
}

if scheme == "headless" {
err = DecodeJWT(s.store.secret, []byte(token), session)
return err
} else {
if scheme == "Bearer" {
err = s.store.Refresh(token, s.timeout, session)
return err
} else {
err = DecodeJWT(s.store.secret, []byte(token), session)
}

if err != nil {
panic(anansi.APIError{
Code: http.StatusUnauthorized,
Message: err.Error(),
Err: err,
})
}
} else {
token = cookie.Value
if err = s.store.Refresh(token, s.timeout, session); err != nil {
return err
err = s.store.Refresh(cookie.Value, s.timeout, session)

if err != nil {
panic(anansi.APIError{
Code: http.StatusUnauthorized,
Message: err.Error(),
Err: err,
})
}

// extend the cookie's lifetime
cookie.Expires = time.Now().Add(s.timeout)
http.SetCookie(w, cookie)
}

if err = s.store.Refresh(token, s.timeout, session); err != nil {
return err
}

return nil
}

// Secure loads a user session into the request context
func (s *SessionStore) Secure() func(http.Handler) http.Handler {
func (s *SessionStore) Headless() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var session interface{}
var ctx context.Context

if err := s.load(r, w, &session); err != nil {
ctx = context.WithValue(r.Context(), errorKey{}, err)
} else {
ctx = context.WithValue(r.Context(), sessionKey{}, session)
authHeader := r.Header.Get("Authorization")
// if there's no authorisation header, then there's no use going further
if len(authHeader) == 0 {
panic(anansi.APIError{
Code: http.StatusUnauthorized,
Message: ErrHeaderNotSet.Error(),
Err: ErrHeaderNotSet,
})
}

r = r.WithContext(ctx)
next.ServeHTTP(w, r)
})
}
}
splitAuth := strings.Split(authHeader, " ")

// Get retrieves a user session stored in the context, panics with an appropriate error if the
// error value has been set on the request context(for failed session loads). Make sure to use this
// for handlers protected by the Secure method.
func Get(r *http.Request) interface{} {
ctx := r.Context()
err := ctx.Value(errorKey{}).(error)
session := ctx.Value(sessionKey{})

if err != nil {
switch err {
case ErrAuthorisationFormat, ErrEmptyToken, ErrUnsupportedScheme, ErrInvalidToken, ErrJWTExpired, ErrTokenNotFound:
panic(anansi.APIError{
Code: http.StatusUnauthorized,
Message: err.Error(),
})
default:
panic(anansi.APIError{
Code: http.StatusUnauthorized,
Message: "We could not load your session. Please reach out to support to check the problem",
Err: err,
})
}
}
// we are expecting "${Scheme} ${Token}"
if len(splitAuth) != 2 {
panic(anansi.APIError{
Code: http.StatusUnauthorized,
Message: ErrAuthorisationFormat.Error(),
Err: ErrAuthorisationFormat,
})
}

scheme := splitAuth[0]
if scheme != s.scheme {
panic(anansi.APIError{
Code: http.StatusUnauthorized,
Message: ErrUnsupportedScheme.Error(),
Err: ErrUnsupportedScheme,
})
}

if session == nil {
panic(anansi.APIError{
Code: http.StatusUnauthorized,
Message: "We could not find a session for your request",
token := splitAuth[1]

if len(token) == 0 {
panic(anansi.APIError{
Code: http.StatusUnauthorized,
Message: ErrEmptyToken.Error(),
Err: ErrEmptyToken,
})
}

if err := DecodeJWT(s.store.secret, []byte(token), session); err != nil {
panic(anansi.APIError{
Code: http.StatusUnauthorized,
Message: err.Error(),
Err: err,
})
}

next.ServeHTTP(w, r)
})
}

return session
}
31 changes: 31 additions & 0 deletions auth/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,37 @@ func (ts *TokenStore) Refresh(token string, timeout time.Duration, data interfac
return nil
}

// Reset changes the contents of the token without changing it's TTL
func (ts *TokenStore) Reset(key string, timeout time.Duration, data interface{}) error {
var err error
var encoded []byte
var token string

sig := hmac.New(sha256.New, ts.secret)
if _, err := sig.Write([]byte(key)); err != nil {
return err
}

token = hex.EncodeToString(sig.Sum(nil))

if _, err = ts.redis.Get(token).Result(); err != nil {
return ErrTokenNotFound
}

// we already know the key exists
ttl, _ := ts.redis.TTL(token).Result()

// TODO: replace this something lighter and faster
if encoded, err = json.Marshal(data); err != nil {
return err
}

if _, err = ts.redis.Set(token, encoded, ttl).Result(); err != nil {
return err
}
return nil
}

// Decommission loads the value referenced by the token and dispenses of the token,
// making it unvailable for further use.
func (ts *TokenStore) Decommission(token string, data interface{}) error {
Expand Down
2 changes: 1 addition & 1 deletion errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ func Recoverer(next http.Handler) http.Handler {
} else {
fmt.Fprintf(os.Stderr, "Panic: %+v\n", rvr)
}
// debug.PrintStack()

if e, ok := rvr.(APIError); ok {
SendError(r, w, e)
} else {
debug.PrintStack()
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
}
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ require (
github.com/kelseyhightower/envconfig v1.4.0
github.com/rs/cors v1.7.0
github.com/rs/zerolog v1.19.0
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9
golang.org/x/crypto v0.0.0-20200709230013-948cd5f35899
)
Loading

0 comments on commit c921f23

Please sign in to comment.