diff --git a/configs/platform.yaml b/configs/platform.yaml index 2b8acd68..faea20fe 100644 --- a/configs/platform.yaml +++ b/configs/platform.yaml @@ -142,9 +142,11 @@ toolkits: password: "${TRINO_PASSWORD}" catalog: "iceberg" ssl: true + description: "Production data warehouse for batch analytics and reporting" staging: host: "trino-staging.example.com" port: 8080 + description: "Staging environment for testing queries before production" default: production config: default_limit: 1000 diff --git a/pkg/middleware/mcp.go b/pkg/middleware/mcp.go index 43f4e292..6fc79b09 100644 --- a/pkg/middleware/mcp.go +++ b/pkg/middleware/mcp.go @@ -3,6 +3,7 @@ package middleware import ( "context" "crypto/rand" + "encoding/json" "errors" "fmt" "log/slog" @@ -69,28 +70,13 @@ func MCPToolCallMiddleware(authenticator Authenticator, authorizer Authorizer, t return nil, newInvalidParamsError(fmt.Sprintf("invalid request: %v", err)) } - // Create platform context + // Build platform context and enrich the Go context. pc := NewPlatformContext(generateRequestID()) pc.ToolName = toolName pc.SessionID = extractSessionID(req) pc.Transport = transport pc.Source = "mcp" - ctx = WithPlatformContext(ctx, pc) - - // Store ServerSession and progress token in context for - // progress notifications and client logging. - if ss := extractServerSession(req); ss != nil { - ctx = WithServerSession(ctx, ss) - } - if pt := extractProgressToken(req); pt != nil { - ctx = WithProgressToken(ctx, pt) - } - - // Populate toolkit metadata - populateToolkitMetadata(pc, toolkitLookup, toolName) - - // Bridge auth token from Streamable HTTP per-request headers. - ctx = bridgeAuthToken(ctx, req) + ctx = buildToolCallContext(ctx, req, pc, toolkitLookup, toolName) // Authenticate and authorize return authenticateAndAuthorize(ctx, method, req, next, authParams{ @@ -103,6 +89,35 @@ func MCPToolCallMiddleware(authenticator Authenticator, authorizer Authorizer, t } } +// buildToolCallContext enriches the context with session, progress, toolkit +// metadata, connection override, and auth token bridging for a tool call. +func buildToolCallContext(ctx context.Context, req mcp.Request, pc *PlatformContext, toolkitLookup ToolkitLookup, toolName string) context.Context { + ctx = WithPlatformContext(ctx, pc) + + // Store ServerSession and progress token in context for + // progress notifications and client logging. + if ss := extractServerSession(req); ss != nil { + ctx = WithServerSession(ctx, ss) + } + if pt := extractProgressToken(req); pt != nil { + ctx = WithProgressToken(ctx, pt) + } + + // Populate toolkit metadata (kind, name, default connection). + populateToolkitMetadata(pc, toolkitLookup, toolName) + + // Override connection from request arguments for accurate audit logging. + // With multi-connection toolkits, the toolkit's Connection() returns the + // default, but the actual connection is determined by the request's + // "connection" argument. + if connFromArgs := extractConnectionArg(req); connFromArgs != "" { + pc.Connection = connFromArgs + } + + // Bridge auth token from Streamable HTTP per-request headers. + return bridgeAuthToken(ctx, req) +} + // populateToolkitMetadata fills PlatformContext toolkit fields from the lookup. func populateToolkitMetadata(pc *PlatformContext, lookup ToolkitLookup, toolName string) { if lookup == nil { @@ -349,3 +364,25 @@ func generateRequestID() string { } return fmt.Sprintf("req-%x", b) } + +// extractConnectionArg extracts the "connection" field from tool call arguments. +// Returns an empty string if the request has no connection argument. +func extractConnectionArg(req mcp.Request) string { + if req == nil { + return "" + } + params := req.GetParams() + if params == nil { + return "" + } + callParams, ok := params.(*mcp.CallToolParamsRaw) + if !ok || callParams == nil || len(callParams.Arguments) == 0 { + return "" + } + var args map[string]any + if err := json.Unmarshal(callParams.Arguments, &args); err != nil { + return "" + } + conn, _ := args["connection"].(string) + return conn +} diff --git a/pkg/middleware/mcp_test.go b/pkg/middleware/mcp_test.go index e7d97565..f47916b1 100644 --- a/pkg/middleware/mcp_test.go +++ b/pkg/middleware/mcp_test.go @@ -2,6 +2,7 @@ package middleware import ( "context" + "encoding/json" "errors" "net/http" "testing" @@ -812,3 +813,159 @@ func TestNewInvalidParamsError(t *testing.T) { t.Error("expected non-empty Error() string") } } + +func TestExtractConnectionArg(t *testing.T) { + tests := []struct { + name string + req mcp.Request + want string + }{ + { + name: "nil request", + req: nil, + want: "", + }, + { + name: "no arguments", + req: newMCPTestRequest(mcpTestToolName), + want: "", + }, + { + name: "connection present", + req: &mcp.ServerRequest[*mcp.CallToolParamsRaw]{ + Params: &mcp.CallToolParamsRaw{ + Name: mcpTestToolName, + Arguments: json.RawMessage(`{"connection":"warehouse","sql":"SELECT 1"}`), + }, + }, + want: "warehouse", + }, + { + name: "connection absent in args", + req: &mcp.ServerRequest[*mcp.CallToolParamsRaw]{ + Params: &mcp.CallToolParamsRaw{ + Name: mcpTestToolName, + Arguments: json.RawMessage(`{"sql":"SELECT 1"}`), + }, + }, + want: "", + }, + { + name: "malformed JSON arguments", + req: &mcp.ServerRequest[*mcp.CallToolParamsRaw]{ + Params: &mcp.CallToolParamsRaw{ + Name: mcpTestToolName, + Arguments: json.RawMessage(`{invalid`), + }, + }, + want: "", + }, + { + name: "connection is non-string", + req: &mcp.ServerRequest[*mcp.CallToolParamsRaw]{ + Params: &mcp.CallToolParamsRaw{ + Name: mcpTestToolName, + Arguments: json.RawMessage(`{"connection":42}`), + }, + }, + want: "", + }, + { + name: "nil params", + req: &mcp.ServerRequest[*mcp.CallToolParamsRaw]{ + Params: nil, + }, + want: "", + }, + { + name: "wrong params type", + req: &mcp.ServerRequest[*mcp.ListToolsParams]{ + Params: &mcp.ListToolsParams{}, + }, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractConnectionArg(tt.req) + if got != tt.want { + t.Errorf("extractConnectionArg() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestMCPToolCallMiddleware_ConnectionOverride(t *testing.T) { + authenticator := &mcpTestAuthenticator{ + userInfo: &UserInfo{ + UserID: mcpTestUserID, + Roles: []string{mcpTestPersona}, + }, + } + authorizer := &mcpTestAuthorizer{authorized: true, personaName: mcpTestPersona} + toolkitLookup := &mcpTestToolkitLookup{ + kind: "trino", + name: "default-trino", + connection: "default-trino", + found: true, + } + + middleware := MCPToolCallMiddleware(authenticator, authorizer, toolkitLookup, mcpTestStdio) + + t.Run("connection arg overrides toolkit default", func(t *testing.T) { + next := func(ctx context.Context, _ string, _ mcp.Request) (mcp.Result, error) { + pc := GetPlatformContext(ctx) + if pc == nil { + t.Fatal(mcpTestPCExpected) + } + if pc.Connection != "elasticsearch" { + t.Errorf("Connection = %q, want 'elasticsearch'", pc.Connection) + } + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: "ok"}}, + }, nil + } + + handler := middleware(next) + req := &mcp.ServerRequest[*mcp.CallToolParamsRaw]{ + Params: &mcp.CallToolParamsRaw{ + Name: testAuditToolName, + Arguments: json.RawMessage(`{"connection":"elasticsearch","sql":"SELECT 1"}`), + }, + } + + _, err := handler(context.Background(), mcpTestMethod, req) + if err != nil { + t.Fatalf(mcpTestErrFmt, err) + } + }) + + t.Run("no connection arg keeps toolkit default", func(t *testing.T) { + next := func(ctx context.Context, _ string, _ mcp.Request) (mcp.Result, error) { + pc := GetPlatformContext(ctx) + if pc == nil { + t.Fatal(mcpTestPCExpected) + } + if pc.Connection != "default-trino" { + t.Errorf("Connection = %q, want 'default-trino'", pc.Connection) + } + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: "ok"}}, + }, nil + } + + handler := middleware(next) + req := &mcp.ServerRequest[*mcp.CallToolParamsRaw]{ + Params: &mcp.CallToolParamsRaw{ + Name: testAuditToolName, + Arguments: json.RawMessage(`{"sql":"SELECT 1"}`), + }, + } + + _, err := handler(context.Background(), mcpTestMethod, req) + if err != nil { + t.Fatalf(mcpTestErrFmt, err) + } + }) +} diff --git a/pkg/platform/connections_tool.go b/pkg/platform/connections_tool.go index 15ddb07c..d9754c2f 100644 --- a/pkg/platform/connections_tool.go +++ b/pkg/platform/connections_tool.go @@ -5,13 +5,17 @@ import ( "encoding/json" "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/txn2/mcp-data-platform/pkg/toolkit" ) // connectionEntry describes a single toolkit connection. type connectionEntry struct { - Kind string `json:"kind"` - Name string `json:"name"` - Connection string `json:"connection"` + Kind string `json:"kind"` + Name string `json:"name"` + Connection string `json:"connection"` + Description string `json:"description,omitempty"` + IsDefault bool `json:"is_default,omitempty"` } // listConnectionsOutput is the JSON response for the list_connections tool. @@ -40,11 +44,23 @@ func (p *Platform) handleListConnections(_ context.Context, _ *mcp.CallToolReque entries := make([]connectionEntry, 0, len(toolkits)) for _, tk := range toolkits { - entries = append(entries, connectionEntry{ - Kind: tk.Kind(), - Name: tk.Name(), - Connection: tk.Connection(), - }) + if lister, ok := tk.(toolkit.ConnectionLister); ok { + for _, conn := range lister.ListConnections() { + entries = append(entries, connectionEntry{ + Kind: tk.Kind(), + Name: conn.Name, + Connection: conn.Name, + Description: conn.Description, + IsDefault: conn.IsDefault, + }) + } + } else { + entries = append(entries, connectionEntry{ + Kind: tk.Kind(), + Name: tk.Name(), + Connection: tk.Connection(), + }) + } } out := listConnectionsOutput{ diff --git a/pkg/platform/connections_tool_test.go b/pkg/platform/connections_tool_test.go index fd32fc72..d208b312 100644 --- a/pkg/platform/connections_tool_test.go +++ b/pkg/platform/connections_tool_test.go @@ -14,6 +14,7 @@ import ( "github.com/txn2/mcp-data-platform/pkg/query" "github.com/txn2/mcp-data-platform/pkg/registry" "github.com/txn2/mcp-data-platform/pkg/semantic" + "github.com/txn2/mcp-data-platform/pkg/toolkit" ) // mockToolkit implements registry.Toolkit for testing. @@ -33,6 +34,16 @@ func (*mockToolkit) SetSemanticProvider(_ semantic.Provider) {} func (*mockToolkit) SetQueryProvider(_ query.Provider) {} func (*mockToolkit) Close() error { return nil } +// mockConnectionListerToolkit implements both registry.Toolkit and toolkit.ConnectionLister. +type mockConnectionListerToolkit struct { + mockToolkit + connections []toolkit.ConnectionDetail +} + +func (m *mockConnectionListerToolkit) ListConnections() []toolkit.ConnectionDetail { + return m.connections +} + func TestHandleListConnections(t *testing.T) { t.Run("returns empty list when no toolkits", func(t *testing.T) { p := &Platform{ @@ -142,6 +153,104 @@ func TestHandleListConnections(t *testing.T) { }) } +func TestHandleListConnections_WithConnectionLister(t *testing.T) { + t.Run("expands multi-connection toolkit", func(t *testing.T) { + reg := registry.NewRegistry() + require.NoError(t, reg.Register(&mockConnectionListerToolkit{ + mockToolkit: mockToolkit{ + kind: "trino", + name: "warehouse", + tools: []string{"trino_query"}, + }, + connections: []toolkit.ConnectionDetail{ + {Name: "warehouse", Description: "Analytics warehouse", IsDefault: true}, + {Name: "elasticsearch", Description: "Sales data", IsDefault: false}, + {Name: "cassandra", Description: "", IsDefault: false}, + }, + })) + require.NoError(t, reg.Register(&mockToolkit{ + kind: "datahub", + name: "primary", + connection: "primary-datahub", + tools: []string{"datahub_search"}, + })) + + p := &Platform{toolkitRegistry: reg} + result, _, err := p.handleListConnections(context.Background(), &mcp.CallToolRequest{}) + require.NoError(t, err) + require.NotNil(t, result) + assert.False(t, result.IsError) + + var out listConnectionsOutput + textContent, ok := result.Content[0].(*mcp.TextContent) + require.True(t, ok) + err = json.Unmarshal([]byte(textContent.Text), &out) + require.NoError(t, err) + + // 3 Trino connections + 1 DataHub = 4 total + assert.Equal(t, 4, out.Count) + assert.Len(t, out.Connections, 4) + + // Build map by connection name + connByName := make(map[string]connectionEntry) + for _, c := range out.Connections { + connByName[c.Name] = c + } + + wh := connByName["warehouse"] + assert.Equal(t, "trino", wh.Kind) + assert.Equal(t, "warehouse", wh.Connection) + assert.Equal(t, "Analytics warehouse", wh.Description) + assert.True(t, wh.IsDefault) + + es := connByName["elasticsearch"] + assert.Equal(t, "trino", es.Kind) + assert.Equal(t, "elasticsearch", es.Connection) + assert.Equal(t, "Sales data", es.Description) + assert.False(t, es.IsDefault) + + cass := connByName["cassandra"] + assert.Equal(t, "trino", cass.Kind) + assert.Empty(t, cass.Description) + assert.False(t, cass.IsDefault) + + dh := connByName["primary"] + assert.Equal(t, "datahub", dh.Kind) + assert.Equal(t, "primary-datahub", dh.Connection) + assert.Empty(t, dh.Description) + assert.False(t, dh.IsDefault) + }) + + t.Run("single connection lister returns one entry", func(t *testing.T) { + reg := registry.NewRegistry() + require.NoError(t, reg.Register(&mockConnectionListerToolkit{ + mockToolkit: mockToolkit{ + kind: "trino", + name: "prod", + tools: []string{"trino_query"}, + }, + connections: []toolkit.ConnectionDetail{ + {Name: "prod", Description: "Production", IsDefault: true}, + }, + })) + + p := &Platform{toolkitRegistry: reg} + result, _, err := p.handleListConnections(context.Background(), &mcp.CallToolRequest{}) + require.NoError(t, err) + + var out listConnectionsOutput + textContent, ok := result.Content[0].(*mcp.TextContent) + require.True(t, ok) + err = json.Unmarshal([]byte(textContent.Text), &out) + require.NoError(t, err) + + assert.Equal(t, 1, out.Count) + assert.Equal(t, "prod", out.Connections[0].Name) + assert.Equal(t, "Production", out.Connections[0].Description) + assert.True(t, out.Connections[0].IsDefault) + }) +} + func TestPlatformToolsIncludesListConnections(t *testing.T) { p := &Platform{} tools := p.PlatformTools() diff --git a/pkg/registry/factories.go b/pkg/registry/factories.go index b1580093..51de8c05 100644 --- a/pkg/registry/factories.go +++ b/pkg/registry/factories.go @@ -10,11 +10,27 @@ import ( // RegisterBuiltinFactories registers all built-in toolkit factories. func RegisterBuiltinFactories(r *Registry) { - r.RegisterFactory("trino", TrinoFactory) + r.RegisterAggregateFactory("trino", TrinoAggregateFactory) r.RegisterFactory("datahub", DataHubFactory) r.RegisterFactory("s3", S3Factory) } +// TrinoAggregateFactory creates a single multi-connection Trino toolkit +// from all configured instances. This ensures deterministic connection +// routing based on the "connection" parameter in each tool call, rather +// than the non-deterministic last-write-wins behavior of N separate toolkits. +func TrinoAggregateFactory(defaultName string, instances map[string]map[string]any) (Toolkit, error) { + multiCfg, err := trinokit.ParseMultiConfig(defaultName, instances) + if err != nil { + return nil, fmt.Errorf("parsing trino multi config: %w", err) + } + tk, err := trinokit.NewMulti(multiCfg) + if err != nil { + return nil, fmt.Errorf("creating trino toolkit: %w", err) + } + return tk, nil +} + // TrinoFactory creates a Trino toolkit from configuration. func TrinoFactory(name string, cfg map[string]any) (Toolkit, error) { config, err := trinokit.ParseConfig(cfg) diff --git a/pkg/registry/loader.go b/pkg/registry/loader.go index c107ad3b..7d795f61 100644 --- a/pkg/registry/loader.go +++ b/pkg/registry/loader.go @@ -35,12 +35,19 @@ func (l *Loader) Load(cfg LoaderConfig) error { continue } - for name, instanceCfg := range kindCfg.Instances { - // Merge kind-level config with instance config - mergedCfg := make(map[string]any) - maps.Copy(mergedCfg, kindCfg.Config) - maps.Copy(mergedCfg, instanceCfg) + // Build merged instance configs for aggregate factory detection. + mergedInstances := mergeInstanceConfigs(kindCfg.Instances, kindCfg.Config) + // Check for aggregate factory first (multi-connection → single toolkit). + if aggFactory, ok := l.registry.GetAggregateFactory(kind); ok { + if err := l.loadAggregate(kind, kindCfg.Default, mergedInstances, aggFactory); err != nil { + return err + } + continue + } + + // Fall through to per-instance factory loop. + for name, mergedCfg := range mergedInstances { toolkitCfg := ToolkitConfig{ Kind: kind, Name: name, @@ -71,31 +78,78 @@ func (l *Loader) LoadFromMap(toolkits map[string]any) error { continue } - instances, _ := kindMap["instances"].(map[string]any) - defaultName, _ := kindMap["default"].(string) - kindConfig, _ := kindMap["config"].(map[string]any) + if err := l.loadKindFromMap(kind, kindMap); err != nil { + return err + } + } - for name, instanceV := range instances { - instanceCfg, _ := instanceV.(map[string]any) + return nil +} - // Merge configs - mergedCfg := make(map[string]any) - maps.Copy(mergedCfg, kindConfig) - maps.Copy(mergedCfg, instanceCfg) +// loadKindFromMap loads all instances of a toolkit kind from a map config. +func (l *Loader) loadKindFromMap(kind string, kindMap map[string]any) error { + instances, _ := kindMap["instances"].(map[string]any) + defaultName, _ := kindMap["default"].(string) + kindConfig, _ := kindMap["config"].(map[string]any) - toolkitCfg := ToolkitConfig{ - Kind: kind, - Name: name, - Enabled: true, - Config: mergedCfg, - Default: name == defaultName, - } + mergedInstances := mergeMapInstances(instances, kindConfig) - if err := l.registry.CreateAndRegister(toolkitCfg); err != nil { - return fmt.Errorf("loading toolkit %s/%s: %w", kind, name, err) - } - } + // Check for aggregate factory first. + if aggFactory, ok := l.registry.GetAggregateFactory(kind); ok { + return l.loadAggregate(kind, defaultName, mergedInstances, aggFactory) } + // Fall through to per-instance factory loop. + for name, mergedCfg := range mergedInstances { + toolkitCfg := ToolkitConfig{ + Kind: kind, + Name: name, + Enabled: true, + Config: mergedCfg, + Default: name == defaultName, + } + + if err := l.registry.CreateAndRegister(toolkitCfg); err != nil { + return fmt.Errorf("loading toolkit %s/%s: %w", kind, name, err) + } + } return nil } + +// mergeMapInstances builds typed instance configs from untyped map, merging kind-level config. +func mergeMapInstances(instances, kindConfig map[string]any) map[string]map[string]any { + merged := make(map[string]map[string]any, len(instances)) + for name, instanceV := range instances { + instanceCfg, _ := instanceV.(map[string]any) + mergedCfg := make(map[string]any) + maps.Copy(mergedCfg, kindConfig) + maps.Copy(mergedCfg, instanceCfg) + merged[name] = mergedCfg + } + return merged +} + +// loadAggregate invokes an aggregate factory and registers the resulting toolkit. +func (l *Loader) loadAggregate( + kind, defaultName string, + instances map[string]map[string]any, + factory AggregateToolkitFactory, +) error { + toolkit, err := factory(defaultName, instances) + if err != nil { + return fmt.Errorf("loading aggregate toolkit %s: %w", kind, err) + } + return l.registry.Register(toolkit) +} + +// mergeInstanceConfigs merges kind-level config into each instance config. +func mergeInstanceConfigs(instances map[string]map[string]any, kindConfig map[string]any) map[string]map[string]any { + merged := make(map[string]map[string]any, len(instances)) + for name, instanceCfg := range instances { + mergedCfg := make(map[string]any) + maps.Copy(mergedCfg, kindConfig) + maps.Copy(mergedCfg, instanceCfg) + merged[name] = mergedCfg + } + return merged +} diff --git a/pkg/registry/loader_test.go b/pkg/registry/loader_test.go index 828c2e0a..31cb7f1f 100644 --- a/pkg/registry/loader_test.go +++ b/pkg/registry/loader_test.go @@ -1,6 +1,7 @@ package registry import ( + "fmt" "testing" ) @@ -112,6 +113,87 @@ func TestLoader_Load(t *testing.T) { t.Error("expected error for missing factory") } }) + + t.Run("aggregate factory called instead of per-instance", func(t *testing.T) { + reg := NewRegistry() + + perInstanceCalled := false + reg.RegisterFactory("agg-kind", func(_ string, _ map[string]any) (Toolkit, error) { + perInstanceCalled = true + return &mockToolkit{kind: "agg-kind", name: "should-not-be-used"}, nil + }) + + aggCalled := false + reg.RegisterAggregateFactory("agg-kind", func(defaultName string, instances map[string]map[string]any) (Toolkit, error) { + aggCalled = true + if defaultName != "inst1" { + t.Errorf("defaultName = %q, want 'inst1'", defaultName) + } + if len(instances) != 2 { + t.Errorf("expected 2 instances, got %d", len(instances)) + } + // Verify kind-level config is merged + if instances["inst1"]["shared"] != "yes" { + t.Error("expected shared config to be merged") + } + return &mockToolkit{kind: "agg-kind", name: defaultName, tools: []string{"agg_tool"}}, nil + }) + + loader := NewLoader(reg) + + cfg := LoaderConfig{ + Toolkits: map[string]ToolkitKindConfig{ + "agg-kind": { + Enabled: true, + Default: "inst1", + Config: map[string]any{"shared": "yes"}, + Instances: map[string]map[string]any{ + "inst1": {"host": "a"}, + "inst2": {"host": "b"}, + }, + }, + }, + } + + err := loader.Load(cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !aggCalled { + t.Error("aggregate factory was not called") + } + if perInstanceCalled { + t.Error("per-instance factory should not be called when aggregate is registered") + } + + // Should register one toolkit, not two. + if len(reg.All()) != 1 { + t.Errorf("expected 1 toolkit, got %d", len(reg.All())) + } + }) + + t.Run("aggregate factory error propagated", func(t *testing.T) { + reg := NewRegistry() + reg.RegisterAggregateFactory("failing", func(_ string, _ map[string]map[string]any) (Toolkit, error) { + return nil, fmt.Errorf("aggregate creation failed") + }) + + loader := NewLoader(reg) + cfg := LoaderConfig{ + Toolkits: map[string]ToolkitKindConfig{ + "failing": { + Enabled: true, + Instances: map[string]map[string]any{"inst1": {}}, + }, + }, + } + + err := loader.Load(cfg) + if err == nil { + t.Error("expected error from aggregate factory") + } + }) } func TestLoader_LoadFromMap(t *testing.T) { @@ -217,4 +299,73 @@ func TestLoader_LoadFromMap(t *testing.T) { t.Error("expected error for missing factory") } }) + + t.Run("aggregate factory from map", func(t *testing.T) { + reg := NewRegistry() + + aggCalled := false + reg.RegisterAggregateFactory("agg-map", func(defaultName string, instances map[string]map[string]any) (Toolkit, error) { + aggCalled = true + if defaultName != "main" { + t.Errorf("defaultName = %q, want 'main'", defaultName) + } + if len(instances) != 2 { + t.Errorf("expected 2 instances, got %d", len(instances)) + } + // Verify kind-level config is merged into instances. + if instances["main"]["shared"] != "value" { + t.Error("expected shared config to be merged into main") + } + if instances["secondary"]["shared"] != "value" { + t.Error("expected shared config to be merged into secondary") + } + return &mockToolkit{kind: "agg-map", name: defaultName}, nil + }) + + loader := NewLoader(reg) + + toolkits := map[string]any{ + "agg-map": map[string]any{ + "enabled": true, + "default": "main", + "config": map[string]any{"shared": "value"}, + "instances": map[string]any{ + "main": map[string]any{"host": "a"}, + "secondary": map[string]any{"host": "b"}, + }, + }, + } + + err := loader.LoadFromMap(toolkits) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !aggCalled { + t.Error("aggregate factory was not called") + } + if len(reg.All()) != 1 { + t.Errorf("expected 1 toolkit, got %d", len(reg.All())) + } + }) + + t.Run("aggregate factory error from map", func(t *testing.T) { + reg := NewRegistry() + reg.RegisterAggregateFactory("fail-map", func(_ string, _ map[string]map[string]any) (Toolkit, error) { + return nil, fmt.Errorf("map aggregate failed") + }) + + loader := NewLoader(reg) + toolkits := map[string]any{ + "fail-map": map[string]any{ + "enabled": true, + "instances": map[string]any{"inst1": map[string]any{}}, + }, + } + + err := loader.LoadFromMap(toolkits) + if err == nil { + t.Error("expected error from aggregate factory") + } + }) } diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go index ad8fa1f2..6b1c8242 100644 --- a/pkg/registry/registry.go +++ b/pkg/registry/registry.go @@ -21,6 +21,9 @@ type Registry struct { // Factory functions by kind factories map[string]ToolkitFactory + // Aggregate factory functions by kind (multi-instance → single toolkit) + aggregateFactories map[string]AggregateToolkitFactory + // Providers for cross-injection semanticProvider semantic.Provider queryProvider query.Provider @@ -29,8 +32,9 @@ type Registry struct { // NewRegistry creates a new toolkit registry. func NewRegistry() *Registry { return &Registry{ - toolkits: make(map[string]Toolkit), - factories: make(map[string]ToolkitFactory), + toolkits: make(map[string]Toolkit), + factories: make(map[string]ToolkitFactory), + aggregateFactories: make(map[string]AggregateToolkitFactory), } } @@ -41,6 +45,23 @@ func (r *Registry) RegisterFactory(kind string, factory ToolkitFactory) { r.factories[kind] = factory } +// RegisterAggregateFactory registers an aggregate toolkit factory for a kind. +// Aggregate factories receive all instance configs and produce a single toolkit +// that handles multi-connection routing internally. +func (r *Registry) RegisterAggregateFactory(kind string, factory AggregateToolkitFactory) { + r.mu.Lock() + defer r.mu.Unlock() + r.aggregateFactories[kind] = factory +} + +// GetAggregateFactory returns the aggregate factory for a kind, if registered. +func (r *Registry) GetAggregateFactory(kind string) (AggregateToolkitFactory, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + f, ok := r.aggregateFactories[kind] + return f, ok +} + // SetSemanticProvider sets the semantic provider for all toolkits. func (r *Registry) SetSemanticProvider(provider semantic.Provider) { r.mu.Lock() diff --git a/pkg/registry/registry_test.go b/pkg/registry/registry_test.go index f7b0afb0..ba818aeb 100644 --- a/pkg/registry/registry_test.go +++ b/pkg/registry/registry_test.go @@ -226,16 +226,10 @@ func TestRegisterBuiltinFactories(t *testing.T) { reg := NewRegistry() RegisterBuiltinFactories(reg) - // Verify all three factories are registered by trying to create with invalid config - t.Run("trino factory registered", func(t *testing.T) { - // Should fail with invalid config (missing host) - err := reg.CreateAndRegister(ToolkitConfig{ - Kind: regTestTrino, - Name: regTestTest, - Config: map[string]any{}, - }) - if err == nil { - t.Error("expected error for missing trino config") + t.Run("trino aggregate factory registered", func(t *testing.T) { + _, ok := reg.GetAggregateFactory(regTestTrino) + if !ok { + t.Error("expected trino aggregate factory to be registered") } }) @@ -270,6 +264,98 @@ func TestTrinoFactory(t *testing.T) { } } +func TestRegisterAggregateFactory(t *testing.T) { + reg := NewRegistry() + + called := false + factory := func(defaultName string, instances map[string]map[string]any) (Toolkit, error) { + called = true + if defaultName != "primary" { + t.Errorf("defaultName = %q, want 'primary'", defaultName) + } + if len(instances) != 2 { + t.Errorf("expected 2 instances, got %d", len(instances)) + } + return &mockToolkit{kind: "agg", name: defaultName, tools: []string{"agg_tool"}}, nil + } + + reg.RegisterAggregateFactory("agg", factory) + + got, ok := reg.GetAggregateFactory("agg") + if !ok { + t.Fatal("expected aggregate factory to be registered") + } + + tk, err := got("primary", map[string]map[string]any{ + "primary": {"host": "a"}, + "secondary": {"host": "b"}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !called { + t.Error("factory was not called") + } + if tk.Name() != "primary" { + t.Errorf("Name() = %q, want 'primary'", tk.Name()) + } +} + +func TestGetAggregateFactory_NotFound(t *testing.T) { + reg := NewRegistry() + _, ok := reg.GetAggregateFactory("nonexistent") + if ok { + t.Error("expected false for unregistered aggregate factory") + } +} + +func TestTrinoAggregateFactory(t *testing.T) { + t.Run("valid multi-instance config", func(t *testing.T) { + tk, err := TrinoAggregateFactory("warehouse", map[string]map[string]any{ + "warehouse": { + "host": "warehouse.example.com", + "user": "trino", + "port": 8080, + }, + "elasticsearch": { + "host": "es.example.com", + "user": "trino", + "port": 8080, + }, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tk.Kind() != "trino" { + t.Errorf("Kind() = %q, want 'trino'", tk.Kind()) + } + if tk.Name() != "warehouse" { + t.Errorf("Name() = %q, want 'warehouse'", tk.Name()) + } + if tk.Connection() != "warehouse" { + t.Errorf("Connection() = %q, want 'warehouse'", tk.Connection()) + } + }) + + t.Run("invalid config returns error", func(t *testing.T) { + _, err := TrinoAggregateFactory("bad", map[string]map[string]any{ + "bad": {"timeout": "invalid-duration"}, + }) + if err == nil { + t.Error("expected error for invalid config") + } + }) + + t.Run("missing host returns error", func(t *testing.T) { + _, err := TrinoAggregateFactory("empty", map[string]map[string]any{ + "empty": {}, + }) + if err == nil { + t.Error("expected error for missing host") + } + }) +} + func TestDataHubFactory(t *testing.T) { // Test with invalid config _, err := DataHubFactory(regTestTest, map[string]any{}) diff --git a/pkg/registry/toolkit.go b/pkg/registry/toolkit.go index 4dd8761d..d029e41e 100644 --- a/pkg/registry/toolkit.go +++ b/pkg/registry/toolkit.go @@ -39,6 +39,11 @@ type Toolkit interface { // ToolkitFactory creates a toolkit from configuration. type ToolkitFactory func(name string, config map[string]any) (Toolkit, error) +// AggregateToolkitFactory creates a single toolkit from multiple instance configs. +// Used for toolkit kinds that support multi-connection routing internally +// (e.g., Trino with multiserver.Manager). +type AggregateToolkitFactory func(defaultName string, instances map[string]map[string]any) (Toolkit, error) + // ToolkitConfig holds configuration for a toolkit instance. type ToolkitConfig struct { Kind string diff --git a/pkg/toolkit/connection.go b/pkg/toolkit/connection.go new file mode 100644 index 00000000..d8ecdcd2 --- /dev/null +++ b/pkg/toolkit/connection.go @@ -0,0 +1,19 @@ +// Package toolkit provides shared types for toolkit implementations and the +// platform layer. This package has zero internal dependencies to avoid import +// cycles between pkg/registry (which imports toolkit implementations) and the +// toolkit implementations themselves. +package toolkit + +// ConnectionDetail provides information about a single connection within a toolkit. +type ConnectionDetail struct { + Name string + Description string + IsDefault bool +} + +// ConnectionLister is an optional interface for toolkits that manage multiple +// connections internally. Toolkits implementing this interface expose all their +// connections for discovery via the list_connections tool. +type ConnectionLister interface { + ListConnections() []ConnectionDetail +} diff --git a/pkg/toolkit/connection_test.go b/pkg/toolkit/connection_test.go new file mode 100644 index 00000000..75257e66 --- /dev/null +++ b/pkg/toolkit/connection_test.go @@ -0,0 +1,35 @@ +package toolkit + +import "testing" + +// TestConnectionDetailFields verifies the ConnectionDetail struct fields. +func TestConnectionDetailFields(t *testing.T) { + d := ConnectionDetail{ + Name: "warehouse", + Description: "Data warehouse", + IsDefault: true, + } + if d.Name != "warehouse" { + t.Errorf("Name = %q", d.Name) + } + if d.Description != "Data warehouse" { + t.Errorf("Description = %q", d.Description) + } + if !d.IsDefault { + t.Error("IsDefault = false") + } +} + +// TestConnectionListerInterface is a compile-time check that the interface is usable. +func TestConnectionListerInterface(t *testing.T) { + var _ ConnectionLister = mockLister{} + t.Log("ConnectionLister interface is satisfiable") +} + +type mockLister struct{} + +func (mockLister) ListConnections() []ConnectionDetail { + return []ConnectionDetail{ + {Name: "test", Description: "test desc", IsDefault: true}, + } +} diff --git a/pkg/toolkits/trino/config.go b/pkg/toolkits/trino/config.go index 141170fe..bad303fe 100644 --- a/pkg/toolkits/trino/config.go +++ b/pkg/toolkits/trino/config.go @@ -5,6 +5,34 @@ import ( "time" ) +// MultiConfig holds configuration for a multi-connection Trino toolkit. +type MultiConfig struct { + // DefaultConnection is the name of the default connection. + DefaultConnection string + + // Instances maps connection names to their parsed configurations. + Instances map[string]Config +} + +// ParseMultiConfig builds a MultiConfig from the aggregate factory's instance map. +func ParseMultiConfig(defaultName string, instances map[string]map[string]any) (MultiConfig, error) { + mc := MultiConfig{ + DefaultConnection: defaultName, + Instances: make(map[string]Config, len(instances)), + } + for name, raw := range instances { + cfg, err := ParseConfig(raw) + if err != nil { + return mc, fmt.Errorf("instance %s: %w", name, err) + } + if cfg.ConnectionName == "" { + cfg.ConnectionName = name + } + mc.Instances[name] = cfg + } + return mc, nil +} + // ParseConfig parses a Trino toolkit configuration from a map. func ParseConfig(cfg map[string]any) (Config, error) { c := Config{ @@ -28,6 +56,7 @@ func ParseConfig(cfg map[string]any) (Config, error) { c.Catalog = getString(cfg, "catalog") c.Schema = getString(cfg, "schema") c.ConnectionName = getString(cfg, "connection_name") + c.Description = getString(cfg, "description") // Optional int fields c.Port = getInt(cfg, "port", c.Port) diff --git a/pkg/toolkits/trino/config_test.go b/pkg/toolkits/trino/config_test.go index e4633b26..8d616e9d 100644 --- a/pkg/toolkits/trino/config_test.go +++ b/pkg/toolkits/trino/config_test.go @@ -297,6 +297,35 @@ func TestParseConfig_WithDescriptions(t *testing.T) { } } +func TestParseConfig_WithDescription(t *testing.T) { + cfg := map[string]any{ + "host": "trino.example.com", + "description": "Production data warehouse for analytics", + } + + result, err := ParseConfig(cfg) + if err != nil { + t.Fatalf(trinoCfgTestUnexpectedErr, err) + } + if result.Description != "Production data warehouse for analytics" { + t.Errorf("Description = %q, want 'Production data warehouse for analytics'", result.Description) + } +} + +func TestParseConfig_NoDescription(t *testing.T) { + cfg := map[string]any{ + "host": "trino.example.com", + } + + result, err := ParseConfig(cfg) + if err != nil { + t.Fatalf(trinoCfgTestUnexpectedErr, err) + } + if result.Description != "" { + t.Errorf("Description should be empty, got %q", result.Description) + } +} + func TestParseConfig_NoDescriptions(t *testing.T) { cfg := map[string]any{ "host": "trino.example.com", @@ -434,6 +463,91 @@ func TestParseConfig_NoAnnotations(t *testing.T) { } } +func TestParseMultiConfig(t *testing.T) { + t.Run("multiple instances", func(t *testing.T) { + instances := map[string]map[string]any{ + trinoTestWarehouse: { + "host": "warehouse.example.com", + "user": "trino", + "port": trinoTestPort8080, + }, + "elasticsearch": { + "host": "es.example.com", + "user": "es-user", + }, + } + + mc, err := ParseMultiConfig(trinoTestWarehouse, instances) + if err != nil { + t.Fatalf(trinoCfgTestUnexpectedErr, err) + } + + if mc.DefaultConnection != trinoTestWarehouse { + t.Errorf("DefaultConnection = %q, want %q", mc.DefaultConnection, trinoTestWarehouse) + } + if len(mc.Instances) != 2 { + t.Fatalf("expected 2 instances, got %d", len(mc.Instances)) + } + + wh := mc.Instances[trinoTestWarehouse] + if wh.Host != "warehouse.example.com" { + t.Errorf("warehouse.Host = %q", wh.Host) + } + if wh.ConnectionName != trinoTestWarehouse { + t.Errorf("warehouse.ConnectionName = %q, want %q", wh.ConnectionName, trinoTestWarehouse) + } + + es := mc.Instances["elasticsearch"] + if es.Host != "es.example.com" { + t.Errorf("es.Host = %q", es.Host) + } + if es.ConnectionName != "elasticsearch" { + t.Errorf("es.ConnectionName = %q, want 'elasticsearch'", es.ConnectionName) + } + }) + + t.Run("preserves explicit connection name", func(t *testing.T) { + instances := map[string]map[string]any{ + "main": { + "host": "trino.example.com", + "user": "trino", + "connection_name": "custom-name", + }, + } + + mc, err := ParseMultiConfig("main", instances) + if err != nil { + t.Fatalf(trinoCfgTestUnexpectedErr, err) + } + if mc.Instances["main"].ConnectionName != "custom-name" { + t.Errorf("ConnectionName = %q, want 'custom-name'", mc.Instances["main"].ConnectionName) + } + }) + + t.Run("error in instance config", func(t *testing.T) { + instances := map[string]map[string]any{ + "good": {"host": "good.example.com"}, + "bad": {"timeout": "not-a-duration"}, + } + + _, err := ParseMultiConfig("good", instances) + if err == nil { + t.Error("expected error for invalid instance config") + } + }) + + t.Run("missing host returns error", func(t *testing.T) { + instances := map[string]map[string]any{ + "missing-host": {"user": "testuser"}, + } + + _, err := ParseMultiConfig("missing-host", instances) + if err == nil { + t.Error("expected error for missing host") + } + }) +} + func TestGetDuration(t *testing.T) { cfg := map[string]any{ trinoCfgTestString: "5m", diff --git a/pkg/toolkits/trino/connection_required.go b/pkg/toolkits/trino/connection_required.go new file mode 100644 index 00000000..997cac26 --- /dev/null +++ b/pkg/toolkits/trino/connection_required.go @@ -0,0 +1,107 @@ +package trino + +import ( + "context" + "fmt" + "reflect" + "sort" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" + trinotools "github.com/txn2/mcp-trino/pkg/tools" +) + +// ConnectionDescription holds display information about a connection +// for error messages when the connection parameter is missing. +type ConnectionDescription struct { + Name string + Description string + IsDefault bool +} + +// ConnectionRequiredMiddleware rejects tool calls that omit the connection +// parameter when multiple Trino connections are configured. The error message +// lists all available connections with their descriptions so the LLM can +// choose the correct one. +type ConnectionRequiredMiddleware struct { + connections []ConnectionDescription +} + +// NewConnectionRequiredMiddleware creates a middleware that enforces explicit +// connection selection. The connections slice describes all available backends. +func NewConnectionRequiredMiddleware(connections []ConnectionDescription) *ConnectionRequiredMiddleware { + // Sort by name for deterministic error messages. + sorted := make([]ConnectionDescription, len(connections)) + copy(sorted, connections) + sort.Slice(sorted, func(i, j int) bool { + return sorted[i].Name < sorted[j].Name + }) + return &ConnectionRequiredMiddleware{connections: sorted} +} + +// Before checks that the connection parameter is set for tools that need it. +func (m *ConnectionRequiredMiddleware) Before(ctx context.Context, tc *trinotools.ToolContext) (context.Context, error) { + // list_connections doesn't need a connection parameter. + if tc.Name == trinotools.ToolListConnections { + return ctx, nil + } + + conn := extractConnectionFromInput(tc.Input) + if conn != "" { + return ctx, nil + } + + return ctx, fmt.Errorf("multiple Trino connections are configured — you must specify the 'connection' parameter.\n\n%s", + m.formatAvailableConnections()) +} + +// After is a no-op — validation happens before execution. +func (*ConnectionRequiredMiddleware) After( + _ context.Context, + _ *trinotools.ToolContext, + result *mcp.CallToolResult, + handlerErr error, +) (*mcp.CallToolResult, error) { + return result, handlerErr +} + +// extractConnectionFromInput extracts the Connection field from a tool input +// struct using reflection. All Trino tool inputs (except ListConnectionsInput) +// have a Connection string field. +func extractConnectionFromInput(input any) string { + if input == nil { + return "" + } + v := reflect.ValueOf(input) + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return "" + } + v = v.Elem() + } + if v.Kind() != reflect.Struct { + return "" + } + f := v.FieldByName("Connection") + if !f.IsValid() || f.Kind() != reflect.String { + return "" + } + return f.String() +} + +// formatAvailableConnections builds a human-readable list of available connections. +func (m *ConnectionRequiredMiddleware) formatAvailableConnections() string { + var b strings.Builder + b.WriteString("Available connections:\n") + for _, c := range m.connections { + fmt.Fprintf(&b, " - %s", c.Name) + if c.IsDefault { + b.WriteString(" (default)") + } + if c.Description != "" { + fmt.Fprintf(&b, ": %s", c.Description) + } + b.WriteByte('\n') + } + return b.String() +} diff --git a/pkg/toolkits/trino/connection_required_test.go b/pkg/toolkits/trino/connection_required_test.go new file mode 100644 index 00000000..21b138a0 --- /dev/null +++ b/pkg/toolkits/trino/connection_required_test.go @@ -0,0 +1,253 @@ +package trino + +import ( + "context" + "strings" + "testing" + + trinotools "github.com/txn2/mcp-trino/pkg/tools" +) + +func TestConnectionRequiredMiddleware_Before(t *testing.T) { + conns := []ConnectionDescription{ + {Name: trinoTestWarehouse, Description: "Data warehouse for analytics", IsDefault: true}, + {Name: "elasticsearch", Description: "Elasticsearch for sales data", IsDefault: false}, + {Name: "cassandra", Description: "", IsDefault: false}, + } + mw := NewConnectionRequiredMiddleware(conns) + + t.Run("passes when connection is set", func(t *testing.T) { + tc := trinotools.NewToolContext(trinotools.ToolQuery, trinotools.QueryInput{ + SQL: "SELECT 1", + Connection: trinoTestWarehouse, + }) + + _, err := mw.Before(context.Background(), tc) + if err != nil { + t.Errorf("expected no error, got: %v", err) + } + }) + + t.Run("rejects when connection is empty", func(t *testing.T) { + tc := trinotools.NewToolContext(trinotools.ToolQuery, trinotools.QueryInput{ + SQL: "SELECT 1", + }) + + _, err := mw.Before(context.Background(), tc) + if err == nil { + t.Fatal("expected error for missing connection") + } + + errMsg := err.Error() + if !strings.Contains(errMsg, "multiple Trino connections") { + t.Errorf("error should mention multiple connections, got: %s", errMsg) + } + if !strings.Contains(errMsg, trinoTestWarehouse) { + t.Errorf("error should list warehouse, got: %s", errMsg) + } + if !strings.Contains(errMsg, "elasticsearch") { + t.Errorf("error should list elasticsearch, got: %s", errMsg) + } + if !strings.Contains(errMsg, "Data warehouse for analytics") { + t.Errorf("error should include descriptions, got: %s", errMsg) + } + if !strings.Contains(errMsg, "(default)") { + t.Errorf("error should mark default connection, got: %s", errMsg) + } + }) + + t.Run("skips list_connections tool", func(t *testing.T) { + tc := trinotools.NewToolContext(trinotools.ToolListConnections, trinotools.ListConnectionsInput{}) + + _, err := mw.Before(context.Background(), tc) + if err != nil { + t.Errorf("list_connections should be skipped, got: %v", err) + } + }) + + t.Run("passes for describe_table with connection", func(t *testing.T) { + tc := trinotools.NewToolContext(trinotools.ToolDescribeTable, trinotools.DescribeTableInput{ + Catalog: "hive", + Schema: "default", + Table: "users", + Connection: trinoTestWarehouse, + }) + + _, err := mw.Before(context.Background(), tc) + if err != nil { + t.Errorf("expected no error, got: %v", err) + } + }) + + t.Run("rejects describe_table without connection", func(t *testing.T) { + tc := trinotools.NewToolContext(trinotools.ToolDescribeTable, trinotools.DescribeTableInput{ + Catalog: "hive", + Schema: "default", + Table: "users", + }) + + _, err := mw.Before(context.Background(), tc) + if err == nil { + t.Fatal("expected error for missing connection on describe_table") + } + }) + + t.Run("rejects list_catalogs without connection", func(t *testing.T) { + tc := trinotools.NewToolContext(trinotools.ToolListCatalogs, trinotools.ListCatalogsInput{}) + + _, err := mw.Before(context.Background(), tc) + if err == nil { + t.Fatal("expected error for missing connection on list_catalogs") + } + }) + + t.Run("passes list_schemas with connection", func(t *testing.T) { + tc := trinotools.NewToolContext(trinotools.ToolListSchemas, trinotools.ListSchemasInput{ + Catalog: "hive", + Connection: "elasticsearch", + }) + + _, err := mw.Before(context.Background(), tc) + if err != nil { + t.Errorf("expected no error, got: %v", err) + } + }) + + t.Run("connection without description in error", func(t *testing.T) { + _, err := mw.Before(context.Background(), trinotools.NewToolContext( + trinotools.ToolQuery, trinotools.QueryInput{SQL: "SELECT 1"}, + )) + if err == nil { + t.Fatal("expected error") + } + // cassandra has no description, should just show the name + if !strings.Contains(err.Error(), "cassandra") { + t.Errorf("error should list cassandra, got: %s", err.Error()) + } + }) +} + +func TestConnectionRequiredMiddleware_After(t *testing.T) { + mw := NewConnectionRequiredMiddleware(nil) + result, err := mw.After(context.Background(), nil, nil, nil) + if result != nil { + t.Error("expected nil result passthrough") + } + if err != nil { + t.Errorf("expected nil error passthrough, got: %v", err) + } +} + +func TestNewConnectionRequiredMiddleware_SortsDeterministically(t *testing.T) { + conns := []ConnectionDescription{ + {Name: "zebra"}, + {Name: "alpha"}, + {Name: "middle"}, + } + mw := NewConnectionRequiredMiddleware(conns) + + if mw.connections[0].Name != "alpha" { + t.Errorf("expected first connection to be 'alpha', got %q", mw.connections[0].Name) + } + if mw.connections[1].Name != "middle" { + t.Errorf("expected second connection to be 'middle', got %q", mw.connections[1].Name) + } + if mw.connections[2].Name != "zebra" { + t.Errorf("expected third connection to be 'zebra', got %q", mw.connections[2].Name) + } +} + +func TestExtractConnectionFromInput(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + if got := extractConnectionFromInput(nil); got != "" { + t.Errorf("expected empty, got %q", got) + } + }) + + t.Run("non-struct input", func(t *testing.T) { + if got := extractConnectionFromInput("not a struct"); got != "" { + t.Errorf("expected empty for string, got %q", got) + } + }) + + t.Run("struct without Connection field", func(t *testing.T) { + type noConn struct { + SQL string + } + if got := extractConnectionFromInput(noConn{SQL: "SELECT 1"}); got != "" { + t.Errorf("expected empty for struct without Connection, got %q", got) + } + }) + + t.Run("struct with Connection field", func(t *testing.T) { + type withConn struct { + Connection string + } + if got := extractConnectionFromInput(withConn{Connection: trinoTestWarehouse}); got != trinoTestWarehouse { + t.Errorf("expected 'warehouse', got %q", got) + } + }) + + t.Run("pointer to struct", func(t *testing.T) { + type withConn struct { + Connection string + } + input := &withConn{Connection: "prod"} + if got := extractConnectionFromInput(input); got != "prod" { + t.Errorf("expected 'prod', got %q", got) + } + }) + + t.Run("nil pointer", func(t *testing.T) { + var input *trinotools.QueryInput + if got := extractConnectionFromInput(input); got != "" { + t.Errorf("expected empty for nil pointer, got %q", got) + } + }) + + t.Run("struct with non-string Connection", func(t *testing.T) { + type badConn struct { + Connection int + } + if got := extractConnectionFromInput(badConn{Connection: 42}); got != "" { + t.Errorf("expected empty for non-string Connection, got %q", got) + } + }) + + t.Run("real QueryInput", func(t *testing.T) { + input := trinotools.QueryInput{SQL: "SELECT 1", Connection: trinoTestWarehouse} + if got := extractConnectionFromInput(input); got != trinoTestWarehouse { + t.Errorf("expected 'warehouse', got %q", got) + } + }) + + t.Run("real ExecuteInput", func(t *testing.T) { + input := trinotools.ExecuteInput{SQL: "INSERT INTO t VALUES (1)", Connection: trinoTestWarehouse} + if got := extractConnectionFromInput(input); got != trinoTestWarehouse { + t.Errorf("expected 'warehouse', got %q", got) + } + }) +} + +func TestFormatAvailableConnections(t *testing.T) { + mw := NewConnectionRequiredMiddleware([]ConnectionDescription{ + {Name: trinoTestWarehouse, Description: "Analytics warehouse", IsDefault: true}, + {Name: "elasticsearch", Description: "Sales data", IsDefault: false}, + {Name: "bare", Description: "", IsDefault: false}, + }) + + output := mw.formatAvailableConnections() + + if !strings.Contains(output, "Available connections:") { + t.Error("expected header") + } + if !strings.Contains(output, "bare") { + t.Error("expected bare connection") + } + if !strings.Contains(output, "warehouse (default): Analytics warehouse") { + t.Errorf("expected formatted default with description, got:\n%s", output) + } + if !strings.Contains(output, "elasticsearch: Sales data") { + t.Errorf("expected formatted non-default with description, got:\n%s", output) + } +} diff --git a/pkg/toolkits/trino/toolkit.go b/pkg/toolkits/trino/toolkit.go index c111ad97..cad0f0bb 100644 --- a/pkg/toolkits/trino/toolkit.go +++ b/pkg/toolkits/trino/toolkit.go @@ -7,10 +7,12 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" trinoclient "github.com/txn2/mcp-trino/pkg/client" + "github.com/txn2/mcp-trino/pkg/multiserver" trinotools "github.com/txn2/mcp-trino/pkg/tools" "github.com/txn2/mcp-data-platform/pkg/query" "github.com/txn2/mcp-data-platform/pkg/semantic" + "github.com/txn2/mcp-data-platform/pkg/toolkit" ) const ( @@ -45,6 +47,7 @@ type Config struct { MaxLimit int `yaml:"max_limit"` ReadOnly bool `yaml:"read_only"` ConnectionName string `yaml:"connection_name"` + Description string `yaml:"description"` // Human-readable description of this connection's purpose Descriptions map[string]string `yaml:"descriptions"` Annotations map[string]AnnotationConfig `yaml:"annotations"` @@ -85,6 +88,7 @@ type Toolkit struct { name string config Config client *trinoclient.Client + manager *multiserver.Manager // non-nil in multi-connection mode trinoToolkit *trinotools.Toolkit semanticProvider semantic.Provider @@ -92,6 +96,9 @@ type Toolkit struct { // elicitation holds the middleware so providers can be propagated after init. elicitation *ElicitationMiddleware + + // connectionDescriptions maps connection name → description (multi-connection mode). + connectionDescriptions map[string]string } // New creates a new Trino toolkit. @@ -126,6 +133,168 @@ func New(name string, cfg Config) (*Toolkit, error) { return t, nil } +// NewMulti creates a multi-connection Trino toolkit that routes requests +// to the correct backend based on the "connection" parameter in each tool call. +// This replaces the previous pattern of creating N separate single-client +// toolkits that would clobber each other's tool registrations. +func NewMulti(cfg MultiConfig) (*Toolkit, error) { + if len(cfg.Instances) == 0 { + return nil, fmt.Errorf("at least one trino instance is required") + } + + // Resolve the default connection name. + defaultName := cfg.DefaultConnection + if defaultName == "" { + // Pick the first instance alphabetically for determinism. + for name := range cfg.Instances { + if defaultName == "" || name < defaultName { + defaultName = name + } + } + } + + defaultCfg, ok := cfg.Instances[defaultName] + if !ok { + return nil, fmt.Errorf("default connection %q not found in instances", defaultName) + } + + // Validate all instance configs. + for name, instCfg := range cfg.Instances { + if err := validateConfig(instCfg); err != nil { + return nil, fmt.Errorf("instance %s: %w", name, err) + } + } + + // Build multiserver config from instance configs. + msCfg := buildMultiserverConfig(defaultName, defaultCfg, cfg.Instances) + + mgr := multiserver.NewManager(msCfg) + + // Use the default instance config for toolkit-level settings. + defaultCfg = applyDefaults(defaultName, defaultCfg) + + descs := make(map[string]string, len(cfg.Instances)) + for name, instCfg := range cfg.Instances { + descs[name] = instCfg.Description + } + + t := &Toolkit{ + name: defaultName, + config: defaultCfg, + manager: mgr, + connectionDescriptions: descs, + } + + connRequired := buildConnectionRequired(defaultName, cfg.Instances) + opts := buildToolkitOptions(defaultCfg, nil, connRequired) // elicitation not supported in multi-mode yet + t.trinoToolkit = trinotools.NewToolkitWithManager(mgr, trinotools.Config{ + DefaultLimit: defaultCfg.DefaultLimit, + MaxLimit: defaultCfg.MaxLimit, + }, opts...) + + return t, nil +} + +// buildConnectionRequired creates a ConnectionRequiredMiddleware when multiple +// instances are configured. Returns nil for single-instance deployments. +func buildConnectionRequired(defaultName string, instances map[string]Config) *ConnectionRequiredMiddleware { + if len(instances) <= 1 { + return nil + } + connDescs := make([]ConnectionDescription, 0, len(instances)) + for name, instCfg := range instances { + connDescs = append(connDescs, ConnectionDescription{ + Name: name, + Description: instCfg.Description, + IsDefault: name == defaultName, + }) + } + return NewConnectionRequiredMiddleware(connDescs) +} + +// buildMultiserverConfig constructs a multiserver.Config from instance configs. +func buildMultiserverConfig( + defaultName string, + defaultCfg Config, + instances map[string]Config, +) multiserver.Config { + defaultCfg = applyDefaults(defaultName, defaultCfg) + primary := trinoclient.Config{ + Host: defaultCfg.Host, + Port: defaultCfg.Port, + User: defaultCfg.User, + Password: defaultCfg.Password, + Catalog: defaultCfg.Catalog, + Schema: defaultCfg.Schema, + SSL: defaultCfg.SSL, + SSLVerify: defaultCfg.SSLVerify, + Timeout: defaultCfg.Timeout, + Source: "mcp-data-platform", + } + + connections := make(map[string]multiserver.ConnectionConfig, len(instances)-1) + for name, instCfg := range instances { + if name == defaultName { + continue + } + cc := multiserver.ConnectionConfig{ + Host: instCfg.Host, + } + if instCfg.Port != 0 { + cc.Port = instCfg.Port + } + if instCfg.User != "" { + cc.User = instCfg.User + } + if instCfg.Password != "" { + cc.Password = instCfg.Password + } + if instCfg.Catalog != "" { + cc.Catalog = instCfg.Catalog + } + if instCfg.Schema != "" { + cc.Schema = instCfg.Schema + } + if instCfg.SSL { + ssl := true + cc.SSL = &ssl + } + connections[name] = cc + } + + return multiserver.Config{ + Default: defaultName, + Primary: primary, + Connections: connections, + } +} + +// buildToolkitOptions constructs toolkit options from config. +func buildToolkitOptions(cfg Config, elicit *ElicitationMiddleware, connRequired *ConnectionRequiredMiddleware) []trinotools.ToolkitOption { + var opts []trinotools.ToolkitOption + + if cfg.ReadOnly { + opts = append(opts, trinotools.WithQueryInterceptor(NewReadOnlyInterceptor())) + } + if len(cfg.Descriptions) > 0 { + opts = append(opts, trinotools.WithDescriptions(toTrinoToolNames(cfg.Descriptions))) + } + if len(cfg.Annotations) > 0 { + opts = append(opts, trinotools.WithAnnotations(toTrinoAnnotations(cfg.Annotations))) + } + if connRequired != nil { + opts = append(opts, trinotools.WithMiddleware(connRequired)) + } + if cfg.ProgressEnabled { + opts = append(opts, trinotools.WithMiddleware(&ProgressInjector{})) + } + if elicit != nil { + opts = append(opts, trinotools.WithMiddleware(elicit)) + } + + return opts +} + // validateConfig validates the required configuration fields. func validateConfig(cfg Config) error { if cfg.Host == "" { @@ -201,33 +370,7 @@ func toTrinoToolNames(m map[string]string) map[trinotools.ToolName]string { // createToolkit creates the mcp-trino toolkit with appropriate options. func createToolkit(client *trinoclient.Client, cfg Config, elicit *ElicitationMiddleware) *trinotools.Toolkit { - var opts []trinotools.ToolkitOption - - // Add read-only interceptor if configured - if cfg.ReadOnly { - opts = append(opts, trinotools.WithQueryInterceptor(NewReadOnlyInterceptor())) - } - - // Add description overrides if configured - if len(cfg.Descriptions) > 0 { - opts = append(opts, trinotools.WithDescriptions(toTrinoToolNames(cfg.Descriptions))) - } - - // Add annotation overrides if configured - if len(cfg.Annotations) > 0 { - opts = append(opts, trinotools.WithAnnotations(toTrinoAnnotations(cfg.Annotations))) - } - - // Add progress notifier injector if enabled - if cfg.ProgressEnabled { - opts = append(opts, trinotools.WithMiddleware(&ProgressInjector{})) - } - - // Add elicitation middleware if enabled - if elicit != nil { - opts = append(opts, trinotools.WithMiddleware(elicit)) - } - + opts := buildToolkitOptions(cfg, elicit, nil) return trinotools.NewToolkit(client, trinotools.Config{ DefaultLimit: cfg.DefaultLimit, MaxLimit: cfg.MaxLimit, @@ -322,8 +465,38 @@ func (t *Toolkit) SetQueryProvider(provider query.Provider) { t.queryProvider = provider } +// ListConnections returns details for all connections managed by this toolkit. +// Implements toolkit.ConnectionLister. +func (t *Toolkit) ListConnections() []toolkit.ConnectionDetail { + if t.manager == nil { + // Single-client mode: one connection. + return []toolkit.ConnectionDetail{{ + Name: t.name, + Description: t.config.Description, + IsDefault: true, + }} + } + + infos := t.manager.ConnectionInfos() + details := make([]toolkit.ConnectionDetail, len(infos)) + for i, info := range infos { + details[i] = toolkit.ConnectionDetail{ + Name: info.Name, + Description: t.connectionDescriptions[info.Name], + IsDefault: info.IsDefault, + } + } + return details +} + // Close releases resources. func (t *Toolkit) Close() error { + if t.manager != nil { + if err := t.manager.Close(); err != nil { + return fmt.Errorf("closing trino manager: %w", err) + } + return nil + } if t.client != nil { if err := t.client.Close(); err != nil { return fmt.Errorf("closing trino client: %w", err) @@ -343,13 +516,16 @@ func (t *Toolkit) Config() Config { } // Verify interface compliance. -var _ interface { - Kind() string - Name() string - Connection() string - RegisterTools(s *mcp.Server) - Tools() []string - SetSemanticProvider(provider semantic.Provider) - SetQueryProvider(provider query.Provider) - Close() error -} = (*Toolkit)(nil) +var ( + _ interface { + Kind() string + Name() string + Connection() string + RegisterTools(s *mcp.Server) + Tools() []string + SetSemanticProvider(provider semantic.Provider) + SetQueryProvider(provider query.Provider) + Close() error + } = (*Toolkit)(nil) + _ toolkit.ConnectionLister = (*Toolkit)(nil) +) diff --git a/pkg/toolkits/trino/toolkit_test.go b/pkg/toolkits/trino/toolkit_test.go index 0b58936b..8dd3e599 100644 --- a/pkg/toolkits/trino/toolkit_test.go +++ b/pkg/toolkits/trino/toolkit_test.go @@ -10,6 +10,7 @@ import ( "github.com/txn2/mcp-data-platform/pkg/query" "github.com/txn2/mcp-data-platform/pkg/semantic" + "github.com/txn2/mcp-data-platform/pkg/toolkit" ) const ( @@ -25,6 +26,8 @@ const ( trinoTestDefLimit = 1000 trinoTestDefMaxLimit = 10000 trinoTestDefTimeoutSec = 120 + trinoTestKind = "trino" + trinoTestWarehouse = "warehouse" ) func TestNew(t *testing.T) { @@ -284,8 +287,8 @@ func newTestTrinoToolkit() *Toolkit { func TestToolkit_KindAndName(t *testing.T) { tk := newTestTrinoToolkit() - if tk.Kind() != "trino" { - t.Errorf("Kind() = %q, want 'trino'", tk.Kind()) + if tk.Kind() != trinoTestKind { + t.Errorf("Kind() = %q, want %q", tk.Kind(), trinoTestKind) } if tk.Name() != "test-toolkit" { t.Errorf("Name() = %q", tk.Name()) @@ -628,3 +631,330 @@ func TestCreateToolkit_WithProgressAndElicitation(t *testing.T) { t.Fatal("expected non-nil toolkit") } } + +func TestNewMulti(t *testing.T) { + t.Run("empty instances", func(t *testing.T) { + _, err := NewMulti(MultiConfig{}) + if err == nil { + t.Error("expected error for empty instances") + } + }) + + t.Run("default not found in instances", func(t *testing.T) { + _, err := NewMulti(MultiConfig{ + DefaultConnection: "nonexistent", + Instances: map[string]Config{ + "warehouse": {Host: "localhost", User: "testuser"}, + }, + }) + if err == nil { + t.Error("expected error for missing default connection") + } + }) + + t.Run("invalid instance config", func(t *testing.T) { + _, err := NewMulti(MultiConfig{ + DefaultConnection: trinoTestWarehouse, + Instances: map[string]Config{ + "warehouse": {Host: "localhost", User: "testuser"}, + "bad": {Host: ""}, // missing host triggers validation error + }, + }) + if err == nil { + t.Error("expected error for invalid instance config") + } + }) + + t.Run("single instance succeeds", func(t *testing.T) { + tk, err := NewMulti(MultiConfig{ + DefaultConnection: trinoTestWarehouse, + Instances: map[string]Config{ + "warehouse": {Host: "localhost", User: "testuser", Port: trinoTestPort8080}, + }, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tk.Kind() != trinoTestKind { + t.Errorf("Kind() = %q, want %q", tk.Kind(), trinoTestKind) + } + if tk.Name() != trinoTestWarehouse { + t.Errorf("Name() = %q, want %q", tk.Name(), trinoTestWarehouse) + } + if tk.Connection() != trinoTestWarehouse { + t.Errorf("Connection() = %q, want %q", tk.Connection(), trinoTestWarehouse) + } + if tk.manager == nil { + t.Error("expected non-nil manager") + } + if tk.client != nil { + t.Error("expected nil client in multi-connection mode") + } + + tools := tk.Tools() + if len(tools) != 7 { //nolint:mnd // 7 trino tools + t.Errorf("expected 7 tools, got %d", len(tools)) + } + }) + + t.Run("multiple instances succeeds", func(t *testing.T) { + tk, err := NewMulti(MultiConfig{ + DefaultConnection: trinoTestWarehouse, + Instances: map[string]Config{ + "warehouse": {Host: "warehouse.example.com", User: "trino", Port: trinoTestPort443, SSL: true, Catalog: "hive"}, + "elasticsearch": {Host: "es.example.com", User: "trino", Port: trinoTestPort443, SSL: true, Catalog: "elasticsearch"}, + "cassandra": {Host: "cass.example.com", User: "trino", Port: trinoTestPort443, SSL: true, Catalog: "cassandra"}, + }, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tk.Name() != trinoTestWarehouse { + t.Errorf("Name() = %q, want %q", tk.Name(), trinoTestWarehouse) + } + if tk.Connection() != trinoTestWarehouse { + t.Errorf("Connection() = %q, want %q", tk.Connection(), trinoTestWarehouse) + } + }) + + t.Run("auto-selects default alphabetically when not specified", func(t *testing.T) { + tk, err := NewMulti(MultiConfig{ + Instances: map[string]Config{ + "charlie": {Host: "c.example.com", User: "trino", Port: trinoTestPort8080}, + "alpha": {Host: "a.example.com", User: "trino", Port: trinoTestPort8080}, + "bravo": {Host: "b.example.com", User: "trino", Port: trinoTestPort8080}, + }, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tk.Name() != "alpha" { + t.Errorf("Name() = %q, want 'alpha' (first alphabetically)", tk.Name()) + } + }) + + t.Run("close delegates to manager", func(t *testing.T) { + tk, err := NewMulti(MultiConfig{ + DefaultConnection: trinoTestWarehouse, + Instances: map[string]Config{ + "warehouse": {Host: "localhost", User: "testuser", Port: trinoTestPort8080}, + }, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // Close should not error (no active clients yet). + if err := tk.Close(); err != nil { + t.Errorf("Close() error = %v", err) + } + }) + + t.Run("register tools on server", func(t *testing.T) { + tk, err := NewMulti(MultiConfig{ + DefaultConnection: trinoTestWarehouse, + Instances: map[string]Config{ + "warehouse": {Host: "localhost", User: "testuser", Port: trinoTestPort8080}, + }, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + server := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "1.0.0"}, nil) + tk.RegisterTools(server) // Should not panic. + }) +} + +func TestListConnections_SingleMode(t *testing.T) { + tk := &Toolkit{ + name: "prod-trino", + config: Config{ + Description: "Production data warehouse", + }, + } + conns := tk.ListConnections() + if len(conns) != 1 { + t.Fatalf("expected 1 connection, got %d", len(conns)) + } + if conns[0].Name != "prod-trino" { + t.Errorf("Name = %q, want 'prod-trino'", conns[0].Name) + } + if conns[0].Description != "Production data warehouse" { + t.Errorf("Description = %q", conns[0].Description) + } + if !conns[0].IsDefault { + t.Error("expected IsDefault=true for single connection") + } +} + +func TestListConnections_MultiMode(t *testing.T) { + tk, err := NewMulti(MultiConfig{ + DefaultConnection: trinoTestWarehouse, + Instances: map[string]Config{ + "warehouse": {Host: "wh.example.com", User: "trino", Port: trinoTestPort443, SSL: true, Description: "Analytics warehouse"}, + "elasticsearch": {Host: "es.example.com", User: "trino", Port: trinoTestPort443, SSL: true, Description: "Sales data"}, + }, + }) + if err != nil { + t.Fatalf("NewMulti error: %v", err) + } + + conns := tk.ListConnections() + if len(conns) != 2 { + t.Fatalf("expected 2 connections, got %d", len(conns)) + } + + byName := make(map[string]toolkit.ConnectionDetail, len(conns)) + for _, c := range conns { + byName[c.Name] = c + } + + wh, ok := byName["warehouse"] + if !ok { + t.Fatal("missing warehouse connection") + } + if wh.Description != "Analytics warehouse" { + t.Errorf("warehouse.Description = %q", wh.Description) + } + if !wh.IsDefault { + t.Error("warehouse should be default") + } + + es, ok := byName["elasticsearch"] + if !ok { + t.Fatal("missing elasticsearch connection") + } + if es.Description != "Sales data" { + t.Errorf("elasticsearch.Description = %q", es.Description) + } + if es.IsDefault { + t.Error("elasticsearch should not be default") + } +} + +func TestListConnections_ImplementsConnectionLister(t *testing.T) { + tk := &Toolkit{name: "test"} + var _ toolkit.ConnectionLister = tk // compile-time check + conns := tk.ListConnections() + if len(conns) != 1 { + t.Errorf("expected 1 connection, got %d", len(conns)) + } +} + +func TestBuildMultiserverConfig(t *testing.T) { + instances := map[string]Config{ + "warehouse": { + Host: "warehouse.example.com", User: "trino", Port: trinoTestPort443, + SSL: true, Catalog: "hive", Schema: "default", Password: "pass1", + }, + "elasticsearch": { + Host: "es.example.com", User: "es-user", Catalog: "elasticsearch", + SSL: true, Password: "pass2", + }, + "cassandra": { + Host: "cass.example.com", Catalog: "cassandra", + }, + } + defaultCfg := instances[trinoTestWarehouse] + + msCfg := buildMultiserverConfig(trinoTestWarehouse, defaultCfg, instances) + + // Primary should reflect the default instance. + if msCfg.Primary.Host != "warehouse.example.com" { + t.Errorf("Primary.Host = %q", msCfg.Primary.Host) + } + if msCfg.Primary.Source != "mcp-data-platform" { + t.Errorf("Primary.Source = %q", msCfg.Primary.Source) + } + if msCfg.Default != trinoTestWarehouse { + t.Errorf("Default = %q, want %q", msCfg.Default, trinoTestWarehouse) + } + + // Should have connections for non-default instances. + if len(msCfg.Connections) != 2 { + t.Fatalf("expected 2 connections, got %d", len(msCfg.Connections)) + } + + esCfg, ok := msCfg.Connections["elasticsearch"] + if !ok { + t.Fatal("missing elasticsearch connection") + } + if esCfg.Host != "es.example.com" { + t.Errorf("es.Host = %q", esCfg.Host) + } + if esCfg.User != "es-user" { + t.Errorf("es.User = %q", esCfg.User) + } + if esCfg.Catalog != "elasticsearch" { + t.Errorf("es.Catalog = %q", esCfg.Catalog) + } + if esCfg.SSL == nil || !*esCfg.SSL { + t.Error("es.SSL should be true") + } + + cassCfg := msCfg.Connections["cassandra"] + if cassCfg.Host != "cass.example.com" { + t.Errorf("cass.Host = %q", cassCfg.Host) + } + if cassCfg.SSL != nil { + t.Error("cass.SSL should be nil (not explicitly set)") + } +} + +func TestBuildToolkitOptions(t *testing.T) { + t.Run("empty config produces no options", func(t *testing.T) { + opts := buildToolkitOptions(Config{}, nil, nil) + if len(opts) != 0 { + t.Errorf("expected 0 options, got %d", len(opts)) + } + }) + + t.Run("read-only adds interceptor", func(t *testing.T) { + opts := buildToolkitOptions(Config{ReadOnly: true}, nil, nil) + if len(opts) != 1 { + t.Errorf("expected 1 option, got %d", len(opts)) + } + }) + + t.Run("descriptions and annotations add options", func(t *testing.T) { + opts := buildToolkitOptions(Config{ + Descriptions: map[string]string{"trino_query": "custom"}, + Annotations: map[string]AnnotationConfig{"trino_query": {}}, + }, nil, nil) + if len(opts) != 2 { + t.Errorf("expected 2 options, got %d", len(opts)) + } + }) + + t.Run("progress adds middleware", func(t *testing.T) { + opts := buildToolkitOptions(Config{ProgressEnabled: true}, nil, nil) + if len(opts) != 1 { + t.Errorf("expected 1 option, got %d", len(opts)) + } + }) + + t.Run("connection required adds middleware", func(t *testing.T) { + cr := NewConnectionRequiredMiddleware([]ConnectionDescription{ + {Name: "a"}, {Name: "b"}, + }) + opts := buildToolkitOptions(Config{}, nil, cr) + if len(opts) != 1 { + t.Errorf("expected 1 option, got %d", len(opts)) + } + }) + + t.Run("all features combined", func(t *testing.T) { + em := &ElicitationMiddleware{} + cr := NewConnectionRequiredMiddleware([]ConnectionDescription{ + {Name: "a"}, {Name: "b"}, + }) + opts := buildToolkitOptions(Config{ + ReadOnly: true, + Descriptions: map[string]string{"a": "b"}, + Annotations: map[string]AnnotationConfig{"a": {}}, + ProgressEnabled: true, + }, em, cr) + if len(opts) != 6 { //nolint:mnd // 6 option types: readonly + descs + annots + connRequired + progress + elicit + t.Errorf("expected 6 options, got %d", len(opts)) + } + }) +}