Skip to content

Commit

Permalink
test: Add more unit tests for server.go
Browse files Browse the repository at this point in the history
Signed-off-by: Noble Mittal <[email protected]>
  • Loading branch information
beingnoble03 committed Feb 1, 2025
1 parent 770dcf0 commit 2131df6
Show file tree
Hide file tree
Showing 2 changed files with 224 additions and 1 deletion.
47 changes: 46 additions & 1 deletion go/vt/vtctl/workflow/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,13 +272,15 @@ type testTMClient struct {
createVReplicationWorkflowRequests map[uint32]*createVReplicationWorkflowRequestResponse
readVReplicationWorkflowRequests map[uint32]*readVReplicationWorkflowRequestResponse
updateVReplicationWorklowsRequests map[uint32]*tabletmanagerdatapb.UpdateVReplicationWorkflowsRequest
updateVReplicationWorklowRequests map[uint32]*updateVReplicationWorkflowRequestResponse
applySchemaRequests map[uint32][]*applySchemaRequestResponse
primaryPositions map[uint32]string
vdiffRequests map[uint32]*vdiffRequestResponse
refreshStateErrors map[uint32]error

// Stack of ReadVReplicationWorkflowsResponse to return, in order, for each shard
readVReplicationWorkflowsResponses map[string][]*tabletmanagerdatapb.ReadVReplicationWorkflowsResponse
readVReplicationWorkflowsResponses map[string][]*tabletmanagerdatapb.ReadVReplicationWorkflowsResponse
validateVReplicationPermissionsResponses map[uint32]*validateVReplicationPermissionsResponse

env *testEnv // For access to the env config from tmc methods.
reverse atomic.Bool // Are we reversing traffic?
Expand All @@ -296,6 +298,7 @@ func newTestTMClient(env *testEnv) *testTMClient {
createVReplicationWorkflowRequests: make(map[uint32]*createVReplicationWorkflowRequestResponse),
readVReplicationWorkflowRequests: make(map[uint32]*readVReplicationWorkflowRequestResponse),
updateVReplicationWorklowsRequests: make(map[uint32]*tabletmanagerdatapb.UpdateVReplicationWorkflowsRequest),
updateVReplicationWorklowRequests: make(map[uint32]*updateVReplicationWorkflowRequestResponse),
applySchemaRequests: make(map[uint32][]*applySchemaRequestResponse),
readVReplicationWorkflowsResponses: make(map[string][]*tabletmanagerdatapb.ReadVReplicationWorkflowsResponse),
primaryPositions: make(map[uint32]string),
Expand Down Expand Up @@ -519,6 +522,17 @@ func (tmc *testTMClient) expectApplySchemaRequest(tabletID uint32, req *applySch
tmc.applySchemaRequests[tabletID] = append(tmc.applySchemaRequests[tabletID], req)
}

func (tmc *testTMClient) expectValidateVReplicationPermissionsResponse(tabletID uint32, req *validateVReplicationPermissionsResponse) {
tmc.mu.Lock()
defer tmc.mu.Unlock()

if tmc.validateVReplicationPermissionsResponses == nil {
tmc.validateVReplicationPermissionsResponses = make(map[uint32]*validateVReplicationPermissionsResponse)
}

tmc.validateVReplicationPermissionsResponses[tabletID] = req
}

// Note: ONLY breaks up change.SQL into individual statements and executes it. Does NOT fully implement ApplySchema.
func (tmc *testTMClient) ApplySchema(ctx context.Context, tablet *topodatapb.Tablet, change *tmutils.SchemaChange) (*tabletmanagerdatapb.SchemaChangeResult, error) {
tmc.mu.Lock()
Expand Down Expand Up @@ -578,6 +592,17 @@ type applySchemaRequestResponse struct {
err error
}

type updateVReplicationWorkflowRequestResponse struct {
req *tabletmanagerdatapb.UpdateVReplicationWorkflowRequest
res *tabletmanagerdatapb.UpdateVReplicationWorkflowResponse
err error
}

type validateVReplicationPermissionsResponse struct {
res *tabletmanagerdatapb.ValidateVReplicationPermissionsResponse
err error
}

func (tmc *testTMClient) expectVDiffRequest(tablet *topodatapb.Tablet, vrr *vdiffRequestResponse) {
tmc.mu.Lock()
defer tmc.mu.Unlock()
Expand Down Expand Up @@ -692,6 +717,15 @@ func (tmc *testTMClient) ReadVReplicationWorkflows(ctx context.Context, tablet *
}

func (tmc *testTMClient) UpdateVReplicationWorkflow(ctx context.Context, tablet *topodatapb.Tablet, req *tabletmanagerdatapb.UpdateVReplicationWorkflowRequest) (*tabletmanagerdatapb.UpdateVReplicationWorkflowResponse, error) {
tmc.mu.Lock()
defer tmc.mu.Unlock()
if expect := tmc.updateVReplicationWorklowRequests[tablet.Alias.Uid]; expect != nil {
if !proto.Equal(expect.req, req) {
return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected ReadVReplicationWorkflow request on tablet %s: got %+v, want %+v",
topoproto.TabletAliasString(tablet.Alias), req, expect)
}
return expect.res, expect.err
}
return &tabletmanagerdatapb.UpdateVReplicationWorkflowResponse{
Result: &querypb.QueryResult{
RowsAffected: 1,
Expand All @@ -713,6 +747,11 @@ func (tmc *testTMClient) UpdateVReplicationWorkflows(ctx context.Context, tablet
}

func (tmc *testTMClient) ValidateVReplicationPermissions(ctx context.Context, tablet *topodatapb.Tablet, req *tabletmanagerdatapb.ValidateVReplicationPermissionsRequest) (*tabletmanagerdatapb.ValidateVReplicationPermissionsResponse, error) {
tmc.mu.Lock()
defer tmc.mu.Unlock()
if resp, ok := tmc.validateVReplicationPermissionsResponses[tablet.Alias.Uid]; ok {
return resp.res, resp.err
}
return &tabletmanagerdatapb.ValidateVReplicationPermissionsResponse{
User: "vt_filtered",
Ok: true,
Expand Down Expand Up @@ -777,6 +816,12 @@ func (tmc *testTMClient) AddUpdateVReplicationRequests(tabletUID uint32, req *ta
tmc.updateVReplicationWorklowsRequests[tabletUID] = req
}

func (tmc *testTMClient) AddUpdateVReplicationWorkflowRequestResponse(tabletUID uint32, reqres *updateVReplicationWorkflowRequestResponse) {
tmc.mu.Lock()
defer tmc.mu.Unlock()
tmc.updateVReplicationWorklowRequests[tabletUID] = reqres
}

func (tmc *testTMClient) getVReplicationWorkflowsResponse(key string) *tabletmanagerdatapb.ReadVReplicationWorkflowsResponse {
if len(tmc.readVReplicationWorkflowsResponses) == 0 {
return nil
Expand Down
178 changes: 178 additions & 0 deletions go/vt/vtctl/workflow/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/prototext"

"vitess.io/vitess/go/sqltypes"
Expand Down Expand Up @@ -2397,3 +2399,179 @@ func TestCopySchemaShard(t *testing.T) {
assert.NoError(t, err)
assert.Empty(t, te.tmc.applySchemaRequests[200])
}

func TestValidateShardsHaveVReplicationPermissions(t *testing.T) {
ctx := context.Background()

sourceKeyspace := &testKeyspace{"source_keyspace", []string{"-"}}
targetKeyspace := &testKeyspace{"target_keyspace", []string{"-80", "80-"}}

te := newTestEnv(t, ctx, defaultCellName, sourceKeyspace, targetKeyspace)
defer te.close()

si1, err := te.ts.GetShard(ctx, targetKeyspace.KeyspaceName, targetKeyspace.ShardNames[0])
require.NoError(t, err)
si2, err := te.ts.GetShard(ctx, targetKeyspace.KeyspaceName, targetKeyspace.ShardNames[1])
require.NoError(t, err)

testcases := []struct {
name string
response *validateVReplicationPermissionsResponse
expectedErrContains string
}{
{
// Expect no error in this case.
name: "unimplemented error",
response: &validateVReplicationPermissionsResponse{
err: status.Error(codes.Unimplemented, "unimplemented test"),
},
},
{
name: "tmc error",
response: &validateVReplicationPermissionsResponse{
err: fmt.Errorf("tmc throws error"),
},
expectedErrContains: "tmc throws error",
},
{
name: "no permissions",
response: &validateVReplicationPermissionsResponse{
res: &tabletmanagerdatapb.ValidateVReplicationPermissionsResponse{
User: "vt_test_user",
Ok: false,
},
},
expectedErrContains: "vt_test_user does not have the required set of permissions",
},
{
name: "success",
response: &validateVReplicationPermissionsResponse{
res: &tabletmanagerdatapb.ValidateVReplicationPermissionsResponse{
User: "vt_filtered",
Ok: true,
},
},
},
}

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
te.tmc.expectValidateVReplicationPermissionsResponse(200, tc.response)
te.tmc.expectValidateVReplicationPermissionsResponse(210, tc.response)
err = te.ws.validateShardsHaveVReplicationPermissions(ctx, targetKeyspace.KeyspaceName, []*topo.ShardInfo{si1, si2})
if tc.expectedErrContains == "" {
assert.NoError(t, err)
return
}
assert.ErrorContains(t, err, tc.expectedErrContains)
})
}
}

func TestWorkflowUpdate(t *testing.T) {
ctx := context.Background()

sourceKeyspace := &testKeyspace{"source_keyspace", []string{"-"}}
targetKeyspace := &testKeyspace{"target_keyspace", []string{"-80", "80-"}}

te := newTestEnv(t, ctx, defaultCellName, sourceKeyspace, targetKeyspace)
defer te.close()

req := &vtctldatapb.WorkflowUpdateRequest{
Keyspace: targetKeyspace.KeyspaceName,
TabletRequest: &tabletmanagerdatapb.UpdateVReplicationWorkflowRequest{
Workflow: "wf1",
State: binlogdatapb.VReplicationWorkflowState_Running.Enum(),
},
}

testcases := []struct {
name string
response map[uint32]*tabletmanagerdatapb.UpdateVReplicationWorkflowResponse
err map[uint32]error

// Match the tablet `changed` field from response.
expectedResponse map[uint32]bool
expectedErrContains string
}{
{
name: "one tablet stream changed",
response: map[uint32]*tabletmanagerdatapb.UpdateVReplicationWorkflowResponse{
200: {
Result: &querypb.QueryResult{
RowsAffected: 1,
},
},
210: {
Result: &querypb.QueryResult{
RowsAffected: 0,
},
},
},
expectedResponse: map[uint32]bool{
200: true,
210: false,
},
},
{
name: "two tablet stream changed",
response: map[uint32]*tabletmanagerdatapb.UpdateVReplicationWorkflowResponse{
200: {
Result: &querypb.QueryResult{
RowsAffected: 1,
},
},
210: {
Result: &querypb.QueryResult{
RowsAffected: 2,
},
},
},
expectedResponse: map[uint32]bool{
200: true,
210: true,
},
},
{
name: "tablet throws error",
err: map[uint32]error{
200: fmt.Errorf("test error from 200"),
},
expectedErrContains: "test error from 200",
},
}

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
// Add responses
for tabletID, resp := range tc.response {
te.tmc.AddUpdateVReplicationWorkflowRequestResponse(tabletID, &updateVReplicationWorkflowRequestResponse{
req: req.TabletRequest,
res: resp,
})
}
// Add errors
for tabletID, err := range tc.err {
te.tmc.AddUpdateVReplicationWorkflowRequestResponse(tabletID, &updateVReplicationWorkflowRequestResponse{
req: req.TabletRequest,
err: err,
})
}

res, err := te.ws.WorkflowUpdate(ctx, req)
if tc.expectedErrContains != "" {
assert.ErrorContains(t, err, tc.expectedErrContains)
return
}

assert.NoError(t, err)
for tabletID, changed := range tc.expectedResponse {
i := slices.IndexFunc(res.Details, func(det *vtctldatapb.WorkflowUpdateResponse_TabletInfo) bool {
return det.Tablet.Uid == tabletID
})
assert.NotEqual(t, -1, i)
assert.Equal(t, changed, res.Details[i].Changed)
}
})
}
}

0 comments on commit 2131df6

Please sign in to comment.