Skip to content

Commit

Permalink
[keyvault/azsecrets] make azsecrets.Client thread-safe
Browse files Browse the repository at this point in the history
azsecrets.Client uses NewKeyVaultChallengePolicy. This policy is not
goroutine-safe, violating the documented requirement that policies are
goroutine-safe [1]. This leads to data races which are reported by Go's race
detector.

Fix NewKeyVaultChallengePolicy to be goroutine-safe using a mutex. This can lead
to redundant preflight requests, but at least Go's race detector no longer
complains.

Test plan:

    $ cd sdk/security/keyvault/internal/
    $ go test -race

[1] https://learn.microsoft.com/en-us/azure/developer/go/azure-sdk-core-concepts
  • Loading branch information
strager committed Feb 3, 2025
1 parent 90c29cc commit 7adba6e
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 6 deletions.
1 change: 1 addition & 0 deletions sdk/security/keyvault/azsecrets/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
### Breaking Changes

### Bugs Fixed
* Fixed data race when using Client from multiple goroutines concurrently.

### Other Changes

Expand Down
37 changes: 31 additions & 6 deletions sdk/security/keyvault/internal/challenge_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"net/http"
"net/url"
"strings"
"sync"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
Expand All @@ -30,7 +31,12 @@ type KeyVaultChallengePolicyOptions struct {
type keyVaultAuthorizer struct {
// tro is the policy's authentication parameters. These are discovered from an authentication challenge
// elicited ahead of the first client request.
tro policy.TokenRequestOptions
//
// Protected by troLock.
tro policy.TokenRequestOptions
// Lock protecting tro in case there are multiple concurrent initial requests.
troLock sync.RWMutex

verifyChallengeResource bool
}

Expand All @@ -55,7 +61,8 @@ func NewKeyVaultChallengePolicy(cred azcore.TokenCredential, opts *KeyVaultChall
}

func (k *keyVaultAuthorizer) authorize(req *policy.Request, authNZ func(policy.TokenRequestOptions) error) error {
if len(k.tro.Scopes) == 0 || k.tro.TenantID == "" {
tro := k.getTokenRequestOptions()
if len(tro.Scopes) == 0 || tro.TenantID == "" {
if body := req.Body(); body != nil {
// We don't know the scope or tenant ID because we haven't seen a challenge yet. We elicit one now by sending
// the request without authorization, first removing its body, if any. authorizeOnChallenge will reattach the
Expand All @@ -70,7 +77,7 @@ func (k *keyVaultAuthorizer) authorize(req *policy.Request, authNZ func(policy.T
return nil
}
// else we know the auth parameters and can authorize the request as normal
return authNZ(k.tro)
return authNZ(tro)
}

func (k *keyVaultAuthorizer) authorizeOnChallenge(req *policy.Request, res *http.Response, authNZ func(policy.TokenRequestOptions) error) error {
Expand All @@ -87,7 +94,7 @@ func (k *keyVaultAuthorizer) authorizeOnChallenge(req *policy.Request, res *http
}
}
// authenticate with the parameters supplied by Key Vault, authorize the request, send it again
return authNZ(k.tro)
return authNZ(k.getTokenRequestOptions())
}

// parses Tenant ID from auth challenge
Expand Down Expand Up @@ -126,7 +133,6 @@ func (k *keyVaultAuthorizer) updateTokenRequestOptions(resp *http.Response, req
}
}

k.tro.TenantID = parseTenant(vals["authorization"])
scope := ""
if v, ok := vals["scope"]; ok {
scope = v
Expand All @@ -149,6 +155,25 @@ func (k *keyVaultAuthorizer) updateTokenRequestOptions(resp *http.Response, req
if !strings.HasSuffix(scope, "/.default") {
scope += "/.default"
}
k.tro.Scopes = []string{scope}
k.setTokenRequestOptions(policy.TokenRequestOptions{
TenantID: parseTenant(vals["authorization"]),
Scopes: []string{scope},
})
return nil
}

// Returns a (possibly-zero) copy of TokenRequestOptions.
//
// The returned value's Scopes and other fields must not be modified.
func (k *keyVaultAuthorizer) getTokenRequestOptions() policy.TokenRequestOptions {
k.troLock.RLock()
defer k.troLock.RUnlock()
return k.tro // Copy.
}

// After calling this function, tro.Scopes and other fields must not be modified.
func (k *keyVaultAuthorizer) setTokenRequestOptions(tro policy.TokenRequestOptions) {
k.troLock.Lock()
defer k.troLock.Unlock()
k.tro = tro // Copy.
}
64 changes: 64 additions & 0 deletions sdk/security/keyvault/internal/challenge_policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -355,3 +358,64 @@ func TestParseTenant(t *testing.T) {
actual = parseTenant(sampleURL)
require.Equal(t, expected, actual, "tenant was not properly parsed, got %s, expected %s", actual, expected)
}

func TestChallengePolicy_ConcurrentRequests(t *testing.T) {
concurrentRequestCount := 3

serverAuthenticateRequests := atomic.Int32{}
serverAuthenticatedRequests := atomic.Int32{}
var srv *httptest.Server
srv = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authz := r.Header.Values("Authorization")
if len(authz) == 0 {
// Initial request without Authorization header. Send a
// challenge response to the client.
serverAuthenticateRequests.Add(1)
resource := srv.URL
w.Header().Add("WWW-Authenticate", fmt.Sprintf(`Bearer authorization="https://login.microsoftonline.com/{tenant}", resource="%s"`, resource))
w.WriteHeader(401)
} else {
// Authenticated request.
serverAuthenticatedRequests.Add(1)
if len(authz) != 1 || authz[0] != "Bearer ***" {
t.Errorf(`unexpected Authorization "%s"`, authz)
}
// Return nothing.
w.WriteHeader(200)
}
}))
defer srv.Close()
srv.StartTLS()

cred := credentialFunc(func(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
return azcore.AccessToken{Token: "***", ExpiresOn: time.Now().Add(time.Hour)}, nil
})
p := NewKeyVaultChallengePolicy(cred, &KeyVaultChallengePolicyOptions{
// Challenge resource verification will always fail because we
// use local IPs instead of domain names and subdomains in this
// test.
DisableChallengeResourceVerification: true,
})
pl := runtime.NewPipeline("", "",
runtime.PipelineOptions{PerRetry: []policy.Policy{p}},
&policy.ClientOptions{Transport: srv.Client()},
)

wg := sync.WaitGroup{}
for i := 0; i < concurrentRequestCount; i += 1 {
go (func() {
defer wg.Done()
req, err := runtime.NewRequest(context.Background(), "GET", srv.URL)
require.NoError(t, err)
res, err := pl.Do(req)
require.NoError(t, err)
defer res.Body.Close()
})()
wg.Add(1)
}
wg.Wait()

require.GreaterOrEqual(t, int(serverAuthenticateRequests.Load()), 1, "client should have sent at least one preflight request")
require.LessOrEqual(t, int(serverAuthenticateRequests.Load()), concurrentRequestCount, "client should have sent no more preflight requests than client requests")
require.EqualValues(t, concurrentRequestCount, serverAuthenticatedRequests.Load(), "client preflight request count should equal server preflight request count")
}

0 comments on commit 7adba6e

Please sign in to comment.