Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions pkg/context/request_ids.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package context

import (
"context"
"crypto/rand"
"fmt"
)

type requestIDCtxKey struct{}
type operationIDCtxKey struct{}

func WithRequestID(ctx context.Context, requestID string) context.Context {
return context.WithValue(ctx, requestIDCtxKey{}, requestID)
}

func RequestID(ctx context.Context) (string, bool) {
requestID, ok := ctx.Value(requestIDCtxKey{}).(string)
return requestID, ok
}

func WithOperationID(ctx context.Context, operationID string) context.Context {
return context.WithValue(ctx, operationIDCtxKey{}, operationID)
}

func OperationID(ctx context.Context) (string, bool) {
operationID, ok := ctx.Value(operationIDCtxKey{}).(string)
return operationID, ok
}

func GenerateRequestID() (string, error) {
return generateID("req")
}

func GenerateOperationID() (string, error) {
return generateID("op")
}

func generateID(prefix string) (string, error) {
buf := make([]byte, 16)
if _, err := rand.Read(buf); err != nil {
return "", fmt.Errorf("generate %s id: %w", prefix, err)
}
return fmt.Sprintf("%s_%x", prefix, buf), nil
}
23 changes: 23 additions & 0 deletions pkg/github/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strings"
"time"

ghcontext "github.com/github/github-mcp-server/pkg/context"
gherrors "github.com/github/github-mcp-server/pkg/errors"
"github.com/github/github-mcp-server/pkg/inventory"
"github.com/github/github-mcp-server/pkg/octicons"
Expand Down Expand Up @@ -107,6 +108,7 @@ func NewMCPServer(ctx context.Context, cfg *MCPServerConfig, deps ToolDependenci
// and any middleware that needs to read or modify the context should be before it.
ghServer.AddReceivingMiddleware(middleware...)
ghServer.AddReceivingMiddleware(InjectDepsMiddleware(deps))
ghServer.AddReceivingMiddleware(withOperationID)
ghServer.AddReceivingMiddleware(addGitHubAPIErrorToContext)

if unrecognized := inv.UnrecognizedToolsets(); len(unrecognized) > 0 {
Expand Down Expand Up @@ -176,6 +178,27 @@ func addGitHubAPIErrorToContext(next mcp.MethodHandler) mcp.MethodHandler {
}
}

func withOperationID(next mcp.MethodHandler) mcp.MethodHandler {
return func(ctx context.Context, method string, req mcp.Request) (result mcp.Result, err error) {
requestID, ok := ghcontext.RequestID(ctx)
if !ok || requestID == "" {
requestID, err = ghcontext.GenerateRequestID()
if err != nil {
return nil, err
}
ctx = ghcontext.WithRequestID(ctx, requestID)
}

operationID, err := ghcontext.GenerateOperationID()
if err != nil {
return nil, err
}
ctx = ghcontext.WithOperationID(ctx, operationID)

return next(ctx, method, req)
}
}

// NewServer creates a new GitHub MCP server with the specified GH client and logger.
func NewServer(version string, opts *mcp.ServerOptions) *mcp.Server {
if opts == nil {
Expand Down
53 changes: 53 additions & 0 deletions pkg/github/server_operation_id_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package github

import (
"context"
"testing"

ghcontext "github.com/github/github-mcp-server/pkg/context"
"github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestWithOperationID_PreservesRequestIDAndAddsOperationID(t *testing.T) {
t.Parallel()

var capturedRequestID string
var capturedOperationID string
handler := withOperationID(func(ctx context.Context, _ string, _ mcp.Request) (mcp.Result, error) {
var ok bool
capturedRequestID, ok = ghcontext.RequestID(ctx)
require.True(t, ok)

capturedOperationID, ok = ghcontext.OperationID(ctx)
require.True(t, ok)
return nil, nil
})

_, err := handler(ghcontext.WithRequestID(context.Background(), "req_client"), "tools/call", nil)
require.NoError(t, err)

assert.Equal(t, "req_client", capturedRequestID)
assert.Regexp(t, `^op_[0-9a-f]+$`, capturedOperationID)
}

func TestWithOperationID_GeneratesUniqueOperationIDs(t *testing.T) {
t.Parallel()

var operationIDs []string
handler := withOperationID(func(ctx context.Context, _ string, _ mcp.Request) (mcp.Result, error) {
operationID, ok := ghcontext.OperationID(ctx)
require.True(t, ok)
operationIDs = append(operationIDs, operationID)
return nil, nil
})

_, err := handler(context.Background(), "tools/call", nil)
require.NoError(t, err)
_, err = handler(context.Background(), "tools/call", nil)
require.NoError(t, err)

require.Len(t, operationIDs, 2)
assert.NotEqual(t, operationIDs[0], operationIDs[1])
}
2 changes: 2 additions & 0 deletions pkg/http/headers/headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ const (
ForwardedHostHeader = "X-Forwarded-Host"
// ForwardedProtoHeader is a standard HTTP Header for preserving the original protocol when proxying.
ForwardedProtoHeader = "X-Forwarded-Proto"
// RequestIDHeader is a standard request-correlation header.
RequestIDHeader = "X-Request-ID"

// RequestHmacHeader is used to authenticate requests to the Raw API.
RequestHmacHeader = "Request-Hmac"
Expand Down
12 changes: 12 additions & 0 deletions pkg/http/middleware/request_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,18 @@ func WithRequestConfig(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

requestID := strings.TrimSpace(r.Header.Get(headers.RequestIDHeader))
if requestID == "" {
generatedRequestID, err := ghcontext.GenerateRequestID()
if err != nil {
http.Error(w, "failed to generate request id", http.StatusInternalServerError)
return
}
requestID = generatedRequestID
}
ctx = ghcontext.WithRequestID(ctx, requestID)
w.Header().Set(headers.RequestIDHeader, requestID)

// Readonly mode
if relaxedParseBool(r.Header.Get(headers.MCPReadOnlyHeader)) {
ctx = ghcontext.WithReadonly(ctx, true)
Expand Down
52 changes: 52 additions & 0 deletions pkg/http/middleware/request_config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package middleware

import (
"net/http"
"net/http/httptest"
"testing"

ghcontext "github.com/github/github-mcp-server/pkg/context"
"github.com/github/github-mcp-server/pkg/http/headers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestWithRequestConfig_PreservesProvidedRequestID(t *testing.T) {
t.Parallel()

recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/", nil)
request.Header.Set(headers.RequestIDHeader, "client-request-id")

var requestID string
handler := WithRequestConfig(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
var ok bool
requestID, ok = ghcontext.RequestID(r.Context())
require.True(t, ok)
}))

handler.ServeHTTP(recorder, request)

assert.Equal(t, "client-request-id", requestID)
assert.Equal(t, "client-request-id", recorder.Header().Get(headers.RequestIDHeader))
}

func TestWithRequestConfig_GeneratesRequestIDWhenMissing(t *testing.T) {
t.Parallel()

recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/", nil)

var requestID string
handler := WithRequestConfig(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
var ok bool
requestID, ok = ghcontext.RequestID(r.Context())
require.True(t, ok)
}))

handler.ServeHTTP(recorder, request)

assert.NotEmpty(t, requestID)
assert.Equal(t, requestID, recorder.Header().Get(headers.RequestIDHeader))
assert.Regexp(t, `^req_[0-9a-f]+$`, requestID)
}