-
Notifications
You must be signed in to change notification settings - Fork 1
/
issuer.go
153 lines (125 loc) · 3.87 KB
/
issuer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
package jwit
import (
"encoding/json"
"io/ioutil"
"log"
"net/http"
"sync"
"time"
"github.com/go-jose/go-jose/v3"
)
// Issuer is a third party that publishes a set of public keys.
//
// The exported fields configure the issuer; the unexported ones hold the key sets derived from
// that configuration and the state of the last JWKS fetch.
type Issuer struct {
	// Mutex is held by refreshJWKS for the whole duration of a JWKS fetch, which serializes
	// concurrent refreshes.
	sync.Mutex

	// Name is the name of the issuer (corresponding to an "iss" claim).
	Name string

	// PublicKeys is a set of known public keys for this issuer. A []byte entry is parsed as a
	// JWKS when it looks like a JSON object and as PEM otherwise; any other entry is wrapped in
	// a jose.JSONWebKey. initialize converts these entries into localJWKS and then sets this
	// field to nil.
	PublicKeys []interface{}

	// JWKSURL is a URL where the issuer publishes its JWKS. When it is empty, the issuer's keys
	// are never refreshed and only the keys derived from PublicKeys are used.
	JWKSURL string

	// TTL defines how long a fetched JWKS is considered fresh. Past that TTL, jwit will try to
	// refresh the JWKS asynchronously. initialize sets it to 24 hours when it is zero.
	TTL time.Duration

	// localJWKS is a JWKS derived from the PublicKeys the user gave when creating the issuer.
	localJWKS jose.JSONWebKeySet

	// jwks is the last JSON Web Key Set known for this issuer. It contains the union of the
	// local JWKS and the JWKS resulting from the last successful call to JWKSURL.
	jwks jose.JSONWebKeySet

	// lastRefreshedAt records when the issuer's JWKS were last fetched successfully; it stays
	// zero until the first successful fetch.
	lastRefreshedAt time.Time
}
// initialize applies the issuer's default parameters and builds its local JWKS from PublicKeys.
//
// A zero TTL is replaced by a 24-hour default. Each entry of PublicKeys is handled according to
// its dynamic type:
//   - []byte: parsed with loadJWKS when its first non-whitespace byte is '{' (a JSON JWKS), and
//     with pemToJWKS otherwise. An empty or all-whitespace slice is rejected with an error.
//   - anything else: wrapped in a jose.JSONWebKey.
//
// Only the public half of every key is kept. On success PublicKeys is set to nil and the
// issuer's current JWKS is set to the local JWKS. Any error from parsing a []byte entry is
// returned unchanged, and the issuer's key sets are then left untouched.
func (issuer *Issuer) initialize() error {
	// Set a default expiration to 24 hours.
	if issuer.TTL == 0 {
		issuer.TTL = 24 * time.Hour
	}

	keys := make([]jose.JSONWebKey, 0, len(issuer.PublicKeys))
	for _, publicKey := range issuer.PublicKeys {
		switch publicKey := publicKey.(type) {
		case []byte:
			first, ok := firstNonSpaceByte(publicKey)
			if !ok {
				// There is no content to sniff; reject it rather than guess a format.
				return emptyPublicKeyError{}
			}
			var jwks jose.JSONWebKeySet
			var err error
			if first == '{' {
				// Bytes look like a JSON payload. Likely a JWKS.
				jwks, err = loadJWKS(publicKey)
			} else {
				// If it's not a JWKS, then it's probably a PEM.
				jwks, err = pemToJWKS(publicKey)
			}
			if err != nil {
				return err
			}
			for i := range jwks.Keys {
				keys = append(keys, jwks.Keys[i].Public())
			}
		default:
			keys = append(keys, (&jose.JSONWebKey{Key: publicKey}).Public())
		}
	}

	// Clip the capacity so that any later append onto the local keys (refreshJWKS merges the
	// fetched keys onto them) must allocate a new array. Otherwise it could write into spare
	// capacity shared with key sets already handed out to callers.
	issuer.localJWKS.Keys = keys[:len(keys):len(keys)]
	issuer.PublicKeys = nil
	issuer.jwks = issuer.localJWKS
	return nil
}

// firstNonSpaceByte returns the first byte of b that is not JSON whitespace (space, tab, LF,
// CR), and false when b has no such byte.
func firstNonSpaceByte(b []byte) (byte, bool) {
	for _, c := range b {
		switch c {
		case ' ', '\t', '\n', '\r':
			continue
		}
		return c, true
	}
	return 0, false
}

// emptyPublicKeyError is returned by initialize when a []byte public key is empty or contains
// only whitespace.
type emptyPublicKeyError struct{}

func (emptyPublicKeyError) Error() string { return "jwit: empty public key" }
// needsRefresh reports whether the issuer's JWKS are stale, meaning at least TTL has elapsed
// since lastRefreshedAt (a zero lastRefreshedAt therefore always counts as stale). An issuer
// with no JWKSURL has nothing to fetch, so it never needs a refresh.
func (issuer *Issuer) needsRefresh() bool {
	hasRemote := issuer.JWKSURL != ""
	return hasRemote && time.Since(issuer.lastRefreshedAt) >= issuer.TTL
}
// getJWKS returns the issuer's JWKS.
//
// If the JWKS have never been fetched from JWKSURL, it fetches them synchronously and returns
// whatever refreshJWKS returns, including its error. If they are merely stale, it returns the
// current set immediately and starts a background goroutine to refresh it; a failure of that
// refresh is only logged.
func (issuer *Issuer) getJWKS(client *http.Client) (*jose.JSONWebKeySet, error) {
	// NOTE(review): jwks and lastRefreshedAt are read here without holding the issuer's lock,
	// while a background refreshJWKS may be writing them. That is a data race. Taking the lock
	// here would make readers wait out the whole HTTP fetch, because refreshJWKS holds the lock
	// while fetching. A proper fix needs a separate synchronization primitive on Issuer.
	current := issuer.jwks
	stale := issuer.needsRefresh()

	switch {
	case stale && issuer.lastRefreshedAt.IsZero():
		// Nothing has been fetched yet, so block until the first fetch completes.
		return issuer.refreshJWKS(client)
	case stale:
		// Serve the current keys right away and refresh them out of band.
		go func() {
			if _, err := issuer.refreshJWKS(client); err != nil {
				log.Printf("jwit: couldn't fetch issuer's JWKS: %s", err)
			}
		}()
	}
	return &current, nil
}
// refreshJWKS fetches the issuer's JWKS from JWKSURL. It stores the union of the local keys and
// the fetched keys as the issuer's current JWKS.
//
// The issuer's lock is held for the whole call, so concurrent refreshes are serialized. A caller
// that waited for the lock while another goroutine refreshed gets the current set back without
// issuing a new request.
//
// The returned set is always a copy of the issuer's current JWKS: local keys first, then fetched
// keys. It is never just the fetched keys, so the first synchronous fetch sees the same keys as
// later calls. Each refresh stores the merged keys in a freshly allocated array, so a returned
// set is not modified by later refreshes.
//
// It returns ErrJWKSFetchFailed when JWKSURL answers with a non-2xx status. When the request,
// the body read, or JSON decoding fails, it returns the underlying error. On any error the
// issuer's state is left unchanged.
func (issuer *Issuer) refreshJWKS(client *http.Client) (*jose.JSONWebKeySet, error) {
	// Lock the issuer to prevent races
	issuer.Lock()
	defer issuer.Unlock()

	// Check data isn't fresh already (in case it was already fetched by a concurrent goroutine)
	if !issuer.needsRefresh() {
		snapshot := issuer.jwks
		return &snapshot, nil
	}

	resp, err := client.Get(issuer.JWKSURL)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return nil, ErrJWKSFetchFailed
	}
	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	var fetched jose.JSONWebKeySet
	if err := json.Unmarshal(body, &fetched); err != nil {
		return nil, err
	}

	// Build the merged set in a new array. Appending straight onto localJWKS.Keys could write
	// into spare capacity shared with key sets previously returned to callers.
	merged := make([]jose.JSONWebKey, 0, len(issuer.localJWKS.Keys)+len(fetched.Keys))
	merged = append(merged, issuer.localJWKS.Keys...)
	merged = append(merged, fetched.Keys...)
	issuer.jwks.Keys = merged
	issuer.lastRefreshedAt = time.Now()

	snapshot := issuer.jwks
	return &snapshot, nil
}