Skip to content

Commit d03e03a

Browse files
committed
breaking change: HeaderValidator: returns an extra output argument which can optionally (if not nil) set the decryption method dynamically based on the kid
1 parent 4ee796d commit d03e03a

File tree

4 files changed

+57
-26
lines changed

4 files changed

+57
-26
lines changed

_examples/custom-header/main.go

+7-7
Original file line numberDiff line numberDiff line change
@@ -68,30 +68,30 @@ type Header struct {
6868
Alg string `json:"alg"`
6969
}
7070

71-
func validateHeader(alg string, headerDecoded []byte) (jwt.PublicKey, error) {
71+
func validateHeader(alg string, headerDecoded []byte) (jwt.Alg, jwt.PublicKey, jwt.InjectFunc, error) {
7272
var h Header
7373
err := jwt.Unmarshal(headerDecoded, &h)
7474
if err != nil {
75-
return nil, err
75+
return nil, nil, nil, err
7676
}
7777

7878
if h.Alg != alg {
79-
return nil, jwt.ErrTokenAlg
79+
return nil, nil, nil, jwt.ErrTokenAlg
8080
}
8181

8282
if h.Kid == "" {
83-
return nil, fmt.Errorf("kid is empty")
83+
return nil, nil, nil, fmt.Errorf("kid is empty")
8484
}
8585

8686
key, ok := keys[h.Kid]
8787
if !ok {
88-
return nil, fmt.Errorf("unknown kid")
88+
return nil, nil, nil, fmt.Errorf("unknown kid")
8989
}
9090

9191
publicKey, err := jwt.ParsePublicKeyRSA(key)
9292
if err != nil {
93-
return nil, jwt.ErrTokenAlg
93+
return nil, nil, nil, jwt.ErrTokenAlg
9494
}
9595

96-
return publicKey, nil
96+
return nil, publicKey, nil, nil
9797
}

kid_keys.go

+30-9
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ type (
3434
Public PublicKey
3535
Private PrivateKey
3636
MaxAge time.Duration // optional.
37+
Encrypt InjectFunc // optional.
38+
Decrypt InjectFunc // optional.
3739
}
3840

3941
// Keys is a map which holds the key id and a key pair.
@@ -75,11 +77,20 @@ type (
7577
Alg string `json:"alg" yaml:"Alg" toml:"Alg" ini:"alg"`
7678
Private string `json:"private" yaml:"Private" toml:"Private" ini:"private"`
7779
Public string `json:"public" yaml:"Public" toml:"Public" ini:"public"`
78-
// Token expiration. Optional.
80+
// MaxAge sets the token expiration. It is optional.
7981
// If greater than zero then the MaxAge token validation
8082
// will be appended to the "VerifyToken" and the token is invalid
8183
// after expiration of its sign time.
8284
MaxAge time.Duration `json:"max_age" yaml:"MaxAge" toml:"MaxAge" ini:"max_age"`
85+
86+
// EncryptionKey enables encryption on the generated token. It is optional.
87+
// Encryption using the Galois Counter mode of operation with
88+
// AES cipher symmetric-key cryptographic.
89+
//
90+
// The value should be the AES key,
91+
// either 16, 24, or 32 bytes to select
92+
// AES-128, AES-192, or AES-256.
93+
EncryptionKey string `json:"encryption_key" yaml:"EncryptionKey" toml:"EncryptionKey" ini:"encryption_key"`
8394
}
8495
)
8596

@@ -131,6 +142,16 @@ func (c KeysConfiguration) Load() (Keys, error) {
131142
p.Public = entry.Public
132143
}
133144

145+
if entry.EncryptionKey != "" {
146+
encrypt, decrypt, err := GCM([]byte(entry.EncryptionKey), nil)
147+
if err != nil {
148+
return nil, fmt.Errorf("jwt: load keys: build encryption: %w", err)
149+
}
150+
151+
p.Encrypt = encrypt
152+
p.Decrypt = decrypt
153+
}
154+
134155
parsedKeys[entry.ID] = p
135156
}
136157

@@ -155,33 +176,33 @@ func (keys Keys) Register(alg Alg, kid string, pubKey PublicKey, privKey Private
155176

156177
// ValidateHeader validates the given json header value (base64 decoded) based on the "keys".
157178
// Keys structure completes the `HeaderValidator` interface.
158-
func (keys Keys) ValidateHeader(alg string, headerDecoded []byte) (Alg, PublicKey, error) {
179+
func (keys Keys) ValidateHeader(alg string, headerDecoded []byte) (Alg, PublicKey, InjectFunc, error) {
159180
var h HeaderWithKid
160181

161182
err := Unmarshal(headerDecoded, &h)
162183
if err != nil {
163-
return nil, nil, err
184+
return nil, nil, nil, err
164185
}
165186

166187
if h.Kid == "" {
167-
return nil, nil, ErrEmptyKid
188+
return nil, nil, nil, ErrEmptyKid
168189
}
169190

170191
key, ok := keys.Get(h.Kid)
171192
if !ok {
172-
return nil, nil, ErrUnknownKid
193+
return nil, nil, nil, ErrUnknownKid
173194
}
174195

175196
if h.Alg != key.Alg.Name() {
176-
return nil, nil, ErrTokenAlg
197+
return nil, nil, nil, ErrTokenAlg
177198
}
178199

179200
// If for some reason a specific alg was given by the caller then check that as well.
180201
if alg != "" && alg != h.Alg {
181-
return nil, nil, ErrTokenAlg
202+
return nil, nil, nil, ErrTokenAlg
182203
}
183204

184-
return key.Alg, key.Public, nil
205+
return key.Alg, key.Public, key.Decrypt, nil
185206
}
186207

187208
// SignToken signs the "claims" using the given "alg" based a specific key.
@@ -195,7 +216,7 @@ func (keys Keys) SignToken(kid string, claims interface{}, opts ...SignOption) (
195216
opts = append([]SignOption{MaxAge(k.MaxAge)}, opts...)
196217
}
197218

198-
return SignWithHeader(k.Alg, k.Private, claims, HeaderWithKid{
219+
return SignEncryptedWithHeader(k.Alg, k.Private, k.Encrypt, claims, HeaderWithKid{
199220
Kid: kid,
200221
Alg: k.Alg.Name(),
201222
}, opts...)

token.go

+19-9
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ func decodeToken(alg Alg, key PublicKey, token []byte, compareHeaderFunc HeaderV
8585
algName = alg.Name()
8686
}
8787

88-
dynamicAlg, pubKey, err := compareHeaderFunc(algName, headerDecoded)
88+
dynamicAlg, pubKey, decrypt, err := compareHeaderFunc(algName, headerDecoded)
8989
if err != nil {
9090
return nil, nil, nil, err
9191
}
@@ -113,6 +113,14 @@ func decodeToken(alg Alg, key PublicKey, token []byte, compareHeaderFunc HeaderV
113113
if err != nil {
114114
return nil, nil, nil, err
115115
}
116+
117+
if decrypt != nil {
118+
payload, err = decrypt(payload)
119+
if err != nil {
120+
return nil, nil, nil, err
121+
}
122+
}
123+
116124
return headerDecoded, payload, signatureDecoded, nil
117125
}
118126

@@ -196,22 +204,24 @@ func createHeaderWithoutTyp(alg string) []byte {
196204
// If the "alg" is empty then this function should return a non-nil algorithm
197205
// based on the token contents.
198206
// It should return a nil PublicKey and a non-nil error on validation failure.
207+
// The out InjectFunc is optional. If it's not nil then decryption of the payload
208+
// using GCM (AES key) is performed before verification.
199209
// On success, if public key is not nil then it overrides the VerifyXXX method's one.
200-
type HeaderValidator func(alg string, headerDecoded []byte) (Alg, PublicKey, error)
210+
type HeaderValidator func(alg string, headerDecoded []byte) (Alg, PublicKey, InjectFunc, error)
201211

202212
// Note that this check is fully hard coded for known
203213
// algorithms and it is fully hard coded in terms of
204214
// its serialized format.
205-
func compareHeader(alg string, headerDecoded []byte) (Alg, PublicKey, error) {
215+
func compareHeader(alg string, headerDecoded []byte) (Alg, PublicKey, InjectFunc, error) {
206216
if n := len(headerDecoded); n < 25 /* 28 but allow custom short algs*/ {
207217
if n == 15 { // header without "typ": "JWT".
208218
expectedHeader := createHeaderWithoutTyp(alg)
209219
if bytes.Equal(expectedHeader, headerDecoded) {
210-
return nil, nil, nil
220+
return nil, nil, nil, nil
211221
}
212222
}
213223

214-
return nil, nil, ErrTokenAlg
224+
return nil, nil, nil, ErrTokenAlg
215225
}
216226

217227
// Fast check if the order is reversed.
@@ -221,18 +231,18 @@ func compareHeader(alg string, headerDecoded []byte) (Alg, PublicKey, error) {
221231
if headerDecoded[2] == 't' {
222232
expectedHeader := createHeaderReversed(alg)
223233
if !bytes.Equal(expectedHeader, headerDecoded) {
224-
return nil, nil, ErrTokenAlg
234+
return nil, nil, nil, ErrTokenAlg
225235
}
226236

227-
return nil, nil, nil
237+
return nil, nil, nil, nil
228238
}
229239

230240
expectedHeader := createHeaderRaw(alg)
231241
if !bytes.Equal(expectedHeader, headerDecoded) {
232-
return nil, nil, ErrTokenAlg
242+
return nil, nil, nil, ErrTokenAlg
233243
}
234244

235-
return nil, nil, nil
245+
return nil, nil, nil, nil
236246
}
237247

238248
func createSignature(alg Alg, key PrivateKey, headerAndPayload []byte) ([]byte, error) {

token_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ func TestCompareHeader(t *testing.T) {
9393
}
9494

9595
for i, tt := range tests {
96-
_, _, err := compareHeader(tt.alg, []byte(tt.header))
96+
_, _, _, err := compareHeader(tt.alg, []byte(tt.header))
9797
if tt.ok && err != nil {
9898
t.Fatalf("[%d] expected to pass but got error: %v", i, err)
9999
}

0 commit comments

Comments
 (0)