-
Notifications
You must be signed in to change notification settings - Fork 5
/
util.go
144 lines (134 loc) · 4.03 KB
/
util.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
package jwt
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"reflect"
)
// b64 encodes and decodes JWT parts using the URL-safe base64 alphabet
// without '=' padding. base64.RawURLEncoding is defined by the standard
// library as base64.URLEncoding.WithPadding(base64.NoPadding).
var b64 = base64.RawURLEncoding
// getFieldWithTag looks up the struct field of obj whose `jwt` struct tag
// equals tagName, returning its reflected value.
//
// obj may be a struct, or a pointer (possibly through several levels of
// pointers/interfaces) to a struct. When the struct is reached through a
// pointer, the returned field is addressable (and settable if exported).
//
// nil is returned when no field carries the tag. It is also returned when
// obj is nil, hits a nil pointer while being dereferenced, or does not
// resolve to a struct (e.g. a *map). Previously these cases panicked
// inside reflect.
//
// Note: untagged fields have an empty `jwt` tag, so an empty tagName
// matches the first untagged field.
func getFieldWithTag(obj interface{}, tagName string) *reflect.Value {
	objVal := reflect.ValueOf(obj)
	// Walk through pointers/interfaces to the underlying value; a nil along
	// the way means there is no struct to inspect.
	for objVal.Kind() == reflect.Ptr || objVal.Kind() == reflect.Interface {
		if objVal.IsNil() {
			return nil
		}
		objVal = objVal.Elem()
	}
	if objVal.Kind() != reflect.Struct {
		return nil
	}

	objType := objVal.Type()
	for i := 0; i < objVal.NumField(); i++ {
		if objType.Field(i).Tag.Get("jwt") == tagName {
			field := objVal.Field(i)
			return &field
		}
	}
	return nil
}
// decodeToObjOrFieldWithTag fills the value obj points to, picking its
// source in this order:
//
//   - if *obj has the same type as *defaultObj, *obj is set to *defaultObj;
//   - otherwise, if obj has a field tagged `jwt:"<tagName>"`, that field is
//     the target: it is set to *defaultObj when the types match, and
//     otherwise buf is JSON-decoded into it;
//   - otherwise buf is JSON-decoded into obj itself.
//
// JSON decoding uses UseNumber, so numbers landing in interface{} values
// become json.Number instead of float64. Both obj and defaultObj must be
// non-nil pointers; reflect panics otherwise.
func decodeToObjOrFieldWithTag(buf []byte, obj interface{}, tagName string, defaultObj interface{}) error {
	dst := reflect.ValueOf(obj).Elem()
	src := reflect.ValueOf(defaultObj).Elem()

	// Whole-object match: copy the default over and stop.
	if dst.Type() == src.Type() {
		dst.Set(src)
		return nil
	}

	// Narrow the target to the tagged field, if there is one.
	target := obj
	if field := getFieldWithTag(obj, tagName); field != nil {
		if field.Type() == src.Type() {
			field.Set(src)
			return nil
		}
		target = field.Addr().Interface()
	}

	dec := json.NewDecoder(bytes.NewReader(buf))
	dec.UseNumber()
	return dec.Decode(target)
}
// grabEncodeTargets picks the header and payload to encode for obj, which
// is expected to be a struct or a pointer to a struct.
//
// The header is the value of the field tagged `jwt:"header"`. It falls back
// to alg.Header() when there is no such field or the field is nil.
//
// The payload is the value of the field tagged `jwt:"payload"`. It falls
// back to obj itself when there is no such field or the field is nil.
//
// The returned error is always nil.
func grabEncodeTargets(alg Algorithm, obj interface{}) (interface{}, interface{}, error) {
	headerObj := encodeTargetField(obj, "header")
	if headerObj == nil {
		headerObj = alg.Header()
	}

	payloadObj := encodeTargetField(obj, "payload")
	if payloadObj == nil {
		payloadObj = obj
	}

	return headerObj, payloadObj, nil
}

// encodeTargetField returns the value of obj's field tagged
// `jwt:"<tagName>"`. It returns nil when there is no such field or when the
// field holds a nil pointer, map, slice or interface.
func encodeTargetField(obj interface{}, tagName string) interface{} {
	field := getFieldWithTag(obj, tagName)
	if field == nil {
		return nil
	}
	switch field.Kind() {
	case reflect.Ptr, reflect.Map, reflect.Slice, reflect.Interface:
		// A typed nil boxed into interface{} compares != nil. Without this
		// check, a nil *Header field would bypass the fallback and be
		// encoded as JSON null.
		if field.IsNil() {
			return nil
		}
	}
	return field.Interface()
}
// encodeTargets returns the header and payload that should be encoded for
// obj:
//
//   - a *Token supplies its own Header and Payload;
//   - a struct, or a pointer to a struct, is inspected for jwt-tagged
//     header/payload fields by grabEncodeTargets;
//   - anything else is used as the payload, paired with alg.Header().
func encodeTargets(alg Algorithm, obj interface{}) (interface{}, interface{}, error) {
	if tok, ok := obj.(*Token); ok {
		return tok.Header, tok.Payload, nil
	}

	// Look through at most one pointer to see whether a struct is underneath.
	underlying := reflect.ValueOf(obj)
	if underlying.Kind() == reflect.Ptr {
		underlying = underlying.Elem()
	}
	if underlying.Kind() == reflect.Struct {
		return grabEncodeTargets(alg, obj)
	}

	return alg.Header(), obj, nil
}
// tokenPosition identifies which constituent part of a JWT to look at.
//
// Used in conjunction with peekField.
type tokenPosition int

// Positions accepted by peekField. No signature position is defined; it is
// left commented out below.
const (
	tokenPositionHeader  tokenPosition = 0
	tokenPositionPayload tokenPosition = 1
	// tokenPositionSignature tokenPosition = 2
)
// peekField looks at an undecoded JWT, JSON decoding the part at pos, and
// returns the named field's value formatted as a string.
//
// Only tokenPositionHeader and tokenPositionPayload are accepted. String
// values are returned as-is. Numbers are returned exactly as written in the
// token (e.g. "1500000000", not "1.5e+09"). Other values use fmt's %v
// formatting (a JSON null yields "<nil>").
//
// An error is returned in any of these cases:
//   - DecodeUnverifiedToken fails;
//   - pos is not a supported position;
//   - the part is not unpadded base64url;
//   - the part is not valid JSON;
//   - fieldName is not present.
func peekField(buf []byte, fieldName string, pos tokenPosition) (string, error) {
	// split token
	var t UnverifiedToken
	if err := DecodeUnverifiedToken(buf, &t); err != nil {
		return "", err
	}

	// select the part to inspect
	var typ string
	var b []byte
	switch pos {
	case tokenPositionHeader:
		typ, b = "header", t.Header
	case tokenPositionPayload:
		typ, b = "payload", t.Payload
	default:
		return "", fmt.Errorf("invalid token position %d", pos)
	}

	// b64 decode
	dec, err := b64.DecodeString(string(b))
	if err != nil {
		return "", fmt.Errorf("could not decode token %s", typ)
	}

	// Validate the whole part strictly, but keep each value raw so the
	// requested one can be decoded with UseNumber below.
	fields := make(map[string]json.RawMessage)
	if err := json.Unmarshal(dec, &fields); err != nil {
		return "", err
	}
	raw, ok := fields[fieldName]
	if !ok {
		return "", fmt.Errorf("token %s field %s not present or invalid", typ, fieldName)
	}

	// Decode numbers as json.Number rather than float64, matching
	// decodeToObjOrFieldWithTag. This makes %v reproduce the original
	// literal instead of exponent form (e.g. for "exp"/"iat").
	var val interface{}
	d := json.NewDecoder(bytes.NewReader(raw))
	d.UseNumber()
	if err := d.Decode(&val); err != nil {
		return "", err
	}
	return fmt.Sprintf("%v", val), nil
}