diff --git a/go/vt/vtctl/workflow/framework_test.go b/go/vt/vtctl/workflow/framework_test.go index fad48e31e0c..0575965c433 100644 --- a/go/vt/vtctl/workflow/framework_test.go +++ b/go/vt/vtctl/workflow/framework_test.go @@ -272,13 +272,15 @@ type testTMClient struct { createVReplicationWorkflowRequests map[uint32]*createVReplicationWorkflowRequestResponse readVReplicationWorkflowRequests map[uint32]*readVReplicationWorkflowRequestResponse updateVReplicationWorklowsRequests map[uint32]*tabletmanagerdatapb.UpdateVReplicationWorkflowsRequest + updateVReplicationWorklowRequests map[uint32]*updateVReplicationWorkflowRequestResponse applySchemaRequests map[uint32][]*applySchemaRequestResponse primaryPositions map[uint32]string vdiffRequests map[uint32]*vdiffRequestResponse refreshStateErrors map[uint32]error // Stack of ReadVReplicationWorkflowsResponse to return, in order, for each shard - readVReplicationWorkflowsResponses map[string][]*tabletmanagerdatapb.ReadVReplicationWorkflowsResponse + readVReplicationWorkflowsResponses map[string][]*tabletmanagerdatapb.ReadVReplicationWorkflowsResponse + validateVReplicationPermissionsResponses map[uint32]*validateVReplicationPermissionsResponse env *testEnv // For access to the env config from tmc methods. reverse atomic.Bool // Are we reversing traffic? @@ -296,6 +298,7 @@ func newTestTMClient(env *testEnv) *testTMClient { createVReplicationWorkflowRequests: make(map[uint32]*createVReplicationWorkflowRequestResponse), readVReplicationWorkflowRequests: make(map[uint32]*readVReplicationWorkflowRequestResponse), updateVReplicationWorklowsRequests: make(map[uint32]*tabletmanagerdatapb.UpdateVReplicationWorkflowsRequest), + updateVReplicationWorklowRequests: make(map[uint32]*updateVReplicationWorkflowRequestResponse), applySchemaRequests: make(map[uint32][]*applySchemaRequestResponse), readVReplicationWorkflowsResponses: make(map[string][]*tabletmanagerdatapb.ReadVReplicationWorkflowsResponse), primaryPositions: make(map[uint32]string), @@ -519,6 +522,17 @@ func (tmc *testTMClient) expectApplySchemaRequest(tabletID uint32, req *applySch tmc.applySchemaRequests[tabletID] = append(tmc.applySchemaRequests[tabletID], req) } +func (tmc *testTMClient) expectValidateVReplicationPermissionsResponse(tabletID uint32, req *validateVReplicationPermissionsResponse) { + tmc.mu.Lock() + defer tmc.mu.Unlock() + + if tmc.validateVReplicationPermissionsResponses == nil { + tmc.validateVReplicationPermissionsResponses = make(map[uint32]*validateVReplicationPermissionsResponse) + } + + tmc.validateVReplicationPermissionsResponses[tabletID] = req +} + // Note: ONLY breaks up change.SQL into individual statements and executes it. Does NOT fully implement ApplySchema. func (tmc *testTMClient) ApplySchema(ctx context.Context, tablet *topodatapb.Tablet, change *tmutils.SchemaChange) (*tabletmanagerdatapb.SchemaChangeResult, error) { tmc.mu.Lock() @@ -578,6 +592,17 @@ type applySchemaRequestResponse struct { err error } +type updateVReplicationWorkflowRequestResponse struct { + req *tabletmanagerdatapb.UpdateVReplicationWorkflowRequest + res *tabletmanagerdatapb.UpdateVReplicationWorkflowResponse + err error +} + +type validateVReplicationPermissionsResponse struct { + res *tabletmanagerdatapb.ValidateVReplicationPermissionsResponse + err error +} + func (tmc *testTMClient) expectVDiffRequest(tablet *topodatapb.Tablet, vrr *vdiffRequestResponse) { tmc.mu.Lock() defer tmc.mu.Unlock() @@ -692,6 +717,15 @@ func (tmc *testTMClient) ReadVReplicationWorkflows(ctx context.Context, tablet * } func (tmc *testTMClient) UpdateVReplicationWorkflow(ctx context.Context, tablet *topodatapb.Tablet, req *tabletmanagerdatapb.UpdateVReplicationWorkflowRequest) (*tabletmanagerdatapb.UpdateVReplicationWorkflowResponse, error) { + tmc.mu.Lock() + defer tmc.mu.Unlock() + if expect := tmc.updateVReplicationWorklowRequests[tablet.Alias.Uid]; expect != nil { + if !proto.Equal(expect.req, req) { + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected ReadVReplicationWorkflow request on tablet %s: got %+v, want %+v", + topoproto.TabletAliasString(tablet.Alias), req, expect) + } + return expect.res, expect.err + } return &tabletmanagerdatapb.UpdateVReplicationWorkflowResponse{ Result: &querypb.QueryResult{ RowsAffected: 1, @@ -713,6 +747,11 @@ func (tmc *testTMClient) UpdateVReplicationWorkflows(ctx context.Context, tablet } func (tmc *testTMClient) ValidateVReplicationPermissions(ctx context.Context, tablet *topodatapb.Tablet, req *tabletmanagerdatapb.ValidateVReplicationPermissionsRequest) (*tabletmanagerdatapb.ValidateVReplicationPermissionsResponse, error) { + tmc.mu.Lock() + defer tmc.mu.Unlock() + if resp, ok := tmc.validateVReplicationPermissionsResponses[tablet.Alias.Uid]; ok { + return resp.res, resp.err + } return &tabletmanagerdatapb.ValidateVReplicationPermissionsResponse{ User: "vt_filtered", Ok: true, @@ -777,6 +816,12 @@ func (tmc *testTMClient) AddUpdateVReplicationRequests(tabletUID uint32, req *ta tmc.updateVReplicationWorklowsRequests[tabletUID] = req } +func (tmc *testTMClient) AddUpdateVReplicationWorkflowRequestResponse(tabletUID uint32, reqres *updateVReplicationWorkflowRequestResponse) { + tmc.mu.Lock() + defer tmc.mu.Unlock() + tmc.updateVReplicationWorklowRequests[tabletUID] = reqres +} + func (tmc *testTMClient) getVReplicationWorkflowsResponse(key string) *tabletmanagerdatapb.ReadVReplicationWorkflowsResponse { if len(tmc.readVReplicationWorkflowsResponses) == 0 { return nil diff --git a/go/vt/vtctl/workflow/server_test.go b/go/vt/vtctl/workflow/server_test.go index 4676b732245..0e7b2c21e75 100644 --- a/go/vt/vtctl/workflow/server_test.go +++ b/go/vt/vtctl/workflow/server_test.go @@ -29,6 +29,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/prototext" "vitess.io/vitess/go/sqltypes" @@ -2397,3 +2399,179 @@ func TestCopySchemaShard(t *testing.T) { assert.NoError(t, err) assert.Empty(t, te.tmc.applySchemaRequests[200]) } + +func TestValidateShardsHaveVReplicationPermissions(t *testing.T) { + ctx := context.Background() + + sourceKeyspace := &testKeyspace{"source_keyspace", []string{"-"}} + targetKeyspace := &testKeyspace{"target_keyspace", []string{"-80", "80-"}} + + te := newTestEnv(t, ctx, defaultCellName, sourceKeyspace, targetKeyspace) + defer te.close() + + si1, err := te.ts.GetShard(ctx, targetKeyspace.KeyspaceName, targetKeyspace.ShardNames[0]) + require.NoError(t, err) + si2, err := te.ts.GetShard(ctx, targetKeyspace.KeyspaceName, targetKeyspace.ShardNames[1]) + require.NoError(t, err) + + testcases := []struct { + name string + response *validateVReplicationPermissionsResponse + expectedErrContains string + }{ + { + // Expect no error in this case. + name: "unimplemented error", + response: &validateVReplicationPermissionsResponse{ + err: status.Error(codes.Unimplemented, "unimplemented test"), + }, + }, + { + name: "tmc error", + response: &validateVReplicationPermissionsResponse{ + err: fmt.Errorf("tmc throws error"), + }, + expectedErrContains: "tmc throws error", + }, + { + name: "no permissions", + response: &validateVReplicationPermissionsResponse{ + res: &tabletmanagerdatapb.ValidateVReplicationPermissionsResponse{ + User: "vt_test_user", + Ok: false, + }, + }, + expectedErrContains: "vt_test_user does not have the required set of permissions", + }, + { + name: "success", + response: &validateVReplicationPermissionsResponse{ + res: &tabletmanagerdatapb.ValidateVReplicationPermissionsResponse{ + User: "vt_filtered", + Ok: true, + }, + }, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + te.tmc.expectValidateVReplicationPermissionsResponse(200, tc.response) + te.tmc.expectValidateVReplicationPermissionsResponse(210, tc.response) + err = te.ws.validateShardsHaveVReplicationPermissions(ctx, targetKeyspace.KeyspaceName, []*topo.ShardInfo{si1, si2}) + if tc.expectedErrContains == "" { + assert.NoError(t, err) + return + } + assert.ErrorContains(t, err, tc.expectedErrContains) + }) + } +} + +func TestWorkflowUpdate(t *testing.T) { + ctx := context.Background() + + sourceKeyspace := &testKeyspace{"source_keyspace", []string{"-"}} + targetKeyspace := &testKeyspace{"target_keyspace", []string{"-80", "80-"}} + + te := newTestEnv(t, ctx, defaultCellName, sourceKeyspace, targetKeyspace) + defer te.close() + + req := &vtctldatapb.WorkflowUpdateRequest{ + Keyspace: targetKeyspace.KeyspaceName, + TabletRequest: &tabletmanagerdatapb.UpdateVReplicationWorkflowRequest{ + Workflow: "wf1", + State: binlogdatapb.VReplicationWorkflowState_Running.Enum(), + }, + } + + testcases := []struct { + name string + response map[uint32]*tabletmanagerdatapb.UpdateVReplicationWorkflowResponse + err map[uint32]error + + // Match the tablet `changed` field from response. + expectedResponse map[uint32]bool + expectedErrContains string + }{ + { + name: "one tablet stream changed", + response: map[uint32]*tabletmanagerdatapb.UpdateVReplicationWorkflowResponse{ + 200: { + Result: &querypb.QueryResult{ + RowsAffected: 1, + }, + }, + 210: { + Result: &querypb.QueryResult{ + RowsAffected: 0, + }, + }, + }, + expectedResponse: map[uint32]bool{ + 200: true, + 210: false, + }, + }, + { + name: "two tablet stream changed", + response: map[uint32]*tabletmanagerdatapb.UpdateVReplicationWorkflowResponse{ + 200: { + Result: &querypb.QueryResult{ + RowsAffected: 1, + }, + }, + 210: { + Result: &querypb.QueryResult{ + RowsAffected: 2, + }, + }, + }, + expectedResponse: map[uint32]bool{ + 200: true, + 210: true, + }, + }, + { + name: "tablet throws error", + err: map[uint32]error{ + 200: fmt.Errorf("test error from 200"), + }, + expectedErrContains: "test error from 200", + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + // Add responses + for tabletID, resp := range tc.response { + te.tmc.AddUpdateVReplicationWorkflowRequestResponse(tabletID, &updateVReplicationWorkflowRequestResponse{ + req: req.TabletRequest, + res: resp, + }) + } + // Add errors + for tabletID, err := range tc.err { + te.tmc.AddUpdateVReplicationWorkflowRequestResponse(tabletID, &updateVReplicationWorkflowRequestResponse{ + req: req.TabletRequest, + err: err, + }) + } + + res, err := te.ws.WorkflowUpdate(ctx, req) + if tc.expectedErrContains != "" { + assert.ErrorContains(t, err, tc.expectedErrContains) + return + } + + assert.NoError(t, err) + for tabletID, changed := range tc.expectedResponse { + i := slices.IndexFunc(res.Details, func(det *vtctldatapb.WorkflowUpdateResponse_TabletInfo) bool { + return det.Tablet.Uid == tabletID + }) + assert.NotEqual(t, -1, i) + assert.Equal(t, changed, res.Details[i].Changed) + } + }) + } +}