Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[keyvault/azsecrets] make azsecrets.Client thread-safe #24032

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/security/keyvault/azsecrets/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
### Breaking Changes

### Bugs Fixed
* Fixed data race when using Client from multiple goroutines concurrently.
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we reference issue numbers here?


### Other Changes

Expand Down
37 changes: 31 additions & 6 deletions sdk/security/keyvault/internal/challenge_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"net/http"
"net/url"
"strings"
"sync"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
Expand All @@ -30,7 +31,12 @@ type KeyVaultChallengePolicyOptions struct {
type keyVaultAuthorizer struct {
// tro is the policy's authentication parameters. These are discovered from an authentication challenge
// elicited ahead of the first client request.
tro policy.TokenRequestOptions
//
// Protected by troLock.
tro policy.TokenRequestOptions
// Lock protecting tro in case there are multiple concurrent initial requests.
troLock sync.RWMutex

verifyChallengeResource bool
}

Expand All @@ -55,7 +61,8 @@ func NewKeyVaultChallengePolicy(cred azcore.TokenCredential, opts *KeyVaultChall
}

func (k *keyVaultAuthorizer) authorize(req *policy.Request, authNZ func(policy.TokenRequestOptions) error) error {
if len(k.tro.Scopes) == 0 || k.tro.TenantID == "" {
tro := k.getTokenRequestOptions()
if len(tro.Scopes) == 0 || tro.TenantID == "" {
if body := req.Body(); body != nil {
// We don't know the scope or tenant ID because we haven't seen a challenge yet. We elicit one now by sending
// the request without authorization, first removing its body, if any. authorizeOnChallenge will reattach the
Expand All @@ -70,7 +77,7 @@ func (k *keyVaultAuthorizer) authorize(req *policy.Request, authNZ func(policy.T
return nil
}
// else we know the auth parameters and can authorize the request as normal
return authNZ(k.tro)
return authNZ(tro)
}

func (k *keyVaultAuthorizer) authorizeOnChallenge(req *policy.Request, res *http.Response, authNZ func(policy.TokenRequestOptions) error) error {
Expand All @@ -87,7 +94,7 @@ func (k *keyVaultAuthorizer) authorizeOnChallenge(req *policy.Request, res *http
}
}
// authenticate with the parameters supplied by Key Vault, authorize the request, send it again
return authNZ(k.tro)
return authNZ(k.getTokenRequestOptions())
}

// parses Tenant ID from auth challenge
Expand Down Expand Up @@ -126,7 +133,6 @@ func (k *keyVaultAuthorizer) updateTokenRequestOptions(resp *http.Response, req
}
}

k.tro.TenantID = parseTenant(vals["authorization"])
scope := ""
if v, ok := vals["scope"]; ok {
scope = v
Expand All @@ -149,6 +155,25 @@ func (k *keyVaultAuthorizer) updateTokenRequestOptions(resp *http.Response, req
if !strings.HasSuffix(scope, "/.default") {
scope += "/.default"
}
k.tro.Scopes = []string{scope}
k.setTokenRequestOptions(policy.TokenRequestOptions{
TenantID: parseTenant(vals["authorization"]),
Scopes: []string{scope},
})
return nil
}

// Returns a (possibly-zero) copy of TokenRequestOptions.
//
// The returned value's Scopes and other fields must not be modified.
func (k *keyVaultAuthorizer) getTokenRequestOptions() policy.TokenRequestOptions {
k.troLock.RLock()
defer k.troLock.RUnlock()
return k.tro // Copy.
}

// After calling this function, tro.Scopes and other fields must not be modified.
func (k *keyVaultAuthorizer) setTokenRequestOptions(tro policy.TokenRequestOptions) {
k.troLock.Lock()
defer k.troLock.Unlock()
k.tro = tro // Copy.
}
64 changes: 64 additions & 0 deletions sdk/security/keyvault/internal/challenge_policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -355,3 +358,64 @@ func TestParseTenant(t *testing.T) {
actual = parseTenant(sampleURL)
require.Equal(t, expected, actual, "tenant was not properly parsed, got %s, expected %s", actual, expected)
}

func TestChallengePolicy_ConcurrentRequests(t *testing.T) {
concurrentRequestCount := 3

serverAuthenticateRequests := atomic.Int32{}
serverAuthenticatedRequests := atomic.Int32{}
var srv *httptest.Server
srv = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other tests use mock.Server, but here I used httptest.Server for two reasons:

  1. mock.Server is documented to not be safe when used from multiple goroutines.
  2. Even if mock.Server was safe at a surface level (e.g. by adding locks), its interface doesn't allow me to handle out-of-order requests.

httptest.Server has its own issues, but it seemed like the least amount of code to get the test working.

authz := r.Header.Values("Authorization")
if len(authz) == 0 {
// Initial request without Authorization header. Send a
// challenge response to the client.
serverAuthenticateRequests.Add(1)
resource := srv.URL
w.Header().Add("WWW-Authenticate", fmt.Sprintf(`Bearer authorization="https://login.microsoftonline.com/{tenant}", resource="%s"`, resource))
w.WriteHeader(401)
} else {
// Authenticated request.
serverAuthenticatedRequests.Add(1)
if len(authz) != 1 || authz[0] != "Bearer ***" {
t.Errorf(`unexpected Authorization "%s"`, authz)
}
// Return nothing.
w.WriteHeader(200)
}
}))
defer srv.Close()
srv.StartTLS()

cred := credentialFunc(func(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
return azcore.AccessToken{Token: "***", ExpiresOn: time.Now().Add(time.Hour)}, nil
})
p := NewKeyVaultChallengePolicy(cred, &KeyVaultChallengePolicyOptions{
// Challenge resource verification will always fail because we
// use local IPs instead of domain names and subdomains in this
// test.
DisableChallengeResourceVerification: true,
})
pl := runtime.NewPipeline("", "",
runtime.PipelineOptions{PerRetry: []policy.Policy{p}},
&policy.ClientOptions{Transport: srv.Client()},
)

wg := sync.WaitGroup{}
for i := 0; i < concurrentRequestCount; i += 1 {
go (func() {
defer wg.Done()
req, err := runtime.NewRequest(context.Background(), "GET", srv.URL)
require.NoError(t, err)
res, err := pl.Do(req)
require.NoError(t, err)
defer res.Body.Close()
})()
wg.Add(1)
}
wg.Wait()

require.GreaterOrEqual(t, int(serverAuthenticateRequests.Load()), 1, "client should have sent at least one preflight request")
require.LessOrEqual(t, int(serverAuthenticateRequests.Load()), concurrentRequestCount, "client should have sent no more preflight requests than client requests")
require.EqualValues(t, concurrentRequestCount, serverAuthenticatedRequests.Load(), "client preflight request count should equal server preflight request count")
}
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The real assertion is the -race flag of go test. The three assertions in the test just make sure the test issued the requests, so they're not doing too much. Should I make the assertions more sophisticated? Is there a way to document that this test is designed to work with -race?