Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions configs/platform.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,11 @@ toolkits:
password: "${TRINO_PASSWORD}"
catalog: "iceberg"
ssl: true
description: "Production data warehouse for batch analytics and reporting"
staging:
host: "trino-staging.example.com"
port: 8080
description: "Staging environment for testing queries before production"
default: production
config:
default_limit: 1000
Expand Down
71 changes: 54 additions & 17 deletions pkg/middleware/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package middleware
import (
"context"
"crypto/rand"
"encoding/json"
"errors"
"fmt"
"log/slog"
Expand Down Expand Up @@ -69,28 +70,13 @@ func MCPToolCallMiddleware(authenticator Authenticator, authorizer Authorizer, t
return nil, newInvalidParamsError(fmt.Sprintf("invalid request: %v", err))
}

// Create platform context
// Build platform context and enrich the Go context.
pc := NewPlatformContext(generateRequestID())
pc.ToolName = toolName
pc.SessionID = extractSessionID(req)
pc.Transport = transport
pc.Source = "mcp"
ctx = WithPlatformContext(ctx, pc)

// Store ServerSession and progress token in context for
// progress notifications and client logging.
if ss := extractServerSession(req); ss != nil {
ctx = WithServerSession(ctx, ss)
}
if pt := extractProgressToken(req); pt != nil {
ctx = WithProgressToken(ctx, pt)
}

// Populate toolkit metadata
populateToolkitMetadata(pc, toolkitLookup, toolName)

// Bridge auth token from Streamable HTTP per-request headers.
ctx = bridgeAuthToken(ctx, req)
ctx = buildToolCallContext(ctx, req, pc, toolkitLookup, toolName)

// Authenticate and authorize
return authenticateAndAuthorize(ctx, method, req, next, authParams{
Expand All @@ -103,6 +89,35 @@ func MCPToolCallMiddleware(authenticator Authenticator, authorizer Authorizer, t
}
}

// buildToolCallContext enriches the context with session, progress, toolkit
// metadata, connection override, and auth token bridging for a tool call.
func buildToolCallContext(ctx context.Context, req mcp.Request, pc *PlatformContext, toolkitLookup ToolkitLookup, toolName string) context.Context {
ctx = WithPlatformContext(ctx, pc)

// Store ServerSession and progress token in context for
// progress notifications and client logging.
if ss := extractServerSession(req); ss != nil {
ctx = WithServerSession(ctx, ss)
}
if pt := extractProgressToken(req); pt != nil {
ctx = WithProgressToken(ctx, pt)
}

// Populate toolkit metadata (kind, name, default connection).
populateToolkitMetadata(pc, toolkitLookup, toolName)

// Override connection from request arguments for accurate audit logging.
// With multi-connection toolkits, the toolkit's Connection() returns the
// default, but the actual connection is determined by the request's
// "connection" argument.
if connFromArgs := extractConnectionArg(req); connFromArgs != "" {
pc.Connection = connFromArgs
}

// Bridge auth token from Streamable HTTP per-request headers.
return bridgeAuthToken(ctx, req)
}

// populateToolkitMetadata fills PlatformContext toolkit fields from the lookup.
func populateToolkitMetadata(pc *PlatformContext, lookup ToolkitLookup, toolName string) {
if lookup == nil {
Expand Down Expand Up @@ -349,3 +364,25 @@ func generateRequestID() string {
}
return fmt.Sprintf("req-%x", b)
}

// extractConnectionArg extracts the "connection" field from tool call arguments.
// Returns an empty string if the request has no connection argument.
func extractConnectionArg(req mcp.Request) string {
if req == nil {
return ""
}
params := req.GetParams()
if params == nil {
return ""
}
callParams, ok := params.(*mcp.CallToolParamsRaw)
if !ok || callParams == nil || len(callParams.Arguments) == 0 {
return ""
}
var args map[string]any
if err := json.Unmarshal(callParams.Arguments, &args); err != nil {
return ""
}
conn, _ := args["connection"].(string)
return conn
}
157 changes: 157 additions & 0 deletions pkg/middleware/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package middleware

import (
"context"
"encoding/json"
"errors"
"net/http"
"testing"
Expand Down Expand Up @@ -812,3 +813,159 @@ func TestNewInvalidParamsError(t *testing.T) {
t.Error("expected non-empty Error() string")
}
}

func TestExtractConnectionArg(t *testing.T) {
tests := []struct {
name string
req mcp.Request
want string
}{
{
name: "nil request",
req: nil,
want: "",
},
{
name: "no arguments",
req: newMCPTestRequest(mcpTestToolName),
want: "",
},
{
name: "connection present",
req: &mcp.ServerRequest[*mcp.CallToolParamsRaw]{
Params: &mcp.CallToolParamsRaw{
Name: mcpTestToolName,
Arguments: json.RawMessage(`{"connection":"warehouse","sql":"SELECT 1"}`),
},
},
want: "warehouse",
},
{
name: "connection absent in args",
req: &mcp.ServerRequest[*mcp.CallToolParamsRaw]{
Params: &mcp.CallToolParamsRaw{
Name: mcpTestToolName,
Arguments: json.RawMessage(`{"sql":"SELECT 1"}`),
},
},
want: "",
},
{
name: "malformed JSON arguments",
req: &mcp.ServerRequest[*mcp.CallToolParamsRaw]{
Params: &mcp.CallToolParamsRaw{
Name: mcpTestToolName,
Arguments: json.RawMessage(`{invalid`),
},
},
want: "",
},
{
name: "connection is non-string",
req: &mcp.ServerRequest[*mcp.CallToolParamsRaw]{
Params: &mcp.CallToolParamsRaw{
Name: mcpTestToolName,
Arguments: json.RawMessage(`{"connection":42}`),
},
},
want: "",
},
{
name: "nil params",
req: &mcp.ServerRequest[*mcp.CallToolParamsRaw]{
Params: nil,
},
want: "",
},
{
name: "wrong params type",
req: &mcp.ServerRequest[*mcp.ListToolsParams]{
Params: &mcp.ListToolsParams{},
},
want: "",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractConnectionArg(tt.req)
if got != tt.want {
t.Errorf("extractConnectionArg() = %q, want %q", got, tt.want)
}
})
}
}

func TestMCPToolCallMiddleware_ConnectionOverride(t *testing.T) {
authenticator := &mcpTestAuthenticator{
userInfo: &UserInfo{
UserID: mcpTestUserID,
Roles: []string{mcpTestPersona},
},
}
authorizer := &mcpTestAuthorizer{authorized: true, personaName: mcpTestPersona}
toolkitLookup := &mcpTestToolkitLookup{
kind: "trino",
name: "default-trino",
connection: "default-trino",
found: true,
}

middleware := MCPToolCallMiddleware(authenticator, authorizer, toolkitLookup, mcpTestStdio)

t.Run("connection arg overrides toolkit default", func(t *testing.T) {
next := func(ctx context.Context, _ string, _ mcp.Request) (mcp.Result, error) {
pc := GetPlatformContext(ctx)
if pc == nil {
t.Fatal(mcpTestPCExpected)
}
if pc.Connection != "elasticsearch" {
t.Errorf("Connection = %q, want 'elasticsearch'", pc.Connection)
}
return &mcp.CallToolResult{
Content: []mcp.Content{&mcp.TextContent{Text: "ok"}},
}, nil
}

handler := middleware(next)
req := &mcp.ServerRequest[*mcp.CallToolParamsRaw]{
Params: &mcp.CallToolParamsRaw{
Name: testAuditToolName,
Arguments: json.RawMessage(`{"connection":"elasticsearch","sql":"SELECT 1"}`),
},
}

_, err := handler(context.Background(), mcpTestMethod, req)
if err != nil {
t.Fatalf(mcpTestErrFmt, err)
}
})

t.Run("no connection arg keeps toolkit default", func(t *testing.T) {
next := func(ctx context.Context, _ string, _ mcp.Request) (mcp.Result, error) {
pc := GetPlatformContext(ctx)
if pc == nil {
t.Fatal(mcpTestPCExpected)
}
if pc.Connection != "default-trino" {
t.Errorf("Connection = %q, want 'default-trino'", pc.Connection)
}
return &mcp.CallToolResult{
Content: []mcp.Content{&mcp.TextContent{Text: "ok"}},
}, nil
}

handler := middleware(next)
req := &mcp.ServerRequest[*mcp.CallToolParamsRaw]{
Params: &mcp.CallToolParamsRaw{
Name: testAuditToolName,
Arguments: json.RawMessage(`{"sql":"SELECT 1"}`),
},
}

_, err := handler(context.Background(), mcpTestMethod, req)
if err != nil {
t.Fatalf(mcpTestErrFmt, err)
}
})
}
32 changes: 24 additions & 8 deletions pkg/platform/connections_tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@ import (
"encoding/json"

"github.com/modelcontextprotocol/go-sdk/mcp"

"github.com/txn2/mcp-data-platform/pkg/toolkit"
)

// connectionEntry describes a single toolkit connection.
type connectionEntry struct {
Kind string `json:"kind"`
Name string `json:"name"`
Connection string `json:"connection"`
Kind string `json:"kind"`
Name string `json:"name"`
Connection string `json:"connection"`
Description string `json:"description,omitempty"`
IsDefault bool `json:"is_default,omitempty"`
}

// listConnectionsOutput is the JSON response for the list_connections tool.
Expand Down Expand Up @@ -40,11 +44,23 @@ func (p *Platform) handleListConnections(_ context.Context, _ *mcp.CallToolReque

entries := make([]connectionEntry, 0, len(toolkits))
for _, tk := range toolkits {
entries = append(entries, connectionEntry{
Kind: tk.Kind(),
Name: tk.Name(),
Connection: tk.Connection(),
})
if lister, ok := tk.(toolkit.ConnectionLister); ok {
for _, conn := range lister.ListConnections() {
entries = append(entries, connectionEntry{
Kind: tk.Kind(),
Name: conn.Name,
Connection: conn.Name,
Description: conn.Description,
IsDefault: conn.IsDefault,
})
}
} else {
entries = append(entries, connectionEntry{
Kind: tk.Kind(),
Name: tk.Name(),
Connection: tk.Connection(),
})
}
}

out := listConnectionsOutput{
Expand Down
Loading
Loading