diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml index dfe5b34b..08556ab1 100644 --- a/.github/workflows/conformance.yml +++ b/.github/workflows/conformance.yml @@ -49,6 +49,6 @@ jobs: uses: modelcontextprotocol/conformance@c2f3fdaf781dcd5a862cb0d2f6454c1c210bf0f0 # v0.1.11 with: mode: client - command: go run ./conformance/everything-client/main.go + command: go run -tags mcp_go_client_oauth ./conformance/everything-client suite: core expected-failures: ./conformance/baseline.yml diff --git a/auth/auth.go b/auth/auth.go index 87665121..29cca526 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -106,6 +106,9 @@ func verify(req *http.Request, verifier TokenVerifier, opts *RequireBearerTokenO } return nil, err.Error(), http.StatusInternalServerError } + if tokenInfo == nil { + return nil, "token validation failed", http.StatusInternalServerError + } // Check scopes. All must be present. if opts != nil { diff --git a/auth/authorization_code.go b/auth/authorization_code.go new file mode 100644 index 00000000..4600ea27 --- /dev/null +++ b/auth/authorization_code.go @@ -0,0 +1,450 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "context" + "crypto/rand" + "errors" + "fmt" + "log" + "net/http" + "net/url" + "slices" + + "github.com/modelcontextprotocol/go-sdk/oauthex" + "golang.org/x/oauth2" +) + +// ClientSecretAuthMethod defines "client_secret_*" authentication methods per +// https://www.iana.org/assignments/oauth-parameters/oauth-parameters.xhtml#token-endpoint-auth-method. +// "client_secret_jwt" is not currently supported. +type ClientSecretAuthMethod int + +const ( + // ClientSecretAuthMethodBasic uses the "client_secret_basic" authentication method. + ClientSecretAuthMethodBasic ClientSecretAuthMethod = iota + // ClientSecretAuthMethodPost uses the "client_secret_post" authentication method. + ClientSecretAuthMethodPost +) + +func (m ClientSecretAuthMethod) String() string { + switch m { + case ClientSecretAuthMethodBasic: + return "client_secret_basic" + case ClientSecretAuthMethodPost: + return "client_secret_post" + default: + return "" + } +} + +// ClientSecretAuthConfig is used to configure client authentication using client_secret. +type ClientSecretAuthConfig struct { + // ClientID is the client ID to be used for client authentication. + ClientID string + // ClientSecret is the client secret to be used for client authentication. + ClientSecret string + // PreferredClientSecretAuthMethod to be used for client authentication. + // If not specified or unsupported by the authorization server, the method + // will be selected based on the authorization server's supported methods, + // according to the following preference order: + // + // 1. "client_secret_post" + // 2. "client_secret_basic" + // + PreferredClientSecretAuthMethod ClientSecretAuthMethod +} + +// ClientIDMetadataDocumentConfig is used to configure the Client ID Metadata Document +// based client registration per +// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#client-id-metadata-documents. +// See https://client.dev/ for more information. +type ClientIDMetadataDocumentConfig struct { + // URL is the client identifier URL as per + // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-client-id-metadata-document-00#section-3. + URL string +} + +// PreregisteredClientConfig is used to configure a pre-registered client per +// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#preregistration. +// Currently only "client_secret_basic" and "client_secret_post" authentication methods are supported. +type PreregisteredClientConfig struct { + // ClientSecretAuthConfig is the client_secret based configuration to be used for client authentication. + ClientSecretAuthConfig *ClientSecretAuthConfig +} + +// DynamicClientRegistrationConfig is used to configure dynamic client registration per +// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#dynamic-client-registration. +type DynamicClientRegistrationConfig struct { + // Metadata to be used in dynamic client registration request as per + // https://datatracker.ietf.org/doc/html/rfc7591#section-2. + Metadata *oauthex.ClientRegistrationMetadata +} + +// AuthorizationResult is the result of an authorization flow. +// It is returned by [AuthorizationCodeHandler.AuthorizationURLHandler] implementations. +type AuthorizationResult struct { + // AuthorizationCode is the authorization code obtained from the authorization server. + AuthorizationCode string + // State string returned by the authorization server. + State string +} + +// AuthorizationCodeHandlerConfig is the configuration for [AuthorizationCodeHandler]. +type AuthorizationCodeHandlerConfig struct { + // Client registration configuration. + // It is attempted in the following order: + // + // 1. Client ID Metadata Document + // 2. Preregistration + // 3. Dynamic Client Registration + // + // At least one method must be configured. + ClientIDMetadataDocumentConfig *ClientIDMetadataDocumentConfig + PreregisteredClientConfig *PreregisteredClientConfig + DynamicClientRegistrationConfig *DynamicClientRegistrationConfig + + // RedirectURL is a required URL to redirect to after authorization. + // The caller is responsible for handling the redirect out of band. + // If Dynamic Client Registration is used, the RedirectURL must be consistent + // with [DynamicClientRegistrationConfig.Metadata.RedirectURIs]. + RedirectURL string + + // AuthorizationURLHandler is a required function called to handle the authorization request. + // It is responsible for opening the URL in a browser for the user to start the authorization process. + // It should return the authorization code and state once the Authorization Server + // redirects back to the [AuthorizationCodeHandler.RedirectURL]. + AuthorizationURLHandler func(ctx context.Context, authorizationURL string) (*AuthorizationResult, error) +} + +// AuthorizationCodeHandler is an implementation of [OAuthHandler] that uses +// the authorization code flow to obtain access tokens. +type AuthorizationCodeHandler struct { + config *AuthorizationCodeHandlerConfig + + // tokenSource is the token source to use for authorization. + tokenSource oauth2.TokenSource +} + +var _ OAuthHandler = (*AuthorizationCodeHandler)(nil) + +func (h *AuthorizationCodeHandler) isOAuthHandler() {} + +func (h *AuthorizationCodeHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { + return h.tokenSource, nil +} + +// NewAuthorizationCodeHandler creates a new AuthorizationCodeHandler. +// It performs validation of the configuration and returns an error if it is invalid. +// The passed config is consumed by the handler and should not be modified after. +func NewAuthorizationCodeHandler(config *AuthorizationCodeHandlerConfig) (*AuthorizationCodeHandler, error) { + if config == nil { + return nil, errors.New("config must be provided") + } + if config.ClientIDMetadataDocumentConfig == nil && + config.PreregisteredClientConfig == nil && + config.DynamicClientRegistrationConfig == nil { + return nil, errors.New("at least one client registration configuration must be provided") + } + if config.RedirectURL == "" { + return nil, errors.New("field RedirectURL is required") + } + if config.AuthorizationURLHandler == nil { + return nil, errors.New("field AuthorizationURLHandler is required") + } + if config.ClientIDMetadataDocumentConfig != nil && !isNonRootHTTPSURL(config.ClientIDMetadataDocumentConfig.URL) { + return nil, fmt.Errorf("client ID metadata document URL must be a non-root HTTPS URL") + } + preCfg := config.PreregisteredClientConfig + if preCfg != nil { + if preCfg.ClientSecretAuthConfig == nil { + return nil, errors.New("field ClientSecretAuthConfig is required for pre-registered client") + } + if preCfg.ClientSecretAuthConfig.ClientID == "" || preCfg.ClientSecretAuthConfig.ClientSecret == "" { + return nil, fmt.Errorf("pre-registered client ID or secret is empty") + } + } + if config.DynamicClientRegistrationConfig != nil { + if config.DynamicClientRegistrationConfig.Metadata == nil { + return nil, errors.New("field Metadata is required for dynamic client registration") + } + if !slices.Contains(config.DynamicClientRegistrationConfig.Metadata.RedirectURIs, config.RedirectURL) { + return nil, fmt.Errorf("redirect URI %q is not in the list of allowed redirect URIs for dynamic client registration", config.RedirectURL) + } + } + return &AuthorizationCodeHandler{config: config}, nil +} + +func isNonRootHTTPSURL(u string) bool { + pu, err := url.Parse(u) + if err != nil { + return false + } + return pu.Scheme == "https" && pu.Path != "" +} + +// Authorize performs the authorization flow. +// It is designed to perform the whole Authorization Code Grant flow. +// On success, [AuthorizationCodeHandler.TokenSource] will return a token source with the fetched token. +func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { + defer resp.Body.Close() + log.Printf("Authorize: %s %s", req.Method, req.URL) + + resourceURL := req.URL.String() + wwwChallenges, err := oauthex.ParseWWWAuthenticate(resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")]) + if err != nil { + return fmt.Errorf("failed to parse WWW-Authenticate header: %v", err) + } + log.Printf("WWW-Authenticate header: %v", wwwChallenges) + + prm, err := h.getProtectedResourceMetadata(ctx, wwwChallenges, resourceURL) + if err != nil { + return err + } + // log.Printf("Protected resource metadata: %+v", prm) + + asm, err := h.getAuthServerMetadata(ctx, prm, resourceURL) + if err != nil { + return err + } + // log.Printf("Authorization server metadata: %+v", asm) + + resolvedClientConfig, err := h.handleRegistration(ctx, asm) + if err != nil { + return err + } + + scopes := oauthex.Scopes(wwwChallenges) + if len(scopes) == 0 && prm != nil && len(prm.ScopesSupported) > 0 { + scopes = prm.ScopesSupported + } + + cfg := &oauth2.Config{ + ClientID: resolvedClientConfig.clientID, + ClientSecret: resolvedClientConfig.clientSecret, + + Endpoint: oauth2.Endpoint{ + AuthURL: asm.AuthorizationEndpoint, + TokenURL: asm.TokenEndpoint, + AuthStyle: resolvedClientConfig.authStyle, + }, + RedirectURL: h.config.RedirectURL, + Scopes: scopes, + } + + authRes, err := h.getAuthorizationCode(ctx, cfg, req.URL.String()) + if err != nil { + // Purposefully leaving the error unwrappable so it can be handled by the caller. + return err + } + + return h.exchangeAuthorizationCode(ctx, cfg, authRes, resourceURL) +} + +// getProtectedResourceMetadata returns the protected resource metadata. +// If no metadata was found or the fetched metadata fails security checks, +// it returns an error. +func (h *AuthorizationCodeHandler) getProtectedResourceMetadata(ctx context.Context, wwwChallenges []oauthex.Challenge, resourceURL string) (*oauthex.ProtectedResourceMetadata, error) { + var errs []error + for _, url := range oauthex.ProtectedResourceMetadataURLs(oauthex.ResourceMetadataURL(wwwChallenges), resourceURL) { + log.Printf("Getting protected resource metadata from %q", url) + prm, err := oauthex.GetProtectedResourceMetadata(ctx, url, http.DefaultClient) + if err != nil { + errs = append(errs, err) + continue + } + return prm, nil + } + return nil, fmt.Errorf("failed to get protected resource metadata: %v", errors.Join(errs...)) +} + +// getAuthServerMetadata returns the authorization server metadata. +// It returns an error if the metadata request fails with non-4xx HTTP status code +// or the fetched metadata fails security checks. +// If no metadata was found, it returns a minimal set of endpoints +// as a fallback to 2025-03-26 spec. +func (h *AuthorizationCodeHandler) getAuthServerMetadata(ctx context.Context, prm *oauthex.ProtectedResourceMetadata, resourceURL string) (*oauthex.AuthServerMeta, error) { + var authServerURL string + if prm != nil && len(prm.AuthorizationServers) > 0 { + // Use the first authorization server, similarly to other SDKs. + authServerURL = prm.AuthorizationServers[0] + } else { + // Fallback to 2025-03-26 spec: MCP server base URL acts as Authorization Server. + authURL, err := url.Parse(resourceURL) + if err != nil { + return nil, fmt.Errorf("failed to parse resource URL: %v", err) + } + authURL.Path = "" + authServerURL = authURL.String() + } + log.Printf("Authorization server URL: %s", authServerURL) + + for _, u := range oauthex.AuthorizationServerMetadataURLs(authServerURL) { + asm, err := oauthex.GetAuthServerMeta(ctx, u, http.DefaultClient) + if err != nil { + return nil, fmt.Errorf("failed to get authorization server metadata: %w", err) + } + if asm != nil { + return asm, nil + } + } + + log.Print("Authorization server metadata not found, using fallback") + // Fallback to 2025-03-26 spec: predefined endpoints. + // https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization#fallbacks-for-servers-without-metadata-discovery + asm := &oauthex.AuthServerMeta{ + Issuer: authServerURL, + AuthorizationEndpoint: authServerURL + "/authorize", + TokenEndpoint: authServerURL + "/token", + RegistrationEndpoint: authServerURL + "/register", + } + return asm, nil +} + +type registrationType int + +const ( + registrationTypeClientIDMetadataDocument registrationType = iota + registrationTypePreregistered + registrationTypeDynamic +) + +type resolvedClientConfig struct { + registrationType registrationType + clientID string + clientSecret string + authStyle oauth2.AuthStyle +} + +func selectTokenAuthMethod(supported []string, preferred ClientSecretAuthMethod) oauth2.AuthStyle { + if slices.Contains(supported, preferred.String()) { + return authMethodToStyle(preferred.String()) + } + prefOrder := []string{ + // Preferred in OAuth 2.1 draft: https://www.ietf.org/archive/id/draft-ietf-oauth-v2-1-14.html#name-client-secret. + "client_secret_post", + "client_secret_basic", + } + for _, method := range prefOrder { + if slices.Contains(supported, method) { + return authMethodToStyle(method) + } + } + return oauth2.AuthStyleAutoDetect +} + +func authMethodToStyle(method string) oauth2.AuthStyle { + switch method { + case "client_secret_post": + return oauth2.AuthStyleInParams + case "client_secret_basic": + return oauth2.AuthStyleInHeader + case "none": + // "none" is equivalent to "client_secret_post" but without sending client secret. + return oauth2.AuthStyleInParams + default: + // "client_secret_basic" is the default per https://datatracker.ietf.org/doc/html/rfc7591#section-2. + return oauth2.AuthStyleInHeader + } +} + +// handleRegistration handles client registration. +// The provided authorization server metadata must be non-nil. +// Support for different registration methods is defined as follows: +// - Client ID Metadata Document: metadata must have +// `ClientIDMetadataDocumentSupported` set to true. +// - Pre-registered client: assumed to be supported. +// - Dynamic client registration: metadata must have +// `RegistrationEndpoint` set to a non-empty value. +func (h *AuthorizationCodeHandler) handleRegistration(ctx context.Context, asm *oauthex.AuthServerMeta) (*resolvedClientConfig, error) { + // 1. Attempt to use Client ID Metadata Document (SEP-991). + cimdCfg := h.config.ClientIDMetadataDocumentConfig + if cimdCfg != nil && asm.ClientIDMetadataDocumentSupported { + return &resolvedClientConfig{ + registrationType: registrationTypeClientIDMetadataDocument, + clientID: cimdCfg.URL, + }, nil + } + // 2. Attempt to use pre-registered client configuration. + pCfg := h.config.PreregisteredClientConfig + if pCfg != nil { + authStyle := selectTokenAuthMethod(asm.TokenEndpointAuthMethodsSupported, pCfg.ClientSecretAuthConfig.PreferredClientSecretAuthMethod) + return &resolvedClientConfig{ + registrationType: registrationTypePreregistered, + clientID: pCfg.ClientSecretAuthConfig.ClientID, + clientSecret: pCfg.ClientSecretAuthConfig.ClientSecret, + authStyle: authStyle, + }, nil + } + // 3. Attempt to use dynamic client registration. + dcrCfg := h.config.DynamicClientRegistrationConfig + if dcrCfg != nil && asm.RegistrationEndpoint != "" { + regResp, err := oauthex.RegisterClient(ctx, asm.RegistrationEndpoint, dcrCfg.Metadata, http.DefaultClient) + if err != nil { + return nil, fmt.Errorf("failed to register client: %w", err) + } + cfg := &resolvedClientConfig{ + registrationType: registrationTypeDynamic, + clientID: regResp.ClientID, + clientSecret: regResp.ClientSecret, + authStyle: authMethodToStyle(regResp.TokenEndpointAuthMethod), + } + log.Printf("Client registered with client ID: %s", regResp.ClientID) + return cfg, nil + } + return nil, fmt.Errorf("no configured client registration methods are supported by the authorization server") +} + +type authResult struct { + *AuthorizationResult + // usedCodeVerifier is the PKCE code verifier used to obtain the authorization code. + // It is preserved for the token exchange step. + usedCodeVerifier string +} + +// getAuthorizationCode uses the [AuthorizationCodeHandler.AuthorizationURLHandler] +// to obtain an authorization code. +func (h *AuthorizationCodeHandler) getAuthorizationCode(ctx context.Context, cfg *oauth2.Config, resourceURL string) (*authResult, error) { + codeVerifier := oauth2.GenerateVerifier() + state := rand.Text() + + authURL := cfg.AuthCodeURL(state, + oauth2.S256ChallengeOption(codeVerifier), + oauth2.SetAuthURLParam("resource", resourceURL), + ) + + log.Printf("Calling AuthorizationURLHandler: %q", authURL) + authRes, err := h.config.AuthorizationURLHandler(ctx, authURL) + if err != nil { + // Purposefully leaving the error unwrappable so it can be handled by the caller. + return nil, err + } + if authRes.State != state { + return nil, fmt.Errorf("state mismatch") + } + return &authResult{ + AuthorizationResult: authRes, + usedCodeVerifier: codeVerifier, + }, nil +} + +// exchangeAuthorizationCode exchanges the authorization code for a token +// and stores it in a token source. +func (h *AuthorizationCodeHandler) exchangeAuthorizationCode(ctx context.Context, cfg *oauth2.Config, authResult *authResult, resourceURL string) error { + log.Printf("Exchanging authorization code for token") + opts := []oauth2.AuthCodeOption{ + oauth2.VerifierOption(authResult.usedCodeVerifier), + oauth2.SetAuthURLParam("resource", resourceURL), + } + token, err := cfg.Exchange(ctx, authResult.AuthorizationCode, opts...) + if err != nil { + return fmt.Errorf("token exchange failed: %w", err) + } + h.tokenSource = cfg.TokenSource(ctx, token) + return nil +} diff --git a/auth/client.go b/auth/client.go index acadc51b..6ddd4a29 100644 --- a/auth/client.go +++ b/auth/client.go @@ -2,122 +2,29 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. -//go:build mcp_go_client_oauth - package auth import ( - "bytes" - "errors" - "io" + "context" "net/http" - "sync" "golang.org/x/oauth2" ) -// An OAuthHandler conducts an OAuth flow and returns a [oauth2.TokenSource] if the authorization -// is approved, or an error if not. -// The handler receives the HTTP request and response that triggered the authentication flow. -// To obtain the protected resource metadata, call [oauthex.GetProtectedResourceMetadataFromHeader]. -type OAuthHandler func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) - -// HTTPTransport is an [http.RoundTripper] that follows the MCP -// OAuth protocol when it encounters a 401 Unauthorized response. -type HTTPTransport struct { - handler OAuthHandler - mu sync.Mutex // protects opts.Base - opts HTTPTransportOptions -} - -// NewHTTPTransport returns a new [*HTTPTransport]. -// The handler is invoked when an HTTP request results in a 401 Unauthorized status. -// It is called only once per transport. Once a TokenSource is obtained, it is used -// for the lifetime of the transport; subsequent 401s are not processed. -func NewHTTPTransport(handler OAuthHandler, opts *HTTPTransportOptions) (*HTTPTransport, error) { - if handler == nil { - return nil, errors.New("handler cannot be nil") - } - t := &HTTPTransport{ - handler: handler, - } - if opts != nil { - t.opts = *opts - } - if t.opts.Base == nil { - t.opts.Base = http.DefaultTransport - } - return t, nil -} - -// HTTPTransportOptions are options to [NewHTTPTransport]. -type HTTPTransportOptions struct { - // Base is the [http.RoundTripper] to use. - // If nil, [http.DefaultTransport] is used. - Base http.RoundTripper -} - -func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { - t.mu.Lock() - base := t.opts.Base - t.mu.Unlock() - - var ( - // If haveBody is set, the request has a nontrivial body, and we need avoid - // reading (or closing) it multiple times. In that case, bodyBytes is its - // content. - haveBody bool - bodyBytes []byte - ) - if req.Body != nil && req.Body != http.NoBody { - // if we're setting Body, we must mutate first. - req = req.Clone(req.Context()) - haveBody = true - var err error - bodyBytes, err = io.ReadAll(req.Body) - if err != nil { - return nil, err - } - // Now that we've read the request body, http.RoundTripper requires that we - // close it. - req.Body.Close() // ignore error - req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - } - - resp, err := base.RoundTrip(req) - if err != nil { - return nil, err - } - if resp.StatusCode != http.StatusUnauthorized { - return resp, nil - } - if _, ok := base.(*oauth2.Transport); ok { - // We failed to authorize even with a token source; give up. - return resp, nil - } - - resp.Body.Close() - // Try to authorize. - t.mu.Lock() - defer t.mu.Unlock() - // If we don't have a token source, get one by following the OAuth flow. - // (We may have obtained one while t.mu was not held above.) - // TODO: We hold the lock for the entire OAuth flow. This could be a long - // time. Is there a better way? - if _, ok := t.opts.Base.(*oauth2.Transport); !ok { - ts, err := t.handler(req, resp) - if err != nil { - return nil, err - } - t.opts.Base = &oauth2.Transport{Base: t.opts.Base, Source: ts} - } - - // If we don't have a body, the request is reusable, though it will be cloned - // by the base. However, if we've had to read the body, we must clone. - if haveBody { - req = req.Clone(req.Context()) - req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - } - - return t.opts.Base.RoundTrip(req) +type OAuthHandler interface { + isOAuthHandler() + + // TokenSource returns a token source to be used for outgoing requests. + // Returned token source might be nil. In that case, the transport will not + // add any authorization headers to the request. + TokenSource(context.Context) (oauth2.TokenSource, error) + + // Authorize is called when an HTTP request results in an error that may + // be addressed by the authorization flow (currently 401 Unauthorized and 403 Forbidden). + // It is responsible for performing the OAuth flow to obtain an access token. + // The arguments are the request that failed and the response that was received for it. + // If the returned error is nil, [TokenSource] is expected to return a non-nil token source. + // After a successful call to [Authorize], the HTTP request should be retried by the transport. + // The function is responsible for closing the response body. + Authorize(context.Context, *http.Request, *http.Response) error } diff --git a/auth/client_private.go b/auth/client_private.go new file mode 100644 index 00000000..f161bdc6 --- /dev/null +++ b/auth/client_private.go @@ -0,0 +1,131 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "bytes" + "errors" + "io" + "net/http" + "sync" + + "golang.org/x/oauth2" +) + +// An OAuthHandlerLegacy conducts an OAuth flow and returns a [oauth2.TokenSource] if the authorization +// is approved, or an error if not. +// The handler receives the HTTP request and response that triggered the authentication flow. +// To obtain the protected resource metadata, call [oauthex.GetProtectedResourceMetadataFromHeader]. +// Deprecated: Please use the new OAuthHandler abstraction that is built +// into the streamable transport. +type OAuthHandlerLegacy func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) + +// HTTPTransport is an [http.RoundTripper] that follows the MCP +// OAuth protocol when it encounters a 401 Unauthorized response. +// Deprecated: Please use the new OAuthHandler abstraction that is built +// into the streamable transport. +type HTTPTransport struct { + handler OAuthHandlerLegacy + mu sync.Mutex // protects opts.Base + opts HTTPTransportOptions +} + +// NewHTTPTransport returns a new [*HTTPTransport]. +// The handler is invoked when an HTTP request results in a 401 Unauthorized status. +// It is called only once per transport. Once a TokenSource is obtained, it is used +// for the lifetime of the transport; subsequent 401s are not processed. +// Deprecated: Please use the new OAuthHandler abstraction that is built +// into the streamable transport. +func NewHTTPTransport(handler OAuthHandlerLegacy, opts *HTTPTransportOptions) (*HTTPTransport, error) { + if handler == nil { + return nil, errors.New("handler cannot be nil") + } + t := &HTTPTransport{ + handler: handler, + } + if opts != nil { + t.opts = *opts + } + if t.opts.Base == nil { + t.opts.Base = http.DefaultTransport + } + return t, nil +} + +// HTTPTransportOptions are options to [NewHTTPTransport]. +// Deprecated: Please use the new OAuthHandler abstraction that is built +// into the streamable transport. +type HTTPTransportOptions struct { + // Base is the [http.RoundTripper] to use. + // If nil, [http.DefaultTransport] is used. + Base http.RoundTripper +} + +func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { + t.mu.Lock() + base := t.opts.Base + t.mu.Unlock() + + var ( + // If haveBody is set, the request has a nontrivial body, and we need avoid + // reading (or closing) it multiple times. In that case, bodyBytes is its + // content. + haveBody bool + bodyBytes []byte + ) + if req.Body != nil && req.Body != http.NoBody { + // if we're setting Body, we must mutate first. + req = req.Clone(req.Context()) + haveBody = true + var err error + bodyBytes, err = io.ReadAll(req.Body) + if err != nil { + return nil, err + } + // Now that we've read the request body, http.RoundTripper requires that we + // close it. + req.Body.Close() // ignore error + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + + resp, err := base.RoundTrip(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusUnauthorized { + return resp, nil + } + if _, ok := base.(*oauth2.Transport); ok { + // We failed to authorize even with a token source; give up. + return resp, nil + } + + resp.Body.Close() + // Try to authorize. + t.mu.Lock() + defer t.mu.Unlock() + // If we don't have a token source, get one by following the OAuth flow. + // (We may have obtained one while t.mu was not held above.) + // TODO: We hold the lock for the entire OAuth flow. This could be a long + // time. Is there a better way? + if _, ok := t.opts.Base.(*oauth2.Transport); !ok { + ts, err := t.handler(req, resp) + if err != nil { + return nil, err + } + t.opts.Base = &oauth2.Transport{Base: t.opts.Base, Source: ts} + } + + // If we don't have a body, the request is reusable, though it will be cloned + // by the base. However, if we've had to read the body, we must clone. + if haveBody { + req = req.Clone(req.Context()) + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + + return t.opts.Base.RoundTrip(req) +} diff --git a/auth/fake.go b/auth/fake.go new file mode 100644 index 00000000..b8d82f33 --- /dev/null +++ b/auth/fake.go @@ -0,0 +1,27 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package auth + +import ( + "context" + "net/http" + + "golang.org/x/oauth2" +) + +type FakeOAuthHandler struct { + Token *oauth2.Token + AuthorizeErr error +} + +func (h *FakeOAuthHandler) isOAuthHandler() {} + +func (h *FakeOAuthHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { + return oauth2.StaticTokenSource(h.Token), nil +} + +func (h *FakeOAuthHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { + return h.AuthorizeErr +} diff --git a/conformance/baseline.yml b/conformance/baseline.yml index ae1f9c63..9d7bef80 100644 --- a/conformance/baseline.yml +++ b/conformance/baseline.yml @@ -1,20 +1,2 @@ server: [] # All tests pass! -client: -- auth/basic-cimd -- auth/metadata-default -- auth/metadata-var1 -- auth/metadata-var2 -- auth/metadata-var3 -- auth/2025-03-26-oauth-metadata-backcompat -- auth/2025-03-26-oauth-endpoint-fallback -- auth/scope-from-www-authenticate -- auth/scope-from-scopes-supported -- auth/scope-omitted-when-undefined -- auth/scope-step-up -- auth/scope-retry-limit -- auth/token-endpoint-auth-basic -- auth/token-endpoint-auth-post -- auth/token-endpoint-auth-none -- auth/client-credentials-jwt -- auth/client-credentials-basic -- auth/pre-registration +client: [] # All tests pass! diff --git a/conformance/everything-client/client_private.go b/conformance/everything-client/client_private.go new file mode 100644 index 00000000..6b4654ef --- /dev/null +++ b/conformance/everything-client/client_private.go @@ -0,0 +1,139 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// The conformance client implements features required for MCP conformance testing. +// It mirrors the functionality of the TypeScript conformance client at +// https://github.com/modelcontextprotocol/typescript-sdk/blob/main/src/conformance/everything-client.ts + +//go:build mcp_go_client_oauth + +package main + +import ( + "context" + "fmt" + "net/http" + "net/url" + + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/modelcontextprotocol/go-sdk/oauthex" +) + +func init() { + authScenarios := []string{ + "auth/2025-03-26-oauth-metadata-backcompat", + "auth/2025-03-26-oauth-endpoint-fallback", + "auth/basic-cimd", + "auth/metadata-default", + "auth/metadata-var1", + "auth/metadata-var2", + "auth/metadata-var3", + "auth/pre-registration", + "auth/resource-mismatch", + "auth/scope-from-www-authenticate", + "auth/scope-from-scopes-supported", + "auth/scope-omitted-when-undefined", + "auth/scope-step-up", + "auth/scope-retry-limit", + "auth/token-endpoint-auth-basic", + "auth/token-endpoint-auth-post", + "auth/token-endpoint-auth-none", + } + for _, scenario := range authScenarios { + registerScenario(scenario, runAuthClient) + } +} + +// ============================================================================ +// Auth scenarios +// ============================================================================ + +func fetchAuthorizationCodeAndState(ctx context.Context, authURL string) (*auth.AuthorizationResult, error) { + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + req, err := http.NewRequestWithContext(ctx, "GET", authURL, nil) + if err != nil { + return nil, err + } + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + // In conformance tests the authorization server immediately redirects + // to the callback URL with the authorization code and state. + locURL, err := url.Parse(resp.Header.Get("Location")) + if err != nil { + return nil, fmt.Errorf("parse location: %v", err) + } + + return &auth.AuthorizationResult{ + AuthorizationCode: locURL.Query().Get("code"), + State: locURL.Query().Get("state"), + }, nil +} + +func runAuthClient(ctx context.Context, serverURL string, configCtx map[string]any) error { + authConfig := &auth.AuthorizationCodeHandlerConfig{ + RedirectURL: "http://localhost:3000/callback", + AuthorizationURLHandler: fetchAuthorizationCodeAndState, + // Try client ID metadata document based registration. + ClientIDMetadataDocumentConfig: &auth.ClientIDMetadataDocumentConfig{ + URL: "https://conformance-test.local/client-metadata.json", + }, + // Try dynamic client registration. + DynamicClientRegistrationConfig: &auth.DynamicClientRegistrationConfig{ + Metadata: &oauthex.ClientRegistrationMetadata{ + RedirectURIs: []string{"http://localhost:3000/callback"}, + }, + }, + } + // Try pre-registered client information if provided in the context. + if clientID, ok := configCtx["client_id"].(string); ok { + if clientSecret, ok := configCtx["client_secret"].(string); ok { + authConfig.PreregisteredClientConfig = &auth.PreregisteredClientConfig{ + ClientSecretAuthConfig: &auth.ClientSecretAuthConfig{ + ClientID: clientID, + ClientSecret: clientSecret, + }, + } + } + } + + authHandler, err := auth.NewAuthorizationCodeHandler(authConfig) + if err != nil { + return fmt.Errorf("failed to create auth handler: %w", err) + } + + session, err := connectToServer(ctx, serverURL, withOAuthHandler(authHandler)) + if err != nil { + return err + } + defer session.Close() + + if _, err := session.ListTools(ctx, nil); err != nil { + return fmt.Errorf("session.ListTools(): %v", err) + } + + if _, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "test-tool", + Arguments: map[string]any{}, + }); err != nil { + return fmt.Errorf("session.CallTool('test-tool'): %v", err) + } + + return nil +} + +func withOAuthHandler(handler auth.OAuthHandler) connectOption { + return func(c *connectConfig) { + c.oauthHandler = handler + } +} diff --git a/conformance/everything-client/main.go b/conformance/everything-client/main.go index 9674dbbc..d34e8328 100644 --- a/conformance/everything-client/main.go +++ b/conformance/everything-client/main.go @@ -9,6 +9,7 @@ package main import ( "context" + "encoding/json" "fmt" "log" "os" @@ -16,12 +17,13 @@ import ( "sort" "strings" + "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/mcp" ) // scenarioHandler is the function signature for all conformance test scenarios. // It takes a context and the server URL to connect to. -type scenarioHandler func(ctx context.Context, serverURL string) error +type scenarioHandler func(ctx context.Context, serverURL string, configCtx map[string]any) error var ( // registry stores all registered scenario handlers. @@ -48,7 +50,7 @@ func init() { // Basic scenarios // ============================================================================ -func runBasicClient(ctx context.Context, serverURL string) error { +func runBasicClient(ctx context.Context, serverURL string, _ map[string]any) error { session, err := connectToServer(ctx, serverURL) if err != nil { return err @@ -63,7 +65,7 @@ func runBasicClient(ctx context.Context, serverURL string) error { return nil } -func runToolsCallClient(ctx context.Context, serverURL string) error { +func runToolsCallClient(ctx context.Context, serverURL string, _ map[string]any) error { session, err := connectToServer(ctx, serverURL) if err != nil { return err @@ -97,7 +99,7 @@ func runToolsCallClient(ctx context.Context, serverURL string) error { // Elicitation scenarios // ============================================================================ -func runElicitationDefaultsClient(ctx context.Context, serverURL string) error { +func runElicitationDefaultsClient(ctx context.Context, serverURL string, _ map[string]any) error { elicitationHandler := func(ctx context.Context, req *mcp.ElicitRequest) (*mcp.ElicitResult, error) { return &mcp.ElicitResult{ Action: "accept", @@ -141,7 +143,7 @@ func runElicitationDefaultsClient(ctx context.Context, serverURL string) error { // SSE retry scenario // ============================================================================ -func runSSERetryClient(ctx context.Context, serverURL string) error { +func runSSERetryClient(ctx context.Context, serverURL string, _ map[string]any) error { // TODO: this scenario is not passing yet. It requires a fix in the client SSE handling. session, err := connectToServer(ctx, serverURL) if err != nil { @@ -185,6 +187,7 @@ func main() { serverURL := os.Args[1] scenarioName := os.Getenv("MCP_CONFORMANCE_SCENARIO") + configCtx := getConformanceContext() if scenarioName == "" { printUsageAndExit("MCP_CONFORMANCE_SCENARIO not set") @@ -196,11 +199,21 @@ func main() { } ctx := context.Background() - if err := handler(ctx, serverURL); err != nil { + if err := handler(ctx, serverURL, configCtx); err != nil { log.Fatalf("Scenario %q failed: %v", scenarioName, err) } } +func getConformanceContext() map[string]any { + ctxStr := os.Getenv("MCP_CONFORMANCE_CONTEXT") + if ctxStr == "" { + return nil + } + var ctx map[string]any + _ = json.Unmarshal([]byte(ctxStr), &ctx) + return ctx +} + func printUsageAndExit(format string, args ...any) { var scenarios []string for name := range registry { @@ -214,6 +227,7 @@ func printUsageAndExit(format string, args ...any) { type connectConfig struct { clientOptions *mcp.ClientOptions + oauthHandler auth.OAuthHandler } type connectOption func(*connectConfig) @@ -237,11 +251,14 @@ func connectToServer(ctx context.Context, serverURL string, opts ...connectOptio Version: "1.0.0", }, config.clientOptions) - transport := &mcp.StreamableClientTransport{Endpoint: serverURL} + transport := &mcp.StreamableClientTransport{ + Endpoint: serverURL, + OAuthHandler: config.oauthHandler, + } session, err := client.Connect(ctx, transport, nil) if err != nil { - return nil, fmt.Errorf("client.Connect(): %v", err) + return nil, fmt.Errorf("client.Connect(): %w", err) } return session, nil diff --git a/docs/protocol.md b/docs/protocol.md index 16ba0bfa..0ed6b3af 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -306,9 +306,50 @@ The [_auth middleware example_](https://github.com/modelcontextprotocol/go-sdk/ ### Client -Client-side OAuth is implemented by setting -[`StreamableClientTransport.HTTPClient`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk@v0.5.0/mcp#StreamableClientTransport.HTTPClient) to a custom [`http.Client`](https://pkg.go.dev/net/http#Client) -Additional support is forthcoming; see modelcontextprotocol/go-sdk#493. +Client-side authorization is supported via the +[`StreamableClientTransport.OAuthHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#StreamableClientTransport.OAuthHandler) +field. If the handler is provided, the transport will automatically use it to +add an `Authorization: Bearer ` header to every request. The transport +will also call the handler's `Authorize` method if the server returns +`401 Unauthorized` or `403 Forbidden` errors to perform the authorization flow +or facilitate scope step-up authorization. + +The SDK implements the Authorization Code flow in +[`auth.AuthorizationCodeHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#AuthorizationCodeHandler). +This handler supports: + +- [Client ID Metadata Documents](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#client-id-metadata-documents) +- [Pre-registered clients](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#preregistration) +- [Dynamic Client Registration](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#dynamic-client-registration) + +To use it, configure the handler and assign it to the transport: + +```go +authHandler, _ := auth.NewAuthorizationCodeHandler(&auth.AuthorizationCodeHandlerConfig{ + RedirectURL: "https://myapp.com/oauth2-callback", + // Configure one of the following: + // ClientIDMetadataDocumentConfig: ... + // PreregisteredClientConfig: ... + // DynamicClientRegistrationConfig: ... + AuthorizationURLHandler: func(ctx context.Context, url string) (*auth.AuthorizationResult, error) { + // Open the URL in a browser and return the resulting code and state. + // See full example in examples/auth/client/main.go. + code := ... + state := ... + return &auth.AuthorizationResult{AuthorizationCode: code, State: state}, nil + }, +}) + +transport := &mcp.StreamableClientTransport{ + Endpoint: "https://example.com/mcp", + OAuthHandler: authHandler, +} +client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) +session, err := client.Connect(ctx, transport, nil) +``` + +The `auth.AuthorizationCodeHandler` automatically manages token refreshing +and step-up authentication (when the server returns `insufficient_scope` error). ## Security @@ -317,9 +358,12 @@ the MCP spec's [Security Best Practices](https://modelcontextprotocol.io/specifi ### Confused Deputy -The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation), obtaining user consent for dynamically registered clients, -happens on the MCP client. At present we don't provide client-side OAuth support. - +The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation), +obtaining user consent for dynamically registered clients, is mostly the +responsibility of the MCP Proxy server implementation. The SDK client does +generate cryptographically secure random `state` values for each authorization +request by default and validates them when the authorization code is returned. +Mismatched state values will result in an error. ### Token Passthrough diff --git a/examples/auth/client/main.go b/examples/auth/client/main.go new file mode 100644 index 00000000..dfc3d102 --- /dev/null +++ b/examples/auth/client/main.go @@ -0,0 +1,131 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "log" + "net" + "net/http" + + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +var ( + // URL of the MCP server. + serverURL = flag.String("server_url", "http://localhost:8000/mcp", "URL of the MCP server.") + // Port for the local HTTP server that will receive the authorization code. + callbackPort = flag.Int("callback_port", 3142, "Port for the local HTTP server that will receive the authorization code.") +) + +type codeReceiver struct { + authChan chan *auth.AuthorizationResult + errChan chan error + listener net.Listener + server *http.Server +} + +func (r *codeReceiver) serveRedirectHandler(listener net.Listener) error { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { + r.authChan <- &auth.AuthorizationResult{ + AuthorizationCode: req.URL.Query().Get("code"), + State: req.URL.Query().Get("state"), + } + fmt.Fprint(w, "Authentication successful. You can close this window.") + }) + + r.server = &http.Server{ + Addr: fmt.Sprintf("localhost:%d", *callbackPort), + Handler: mux, + } + if err := r.server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { + r.errChan <- err + } + return nil +} + +func (r *codeReceiver) getAuthorizationCode(ctx context.Context, authorizationURL string) (*auth.AuthorizationResult, error) { + select { + case authRes := <-r.authChan: + return authRes, nil + case err := <-r.errChan: + return nil, err + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (r *codeReceiver) close() { + if r.server != nil { + r.server.Close() + } +} + +func main() { + flag.Parse() + receiver := &codeReceiver{ + authChan: make(chan *auth.AuthorizationResult), + errChan: make(chan error), + } + listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", *callbackPort)) + if err != nil { + log.Fatalf("failed to listen: %v", err) + } + go receiver.serveRedirectHandler(listener) + defer receiver.close() + + authHandler, err := auth.NewAuthorizationCodeHandler(&auth.AuthorizationCodeHandlerConfig{ + RedirectURL: fmt.Sprintf("http://localhost:%d", *callbackPort), + AuthorizationURLHandler: receiver.getAuthorizationCode, + // Uncomment the client configuration you want to use. + // PreregisteredClientConfig: &auth.PreregisteredClientConfig{ + // ClientID: "", + // ClientSecret: "", + // }, + // DynamicClientRegistrationConfig: &auth.DynamicClientRegistrationConfig{ + // Metadata: &oauthex.ClientRegistrationMetadata{ + // ClientName: "Dynamically registered MCP client", + // RedirectURIs: []string{fmt.Sprintf("http://localhost:%d", *callbackPort)}, + // Scope: "read", + // }, + // }, + }) + if err != nil { + log.Fatalf("failed to create auth handler: %v", err) + } + + transport := &mcp.StreamableClientTransport{ + Endpoint: *serverURL, + OAuthHandler: authHandler, + } + + ctx := context.Background() + client := mcp.NewClient(&mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + }, nil) + + session, err := client.Connect(ctx, transport, nil) + if err != nil { + log.Fatalf("client.Connect(): %v", err) + } + defer session.Close() + + tools, err := session.ListTools(ctx, nil) + if err != nil { + log.Fatalf("session.ListTools(): %v", err) + } + log.Println("Tools:") + for _, tool := range tools.Tools { + log.Printf("- %q", tool.Name) + } +} diff --git a/examples/auth/server/main.go b/examples/auth/server/main.go new file mode 100644 index 00000000..94ad9ae3 --- /dev/null +++ b/examples/auth/server/main.go @@ -0,0 +1,167 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package main + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "log" + "net/http" + "net/http/httputil" + "net/url" + "strings" + "time" + + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/modelcontextprotocol/go-sdk/oauthex" +) + +// Flags. +var ( + port = flag.Int("port", 8000, "Port to listen on") +) + +// Configuration required for this example. +var ( + // Authorization server to return in the protected resource metadata. + authorizationServer = "" + // Introspection endpoint for verifying tokens. + introspectionEndpoint = "" + // Client credentials used in the introspection request. + clientID = "" + clientSecret = "" +) + +func verifyToken(ctx context.Context, token string, _ *http.Request) (*auth.TokenInfo, error) { + data := url.Values{} + data.Set("token", token) + data.Set("token_type_hint", "access_token") + + req, err := http.NewRequestWithContext(ctx, "POST", introspectionEndpoint, strings.NewReader(data.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + req.SetBasicAuth(clientID, clientSecret) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + dump, _ := httputil.DumpResponse(resp, true) + log.Printf("Introspection failed: %s", dump) + return nil, fmt.Errorf("introspection failed with status %d", resp.StatusCode) + } + + var result struct { + Active bool `json:"active"` + Scope string `json:"scope"` + Exp int64 `json:"exp"` + Sub string `json:"sub"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, err + } + + if !result.Active { + return nil, auth.ErrInvalidToken + } + + return &auth.TokenInfo{ + Scopes: strings.Fields(result.Scope), + Expiration: time.Unix(result.Exp, 0), + UserID: result.Sub, + }, nil +} + +type args struct { + Input string `json:"input"` +} + +func echo(ctx context.Context, req *mcp.CallToolRequest, args args) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: args.Input}, + }, + }, nil, nil +} + +func main() { + flag.Parse() + metadata := &oauthex.ProtectedResourceMetadata{ + Resource: fmt.Sprintf("http://localhost:%d/mcp", *port), + AuthorizationServers: []string{authorizationServer}, + ScopesSupported: []string{"read"}, + } + http.Handle("/.well-known/oauth-protected-resource", auth.ProtectedResourceMetadataHandler(metadata)) + + server := mcp.NewServer(&mcp.Implementation{ + Name: "test-server", + Version: "1.0.0", + }, nil) + server.AddReceivingMiddleware(createLoggingMiddleware()) + mcp.AddTool(server, &mcp.Tool{Name: "echo"}, echo) + + handler := mcp.NewStreamableHTTPHandler(func(req *http.Request) *mcp.Server { + return server + }, nil) + + authMiddleware := auth.RequireBearerToken(verifyToken, &auth.RequireBearerTokenOptions{ + Scopes: []string{"read"}, + ResourceMetadataURL: fmt.Sprintf("http://localhost:%d/.well-known/oauth-protected-resource", *port), + }) + + http.Handle("/mcp", authMiddleware(handler)) + + log.Printf("Starting server on http://localhost:%d", *port) + log.Fatal(http.ListenAndServe(fmt.Sprintf("localhost:%d", *port), nil)) +} + +// createLoggingMiddleware creates an MCP middleware that logs method calls. +func createLoggingMiddleware() mcp.Middleware { + return func(next mcp.MethodHandler) mcp.MethodHandler { + return func( + ctx context.Context, + method string, + req mcp.Request, + ) (mcp.Result, error) { + start := time.Now() + sessionID := req.GetSession().ID() + + // Log request details. + log.Printf("[REQUEST] Session: %s | Method: %s", + sessionID, + method) + + // Call the actual handler. + result, err := next(ctx, method, req) + + // Log response details. + duration := time.Since(start) + + if err != nil { + log.Printf("[RESPONSE] Session: %s | Method: %s | Status: ERROR | Duration: %v | Error: %v", + sessionID, + method, + duration, + err) + } else { + log.Printf("[RESPONSE] Session: %s | Method: %s | Status: OK | Duration: %v", + sessionID, + method, + duration) + } + + return result, err + } + } +} diff --git a/internal/docs/protocol.src.md b/internal/docs/protocol.src.md index ada34371..22a39f90 100644 --- a/internal/docs/protocol.src.md +++ b/internal/docs/protocol.src.md @@ -232,9 +232,50 @@ The [_auth middleware example_](https://github.com/modelcontextprotocol/go-sdk/ ### Client -Client-side OAuth is implemented by setting -[`StreamableClientTransport.HTTPClient`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk@v0.5.0/mcp#StreamableClientTransport.HTTPClient) to a custom [`http.Client`](https://pkg.go.dev/net/http#Client) -Additional support is forthcoming; see modelcontextprotocol/go-sdk#493. +Client-side authorization is supported via the +[`StreamableClientTransport.OAuthHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#StreamableClientTransport.OAuthHandler) +field. If the handler is provided, the transport will automatically use it to +add an `Authorization: Bearer ` header to every request. The transport +will also call the handler's `Authorize` method if the server returns +`401 Unauthorized` or `403 Forbidden` errors to perform the authorization flow +or facilitate scope step-up authorization. + +The SDK implements the Authorization Code flow in +[`auth.AuthorizationCodeHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#AuthorizationCodeHandler). +This handler supports: + +- [Client ID Metadata Documents](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#client-id-metadata-documents) +- [Pre-registered clients](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#preregistration) +- [Dynamic Client Registration](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#dynamic-client-registration) + +To use it, configure the handler and assign it to the transport: + +```go +authHandler, _ := auth.NewAuthorizationCodeHandler(&auth.AuthorizationCodeHandlerConfig{ + RedirectURL: "https://myapp.com/oauth2-callback", + // Configure one of the following: + // ClientIDMetadataDocumentConfig: ... + // PreregisteredClientConfig: ... + // DynamicClientRegistrationConfig: ... + AuthorizationURLHandler: func(ctx context.Context, url string) (*auth.AuthorizationResult, error) { + // Open the URL in a browser and return the resulting code and state. + // See full example in examples/auth/client/main.go. + code := ... + state := ... + return &auth.AuthorizationResult{AuthorizationCode: code, State: state}, nil + }, +}) + +transport := &mcp.StreamableClientTransport{ + Endpoint: "https://example.com/mcp", + OAuthHandler: authHandler, +} +client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) +session, err := client.Connect(ctx, transport, nil) +``` + +The `auth.AuthorizationCodeHandler` automatically manages token refreshing +and step-up authentication (when the server returns `insufficient_scope` error). ## Security @@ -243,9 +284,12 @@ the MCP spec's [Security Best Practices](https://modelcontextprotocol.io/specifi ### Confused Deputy -The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation), obtaining user consent for dynamically registered clients, -happens on the MCP client. At present we don't provide client-side OAuth support. - +The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation), +obtaining user consent for dynamically registered clients, is mostly the +responsibility of the MCP Proxy server implementation. The SDK client does +generate cryptographically secure random `state` values for each authorization +request by default and validates them when the authorization code is returned. +Mismatched state values will result in an error. ### Token Passthrough diff --git a/mcp/streamable.go b/mcp/streamable.go index 36ae5b12..9155f7bf 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -26,7 +26,6 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "time" "github.com/modelcontextprotocol/go-sdk/auth" @@ -36,6 +35,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/internal/util" "github.com/modelcontextprotocol/go-sdk/internal/xcontext" "github.com/modelcontextprotocol/go-sdk/jsonrpc" + "github.com/modelcontextprotocol/go-sdk/oauthex" ) const ( @@ -1456,6 +1456,9 @@ type StreamableClientTransport struct { // - You want to avoid maintaining a persistent connection DisableStandaloneSSE bool + // OAuthHandler is an optional field that, if provided, will be used to authorize the requests. + OAuthHandler auth.OAuthHandler + // TODO(rfindley): propose exporting these. // If strict is set, the transport is in 'strict mode', where any violation // of the MCP spec causes a failure. @@ -1531,6 +1534,7 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er cancel: cancel, failed: make(chan struct{}), disableStandaloneSSE: t.DisableStandaloneSSE, + oauthHandler: t.OAuthHandler, } return conn, nil } @@ -1549,6 +1553,9 @@ type streamableClientConn struct { // for receiving server-to-client notifications when no request is in flight. disableStandaloneSSE bool // from [StreamableClientTransport.DisableStandaloneSSE] + // oauthHandler is the OAuth handler for the connection. + oauthHandler auth.OAuthHandler // from [StreamableClientTransport.OAuthHandler] + // Guard calls to Close, as it may be called multiple times. closeOnce sync.Once closeErr error @@ -1724,14 +1731,59 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e } req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json, text/event-stream") - c.setMCPHeaders(req) + doRequest := func() (*http.Response, error) { + if err := c.setMCPHeaders(req); err != nil { + // Failure to set headers means that the request was not sent. + // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr + // and permanently break the connection. + return nil, fmt.Errorf("%s: %w: %v", requestSummary, jsonrpc2.ErrRejected, err) + } + resp, err := c.client.Do(req) + if err != nil { + // Any error from client.Do means the request didn't reach the server. + // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr + // and permanently break the connection. + err = fmt.Errorf("%s: %w: %v", requestSummary, jsonrpc2.ErrRejected, err) + } + return resp, err + } - resp, err := c.client.Do(req) + resp, err := doRequest() if err != nil { - // Any error from client.Do means the request didn't reach the server. - // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr - // and permanently break the connection. - return fmt.Errorf("%w: %s: %v", jsonrpc2.ErrRejected, requestSummary, err) + return err + } + + if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil { + if err := c.oauthHandler.Authorize(ctx, req, resp); err != nil { + // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr + // and permanently break the connection. + // Wrap the authorization error as well for client inspection. + return fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err) + } + // Retry the request after successful authorization. + resp, err = doRequest() + if err != nil { + return err + } + } + if resp.StatusCode == http.StatusForbidden && c.oauthHandler != nil { + challenges, err := oauthex.ParseWWWAuthenticate(resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")]) + if err != nil { + c.logger.Warn("%s: failed to parse WWW-Authenticate header: %v", requestSummary, err) + } else if oauthex.Error(challenges) == "insufficient_scope" { + // Trigger step-up authorization flow. + if err := c.oauthHandler.Authorize(ctx, req, resp); err != nil { + // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr + // and permanently break the connection. + // Wrap the authorization error as well for client inspection. + return fmt.Errorf("%s: %w: %v", requestSummary, jsonrpc2.ErrRejected, err) + } + // Retry the request after successful authorization. + resp, err = doRequest() + if err != nil { + return err + } + } } if err := c.checkResponse(requestSummary, resp); err != nil { @@ -1799,23 +1851,32 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return nil } -// testAuth controls whether a fake Authorization header is added to outgoing requests. -// TODO: replace with a better mechanism when client-side auth is in place. -var testAuth atomic.Bool - -func (c *streamableClientConn) setMCPHeaders(req *http.Request) { +func (c *streamableClientConn) setMCPHeaders(req *http.Request) error { c.mu.Lock() defer c.mu.Unlock() + if c.oauthHandler != nil { + ts, err := c.oauthHandler.TokenSource(c.ctx) + if err != nil { + return err + } + if ts != nil { + token, err := ts.Token() + if err != nil { + return err + } + if token != nil { + req.Header.Set("Authorization", "Bearer "+token.AccessToken) + } + } + } if c.initializedResult != nil { req.Header.Set(protocolVersionHeader, c.initializedResult.ProtocolVersion) } if c.sessionID != "" { req.Header.Set(sessionIDHeader, c.sessionID) } - if testAuth.Load() { - req.Header.Set("Authorization", "Bearer foo") - } + return nil } func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Response) { @@ -2068,7 +2129,9 @@ func (c *streamableClientConn) connectSSE(ctx context.Context, lastEventID strin if err != nil { return nil, err } - c.setMCPHeaders(req) + if err := c.setMCPHeaders(req); err != nil { + return nil, err + } if lastEventID != "" { req.Header.Set(lastEventIDHeader, lastEventID) } @@ -2099,8 +2162,9 @@ func (c *streamableClientConn) Close() error { if err != nil { c.closeErr = err } else { - c.setMCPHeaders(req) - if _, err := c.client.Do(req); err != nil { + if err := c.setMCPHeaders(req); err != nil { + c.closeErr = err + } else if _, err := c.client.Do(req); err != nil { c.closeErr = err } } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 2cbe4002..11089535 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -34,6 +34,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" + "golang.org/x/oauth2" ) func TestStreamableTransports(t *testing.T) { @@ -1667,9 +1668,6 @@ func textContent(t *testing.T, res *CallToolResult) string { } func TestTokenInfo(t *testing.T) { - oldAuth := testAuth.Load() - defer testAuth.Store(oldAuth) - testAuth.Store(true) ctx := context.Background() // Create a server with a tool that returns TokenInfo. @@ -1680,7 +1678,10 @@ func TestTokenInfo(t *testing.T) { AddTool(server, &Tool{Name: "tokenInfo", Description: "return token info"}, tokenInfo) streamHandler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) - verifier := func(context.Context, string, *http.Request) (*auth.TokenInfo, error) { + verifier := func(ctx context.Context, token string, req *http.Request) (*auth.TokenInfo, error) { + if token != "test-token" { + return nil, auth.ErrInvalidToken + } return &auth.TokenInfo{ Scopes: []string{"scope"}, // Expiration is far, far in the future. @@ -1691,7 +1692,10 @@ func TestTokenInfo(t *testing.T) { httpServer := httptest.NewServer(mustNotPanic(t, handler)) defer httpServer.Close() - transport := &StreamableClientTransport{Endpoint: httpServer.URL} + transport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + OAuthHandler: &auth.FakeOAuthHandler{Token: &oauth2.Token{AccessToken: "test-token"}}, + } client := NewClient(testImpl, nil) session, err := client.Connect(ctx, transport, nil) if err != nil { diff --git a/oauthex/auth_meta.go b/oauthex/auth_meta.go index 9aa0c8d7..d9fcc9d8 100644 --- a/oauthex/auth_meta.go +++ b/oauthex/auth_meta.go @@ -13,7 +13,10 @@ import ( "context" "errors" "fmt" + "log" "net/http" + "net/url" + "strings" ) // AuthServerMeta represents the metadata for an OAuth 2.0 authorization server, @@ -113,51 +116,109 @@ type AuthServerMeta struct { // CodeChallengeMethodsSupported is a RECOMMENDED JSON array of strings containing a list of // PKCE code challenge methods supported by this authorization server. CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported,omitempty"` -} -var wellKnownPaths = []string{ - "/.well-known/oauth-authorization-server", - "/.well-known/openid-configuration", + // ClientIDMetadataDocumentSupported is a boolean indicating whether the authorization server + // supports client ID metadata documents. + ClientIDMetadataDocumentSupported bool `json:"client_id_metadata_document_supported,omitempty"` } // GetAuthServerMeta issues a GET request to retrieve authorization server metadata -// from an OAuth authorization server with the given issuerURL. +// from an OAuth authorization server with the given metadataURL. // // It follows [RFC 8414]: -// - The well-known paths specified there are inserted into the URL's path, one at time. -// The first to succeed is used. -// - The Issuer field is checked against issuerURL. +// - The Issuer field is checked against metadataURL.Issuer. +// +// It also verifies that the authorization server supports PKCE and that the URLs +// in the metadata don't use dangerous schemes. +// +// It returns an error if the request fails with a non-4xx status code or the fetched +// metadata doesn't pass security validations. // // [RFC 8414]: https://tools.ietf.org/html/rfc8414 -func GetAuthServerMeta(ctx context.Context, issuerURL string, c *http.Client) (*AuthServerMeta, error) { - var errs []error - for _, p := range wellKnownPaths { - u, err := prependToPath(issuerURL, p) - if err != nil { - // issuerURL is bad; no point in continuing. - return nil, err - } - asm, err := getJSON[AuthServerMeta](ctx, c, u, 1<<20) - if err == nil { - if asm.Issuer != issuerURL { // section 3.3 - // Security violation; don't keep trying. - return nil, fmt.Errorf("metadata issuer %q does not match issuer URL %q", asm.Issuer, issuerURL) +func GetAuthServerMeta(ctx context.Context, metadataURL AuthorizationServerMetadataURL, c *http.Client) (*AuthServerMeta, error) { + asm, err := getJSON[AuthServerMeta](ctx, c, metadataURL.URL, 1<<20) + if err != nil { + log.Printf("Failed to get auth server metadata from %q: %v", metadataURL.URL, err) + var httpErr *httpStatusError + if errors.As(err, &httpErr) { + if 400 <= httpErr.StatusCode && httpErr.StatusCode < 500 { + return nil, nil } + return nil, fmt.Errorf("%v", err) // Do not expose error types. + } + } + if asm.Issuer != metadataURL.Issuer { + // Validate the Issuer field (see RFC 8414, section 3.3). + return nil, fmt.Errorf("metadata issuer %q does not match issuer URL %q", asm.Issuer, metadataURL.Issuer) + } - if len(asm.CodeChallengeMethodsSupported) == 0 { - return nil, fmt.Errorf("authorization server at %s does not implement PKCE", issuerURL) - } + if len(asm.CodeChallengeMethodsSupported) == 0 { + return nil, fmt.Errorf("authorization server at %s does not implement PKCE", metadataURL.Issuer) + } - // Validate endpoint URLs to prevent XSS attacks (see #526). - if err := validateAuthServerMetaURLs(asm); err != nil { - return nil, err - } + // Validate endpoint URLs to prevent XSS attacks (see #526). + if err := validateAuthServerMetaURLs(asm); err != nil { + return nil, err + } + log.Printf("Fetched authorization server metadata from %q", metadataURL.URL) - return asm, nil - } - errs = append(errs, err) + return asm, nil +} + +type AuthorizationServerMetadataURL struct { + // URL where the Authorization Server Metadata may be retrieved. + URL string + // Issuer that was used to construct the [URL]. + Issuer string +} + +// AuthorizationServerMetadataURLs returns a list of URLs to try when looking for +// authorization server metadata as mandated by the MCP specification: +// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#authorization-server-metadata-discovery. +func AuthorizationServerMetadataURLs(issuerURL string) []AuthorizationServerMetadataURL { + var urls []AuthorizationServerMetadataURL + + baseURL, err := url.Parse(issuerURL) + if err != nil { + return nil } - return nil, fmt.Errorf("failed to get auth server metadata from %q: %w", issuerURL, errors.Join(errs...)) + + if baseURL.Path == "" { + // "OAuth 2.0 Authorization Server Metadata". + baseURL.Path = "/.well-known/oauth-authorization-server" + urls = append(urls, AuthorizationServerMetadataURL{ + URL: baseURL.String(), + Issuer: issuerURL, + }) + // "OpenID Connect Discovery 1.0". + baseURL.Path = "/.well-known/openid-configuration" + urls = append(urls, AuthorizationServerMetadataURL{ + URL: baseURL.String(), + Issuer: issuerURL, + }) + return urls + } + + originalPath := baseURL.Path + // "OAuth 2.0 Authorization Server Metadata with path insertion". + baseURL.Path = "/.well-known/oauth-authorization-server/" + strings.TrimLeft(originalPath, "/") + urls = append(urls, AuthorizationServerMetadataURL{ + URL: baseURL.String(), + Issuer: issuerURL, + }) + // "OpenID Connect Discovery 1.0 with path insertion". + baseURL.Path = "/.well-known/openid-configuration/" + strings.TrimLeft(originalPath, "/") + urls = append(urls, AuthorizationServerMetadataURL{ + URL: baseURL.String(), + Issuer: issuerURL, + }) + // "OpenID Connect Discovery 1.0 with path appending". + baseURL.Path = "/" + strings.Trim(originalPath, "/") + "/.well-known/openid-configuration" + urls = append(urls, AuthorizationServerMetadataURL{ + URL: baseURL.String(), + Issuer: issuerURL, + }) + return urls } // validateAuthServerMetaURLs validates all URL fields in AuthServerMeta diff --git a/oauthex/auth_meta_test.go b/oauthex/auth_meta_test.go index 1e608824..6363e098 100644 --- a/oauthex/auth_meta_test.go +++ b/oauthex/auth_meta_test.go @@ -85,7 +85,10 @@ func TestGetAuthServerMetaPKCESupport(t *testing.T) { // The fake server sets issuer to https://localhost:, so compute that issuer. u, _ := url.Parse(ts.URL) - issuer := "https://localhost:" + u.Port() + metadataURL := AuthorizationServerMetadataURL{ + URL: "https://localhost:" + u.Port() + "/.well-known/oauth-authorization-server", + Issuer: "https://localhost:" + u.Port(), + } // The fake server presents a cert for example.com; set ServerName accordingly. httpClient := ts.Client() @@ -95,7 +98,7 @@ func TestGetAuthServerMetaPKCESupport(t *testing.T) { httpClient.Transport = clone } - meta, err := GetAuthServerMeta(ctx, issuer, httpClient) + meta, err := GetAuthServerMeta(ctx, metadataURL, httpClient) if tt.wantError != "" { if err == nil { t.Fatal("wanted error but got none") diff --git a/oauthex/oauth2.go b/oauthex/oauth2.go index cdda695b..5b76116d 100644 --- a/oauthex/oauth2.go +++ b/oauthex/oauth2.go @@ -36,6 +36,14 @@ func prependToPath(urlStr, pre string) (string, error) { return u.String(), nil } +type httpStatusError struct { + StatusCode int +} + +func (e *httpStatusError) Error() string { + return fmt.Sprintf("bad status %d", e.StatusCode) +} + // getJSON retrieves JSON and unmarshals JSON from the URL, as specified in both // RFC 9728 and RFC 8414. // It will not read more than limit bytes from the body. @@ -53,11 +61,9 @@ func getJSON[T any](ctx context.Context, c *http.Client, url string, limit int64 } defer res.Body.Close() - // Specs require a 200. if res.StatusCode != http.StatusOK { - return nil, fmt.Errorf("bad status %s", res.Status) + return nil, &httpStatusError{StatusCode: res.StatusCode} } - // Specs require application/json. ct := res.Header.Get("Content-Type") mediaType, _, err := mime.ParseMediaType(ct) if err != nil || mediaType != "application/json" { diff --git a/oauthex/oauth2_test.go b/oauthex/oauth2_test.go index 08d2d314..36f732e8 100644 --- a/oauthex/oauth2_test.go +++ b/oauthex/oauth2_test.go @@ -82,13 +82,13 @@ func TestParseSingleChallenge(t *testing.T) { tests := []struct { name string input string - want challenge + want Challenge wantErr bool }{ { name: "scheme only", input: "Basic", - want: challenge{ + want: Challenge{ Scheme: "basic", }, wantErr: false, @@ -96,7 +96,7 @@ func TestParseSingleChallenge(t *testing.T) { { name: "scheme with one quoted param", input: `Bearer realm="example.com"`, - want: challenge{ + want: Challenge{ Scheme: "bearer", Params: map[string]string{"realm": "example.com"}, }, @@ -105,7 +105,7 @@ func TestParseSingleChallenge(t *testing.T) { { name: "scheme with one unquoted param", input: `Bearer realm=example.com`, - want: challenge{ + want: Challenge{ Scheme: "bearer", Params: map[string]string{"realm": "example.com"}, }, @@ -114,7 +114,7 @@ func TestParseSingleChallenge(t *testing.T) { { name: "scheme with multiple params", input: `Bearer realm="example", error="invalid_token", error_description="The token expired"`, - want: challenge{ + want: Challenge{ Scheme: "bearer", Params: map[string]string{ "realm": "example", @@ -127,7 +127,7 @@ func TestParseSingleChallenge(t *testing.T) { { name: "scheme with multiple unquoted params", input: `Bearer realm=example, error=invalid_token, error_description=The token expired`, - want: challenge{ + want: Challenge{ Scheme: "bearer", Params: map[string]string{ "realm": "example", @@ -140,7 +140,7 @@ func TestParseSingleChallenge(t *testing.T) { { name: "case-insensitive scheme and keys", input: `BEARER ReAlM="example"`, - want: challenge{ + want: Challenge{ Scheme: "bearer", Params: map[string]string{"realm": "example"}, }, @@ -149,7 +149,7 @@ func TestParseSingleChallenge(t *testing.T) { { name: "param with escaped quote", input: `Bearer realm="example \"foo\" bar"`, - want: challenge{ + want: Challenge{ Scheme: "bearer", Params: map[string]string{"realm": `example "foo" bar`}, }, @@ -158,7 +158,7 @@ func TestParseSingleChallenge(t *testing.T) { { name: "param without quotes (token)", input: "Bearer realm=example.com", - want: challenge{ + want: Challenge{ Scheme: "bearer", Params: map[string]string{"realm": "example.com"}, }, diff --git a/oauthex/oauthex.go b/oauthex/oauthex.go index 34ed55b5..151da7e5 100644 --- a/oauthex/oauthex.go +++ b/oauthex/oauthex.go @@ -4,89 +4,3 @@ // Package oauthex implements extensions to OAuth2. package oauthex - -// ProtectedResourceMetadata is the metadata for an OAuth 2.0 protected resource, -// as defined in section 2 of https://www.rfc-editor.org/rfc/rfc9728.html. -// -// The following features are not supported: -// - additional keys (§2, last sentence) -// - human-readable metadata (§2.1) -// - signed metadata (§2.2) -type ProtectedResourceMetadata struct { - // GENERATED BY GEMINI 2.5. - - // Resource (resource) is the protected resource's resource identifier. - // Required. - Resource string `json:"resource"` - - // AuthorizationServers (authorization_servers) is an optional slice containing a list of - // OAuth authorization server issuer identifiers (as defined in RFC 8414) that can be - // used with this protected resource. - AuthorizationServers []string `json:"authorization_servers,omitempty"` - - // JWKSURI (jwks_uri) is an optional URL of the protected resource's JSON Web Key (JWK) Set - // document. This contains public keys belonging to the protected resource, such as - // signing key(s) that the resource server uses to sign resource responses. - JWKSURI string `json:"jwks_uri,omitempty"` - - // ScopesSupported (scopes_supported) is a recommended slice containing a list of scope - // values (as defined in RFC 6749) used in authorization requests to request access - // to this protected resource. - ScopesSupported []string `json:"scopes_supported,omitempty"` - - // BearerMethodsSupported (bearer_methods_supported) is an optional slice containing - // a list of the supported methods of sending an OAuth 2.0 bearer token to the - // protected resource. Defined values are "header", "body", and "query". - BearerMethodsSupported []string `json:"bearer_methods_supported,omitempty"` - - // ResourceSigningAlgValuesSupported (resource_signing_alg_values_supported) is an optional - // slice of JWS signing algorithms (alg values) supported by the protected - // resource for signing resource responses. - ResourceSigningAlgValuesSupported []string `json:"resource_signing_alg_values_supported,omitempty"` - - // ResourceName (resource_name) is a human-readable name of the protected resource - // intended for display to the end user. It is RECOMMENDED that this field be included. - // This value may be internationalized. - ResourceName string `json:"resource_name,omitempty"` - - // ResourceDocumentation (resource_documentation) is an optional URL of a page containing - // human-readable information for developers using the protected resource. - // This value may be internationalized. - ResourceDocumentation string `json:"resource_documentation,omitempty"` - - // ResourcePolicyURI (resource_policy_uri) is an optional URL of a page containing - // human-readable policy information on how a client can use the data provided. - // This value may be internationalized. - ResourcePolicyURI string `json:"resource_policy_uri,omitempty"` - - // ResourceTOSURI (resource_tos_uri) is an optional URL of a page containing the protected - // resource's human-readable terms of service. This value may be internationalized. - ResourceTOSURI string `json:"resource_tos_uri,omitempty"` - - // TLSClientCertificateBoundAccessTokens (tls_client_certificate_bound_access_tokens) is an - // optional boolean indicating support for mutual-TLS client certificate-bound - // access tokens (RFC 8705). Defaults to false if omitted. - TLSClientCertificateBoundAccessTokens bool `json:"tls_client_certificate_bound_access_tokens,omitempty"` - - // AuthorizationDetailsTypesSupported (authorization_details_types_supported) is an optional - // slice of 'type' values supported by the resource server for the - // 'authorization_details' parameter (RFC 9396). - AuthorizationDetailsTypesSupported []string `json:"authorization_details_types_supported,omitempty"` - - // DPOPSigningAlgValuesSupported (dpop_signing_alg_values_supported) is an optional - // slice of JWS signing algorithms supported by the resource server for validating - // DPoP proof JWTs (RFC 9449). - DPOPSigningAlgValuesSupported []string `json:"dpop_signing_alg_values_supported,omitempty"` - - // DPOPBoundAccessTokensRequired (dpop_bound_access_tokens_required) is an optional boolean - // specifying whether the protected resource always requires the use of DPoP-bound - // access tokens (RFC 9449). Defaults to false if omitted. - DPOPBoundAccessTokensRequired bool `json:"dpop_bound_access_tokens_required,omitempty"` - - // SignedMetadata (signed_metadata) is an optional JWT containing metadata parameters - // about the protected resource as claims. If present, these values take precedence - // over values conveyed in plain JSON. - // TODO:implement. - // Note that §2.2 says it's okay to ignore this. - // SignedMetadata string `json:"signed_metadata,omitempty"` -} diff --git a/oauthex/resource_meta.go b/oauthex/resource_meta.go index bb61f797..bd869fa0 100644 --- a/oauthex/resource_meta.go +++ b/oauthex/resource_meta.go @@ -11,13 +11,12 @@ package oauthex import ( "context" - "errors" "fmt" + "log" "net/http" "net/url" "path" "strings" - "unicode" "github.com/modelcontextprotocol/go-sdk/internal/util" ) @@ -38,6 +37,7 @@ const defaultProtectedResourceMetadataURI = "/.well-known/oauth-protected-resour // // It then retrieves the metadata at that location using the given client (or the // default client if nil) and validates its resource field against resourceID. +// Deprecated: Use [GetProtectedResourceMetadata] instead. func GetProtectedResourceMetadataFromID(ctx context.Context, resourceID string, c *http.Client) (_ *ProtectedResourceMetadata, err error) { defer util.Wrapf(&err, "GetProtectedResourceMetadataFromID(%q)", resourceID) @@ -47,7 +47,10 @@ func GetProtectedResourceMetadataFromID(ctx context.Context, resourceID string, } // Insert well-known URI into URL. u.Path = path.Join(defaultProtectedResourceMetadataURI, u.Path) - return getPRM(ctx, u.String(), c, resourceID) + return GetProtectedResourceMetadata(ctx, ProtectedResourceMetadataURL{ + URL: u.String(), + Resource: resourceID, + }, c) } // GetProtectedResourceMetadataFromHeader retrieves protected resource metadata @@ -57,8 +60,8 @@ func GetProtectedResourceMetadataFromID(ctx context.Context, resourceID string, // Per RFC 9728 section 3.3, it validates that the resource field of the resulting metadata // matches the serverURL (the URL that the client used to make the original request to the resource server). // If there is no metadata URL in the header, it returns nil, nil. +// Deprecated: Use [GetProtectedResourceMetadata] instead. func GetProtectedResourceMetadataFromHeader(ctx context.Context, serverURL string, header http.Header, c *http.Client) (_ *ProtectedResourceMetadata, err error) { - defer util.Wrapf(&err, "GetProtectedResourceMetadataFromHeader") headers := header[http.CanonicalHeaderKey("WWW-Authenticate")] if len(headers) == 0 { return nil, nil @@ -71,22 +74,31 @@ func GetProtectedResourceMetadataFromHeader(ctx context.Context, serverURL strin if metadataURL == "" { return nil, nil } - return getPRM(ctx, metadataURL, c, serverURL) + return GetProtectedResourceMetadata(ctx, ProtectedResourceMetadataURL{ + URL: metadataURL, + Resource: serverURL, + }, c) } -// getPRM makes a GET request to the given URL, and validates the response. -// As part of the validation, it compares the returned resource field to wantResource. -func getPRM(ctx context.Context, purl string, c *http.Client, wantResource string) (*ProtectedResourceMetadata, error) { - if !strings.HasPrefix(strings.ToUpper(purl), "HTTPS://") { - return nil, fmt.Errorf("resource URL %q does not use HTTPS", purl) - } - prm, err := getJSON[ProtectedResourceMetadata](ctx, c, purl, 1<<20) +// GetProtectedResourceMetadataFromID issues a GET request to retrieve protected resource +// metadata from a resource server. +// The metadataURL is typically a URL with a host:port and possibly a path. +// For example: +// +// https://example.com/server +func GetProtectedResourceMetadata(ctx context.Context, metadataURL ProtectedResourceMetadataURL, c *http.Client) (_ *ProtectedResourceMetadata, err error) { + defer util.Wrapf(&err, "GetProtectedResourceMetadata(%q)", metadataURL) + // TODO: where HTTPS requirement comes from? conformance tests use HTTP. + // if !strings.HasPrefix(strings.ToUpper(purl), "HTTPS://") { + // return nil, fmt.Errorf("resource URL %q does not use HTTPS", purl) + // } + prm, err := getJSON[ProtectedResourceMetadata](ctx, c, metadataURL.URL, 1<<20) if err != nil { return nil, err } // Validate the Resource field (see RFC 9728, section 3.3). - if prm.Resource != wantResource { - return nil, fmt.Errorf("got metadata resource %q, want %q", prm.Resource, wantResource) + if prm.Resource != metadataURL.Resource { + return nil, fmt.Errorf("got metadata resource %q, want %q", prm.Resource, metadataURL.Resource) } // Validate the authorization server URLs to prevent XSS attacks (see #526). for _, u := range prm.AuthorizationServers { @@ -97,22 +109,51 @@ func getPRM(ctx context.Context, purl string, c *http.Client, wantResource strin return prm, nil } -// challenge represents a single authentication challenge from a WWW-Authenticate header. -// As per RFC 9110, Section 11.6.1, a challenge consists of a scheme and optional parameters. -type challenge struct { - // GENERATED BY GEMINI 2.5. - // - // Scheme is the authentication scheme (e.g., "Bearer", "Basic"). - // It is case-insensitive. A parsed value will always be lower-case. - Scheme string - // Params is a map of authentication parameters. - // Keys are case-insensitive. Parsed keys are always lower-case. - Params map[string]string +type ProtectedResourceMetadataURL struct { + // URL represents a URL where Protected Resource Metadata may be retrieved. + URL string + // Resource represents the corresponding resource URL for [URL]. + // It is required to perform validation described in RFC 9728, section 3.3. + Resource string +} + +// ProtectedResourceMetadataURLs returns a list of URLs to try when looking for +// protected resource metadata as mandated by the MCP specification. +func ProtectedResourceMetadataURLs(metadataURL, resourceURL string) []ProtectedResourceMetadataURL { + var urls []ProtectedResourceMetadataURL + if metadataURL != "" { + urls = append(urls, ProtectedResourceMetadataURL{ + URL: metadataURL, + Resource: resourceURL, + }) + } + // Produce fallbacks per + // https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#protected-resource-metadata-discovery-requirements + ru, err := url.Parse(resourceURL) + if err != nil { + return urls + } + mu := *ru + // "At the path of the server's MCP endpoint". + mu.Path = "/.well-known/oauth-protected-resource/" + strings.TrimLeft(ru.Path, "/") + urls = append(urls, ProtectedResourceMetadataURL{ + URL: mu.String(), + Resource: resourceURL, + }) + // "At the root". + mu.Path = "/.well-known/oauth-protected-resource" + ru.Path = "" + urls = append(urls, ProtectedResourceMetadataURL{ + URL: mu.String(), + Resource: ru.String(), + }) + log.Printf("Resource metadata URLs: %v", urls) + return urls } // ResourceMetadataURL returns a resource metadata URL from the given challenges, // or the empty string if there is none. -func ResourceMetadataURL(cs []challenge) string { +func ResourceMetadataURL(cs []Challenge) string { for _, c := range cs { if u := c.Params["resource_metadata"]; u != "" { return u @@ -121,161 +162,11 @@ func ResourceMetadataURL(cs []challenge) string { return "" } -// ParseWWWAuthenticate parses a WWW-Authenticate header string. -// The header format is defined in RFC 9110, Section 11.6.1, and can contain -// one or more challenges, separated by commas. -// It returns a slice of challenges or an error if one of the headers is malformed. -func ParseWWWAuthenticate(headers []string) ([]challenge, error) { - // GENERATED BY GEMINI 2.5 (human-tweaked) - var challenges []challenge - for _, h := range headers { - challengeStrings, err := splitChallenges(h) - if err != nil { - return nil, err - } - for _, cs := range challengeStrings { - if strings.TrimSpace(cs) == "" { - continue - } - challenge, err := parseSingleChallenge(cs) - if err != nil { - return nil, fmt.Errorf("failed to parse challenge %q: %w", cs, err) - } - challenges = append(challenges, challenge) - } - } - return challenges, nil -} - -// splitChallenges splits a header value containing one or more challenges. -// It correctly handles commas within quoted strings and distinguishes between -// commas separating auth-params and commas separating challenges. -func splitChallenges(header string) ([]string, error) { - // GENERATED BY GEMINI 2.5. - var challenges []string - inQuotes := false - start := 0 - for i, r := range header { - if r == '"' { - if i > 0 && header[i-1] != '\\' { - inQuotes = !inQuotes - } else if i == 0 { - // A challenge begins with an auth-scheme, which is a token, which cannot contain - // a quote. - return nil, errors.New(`challenge begins with '"'`) - } - } else if r == ',' && !inQuotes { - // This is a potential challenge separator. - // A new challenge does not start with `key=value`. - // We check if the part after the comma looks like a parameter. - lookahead := strings.TrimSpace(header[i+1:]) - eqPos := strings.Index(lookahead, "=") - - isParam := false - if eqPos > 0 { - // Check if the part before '=' is a single token (no spaces). - token := lookahead[:eqPos] - if strings.IndexFunc(token, unicode.IsSpace) == -1 { - isParam = true - } - } - - if !isParam { - // The part after the comma does not look like a parameter, - // so this comma separates challenges. - challenges = append(challenges, header[start:i]) - start = i + 1 - } - } - } - // Add the last (or only) challenge to the list. - challenges = append(challenges, header[start:]) - return challenges, nil -} - -// parseSingleChallenge parses a string containing exactly one challenge. -// challenge = auth-scheme [ 1*SP ( token68 / #auth-param ) ] -func parseSingleChallenge(s string) (challenge, error) { - // GENERATED BY GEMINI 2.5, human-tweaked. - s = strings.TrimSpace(s) - if s == "" { - return challenge{}, errors.New("empty challenge string") - } - - scheme, paramsStr, found := strings.Cut(s, " ") - c := challenge{Scheme: strings.ToLower(scheme)} - if !found { - return c, nil - } - - params := make(map[string]string) - - // Parse the key-value parameters. - for paramsStr != "" { - // Find the end of the parameter key. - keyEnd := strings.Index(paramsStr, "=") - if keyEnd <= 0 { - return challenge{}, fmt.Errorf("malformed auth parameter: expected key=value, but got %q", paramsStr) - } - key := strings.TrimSpace(paramsStr[:keyEnd]) - - // Move the string past the key and the '='. - paramsStr = strings.TrimSpace(paramsStr[keyEnd+1:]) - - var value string - if strings.HasPrefix(paramsStr, "\"") { - // The value is a quoted string. - paramsStr = paramsStr[1:] // Consume the opening quote. - var valBuilder strings.Builder - i := 0 - for ; i < len(paramsStr); i++ { - // Handle escaped characters. - if paramsStr[i] == '\\' && i+1 < len(paramsStr) { - valBuilder.WriteByte(paramsStr[i+1]) - i++ // We've consumed two characters. - } else if paramsStr[i] == '"' { - // End of the quoted string. - break - } else { - valBuilder.WriteByte(paramsStr[i]) - } - } - - // A quoted string must be terminated. - if i == len(paramsStr) { - return challenge{}, fmt.Errorf("unterminated quoted string in auth parameter") - } - - value = valBuilder.String() - // Move the string past the value and the closing quote. - paramsStr = strings.TrimSpace(paramsStr[i+1:]) - } else { - // The value is a token. It ends at the next comma or the end of the string. - commaPos := strings.Index(paramsStr, ",") - if commaPos == -1 { - value = paramsStr - paramsStr = "" - } else { - value = strings.TrimSpace(paramsStr[:commaPos]) - paramsStr = strings.TrimSpace(paramsStr[commaPos:]) // Keep comma for next check - } - } - if value == "" { - return challenge{}, fmt.Errorf("no value for auth param %q", key) - } - - // Per RFC 9110, parameter keys are case-insensitive. - params[strings.ToLower(key)] = value - - // If there is a comma, consume it and continue to the next parameter. - if strings.HasPrefix(paramsStr, ",") { - paramsStr = strings.TrimSpace(paramsStr[1:]) - } else if paramsStr != "" { - // If there's content but it's not a new parameter, the format is wrong. - return challenge{}, fmt.Errorf("malformed auth parameter: expected comma after value, but got %q", paramsStr) +func Scopes(cs []Challenge) []string { + for _, c := range cs { + if c.Scheme == "bearer" && c.Params["scope"] != "" { + return strings.Fields(c.Params["scope"]) } } - - // Per RFC 9110, the scheme is case-insensitive. - return challenge{Scheme: strings.ToLower(scheme), Params: params}, nil + return nil } diff --git a/oauthex/resource_meta_public.go b/oauthex/resource_meta_public.go new file mode 100644 index 00000000..443d5ba8 --- /dev/null +++ b/oauthex/resource_meta_public.go @@ -0,0 +1,284 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements Protected Resource Metadata. +// See https://www.rfc-editor.org/rfc/rfc9728.html. + +// This is a temporary file to expose the required objects to the main package. + +package oauthex + +import ( + "errors" + "fmt" + "strings" + "unicode" +) + +// ProtectedResourceMetadata is the metadata for an OAuth 2.0 protected resource, +// as defined in section 2 of https://www.rfc-editor.org/rfc/rfc9728.html. +// +// The following features are not supported: +// - additional keys (§2, last sentence) +// - human-readable metadata (§2.1) +// - signed metadata (§2.2) +type ProtectedResourceMetadata struct { + // GENERATED BY GEMINI 2.5. + + // Resource (resource) is the protected resource's resource identifier. + // Required. + Resource string `json:"resource"` + + // AuthorizationServers (authorization_servers) is an optional slice containing a list of + // OAuth authorization server issuer identifiers (as defined in RFC 8414) that can be + // used with this protected resource. + AuthorizationServers []string `json:"authorization_servers,omitempty"` + + // JWKSURI (jwks_uri) is an optional URL of the protected resource's JSON Web Key (JWK) Set + // document. This contains public keys belonging to the protected resource, such as + // signing key(s) that the resource server uses to sign resource responses. + JWKSURI string `json:"jwks_uri,omitempty"` + + // ScopesSupported (scopes_supported) is a recommended slice containing a list of scope + // values (as defined in RFC 6749) used in authorization requests to request access + // to this protected resource. + ScopesSupported []string `json:"scopes_supported,omitempty"` + + // BearerMethodsSupported (bearer_methods_supported) is an optional slice containing + // a list of the supported methods of sending an OAuth 2.0 bearer token to the + // protected resource. Defined values are "header", "body", and "query". + BearerMethodsSupported []string `json:"bearer_methods_supported,omitempty"` + + // ResourceSigningAlgValuesSupported (resource_signing_alg_values_supported) is an optional + // slice of JWS signing algorithms (alg values) supported by the protected + // resource for signing resource responses. + ResourceSigningAlgValuesSupported []string `json:"resource_signing_alg_values_supported,omitempty"` + + // ResourceName (resource_name) is a human-readable name of the protected resource + // intended for display to the end user. It is RECOMMENDED that this field be included. + // This value may be internationalized. + ResourceName string `json:"resource_name,omitempty"` + + // ResourceDocumentation (resource_documentation) is an optional URL of a page containing + // human-readable information for developers using the protected resource. + // This value may be internationalized. + ResourceDocumentation string `json:"resource_documentation,omitempty"` + + // ResourcePolicyURI (resource_policy_uri) is an optional URL of a page containing + // human-readable policy information on how a client can use the data provided. + // This value may be internationalized. + ResourcePolicyURI string `json:"resource_policy_uri,omitempty"` + + // ResourceTOSURI (resource_tos_uri) is an optional URL of a page containing the protected + // resource's human-readable terms of service. This value may be internationalized. + ResourceTOSURI string `json:"resource_tos_uri,omitempty"` + + // TLSClientCertificateBoundAccessTokens (tls_client_certificate_bound_access_tokens) is an + // optional boolean indicating support for mutual-TLS client certificate-bound + // access tokens (RFC 8705). Defaults to false if omitted. + TLSClientCertificateBoundAccessTokens bool `json:"tls_client_certificate_bound_access_tokens,omitempty"` + + // AuthorizationDetailsTypesSupported (authorization_details_types_supported) is an optional + // slice of 'type' values supported by the resource server for the + // 'authorization_details' parameter (RFC 9396). + AuthorizationDetailsTypesSupported []string `json:"authorization_details_types_supported,omitempty"` + + // DPOPSigningAlgValuesSupported (dpop_signing_alg_values_supported) is an optional + // slice of JWS signing algorithms supported by the resource server for validating + // DPoP proof JWTs (RFC 9449). + DPOPSigningAlgValuesSupported []string `json:"dpop_signing_alg_values_supported,omitempty"` + + // DPOPBoundAccessTokensRequired (dpop_bound_access_tokens_required) is an optional boolean + // specifying whether the protected resource always requires the use of DPoP-bound + // access tokens (RFC 9449). Defaults to false if omitted. + DPOPBoundAccessTokensRequired bool `json:"dpop_bound_access_tokens_required,omitempty"` + + // SignedMetadata (signed_metadata) is an optional JWT containing metadata parameters + // about the protected resource as claims. If present, these values take precedence + // over values conveyed in plain JSON. + // TODO:implement. + // Note that §2.2 says it's okay to ignore this. + // SignedMetadata string `json:"signed_metadata,omitempty"` +} + +// Challenge represents a single authentication challenge from a WWW-Authenticate header. +// As per RFC 9110, Section 11.6.1, a challenge consists of a scheme and optional parameters. +type Challenge struct { + // GENERATED BY GEMINI 2.5. + // + // Scheme is the authentication scheme (e.g., "Bearer", "Basic"). + // It is case-insensitive. A parsed value will always be lower-case. + Scheme string + // Params is a map of authentication parameters. + // Keys are case-insensitive. Parsed keys are always lower-case. + Params map[string]string +} + +func Error(cs []Challenge) string { + for _, c := range cs { + if c.Scheme == "bearer" && c.Params["error"] != "" { + return c.Params["error"] + } + } + return "" +} + +// ParseWWWAuthenticate parses a WWW-Authenticate header string. +// The header format is defined in RFC 9110, Section 11.6.1, and can contain +// one or more challenges, separated by commas. +// It returns a slice of challenges or an error if one of the headers is malformed. +func ParseWWWAuthenticate(headers []string) ([]Challenge, error) { + // GENERATED BY GEMINI 2.5 (human-tweaked) + var challenges []Challenge + for _, h := range headers { + challengeStrings, err := splitChallenges(h) + if err != nil { + return nil, err + } + for _, cs := range challengeStrings { + if strings.TrimSpace(cs) == "" { + continue + } + challenge, err := parseSingleChallenge(cs) + if err != nil { + return nil, fmt.Errorf("failed to parse challenge %q: %w", cs, err) + } + challenges = append(challenges, challenge) + } + } + return challenges, nil +} + +// splitChallenges splits a header value containing one or more challenges. +// It correctly handles commas within quoted strings and distinguishes between +// commas separating auth-params and commas separating challenges. +func splitChallenges(header string) ([]string, error) { + // GENERATED BY GEMINI 2.5. + var challenges []string + inQuotes := false + start := 0 + for i, r := range header { + if r == '"' { + if i > 0 && header[i-1] != '\\' { + inQuotes = !inQuotes + } else if i == 0 { + // A challenge begins with an auth-scheme, which is a token, which cannot contain + // a quote. + return nil, errors.New(`challenge begins with '"'`) + } + } else if r == ',' && !inQuotes { + // This is a potential challenge separator. + // A new challenge does not start with `key=value`. + // We check if the part after the comma looks like a parameter. + lookahead := strings.TrimSpace(header[i+1:]) + eqPos := strings.Index(lookahead, "=") + + isParam := false + if eqPos > 0 { + // Check if the part before '=' is a single token (no spaces). + token := lookahead[:eqPos] + if strings.IndexFunc(token, unicode.IsSpace) == -1 { + isParam = true + } + } + + if !isParam { + // The part after the comma does not look like a parameter, + // so this comma separates challenges. + challenges = append(challenges, header[start:i]) + start = i + 1 + } + } + } + // Add the last (or only) challenge to the list. + challenges = append(challenges, header[start:]) + return challenges, nil +} + +// parseSingleChallenge parses a string containing exactly one challenge. +// challenge = auth-scheme [ 1*SP ( token68 / #auth-param ) ] +func parseSingleChallenge(s string) (Challenge, error) { + // GENERATED BY GEMINI 2.5, human-tweaked. + s = strings.TrimSpace(s) + if s == "" { + return Challenge{}, errors.New("empty challenge string") + } + + scheme, paramsStr, found := strings.Cut(s, " ") + c := Challenge{Scheme: strings.ToLower(scheme)} + if !found { + return c, nil + } + + params := make(map[string]string) + + // Parse the key-value parameters. + for paramsStr != "" { + // Find the end of the parameter key. + keyEnd := strings.Index(paramsStr, "=") + if keyEnd <= 0 { + return Challenge{}, fmt.Errorf("malformed auth parameter: expected key=value, but got %q", paramsStr) + } + key := strings.TrimSpace(paramsStr[:keyEnd]) + + // Move the string past the key and the '='. + paramsStr = strings.TrimSpace(paramsStr[keyEnd+1:]) + + var value string + if strings.HasPrefix(paramsStr, "\"") { + // The value is a quoted string. + paramsStr = paramsStr[1:] // Consume the opening quote. + var valBuilder strings.Builder + i := 0 + for ; i < len(paramsStr); i++ { + // Handle escaped characters. + if paramsStr[i] == '\\' && i+1 < len(paramsStr) { + valBuilder.WriteByte(paramsStr[i+1]) + i++ // We've consumed two characters. + } else if paramsStr[i] == '"' { + // End of the quoted string. + break + } else { + valBuilder.WriteByte(paramsStr[i]) + } + } + + // A quoted string must be terminated. + if i == len(paramsStr) { + return Challenge{}, fmt.Errorf("unterminated quoted string in auth parameter") + } + + value = valBuilder.String() + // Move the string past the value and the closing quote. + paramsStr = strings.TrimSpace(paramsStr[i+1:]) + } else { + // The value is a token. It ends at the next comma or the end of the string. + commaPos := strings.Index(paramsStr, ",") + if commaPos == -1 { + value = paramsStr + paramsStr = "" + } else { + value = strings.TrimSpace(paramsStr[:commaPos]) + paramsStr = strings.TrimSpace(paramsStr[commaPos:]) // Keep comma for next check + } + } + if value == "" { + return Challenge{}, fmt.Errorf("no value for auth param %q", key) + } + + // Per RFC 9110, parameter keys are case-insensitive. + params[strings.ToLower(key)] = value + + // If there is a comma, consume it and continue to the next parameter. + if strings.HasPrefix(paramsStr, ",") { + paramsStr = strings.TrimSpace(paramsStr[1:]) + } else if paramsStr != "" { + // If there's content but it's not a new parameter, the format is wrong. + return Challenge{}, fmt.Errorf("malformed auth parameter: expected comma after value, but got %q", paramsStr) + } + } + + // Per RFC 9110, the scheme is case-insensitive. + return Challenge{Scheme: strings.ToLower(scheme), Params: params}, nil +} diff --git a/oauthex/url_scheme_test.go b/oauthex/url_scheme_test.go index 531a1f9c..c13bc1d5 100644 --- a/oauthex/url_scheme_test.go +++ b/oauthex/url_scheme_test.go @@ -226,7 +226,11 @@ func TestGetAuthServerMetaRejectsDangerousURLs(t *testing.T) { defer server.Close() ctx := context.Background() - _, err := GetAuthServerMeta(ctx, server.URL, server.Client()) + metadataURL := AuthorizationServerMetadataURL{ + URL: server.URL, + Issuer: server.URL, + } + _, err := GetAuthServerMeta(ctx, metadataURL, server.Client()) if err == nil { t.Fatal("GetAuthServerMeta(): got nil error, want error") } diff --git a/scripts/client-conformance.sh b/scripts/client-conformance.sh index c093c75f..17528450 100755 --- a/scripts/client-conformance.sh +++ b/scripts/client-conformance.sh @@ -10,6 +10,7 @@ set -e RESULT_DIR="" WORKDIR="" CONFORMANCE_REPO="" +SUITE="core" FINAL_EXIT_CODE=0 usage() { @@ -21,9 +22,11 @@ usage() { echo " --result_dir Save results to the specified directory" echo " --conformance_repo Run conformance tests from a local checkout" echo " instead of using the latest npm release" + echo " --suite Which suite to run (default: core)" echo " --help Show this help message" } + # Parse arguments. while [[ $# -gt 0 ]]; do case $1 in @@ -35,6 +38,10 @@ while [[ $# -gt 0 ]]; do CONFORMANCE_REPO="$2" shift 2 ;; + --suite) + SUITE="$2" + shift 2 + ;; --help) usage exit 0 @@ -56,7 +63,7 @@ else fi # Build the conformance server. -go build -o "$WORKDIR/conformance-client" ./conformance/everything-client +go build -tags mcp_go_client_oauth -o "$WORKDIR/conformance-client" ./conformance/everything-client # Run conformance tests from the work directory to avoid writing results to the repo. echo "Running conformance tests..." @@ -65,13 +72,13 @@ if [ -n "$CONFORMANCE_REPO" ]; then (cd "$WORKDIR" && \ npm --prefix "$CONFORMANCE_REPO" run start -- \ client --command "$WORKDIR/conformance-client" \ - --suite core \ + --suite "$SUITE" \ ${RESULT_DIR:+--output-dir "$RESULT_DIR"}) || FINAL_EXIT_CODE=$? else (cd "$WORKDIR" && \ npx @modelcontextprotocol/conformance@latest \ client --command "$WORKDIR/conformance-client" \ - --suite core \ + --suite "$SUITE" \ ${RESULT_DIR:+--output-dir "$RESULT_DIR"}) || FINAL_EXIT_CODE=$? fi