Skip to content

Commit

Permalink
Merge pull request #36 from tsaron/ref/tsaron-anansi
Browse files Browse the repository at this point in the history
Finally refactor for anansi
  • Loading branch information
noxecane authored Sep 12, 2020
2 parents fba17e8 + c7d0c8f commit 047177d
Show file tree
Hide file tree
Showing 8 changed files with 271 additions and 202 deletions.
48 changes: 2 additions & 46 deletions errors.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
package anansi

import (
"fmt"
"net/http"
"os"
"runtime/debug"

"github.com/rs/zerolog"
)
import "fmt"

// APIError is a struct describing an error
type APIError struct {
Expand All @@ -18,43 +11,6 @@ type APIError struct {
}

// implements the error interface
func (e APIError) Error() string { return e.Message }
func (e APIError) Error() string { return fmt.Sprintf("%s: %v", e.Message, e.Err) }

func (e APIError) Unwrap() error { return e.Err }

// Recoverer creates a middleware that handles panics from chi controllers. It uses
// the passed interpreters to try to convert errors to APIErrors where possible
// otherwise it returns a 500 error. When the panic is an APIError or is interpreted
// as one, it sends a response with the right error code.
// TODO: add support for wrapped errors in APIError.
func Recoverer(env string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
defer func() {
if rvr := recover(); rvr != nil && rvr != http.ErrAbortHandler {

log := zerolog.Ctx(r.Context())
if log != nil {
err := rvr.(error) // kill yourself
log.Err(err).Msg("")
} else {
fmt.Fprintf(os.Stderr, "Panic: %+v\n", rvr)
}

if e, ok := rvr.(APIError); ok {
SendError(r, w, e)
} else {
if env == "dev" || env == "test" {
debug.PrintStack()
}
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
}
}()

next.ServeHTTP(w, r)
}

return http.HandlerFunc(fn)
}
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ require (
github.com/joho/godotenv v1.3.0
github.com/kelseyhightower/envconfig v1.4.0
github.com/mitchellh/mapstructure v1.3.3
github.com/pkg/errors v0.9.1
github.com/rs/cors v1.7.0
github.com/rs/zerolog v1.19.0
github.com/satori/go.uuid v1.2.0
Expand Down
107 changes: 86 additions & 21 deletions jwt/jwt.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package jwt

import (
"errors"
"fmt"
"reflect"
"time"

"github.com/dgrijalva/jwt-go"
"github.com/mitchellh/mapstructure"
"github.com/pkg/errors"
)

const expiredErr = jwt.ValidationErrorExpired | jwt.ValidationErrorNotValidYet
Expand All @@ -17,62 +18,126 @@ var (
ErrNoClaims = errors.New("There are no claims in your token")
)

// Encode creates a JWT token for some given struct using the HMAC algorithm. It uses
// the key to separate the data stored from the JWT claims to prevent clashes. This works
// best with structs, please use the jwt-go library directly for primitive types.
func Encode(key string, secret []byte, t time.Duration, v interface{}) (string, error) {
// Encodes generates and signs a JWT token for the given payload using the HMAC algorithm.
func Encode(secret []byte, t time.Duration, payload map[string]interface{}) (string, error) {
jwtClaims := jwt.MapClaims{
"iat": time.Now().Unix(),
"exp": time.Now().Add(t).Unix(),
}

// save data using key
jwtClaims[key] = v
for k, v := range payload {
jwtClaims[k] = v
}

token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwtClaims)

return token.SignedString(secret)
}

// Decode verifies a JWT token and extract the claims stored at key. Note that it expects v
// to be a pointer to a struct(based off what Encode does.)
func Decode(key string, secret []byte, tokenBytes []byte, v interface{}) error {
token, err := jwt.Parse(string(tokenBytes), func(token *jwt.Token) (interface{}, error) {
// EncodeStruct generates a JWT token for the given struct using the HMAC algorithm.
func EncodeStruct(secret []byte, t time.Duration, v interface{}) (string, error) {
payload := make(map[string]interface{})

r := reflect.ValueOf(v)
if r.Kind() == reflect.Ptr {
r = r.Elem()
}

// we only accept structs
if r.Kind() != reflect.Struct {
return "", errors.Errorf("can only encode structs; got %T", v)
}

typ := r.Type()
for i := 0; i < r.NumField(); i++ {
ft := typ.Field(i)

// use json tag if available
n := ft.Tag.Get("json")
if n == "" {
n = ft.Name
}

payload[n] = r.Field(i).Interface()
}

return Encode(secret, t, payload)
}

// EncodeEmbedded attaches the payload as an entry to the final claim using the key
// `claim` to prevent clashes with JWT field names.
func EncodeEmbedded(secret []byte, t time.Duration, v interface{}) (string, error) {
return Encode(secret, t, map[string]interface{}{"claim": v})
}

// Decode validates and parses the given JWT token into a map
func Decode(secret []byte, token []byte) (map[string]interface{}, error) {
jwtToken, err := jwt.Parse(string(token), func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return secret, nil
})

if token == nil || !token.Valid {
if jwtToken == nil || !jwtToken.Valid {
if verr, ok := err.(*jwt.ValidationError); ok {
switch {
case verr.Errors&jwt.ValidationErrorMalformed != 0:
return ErrInvalidToken
return nil, ErrInvalidToken
case verr.Errors&expiredErr != 0:
return ErrJWTExpired
return nil, ErrJWTExpired
default:
return err
return nil, err
}
} else {
return nil, err
}
}

claims, ok := token.Claims.(jwt.MapClaims)
claims, ok := jwtToken.Claims.(jwt.MapClaims)
if !ok {
return errors.New("could not convert JWT to map claims")
return nil, errors.New("could not convert JWT to map claims")
}

// ignore tokens without claim data
if claims[key] == nil {
return nil
return claims, nil
}

// DecodeStruct validates and parses a JWT token into a struct.
func DecodeStruct(secret []byte, tokenBytes []byte, v interface{}) error {
claims, err := Decode(secret, tokenBytes)
if err != nil {
return err
}

// convert claims data map to struct
config := &mapstructure.DecoderConfig{Result: v, TagName: `json`}
decoder, err := mapstructure.NewDecoder(config)
if err != nil {
return errors.Wrap(err, "could not convert claims to struct")
}

return decoder.Decode(claims)
}

// DecodeEmbedded validates and parses a JWT token into a struct. It expects the
// struct's payload to be attached to the key `claim` of the actual JWT claim. Note that
// the struct should have json tags
func DecodeEmbedded(secret []byte, tokenBytes []byte, v interface{}) error {
claims, err := Decode(secret, tokenBytes)
if err != nil {
return err
}

return decoder.Decode(claims[key])
if claims["claim"] == nil {
return nil
}

// convert claims data map to struct
config := &mapstructure.DecoderConfig{Result: v, TagName: `json`}
decoder, err := mapstructure.NewDecoder(config)
if err != nil {
return errors.Wrap(err, "could not convert claims to struct")
}

return decoder.Decode(claims["claim"])
}
135 changes: 108 additions & 27 deletions jwt/jwt_test.go
Original file line number Diff line number Diff line change
@@ -1,57 +1,138 @@
package jwt

import (
"fmt"
"testing"
"time"

"syreclabs.com/go/faker"
)

type jwtStruct struct {
Name string `json:"name"`
}

func TestEncode(t *testing.T) {
jwt := jwtStruct{Name: "Olakunkle"}
token, err := Encode("claims", []byte("mysecret"), time.Minute, jwt)
if err != nil {
t.Fatal(err)
}

if token == "" {
t.Error("Expected EncodeJWT to generate a token")
}
}

func TestDecode(t *testing.T) {
secret := []byte("test-secret")

t.Run("should decode token generated by encode", func(t *testing.T) {
jwt := jwtStruct{Name: "testName"}
token, err := Encode("claims", secret, time.Minute, jwt)
t.Run("should encode a map", func(t *testing.T) {
payload := map[string]interface{}{"name": faker.Name().FirstName()}
token, err := Encode(secret, time.Minute, payload)
if err != nil {
t.Fatal(err)
}

var loaded jwtStruct
if err := Decode("claims", secret, []byte(token), &loaded); err != nil {
if token == "" {
t.Error("Expected a token, got an empty string")
}

parsed, err := Decode(secret, []byte(token))
if err != nil {
t.Fatal(err)
}

if loaded.Name != "testName" {
t.Errorf("Expected Name to be %s, got %s", "Olakunle", loaded.Name)
if parsed["name"] != payload["name"] {
t.Errorf("Expected the parsed name to be %s, got %s", payload["name"], parsed["name"])
}
})

t.Run("should decode token generated externally", func(t *testing.T) {
// don't use this token it would fail.
token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwiY2xhaW0iOnsibmFtZSI6InRva2VuIn0sImV4cCI6MTU5OTUyODEwNSwiaWF0IjoxNTE2MjM5MDIyfQ.-u5OhBHdq7L_wBlI2cBjstOXyhRS2m27bx1tLN869so"

var loaded jwtStruct
if err := Decode("claim", secret, []byte(token), &loaded); err != nil {
t.Run("should fail with an error", func(t *testing.T) {
payload := map[string]interface{}{"name": faker.Name().FirstName()}
token, err := Encode(secret, time.Second, payload)
if err != nil {
t.Fatal(err)
}

if loaded.Name != "token" {
t.Errorf("Expected Name to be %s, got %s", "Olakunle", loaded.Name)
time.Sleep(time.Second * 2)

_, err = Decode(secret, []byte(token))
if err == nil {
t.Fatal("Expected Decode to fail with non-nil error")
}

if err != ErrJWTExpired {
t.Errorf("Expected Decode to fail with ErrJWTExpired, failed with %v", err)
}
})
}

func TestEncodeStruct(t *testing.T) {
secret := []byte("test-secret")
payload := jwtStruct{faker.Name().FirstName()}

token, err := EncodeStruct(secret, time.Second, payload)
if err != nil {
t.Fatal(err)
}

if token == "" {
t.Error("Expected a token, got an empty string")
}

var parsed jwtStruct
err = DecodeStruct(secret, []byte(token), &parsed)
if err != nil {
t.Fatal(err)
}

if parsed.Name != payload.Name {
t.Errorf("Expected the parsed name to be %s, got %s", payload.Name, parsed.Name)
}
}

func TestEncodeEmbedded(t *testing.T) {
secret := []byte("test-secret")
payload := jwtStruct{faker.Name().FirstName()}

token, err := EncodeEmbedded(secret, time.Second, payload)
if err != nil {
t.Fatal(err)
}

if token == "" {
t.Error("Expected a token, got an empty string")
}

parsed, err := Decode(secret, []byte(token))
if err != nil {
t.Fatal(err)
}

if parsed["claim"] == nil {
t.Fatal("Expected the claim to be non-nil")
}

data, ok := parsed["claim"].(map[string]interface{})
fmt.Println(data)
if !ok {
t.Fatalf("Expected claim to be a map of string to string, got %T", parsed["claim"])
}

if data["name"] != payload.Name {
t.Errorf("Expected the parsed name to be %s, got %s", payload.Name, data["name"])
}
}

func TestDecodeEmbedded(t *testing.T) {
secret := []byte("test-secret")
payload := jwtStruct{faker.Name().FirstName()}

token, err := EncodeEmbedded(secret, time.Second, payload)
if err != nil {
t.Fatal(err)
}

if token == "" {
t.Error("Expected a token, got an empty string")
}

var parsed jwtStruct
err = DecodeEmbedded(secret, []byte(token), &parsed)
if err != nil {
t.Fatal(err)
}

if parsed.Name != payload.Name {
t.Errorf("Expected the parsed name to be %s, got %s", payload.Name, parsed.Name)
}
}
Loading

0 comments on commit 047177d

Please sign in to comment.