diff --git a/docker/docker_client.go b/docker/docker_client.go index 97d97fed5..7f61597a8 100644 --- a/docker/docker_client.go +++ b/docker/docker_client.go @@ -34,6 +34,7 @@ import ( digest "github.com/opencontainers/go-digest" imgspecv1 "github.com/opencontainers/image-spec/specs-go/v1" "github.com/sirupsen/logrus" + "golang.org/x/sync/semaphore" ) const ( @@ -86,8 +87,19 @@ type extensionSignatureList struct { Signatures []extensionSignature `json:"signatures"` } -// bearerToken records a cached token we can use to authenticate. +// bearerToken records a cached token we can use to authenticate, or a pending process to obtain one. +// +// The goroutine obtaining the token holds lock to block concurrent token requests, and fills the structure (err and possibly the other fields) +// before releasing the lock. +// Other goroutines obtain lock to block on the token request, if any; and then inspect err to see if the token is usable. +// If it is not, they try to get a new one. type bearerToken struct { + // lock is held while obtaining the token. Potentially nested inside dockerClient.tokenCacheLock. + // This is a counting semaphore only because we need a cancellable lock operation. + lock *semaphore.Weighted + + // The following fields can only be accessed with lock held. + err error // nil if the token was successfully obtained (but may be expired); an error if the next lock holder _must_ obtain a new token. token string expirationTime time.Time } @@ -117,7 +129,8 @@ type dockerClient struct { supportsSignatures bool // Private state for setupRequestAuth (key: string, value: bearerToken) - tokenCache sync.Map + tokenCacheLock sync.Mutex // Protects tokenCache. + tokenCache map[string]*bearerToken // Private state for detectProperties: detectPropertiesOnce sync.Once // detectPropertiesOnce is used to execute detectProperties() at most once. detectPropertiesError error // detectPropertiesError caches the initial error. @@ -269,6 +282,7 @@ func newDockerClient(sys *types.SystemContext, registry, reference string) (*doc registry: registry, userAgent: userAgent, tlsClientConfig: tlsClientConfig, + tokenCache: map[string]*bearerToken{}, reportedWarnings: set.New[string](), }, nil } @@ -712,50 +726,11 @@ func (c *dockerClient) setupRequestAuth(req *http.Request, extraScope *authScope req.SetBasicAuth(c.auth.Username, c.auth.Password) return nil case "bearer": - registryToken := c.registryToken - if registryToken == "" { - cacheKey := "" - scopes := []authScope{c.scope} - if extraScope != nil { - // Using ':' as a separator here is unambiguous because getBearerToken below - // uses the same separator when formatting a remote request (and because - // repository names that we create can't contain colons, and extraScope values - // coming from a server come from `parseAuthScope`, which also splits on colons). - cacheKey = fmt.Sprintf("%s:%s:%s", extraScope.resourceType, extraScope.remoteName, extraScope.actions) - if colonCount := strings.Count(cacheKey, ":"); colonCount != 2 { - return fmt.Errorf( - "Internal error: there must be exactly 2 colons in the cacheKey ('%s') but got %d", - cacheKey, - colonCount, - ) - } - scopes = append(scopes, *extraScope) - } - var token bearerToken - t, inCache := c.tokenCache.Load(cacheKey) - if inCache { - token = t.(bearerToken) - } - if !inCache || time.Now().After(token.expirationTime) { - var ( - t *bearerToken - err error - ) - if c.auth.IdentityToken != "" { - t, err = c.getBearerTokenOAuth2(req.Context(), challenge, scopes) - } else { - t, err = c.getBearerToken(req.Context(), challenge, scopes) - } - if err != nil { - return err - } - - token = *t - c.tokenCache.Store(cacheKey, token) - } - registryToken = token.token + token, err := c.obtainBearerToken(req.Context(), challenge, extraScope) + if err != nil { + return err } - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", registryToken)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) return nil default: logrus.Debugf("no handler for %s authentication", challenge.Scheme) @@ -765,16 +740,115 @@ func (c *dockerClient) setupRequestAuth(req *http.Request, extraScope *authScope return nil } -func (c *dockerClient) getBearerTokenOAuth2(ctx context.Context, challenge challenge, - scopes []authScope) (*bearerToken, error) { +// obtainBearerToken gets an "Authorization: Bearer" token if one is available, or obtains a fresh one. +func (c *dockerClient) obtainBearerToken(ctx context.Context, challenge challenge, extraScope *authScope) (string, error) { + if c.registryToken != "" { + return c.registryToken, nil + } + + cacheKey := "" + scopes := []authScope{c.scope} + if extraScope != nil { + // Using ':' as a separator here is unambiguous because getBearerToken below + // uses the same separator when formatting a remote request (and because + // repository names that we create can't contain colons, and extraScope values + // coming from a server come from `parseAuthScope`, which also splits on colons). + cacheKey = fmt.Sprintf("%s:%s:%s", extraScope.resourceType, extraScope.remoteName, extraScope.actions) + if colonCount := strings.Count(cacheKey, ":"); colonCount != 2 { + return "", fmt.Errorf( + "Internal error: there must be exactly 2 colons in the cacheKey ('%s') but got %d", + cacheKey, + colonCount, + ) + } + scopes = append(scopes, *extraScope) + } + + logrus.Debugf("REMOVE: Checking token cache for key %q", cacheKey) + token, newEntry, err := func() (*bearerToken, bool, error) { // A scope for defer + c.tokenCacheLock.Lock() + defer c.tokenCacheLock.Unlock() + token, ok := c.tokenCache[cacheKey] + if ok { + return token, false, nil + } else { + logrus.Debugf("REMOVE: No token cache for key %q, allocating one…", cacheKey) + token = &bearerToken{ + lock: semaphore.NewWeighted(1), + } + // If this is a new *bearerToken, lock the entry before adding it to the cache, so that any other goroutine that finds + // this entry blocks until we obtain the token for the first time, and does not see an empty object + // (and does not try to obtain the token itself when we are going to do so). + if err := token.lock.Acquire(ctx, 1); err != nil { + // We do not block on this Acquire, so we don’t really expect to fail here — but if ctx is canceled, + // there is no point in trying to continue anyway. + return nil, false, err + } + c.tokenCache[cacheKey] = token + return token, true, nil + } + }() + if err != nil { + return "", err + } + if !newEntry { + // If this is an existing *bearerToken, obtain the lock only after releasing c.tokenCacheLock, + // so that users of other cacheKey values are not blocked for the whole duration of our HTTP roundtrip. + logrus.Debugf("REMOVE: Found existing token cache for key %q, getting lock", cacheKey) + if err := token.lock.Acquire(ctx, 1); err != nil { + return "", err + } + logrus.Debugf("REMOVE: Locked existing token cache for key %q", cacheKey) + } + + defer token.lock.Release(1) + + // Determine if the bearerToken is usable: if it is not, log the cause and fall through, otherwise return early. + switch { + case newEntry: + logrus.Debugf("REMOVE: New token cache entry for key %q, getting first token", cacheKey) + case token.err != nil: + // If obtaining a token fails for any reason, the request that triggered that will fail; + // other requests will see token.err and try obtaining their own token, one goroutine at a time. + // (Consider that a request can fail because a very short timeout was provided to _that one operation_ using a context.Context; + // that clearly shouldn’t prevent other operations from trying with a longer timeout.) + // + // If we got here while holding token.lock, we are the goroutine responsible for trying again; others are blocked + // on token.lock. + logrus.Debugf("REMOVE: Token cache for key %q records failure %v, getting new token", cacheKey, token.err) + case time.Now().After(token.expirationTime): + logrus.Debugf("REMOVE: Token cache for key %q is expired, getting new token", cacheKey) + + default: + return token.token, nil + } + + if c.auth.IdentityToken != "" { + err = c.getBearerTokenOAuth2(ctx, token, challenge, scopes) + } else { + err = c.getBearerToken(ctx, token, challenge, scopes) + } + logrus.Debugf("REMOVE: Obtaining a token for key %q, error %v", cacheKey, err) + token.err = err + if token.err != nil { + return "", token.err + } + return token.token, nil +} + +// getBearerTokenOAuth2 obtains an "Authorization: Bearer" token using a pre-existing identity token per +// https://github.com/distribution/distribution/blob/main/docs/spec/auth/oauth.md for challenge and scopes, +// and writes it into dest. +func (c *dockerClient) getBearerTokenOAuth2(ctx context.Context, dest *bearerToken, challenge challenge, + scopes []authScope) error { realm, ok := challenge.Parameters["realm"] if !ok { - return nil, errors.New("missing realm in bearer auth challenge") + return errors.New("missing realm in bearer auth challenge") } authReq, err := http.NewRequestWithContext(ctx, http.MethodPost, realm, nil) if err != nil { - return nil, err + return err } // Make the form data required against the oauth2 authentication @@ -799,26 +873,29 @@ func (c *dockerClient) getBearerTokenOAuth2(ctx context.Context, challenge chall logrus.Debugf("%s %s", authReq.Method, authReq.URL.Redacted()) res, err := c.client.Do(authReq) if err != nil { - return nil, err + return err } defer res.Body.Close() if err := httpResponseToError(res, "Trying to obtain access token"); err != nil { - return nil, err + return err } - return newBearerTokenFromHTTPResponseBody(res) + return dest.readFromHTTPResponseBody(res) } -func (c *dockerClient) getBearerToken(ctx context.Context, challenge challenge, - scopes []authScope) (*bearerToken, error) { +// getBearerToken obtains an "Authorization: Bearer" token using a GET request, per +// https://github.com/distribution/distribution/blob/main/docs/spec/auth/token.md for challenge and scopes, +// and writes it into dest. +func (c *dockerClient) getBearerToken(ctx context.Context, dest *bearerToken, challenge challenge, + scopes []authScope) error { realm, ok := challenge.Parameters["realm"] if !ok { - return nil, errors.New("missing realm in bearer auth challenge") + return errors.New("missing realm in bearer auth challenge") } authReq, err := http.NewRequestWithContext(ctx, http.MethodGet, realm, nil) if err != nil { - return nil, err + return err } params := authReq.URL.Query() @@ -846,22 +923,22 @@ func (c *dockerClient) getBearerToken(ctx context.Context, challenge challenge, logrus.Debugf("%s %s", authReq.Method, authReq.URL.Redacted()) res, err := c.client.Do(authReq) if err != nil { - return nil, err + return err } defer res.Body.Close() if err := httpResponseToError(res, "Requesting bearer token"); err != nil { - return nil, err + return err } - return newBearerTokenFromHTTPResponseBody(res) + return dest.readFromHTTPResponseBody(res) } -// newBearerTokenFromHTTPResponseBody parses a http.Response to obtain a bearerToken. +// readFromHTTPResponseBody sets token data by parsing a http.Response. // The caller is still responsible for ensuring res.Body is closed. -func newBearerTokenFromHTTPResponseBody(res *http.Response) (*bearerToken, error) { +func (bt *bearerToken) readFromHTTPResponseBody(res *http.Response) error { blob, err := iolimits.ReadAtMost(res.Body, iolimits.MaxAuthTokenBodySize) if err != nil { - return nil, err + return err } var token struct { @@ -877,12 +954,10 @@ func newBearerTokenFromHTTPResponseBody(res *http.Response) (*bearerToken, error if len(bodySample) > bodySampleLength { bodySample = bodySample[:bodySampleLength] } - return nil, fmt.Errorf("decoding bearer token (last URL %q, body start %q): %w", res.Request.URL.Redacted(), string(bodySample), err) + return fmt.Errorf("decoding bearer token (last URL %q, body start %q): %w", res.Request.URL.Redacted(), string(bodySample), err) } - bt := &bearerToken{ - token: token.Token, - } + bt.token = token.Token if bt.token == "" { bt.token = token.AccessToken } @@ -895,7 +970,7 @@ func newBearerTokenFromHTTPResponseBody(res *http.Response) (*bearerToken, error token.IssuedAt = time.Now().UTC() } bt.expirationTime = token.IssuedAt.Add(time.Duration(token.ExpiresIn) * time.Second) - return bt, nil + return nil } // detectPropertiesHelper performs the work of detectProperties which executes diff --git a/docker/docker_client_test.go b/docker/docker_client_test.go index 16e2c286b..3ee318070 100644 --- a/docker/docker_client_test.go +++ b/docker/docker_client_test.go @@ -106,7 +106,7 @@ func testTokenHTTPResponse(t *testing.T, body string) *http.Response { } } -func TestNewBearerTokenFromHTTPResponseBody(t *testing.T) { +func TestBearerTokenReadFromHTTPResponseBody(t *testing.T) { for _, c := range []struct { input string expected *bearerToken // or nil if a failure is expected @@ -128,7 +128,8 @@ func TestNewBearerTokenFromHTTPResponseBody(t *testing.T) { expected: &bearerToken{token: "IAmAToken", expirationTime: time.Unix(1514800802+60, 0)}, }, } { - token, err := newBearerTokenFromHTTPResponseBody(testTokenHTTPResponse(t, c.input)) + token := &bearerToken{} + err := token.readFromHTTPResponseBody(testTokenHTTPResponse(t, c.input)) if c.expected == nil { assert.Error(t, err, c.input) } else { @@ -140,11 +141,12 @@ func TestNewBearerTokenFromHTTPResponseBody(t *testing.T) { } } -func TestNewBearerTokenFromHTTPResponseBodyIssuedAtZero(t *testing.T) { +func TestBearerTokenReadFromHTTPResponseBodyIssuedAtZero(t *testing.T) { zeroTime := time.Time{}.Format(time.RFC3339) now := time.Now() tokenBlob := fmt.Sprintf(`{"token":"IAmAToken","expires_in":100,"issued_at":"%s"}`, zeroTime) - token, err := newBearerTokenFromHTTPResponseBody(testTokenHTTPResponse(t, tokenBlob)) + token := &bearerToken{} + err := token.readFromHTTPResponseBody(testTokenHTTPResponse(t, tokenBlob)) require.NoError(t, err) expectedExpiration := now.Add(time.Duration(100) * time.Second) require.False(t, token.expirationTime.Before(expectedExpiration),