Skip to content

Commit a7a44c6

Browse files
vmg and systay authored
evalengine: fix numeric coercibility (#14473)
Signed-off-by: Vicent Marti <[email protected]>
Signed-off-by: Andres Taylor <[email protected]>
Co-authored-by: Andres Taylor <[email protected]>
1 parent a7f0ead commit a7a44c6

28 files changed

+229 -181 lines changed

go/sqltypes/type.go

+1 -1
Original file line number | Diff line number | Diff line change
@@ -136,7 +136,7 @@ func IsNull(t querypb.Type) bool {
136136
// switch statements for those who want to cover types
137137
// by their category.
138138
const (
139-
Unknown = -1
139+
Unknown = querypb.Type(-1)
140140
Null = querypb.Type_NULL_TYPE
141141
Int8 = querypb.Type_INT8
142142
Uint8 = querypb.Type_UINT8

go/test/endtoend/vtgate/queries/random/random_test.go

+1
Original file line number | Diff line number | Diff line change
@@ -339,4 +339,5 @@ func TestBuggyQueries(t *testing.T) {
339339
mcmp.Exec("select count(tbl1.dname) as caggr1 from dept as tbl0 left join dept as tbl1 on tbl1.dname > tbl1.loc where tbl1.loc <=> tbl1.dname group by tbl1.dname order by tbl1.dname asc")
340340
mcmp.Exec("select count(*) from (select count(*) from dept as tbl0) as tbl0")
341341
mcmp.Exec("select count(*), count(*) from (select count(*) from dept as tbl0) as tbl0, dept as tbl1")
342+
mcmp.Exec(`select distinct case max(tbl0.ename) when min(tbl0.job) then 'sole' else count(case when false then -27 when 'gazelle' then tbl0.deptno end) end as caggr0 from emp as tbl0`)
342343
}

go/vt/vtgate/engine/opcode/constants.go

+6
Original file line number | Diff line number | Diff line change
@@ -139,13 +139,19 @@ func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type {
139139
case AggregateUnassigned:
140140
return sqltypes.Null
141141
case AggregateGroupConcat:
142+
if typ == sqltypes.Unknown {
143+
return sqltypes.Unknown
144+
}
142145
if sqltypes.IsBinary(typ) {
143146
return sqltypes.Blob
144147
}
145148
return sqltypes.Text
146149
case AggregateMax, AggregateMin, AggregateAnyValue:
147150
return typ
148151
case AggregateSumDistinct, AggregateSum:
152+
if typ == sqltypes.Unknown {
153+
return sqltypes.Unknown
154+
}
149155
if sqltypes.IsIntegral(typ) || sqltypes.IsDecimal(typ) {
150156
return sqltypes.Decimal
151157
}

go/vt/vtgate/evalengine/api_literal.go

+1 -1
Original file line number | Diff line number | Diff line change
@@ -223,7 +223,7 @@ func NewColumn(offset int, typ Type, original sqlparser.Expr) *Column {
223223
return &Column{
224224
Offset: offset,
225225
Type: typ.Type,
226-
Collation: defaultCoercionCollation(typ.Coll),
226+
Collation: typedCoercionCollation(typ.Type, typ.Coll),
227227
Original: original,
228228
dynamicTypeOffset: -1,
229229
}

go/vt/vtgate/evalengine/collation.go

+109 -6
Original file line number | Diff line number | Diff line change
@@ -16,12 +16,115 @@ limitations under the License.
1616

1717
package evalengine
1818

19-
import "vitess.io/vitess/go/mysql/collations"
19+
import (
20+
"vitess.io/vitess/go/mysql/collations"
21+
"vitess.io/vitess/go/mysql/collations/colldata"
22+
"vitess.io/vitess/go/sqltypes"
23+
)
2024

21-
func defaultCoercionCollation(id collations.ID) collations.TypedCollation {
22-
return collations.TypedCollation{
23-
Collation: id,
24-
Coercibility: collations.CoerceCoercible,
25-
Repertoire: collations.RepertoireUnicode,
25+
func typedCoercionCollation(typ sqltypes.Type, id collations.ID) collations.TypedCollation {
26+
switch {
27+
case sqltypes.IsNull(typ):
28+
return collationNull
29+
case sqltypes.IsNumber(typ) || sqltypes.IsDateOrTime(typ):
30+
return collationNumeric
31+
case typ == sqltypes.TypeJSON:
32+
return collationJSON
33+
default:
34+
return collations.TypedCollation{
35+
Collation: id,
36+
Coercibility: collations.CoerceCoercible,
37+
Repertoire: collations.RepertoireUnicode,
38+
}
2639
}
2740
}
41+
42+
func evalCollation(e eval) collations.TypedCollation {
43+
switch e := e.(type) {
44+
case nil:
45+
return collationNull
46+
case evalNumeric, *evalTemporal:
47+
return collationNumeric
48+
case *evalJSON:
49+
return collationJSON
50+
case *evalBytes:
51+
return e.col
52+
default:
53+
return collationBinary
54+
}
55+
}
56+
57+
func mergeCollations(c1, c2 collations.TypedCollation, t1, t2 sqltypes.Type) (collations.TypedCollation, colldata.Coercion, colldata.Coercion, error) {
58+
if c1.Collation == c2.Collation {
59+
return c1, nil, nil, nil
60+
}
61+
62+
lt := sqltypes.IsText(t1) || sqltypes.IsBinary(t1)
63+
rt := sqltypes.IsText(t2) || sqltypes.IsBinary(t2)
64+
if !lt || !rt {
65+
if lt {
66+
return c1, nil, nil, nil
67+
}
68+
if rt {
69+
return c2, nil, nil, nil
70+
}
71+
return collationBinary, nil, nil, nil
72+
}
73+
74+
env := collations.Local()
75+
return colldata.Merge(env, c1, c2, colldata.CoercionOptions{
76+
ConvertToSuperset: true,
77+
ConvertWithCoercion: true,
78+
})
79+
}
80+
81+
func mergeAndCoerceCollations(left, right eval) (eval, eval, collations.TypedCollation, error) {
82+
lt := left.SQLType()
83+
rt := right.SQLType()
84+
85+
mc, coerceLeft, coerceRight, err := mergeCollations(evalCollation(left), evalCollation(right), lt, rt)
86+
if err != nil {
87+
return nil, nil, collations.TypedCollation{}, err
88+
}
89+
if coerceLeft == nil && coerceRight == nil {
90+
return left, right, mc, nil
91+
}
92+
93+
left1 := newEvalRaw(lt, left.(*evalBytes).bytes, mc)
94+
right1 := newEvalRaw(rt, right.(*evalBytes).bytes, mc)
95+
96+
if coerceLeft != nil {
97+
left1.bytes, err = coerceLeft(nil, left1.bytes)
98+
if err != nil {
99+
return nil, nil, collations.TypedCollation{}, err
100+
}
101+
}
102+
if coerceRight != nil {
103+
right1.bytes, err = coerceRight(nil, right1.bytes)
104+
if err != nil {
105+
return nil, nil, collations.TypedCollation{}, err
106+
}
107+
}
108+
return left1, right1, mc, nil
109+
}
110+
111+
type collationAggregation struct {
112+
cur collations.TypedCollation
113+
}
114+
115+
func (ca *collationAggregation) add(env *collations.Environment, tc collations.TypedCollation) error {
116+
if ca.cur.Collation == collations.Unknown {
117+
ca.cur = tc
118+
} else {
119+
var err error
120+
ca.cur, _, _, err = colldata.Merge(env, ca.cur, tc, colldata.CoercionOptions{ConvertToSuperset: true, ConvertWithCoercion: true})
121+
if err != nil {
122+
return err
123+
}
124+
}
125+
return nil
126+
}
127+
128+
func (ca *collationAggregation) result() collations.TypedCollation {
129+
return ca.cur
130+
}

go/vt/vtgate/evalengine/compiler_asm.go

+2 -2
Original file line number | Diff line number | Diff line change
@@ -4165,13 +4165,13 @@ func (asm *assembler) Fn_DATEADD_D(unit datetime.IntervalType, sub bool) {
41654165
}
41664166

41674167
tmp := env.vm.stack[env.vm.sp-2].(*evalTemporal)
4168-
env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, collations.TypedCollation{}, env.now)
4168+
env.vm.stack[env.vm.sp-2] = tmp.addInterval(interval, collations.Unknown, env.now)
41694169
env.vm.sp--
41704170
return 1
41714171
}, "FN DATEADD TEMPORAL(SP-2), INTERVAL(SP-1)")
41724172
}
41734173

4174-
func (asm *assembler) Fn_DATEADD_s(unit datetime.IntervalType, sub bool, col collations.TypedCollation) {
4174+
func (asm *assembler) Fn_DATEADD_s(unit datetime.IntervalType, sub bool, col collations.ID) {
41754175
asm.adjustStack(-1)
41764176
asm.emit(func(env *ExpressionEnv) int {
41774177
var interval *datetime.Interval

go/vt/vtgate/evalengine/compiler_test.go

+13
Original file line number | Diff line number | Diff line change
@@ -164,6 +164,7 @@ func TestCompilerSingle(t *testing.T) {
164164
expression string
165165
values []sqltypes.Value
166166
result string
167+
collation collations.ID
167168
}{
168169
{
169170
expression: "1 + column0",
@@ -489,6 +490,12 @@ func TestCompilerSingle(t *testing.T) {
489490
expression: `'2020-01-01' + interval month(date_sub(FROM_UNIXTIME(1234), interval 1 month))-1 month`,
490491
result: `CHAR("2020-12-01")`,
491492
},
493+
{
494+
expression: `case column0 when 1 then column1 else column2 end`,
495+
values: []sqltypes.Value{sqltypes.NewInt64(42), sqltypes.NewVarChar("sole"), sqltypes.NewInt64(0)},
496+
result: `VARCHAR("0")`,
497+
collation: collations.CollationUtf8mb4ID,
498+
},
492499
}
493500

494501
tz, _ := time.LoadLocation("Europe/Madrid")
@@ -524,6 +531,9 @@ func TestCompilerSingle(t *testing.T) {
524531
if expected.String() != tc.result {
525532
t.Fatalf("bad evaluation from eval engine: got %s, want %s", expected.String(), tc.result)
526533
}
534+
if tc.collation != collations.Unknown && tc.collation != expected.Collation() {
535+
t.Fatalf("bad collation evaluation from eval engine: got %d, want %d", expected.Collation(), tc.collation)
536+
}
527537

528538
// re-run the same evaluation multiple times to ensure results are always consistent
529539
for i := 0; i < 8; i++ {
@@ -535,6 +545,9 @@ func TestCompilerSingle(t *testing.T) {
535545
if res.String() != tc.result {
536546
t.Errorf("bad evaluation from compiler: got %s, want %s (iteration %d)", res, tc.result, i)
537547
}
548+
if tc.collation != collations.Unknown && tc.collation != res.Collation() {
549+
t.Fatalf("bad collation evaluation from compiler: got %d, want %d", res.Collation(), tc.collation)
550+
}
538551
}
539552
})
540553
}

go/vt/vtgate/evalengine/eval.go

+10 -10
Original file line number | Diff line number | Diff line change
@@ -238,7 +238,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
238238
fval, _ := fastparse.ParseFloat64(v.RawStr())
239239
return newEvalFloat(fval), nil
240240
default:
241-
e, err := valueToEval(v, defaultCoercionCollation(collation))
241+
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
242242
if err != nil {
243243
return nil, err
244244
}
@@ -265,7 +265,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
265265
fval, _ := fastparse.ParseFloat64(v.RawStr())
266266
dec = decimal.NewFromFloat(fval)
267267
default:
268-
e, err := valueToEval(v, defaultCoercionCollation(collation))
268+
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
269269
if err != nil {
270270
return nil, err
271271
}
@@ -285,7 +285,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
285285
i, err := fastparse.ParseInt64(v.RawStr(), 10)
286286
return newEvalInt64(i), err
287287
default:
288-
e, err := valueToEval(v, defaultCoercionCollation(collation))
288+
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
289289
if err != nil {
290290
return nil, err
291291
}
@@ -304,7 +304,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
304304
u, err := fastparse.ParseUint64(v.RawStr(), 10)
305305
return newEvalUint64(u), err
306306
default:
307-
e, err := valueToEval(v, defaultCoercionCollation(collation))
307+
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
308308
if err != nil {
309309
return nil, err
310310
}
@@ -315,15 +315,15 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
315315
case sqltypes.IsText(typ) || sqltypes.IsBinary(typ):
316316
switch {
317317
case v.IsText() || v.IsBinary():
318-
return newEvalRaw(v.Type(), v.Raw(), defaultCoercionCollation(collation)), nil
318+
return newEvalRaw(v.Type(), v.Raw(), typedCoercionCollation(v.Type(), collation)), nil
319319
case sqltypes.IsText(typ):
320-
e, err := valueToEval(v, defaultCoercionCollation(collation))
320+
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
321321
if err != nil {
322322
return nil, err
323323
}
324324
return evalToVarchar(e, collation, true)
325325
default:
326-
e, err := valueToEval(v, defaultCoercionCollation(collation))
326+
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
327327
if err != nil {
328328
return nil, err
329329
}
@@ -333,7 +333,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
333333
case typ == sqltypes.TypeJSON:
334334
return json.NewFromSQL(v)
335335
case typ == sqltypes.Date:
336-
e, err := valueToEval(v, defaultCoercionCollation(collation))
336+
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
337337
if err != nil {
338338
return nil, err
339339
}
@@ -344,7 +344,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
344344
}
345345
return d, nil
346346
case typ == sqltypes.Datetime || typ == sqltypes.Timestamp:
347-
e, err := valueToEval(v, defaultCoercionCollation(collation))
347+
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
348348
if err != nil {
349349
return nil, err
350350
}
@@ -355,7 +355,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I
355355
}
356356
return dt, nil
357357
case typ == sqltypes.Time:
358-
e, err := valueToEval(v, defaultCoercionCollation(collation))
358+
e, err := valueToEval(v, typedCoercionCollation(v.Type(), collation))
359359
if err != nil {
360360
return nil, err
361361
}

go/vt/vtgate/evalengine/eval_temporal.go

+5 -5
Original file line number | Diff line number | Diff line change
@@ -140,7 +140,7 @@ func (e *evalTemporal) isZero() bool {
140140
return e.dt.IsZero()
141141
}
142142

143-
func (e *evalTemporal) addInterval(interval *datetime.Interval, strcoll collations.TypedCollation, now time.Time) eval {
143+
func (e *evalTemporal) addInterval(interval *datetime.Interval, coll collations.ID, now time.Time) eval {
144144
var tmp *evalTemporal
145145
var ok bool
146146

@@ -150,16 +150,16 @@ func (e *evalTemporal) addInterval(interval *datetime.Interval, strcoll collatio
150150
tmp.dt.Date, ok = e.dt.Date.AddInterval(interval)
151151
case tt == sqltypes.Time && !interval.Unit().HasDateParts():
152152
tmp = &evalTemporal{t: e.t}
153-
tmp.dt.Time, tmp.prec, ok = e.dt.Time.AddInterval(interval, strcoll.Valid())
153+
tmp.dt.Time, tmp.prec, ok = e.dt.Time.AddInterval(interval, coll != collations.Unknown)
154154
case tt == sqltypes.Datetime || tt == sqltypes.Timestamp || (tt == sqltypes.Date && interval.Unit().HasTimeParts()) || (tt == sqltypes.Time && interval.Unit().HasDateParts()):
155155
tmp = e.toDateTime(int(e.prec), now)
156-
tmp.dt, tmp.prec, ok = e.dt.AddInterval(interval, strcoll.Valid())
156+
tmp.dt, tmp.prec, ok = e.dt.AddInterval(interval, coll != collations.Unknown)
157157
}
158158
if !ok {
159159
return nil
160160
}
161-
if strcoll.Valid() {
162-
return newEvalRaw(sqltypes.Char, tmp.ToRawBytes(), strcoll)
161+
if coll != collations.Unknown {
162+
return newEvalRaw(sqltypes.Char, tmp.ToRawBytes(), typedCoercionCollation(sqltypes.Char, coll))
163163
}
164164
return tmp
165165
}

go/vt/vtgate/evalengine/expr_bvar.go

+4 -4
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,7 @@ func (bv *BindVariable) eval(env *ExpressionEnv) (eval, error) {
7070

7171
tuple := make([]eval, 0, len(bvar.Values))
7272
for _, value := range bvar.Values {
73-
e, err := valueToEval(sqltypes.MakeTrusted(value.Type, value.Value), defaultCoercionCollation(collations.CollationForType(value.Type, bv.Collation)))
73+
e, err := valueToEval(sqltypes.MakeTrusted(value.Type, value.Value), typedCoercionCollation(value.Type, collations.CollationForType(value.Type, bv.Collation)))
7474
if err != nil {
7575
return nil, err
7676
}
@@ -86,7 +86,7 @@ func (bv *BindVariable) eval(env *ExpressionEnv) (eval, error) {
8686
if bv.typed() {
8787
typ = bv.Type
8888
}
89-
return valueToEval(sqltypes.MakeTrusted(typ, bvar.Value), defaultCoercionCollation(collations.CollationForType(typ, bv.Collation)))
89+
return valueToEval(sqltypes.MakeTrusted(typ, bvar.Value), typedCoercionCollation(typ, collations.CollationForType(typ, bv.Collation)))
9090
}
9191
}
9292

@@ -110,7 +110,7 @@ func (bv *BindVariable) typeof(env *ExpressionEnv) (ctype, error) {
110110
case sqltypes.BitNum:
111111
return ctype{Type: sqltypes.VarBinary, Flag: flagBit, Col: collationNumeric}, nil
112112
default:
113-
return ctype{Type: tt, Flag: 0, Col: defaultCoercionCollation(collations.CollationForType(tt, bv.Collation))}, nil
113+
return ctype{Type: tt, Flag: 0, Col: typedCoercionCollation(tt, collations.CollationForType(tt, bv.Collation))}, nil
114114
}
115115
}
116116

@@ -119,7 +119,7 @@ func (bvar *BindVariable) compile(c *compiler) (ctype, error) {
119119

120120
if bvar.typed() {
121121
typ.Type = bvar.Type
122-
typ.Col = defaultCoercionCollation(collations.CollationForType(bvar.Type, bvar.Collation))
122+
typ.Col = typedCoercionCollation(bvar.Type, collations.CollationForType(bvar.Type, bvar.Collation))
123123
} else if c.dynamicTypes != nil {
124124
typ = c.dynamicTypes[bvar.dynamicTypeOffset]
125125
} else {

0 commit comments

Comments
 (0)