From 9dda778d7ea261ce4962485906659e091897ca85 Mon Sep 17 00:00:00 2001 From: Florent Poinsard Date: Thu, 23 Jan 2025 08:47:46 -0600 Subject: [PATCH] Finalize the Values OpCode and TestFindRouteValuesJoin Signed-off-by: Florent Poinsard --- go/vt/vtgate/engine/routing.go | 128 ++++++++++-------- go/vt/vtgate/engine/routing_parameter_test.go | 45 +++--- 2 files changed, 94 insertions(+), 79 deletions(-) diff --git a/go/vt/vtgate/engine/routing.go b/go/vt/vtgate/engine/routing.go index 47d4cea04a6..5c77d8b03a5 100644 --- a/go/vt/vtgate/engine/routing.go +++ b/go/vt/vtgate/engine/routing.go @@ -187,75 +187,83 @@ func (rp *RoutingParameters) findRoute(ctx context.Context, vcursor VCursor, bin case vindexes.MultiColumn: return nil, nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unsupported multi column vindex for values") default: - if len(rp.Values) < 2 { - return nil, nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "values slice must at least be of length two for a values") - } - env := evalengine.NewExpressionEnv(ctx, bindVars, vcursor) - value, err := env.Evaluate(rp.Values[0]) - if err != nil { - return nil, nil, err - } + return rp.values(ctx, vcursor, bindVars) + } + default: + // Unreachable. + return nil, nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unsupported opcode: %v", rp.Opcode) + } +} - rval, ok := rp.Values[0].(*evalengine.BindVariable) - if !ok { - return nil, nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "cannot transform evalengine expr to bind variable for values") - } +// values is used by the "Values" OpCode. It takes a tuple of tuple in the bindVars (from a VALUES JOIN), and +// will split all the rows from the tuple to their own shards. Minimizing the amount of bindVars we send to each shard. +// rp.Values has to be formatted a certain way by the planner: The first index has to be the expression that returns a +// tuple of tuples. The second index has to be the offset where the vindex values can be found in every row of the outer tuple. +func (rp *RoutingParameters) values(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) ([]*srvtopo.ResolvedShard, []map[string]*querypb.BindVariable, error) { + if len(rp.Values) < 2 { + return nil, nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "values slice must at least be of length two for a values") + } + env := evalengine.NewExpressionEnv(ctx, bindVars, vcursor) + value, err := env.Evaluate(rp.Values[0]) + if err != nil { + return nil, nil, err + } - tuple := value.TupleValues() + rval, ok := rp.Values[0].(*evalengine.BindVariable) + if !ok { + return nil, nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "cannot transform evalengine expr to bind variable for values") + } - type rssValue struct { - rss *srvtopo.ResolvedShard - vals []sqltypes.Value - } - r := map[string]rssValue{} - for _, row := range tuple { - env.Row = nil - err = row.ForEachValue(func(bv sqltypes.Value) { - env.Row = append(env.Row, bv) - }) - if err != nil { - return nil, nil, err - } - val, err := env.Evaluate(rp.Values[1]) - if err != nil { - return nil, nil, err - } + tuple := value.TupleValues() - rss, _, err := resolveShards(ctx, vcursor, rp.Vindex.(vindexes.SingleColumn), rp.Keyspace, []sqltypes.Value{val.Value(vcursor.ConnCollation())}) - if err != nil { - return nil, nil, err - } - if len(rss) > 1 { - return nil, nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "andres is confused") - } - r[rss[0].Target.String()] = rssValue{ - rss: rss[0], - vals: append(r[rss[0].Target.String()].vals, val.Value(collations.Unknown)), - } - } - var resultRss []*srvtopo.ResolvedShard - var resultBvs []map[string]*querypb.BindVariable - for _, rssVals := range r { - resultRss = append(resultRss, rssVals.rss) + type rssValue struct { + rss *srvtopo.ResolvedShard + vals []sqltypes.Value + } + r := map[string]rssValue{} + for _, row := range tuple { + env.Row = nil + err = row.ForEachValue(func(bv sqltypes.Value) { + env.Row = append(env.Row, bv) + }) + if err != nil { + return nil, nil, err + } + val, err := env.Evaluate(rp.Values[1]) + if err != nil { + return nil, nil, err + } - clonedBindVars := maps.Clone(bindVars) + rss, _, err := resolveShards(ctx, vcursor, rp.Vindex.(vindexes.SingleColumn), rp.Keyspace, []sqltypes.Value{val.Value(vcursor.ConnCollation())}) + if err != nil { + return nil, nil, err + } + if len(rss) > 1 { + return nil, nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "andres is confused") + } + r[rss[0].Target.String()] = rssValue{ + rss: rss[0], + vals: append(r[rss[0].Target.String()].vals, val.Value(collations.Unknown)), + } + } + var resultRss []*srvtopo.ResolvedShard + var resultBvs []map[string]*querypb.BindVariable + for _, rssVals := range r { + resultRss = append(resultRss, rssVals.rss) - newBv := &querypb.BindVariable{ - Type: querypb.Type_TUPLE, - } - for _, s := range rssVals.vals { - newBv.Values = append(newBv.Values, sqltypes.ValueToProto(s)) - } + clonedBindVars := maps.Clone(bindVars) - clonedBindVars[rval.Key] = newBv - resultBvs = append(resultBvs, clonedBindVars) - } - return resultRss, resultBvs, nil + newBv := &querypb.BindVariable{ + Type: querypb.Type_TUPLE, } - default: - // Unreachable. - return nil, nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unsupported opcode: %v", rp.Opcode) + for _, s := range rssVals.vals { + newBv.Values = append(newBv.Values, sqltypes.ValueToProto(s)) + } + + clonedBindVars[rval.Key] = newBv + resultBvs = append(resultBvs, clonedBindVars) } + return resultRss, resultBvs, nil } func (rp *RoutingParameters) systemQuery(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) ([]*srvtopo.ResolvedShard, []map[string]*querypb.BindVariable, error) { diff --git a/go/vt/vtgate/engine/routing_parameter_test.go b/go/vt/vtgate/engine/routing_parameter_test.go index 17f6b9d4eca..863e27d874f 100644 --- a/go/vt/vtgate/engine/routing_parameter_test.go +++ b/go/vt/vtgate/engine/routing_parameter_test.go @@ -36,30 +36,37 @@ func TestFindRouteValuesJoin(t *testing.T) { bv := &querypb.BindVariable{ Type: querypb.Type_TUPLE, + Values: []*querypb.Value{ + sqltypes.TupleToProto([]sqltypes.Value{sqltypes.NewInt64(1), sqltypes.NewVarBinary("hello")}), + sqltypes.TupleToProto([]sqltypes.Value{sqltypes.NewInt64(2), sqltypes.NewVarBinary("good morning")}), + sqltypes.TupleToProto([]sqltypes.Value{sqltypes.NewInt64(3), sqltypes.NewVarBinary("bonjour")}), + sqltypes.TupleToProto([]sqltypes.Value{sqltypes.NewInt64(4), sqltypes.NewVarBinary("bonjour")}), + }, } - bv.Values = append( - bv.Values, - sqltypes.TupleToProto([]sqltypes.Value{sqltypes.NewInt64(1), sqltypes.NewVarBinary("hello")}), - ) - bv.Values = append( - bv.Values, - sqltypes.TupleToProto([]sqltypes.Value{sqltypes.NewInt64(2), sqltypes.NewVarBinary("good morning")}), - ) - vc := newTestVCursor("0") + vc := newTestVCursor("-20", "20-") + vc.shardForKsid = []string{"-20", "-20", "20-", "20-"} rss, bvs, err := rp.findRoute(context.Background(), vc, map[string]*querypb.BindVariable{ valueBvName: bv, }) + require.NoError(t, err) - require.Len(t, rss, 1) - require.Len(t, bvs, 1) - var s []int64 - for _, value := range bvs[0][valueBvName].Values { - v := sqltypes.ProtoToValue(value) - require.Equal(t, sqltypes.Int64, v.Type()) - i, err := v.ToInt64() - require.NoError(t, err) - s = append(s, i) + require.Len(t, rss, 2) + require.Len(t, bvs, 2) + + expectedIdsPerShard := [][]int64{ + {1, 2}, + {3, 4}, + } + for i, ids := range expectedIdsPerShard { + var s []int64 + for _, value := range bvs[i][valueBvName].Values { + v := sqltypes.ProtoToValue(value) + require.Equal(t, sqltypes.Int64, v.Type()) + i, err := v.ToInt64() + require.NoError(t, err) + s = append(s, i) + } + require.Equal(t, ids, s) } - require.Equal(t, []int64{1, 2}, s) }