Skip to content

Commit

Permalink
release-1.0: cherry pick some commits (#5082)
Browse files Browse the repository at this point in the history
  • Loading branch information
winoros authored and hanfei1991 committed Nov 13, 2017
1 parent f80c332 commit 99f2d25
Show file tree
Hide file tree
Showing 8 changed files with 344 additions and 194 deletions.
159 changes: 81 additions & 78 deletions expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,9 @@ func getCmpTp4MinMax(args []Expression) (argTp types.EvalType) {
if isTemporalWithDate {
datetimeFound = true
}
lft := args[0].GetType()
for i := range args {
rft := args[i].GetType()
var tp types.EvalType
tp, isStr, isTemporalWithDate = temporalWithDateAsNumEvalType(args[i].GetType())
if isTemporalWithDate {
Expand All @@ -337,7 +339,8 @@ func getCmpTp4MinMax(args []Expression) (argTp types.EvalType) {
if !isStr {
isAllStr = false
}
cmpEvalType = getCmpType(cmpEvalType, tp)
cmpEvalType = getBaseCmpType(cmpEvalType, tp, lft, rft)
lft = rft
}
argTp = cmpEvalType
if cmpEvalType.IsStringKind() {
Expand Down Expand Up @@ -828,8 +831,18 @@ type compareFunctionClass struct {
op opcode.Op
}

// getCmpType gets the ClassType that the two args will be treated as when comparing.
func getCmpType(lhs, rhs types.EvalType) types.EvalType {
// getBaseCmpType gets the EvalType that the two args will be treated as when comparing.
func getBaseCmpType(lhs, rhs types.EvalType, lft, rft *types.FieldType) types.EvalType {
if lft.Tp == mysql.TypeUnspecified || rft.Tp == mysql.TypeUnspecified {
if lft.Tp == rft.Tp {
return types.ETString
}
if lft.Tp == mysql.TypeUnspecified {
lhs = rhs
} else {
rhs = lhs
}
}
if lhs.IsStringKind() && rhs.IsStringKind() {
return types.ETString
} else if lhs == types.ETInt && rhs == types.ETInt {
Expand All @@ -841,6 +854,64 @@ func getCmpType(lhs, rhs types.EvalType) types.EvalType {
return types.ETReal
}

// GetAccurateCmpType uses a more complex logic to decide the EvalType of the two args when compare with each other than
// getBaseCmpType does.
func GetAccurateCmpType(lhs, rhs Expression) types.EvalType {
lhsFieldType, rhsFieldType := lhs.GetType(), rhs.GetType()
lhsEvalType, rhsEvalType := lhsFieldType.EvalType(), rhsFieldType.EvalType()
cmpType := getBaseCmpType(lhsEvalType, rhsEvalType, lhsFieldType, rhsFieldType)
if (lhsEvalType.IsStringKind() && rhsFieldType.Tp == mysql.TypeJSON) ||
(lhsFieldType.Tp == mysql.TypeJSON && rhsEvalType.IsStringKind()) {
cmpType = types.ETJson
} else if cmpType == types.ETString && (types.IsTypeTime(lhsFieldType.Tp) || types.IsTypeTime(rhsFieldType.Tp)) {
// date[time] <cmp> date[time]
// string <cmp> date[time]
// compare as time
if lhsFieldType.Tp == rhsFieldType.Tp {
cmpType = lhsFieldType.EvalType()
} else {
cmpType = types.ETDatetime
}
} else if lhsFieldType.Tp == mysql.TypeDuration && rhsFieldType.Tp == mysql.TypeDuration {
// duration <cmp> duration
// compare as duration
cmpType = types.ETDuration
} else if cmpType == types.ETReal || cmpType == types.ETString {
_, isLHSConst := lhs.(*Constant)
_, isRHSConst := rhs.(*Constant)
if (lhsEvalType == types.ETDecimal && !isLHSConst && rhsEvalType.IsStringKind() && isRHSConst) ||
(rhsEvalType == types.ETDecimal && !isRHSConst && lhsEvalType.IsStringKind() && isLHSConst) {
/*
<non-const decimal expression> <cmp> <const string expression>
or
<const string expression> <cmp> <non-const decimal expression>
Do comparison as decimal rather than float, in order not to lose precision.
)*/
cmpType = types.ETDecimal
} else if isTemporalColumn(lhs) && isRHSConst ||
isTemporalColumn(rhs) && isLHSConst {
/*
<temporal column> <cmp> <non-temporal constant>
or
<non-temporal constant> <cmp> <temporal column>
Convert the constant to temporal type.
*/
col, isLHSColumn := lhs.(*Column)
if !isLHSColumn {
col = rhs.(*Column)
}
if col.GetType().Tp == mysql.TypeDuration {
cmpType = types.ETDuration
} else {
cmpType = types.ETDatetime
}
}
}
return cmpType
}

// isTemporalColumn checks if a expression is a temporal column,
// temporal column indicates time column or duration column.
func isTemporalColumn(expr Expression) bool {
Expand Down Expand Up @@ -870,8 +941,8 @@ func tryToConvertConstantInt(ctx context.Context, con *Constant) *Constant {
}
}

// refineConstantArg changes the constant argument to it's ceiling or flooring result by the given op.
func refineConstantArg(ctx context.Context, con *Constant, op opcode.Op) *Constant {
// RefineConstantArg changes the constant argument to it's ceiling or flooring result by the given op.
func RefineConstantArg(ctx context.Context, con *Constant, op opcode.Op) *Constant {
sc := ctx.GetSessionVars().StmtCtx
i64, err := con.Value.ToInt64(sc)
if err != nil {
Expand Down Expand Up @@ -913,12 +984,12 @@ func (c *compareFunctionClass) refineArgs(ctx context.Context, args []Expression
arg1, arg1IsCon := args[1].(*Constant)
// int non-constant [cmp] non-int constant
if arg0IsInt && !arg0IsCon && !arg1IsInt && arg1IsCon {
arg1 = refineConstantArg(ctx, arg1, c.op)
arg1 = RefineConstantArg(ctx, arg1, c.op)
return []Expression{args[0], arg1}
}
// non-int constant [cmp] int non-constant
if arg1IsInt && !arg1IsCon && !arg0IsInt && arg0IsCon {
arg0 = refineConstantArg(ctx, arg0, symmetricOp[c.op])
arg0 = RefineConstantArg(ctx, arg0, symmetricOp[c.op])
return []Expression{arg0, args[1]}
}
return args
Expand All @@ -930,77 +1001,9 @@ func (c *compareFunctionClass) getFunction(ctx context.Context, rawArgs []Expres
return nil, errors.Trace(err)
}
args := c.refineArgs(ctx, rawArgs)
lhsFieldType, rhsFieldType := args[0].GetType(), args[1].GetType()
lhsEvalType, rhsEvalType := lhsFieldType.EvalType(), rhsFieldType.EvalType()
cmpType := getCmpType(lhsEvalType, rhsEvalType)
if (lhsEvalType.IsStringKind() && rhsFieldType.Tp == mysql.TypeJSON) ||
(lhsFieldType.Tp == mysql.TypeJSON && rhsEvalType.IsStringKind()) {
sig, err = c.generateCmpSigs(ctx, args, types.ETJson)
} else if cmpType == types.ETString && (types.IsTypeTime(lhsFieldType.Tp) || types.IsTypeTime(rhsFieldType.Tp)) {
// date[time] <cmp> date[time]
// string <cmp> date[time]
// compare as time
if lhsFieldType.Tp == rhsFieldType.Tp {
sig, err = c.generateCmpSigs(ctx, args, lhsFieldType.EvalType())
} else {
sig, err = c.generateCmpSigs(ctx, args, types.ETDatetime)
}
} else if lhsFieldType.Tp == mysql.TypeDuration && rhsFieldType.Tp == mysql.TypeDuration {
// duration <cmp> duration
// compare as duration
sig, err = c.generateCmpSigs(ctx, args, types.ETDuration)
} else if cmpType == types.ETReal || cmpType == types.ETString {
_, isLHSConst := args[0].(*Constant)
_, isRHSConst := args[1].(*Constant)
if (lhsEvalType == types.ETDecimal && !isLHSConst && rhsEvalType.IsStringKind() && isRHSConst) ||
(rhsEvalType == types.ETDecimal && !isRHSConst && lhsEvalType.IsStringKind() && isLHSConst) {
/*
<non-const decimal expression> <cmp> <const string expression>
or
<const string expression> <cmp> <non-const decimal expression>
Do comparison as decimal rather than float, in order not to lose precision.
)*/
cmpType = types.ETDecimal
} else if isTemporalColumn(args[0]) && isRHSConst ||
isTemporalColumn(args[1]) && isLHSConst {
/*
<temporal column> <cmp> <non-temporal constant>
or
<non-temporal constant> <cmp> <temporal column>
Convert the constant to temporal type.
*/
col, isLHSColumn := args[0].(*Column)
if !isLHSColumn {
col = args[1].(*Column)
}
if col.GetType().Tp == mysql.TypeDuration {
sig, err = c.generateCmpSigs(ctx, args, types.ETDuration)
} else {
sig, err = c.generateCmpSigs(ctx, args, types.ETDatetime)
}
}
}
if err != nil {
return nil, errors.Trace(err)
}
if sig == nil {
switch cmpType {
case types.ETString:
sig, err = c.generateCmpSigs(ctx, args, types.ETString)
case types.ETInt:
sig, err = c.generateCmpSigs(ctx, args, types.ETInt)
case types.ETDecimal:
sig, err = c.generateCmpSigs(ctx, args, types.ETDecimal)
case types.ETReal:
sig, err = c.generateCmpSigs(ctx, args, types.ETReal)
}
if err != nil {
return nil, errors.Trace(err)
}
}
return sig, nil
cmpType := GetAccurateCmpType(args[0], args[1])
sig, err = c.generateCmpSigs(ctx, args, cmpType)
return sig, errors.Trace(err)
}

// generateCmpSigs generates compare function signatures.
Expand Down
2 changes: 1 addition & 1 deletion expression/builtin_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (c *inFunctionClass) getFunction(ctx context.Context, args []Expression) (s
}
argTps := make([]types.EvalType, len(args))
for i := range args {
argTps[i] = args[i].GetType().EvalType()
argTps[i] = args[0].GetType().EvalType()
}
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, argTps...)
bf.tp.Flen = 1
Expand Down
16 changes: 11 additions & 5 deletions expression/builtin_other_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,14 @@ func (s *testEvaluatorSuite) TestBitCount(c *C) {
func (s *testEvaluatorSuite) TestInFunc(c *C) {
defer testleak.AfterTest(c)()
fc := funcs[ast.In]
time1 := types.Time{Time: types.FromGoTime(time.Now()), Fsp: 6, Type: mysql.TypeDatetime}
time2 := types.Time{Time: types.FromGoTime(time.Now()), Fsp: 6, Type: mysql.TypeDatetime}
time3 := types.Time{Time: types.FromGoTime(time.Now()), Fsp: 6, Type: mysql.TypeDatetime}
time4 := types.Time{Time: types.FromGoTime(time.Now()), Fsp: 6, Type: mysql.TypeDatetime}
decimal1 := types.NewDecFromFloatForTest(123.121)
decimal2 := types.NewDecFromFloatForTest(123.122)
decimal3 := types.NewDecFromFloatForTest(123.123)
decimal4 := types.NewDecFromFloatForTest(123.124)
time1 := types.Time{Time: types.FromGoTime(time.Date(2017, 1, 1, 1, 1, 1, 1, time.UTC)), Fsp: 6, Type: mysql.TypeDatetime}
time2 := types.Time{Time: types.FromGoTime(time.Date(2017, 1, 2, 1, 1, 1, 1, time.UTC)), Fsp: 6, Type: mysql.TypeDatetime}
time3 := types.Time{Time: types.FromGoTime(time.Date(2017, 1, 3, 1, 1, 1, 1, time.UTC)), Fsp: 6, Type: mysql.TypeDatetime}
time4 := types.Time{Time: types.FromGoTime(time.Date(2017, 1, 4, 1, 1, 1, 1, time.UTC)), Fsp: 6, Type: mysql.TypeDatetime}
duration1 := types.Duration{Duration: time.Duration(12*time.Hour + 1*time.Minute + 1*time.Second)}
duration2 := types.Duration{Duration: time.Duration(12*time.Hour + 1*time.Minute)}
duration3 := types.Duration{Duration: time.Duration(12*time.Hour + 1*time.Second)}
Expand All @@ -104,6 +108,8 @@ func (s *testEvaluatorSuite) TestInFunc(c *C) {
{[]interface{}{1, 0, 2, 3}, int64(0)},
{[]interface{}{1.1, 1.2, 1.3}, int64(0)},
{[]interface{}{1.1, 1.1, 1.2, 1.3}, int64(1)},
{[]interface{}{decimal1, decimal2, decimal3, decimal4}, int64(0)},
{[]interface{}{decimal1, decimal2, decimal3, decimal1}, int64(1)},
{[]interface{}{"1.1", "1.1", "1.2", "1.3"}, int64(1)},
{[]interface{}{"1.1", hack.Slice("1.1"), "1.2", "1.3"}, int64(1)},
{[]interface{}{hack.Slice("1.1"), "1.1", "1.2", "1.3"}, int64(1)},
Expand All @@ -119,7 +125,7 @@ func (s *testEvaluatorSuite) TestInFunc(c *C) {
c.Assert(err, IsNil)
d, err := evalBuiltinFunc(fn, types.MakeDatums(tc.args...))
c.Assert(err, IsNil)
c.Assert(d.GetValue(), Equals, tc.res)
c.Assert(d.GetValue(), Equals, tc.res, Commentf("%v", types.MakeDatums(tc.args)))
}
}

Expand Down
6 changes: 5 additions & 1 deletion expression/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,11 @@ func IndexInfo2Cols(cols []*Column, index *model.IndexInfo) ([]*Column, []int) {
return retCols, lengths
}
retCols = append(retCols, col)
lengths = append(lengths, c.Length)
if c.Length != types.UnspecifiedLength && c.Length == col.RetType.Flen {
lengths = append(lengths, types.UnspecifiedLength)
} else {
lengths = append(lengths, c.Length)
}
}
return retCols, lengths
}
3 changes: 2 additions & 1 deletion expression/column_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ func (s *testEvaluatorSuite) TestColInfo2Col(c *C) {
func (s *testEvaluatorSuite) TestIndexInfo2Cols(c *C) {
defer testleak.AfterTest(c)()

col0, col1 := &Column{ColName: model.NewCIStr("col0")}, &Column{ColName: model.NewCIStr("col1")}
col0 := &Column{ColName: model.NewCIStr("col0"), RetType: types.NewFieldType(mysql.TypeLonglong)}
col1 := &Column{ColName: model.NewCIStr("col1"), RetType: types.NewFieldType(mysql.TypeLonglong)}
indexCol0, indexCol1 := &model.IndexColumn{Name: model.NewCIStr("col0")}, &model.IndexColumn{Name: model.NewCIStr("col1")}
indexInfo := &model.IndexInfo{Columns: []*model.IndexColumn{indexCol0, indexCol1}}

Expand Down
17 changes: 12 additions & 5 deletions plan/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -894,16 +894,23 @@ func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.Field
return
}
}
leftArg := er.ctxStack[stkLen-lLen-1]
leftEt, leftIsNull := leftArg.GetType().EvalType(), leftArg.GetType().Tp == mysql.TypeNull
args := er.ctxStack[stkLen-lLen-1:]
leftEt, leftIsNull := args[0].GetType().EvalType(), args[0].GetType().Tp == mysql.TypeNull
if leftIsNull {
er.ctxStack = er.ctxStack[:stkLen-lLen-1]
er.ctxStack = append(er.ctxStack, expression.Null.Clone())
return
}
if leftEt == types.ETInt {
for i := 1; i < len(args); i++ {
if c, ok := args[i].(*expression.Constant); ok {
args[i] = expression.RefineConstantArg(er.ctx, c, opcode.EQ)
}
}
}
allSameType := true
for i := stkLen - lLen; i < stkLen; i++ {
if er.ctxStack[i].GetType().Tp != mysql.TypeNull && er.ctxStack[i].GetType().EvalType() != leftEt {
for _, arg := range args[1:] {
if arg.GetType().Tp != mysql.TypeNull && expression.GetAccurateCmpType(args[0], arg) != leftEt {
allSameType = false
break
}
Expand All @@ -914,7 +921,7 @@ func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.Field
} else {
eqFunctions := make([]expression.Expression, 0, lLen)
for i := stkLen - lLen; i < stkLen; i++ {
expr, err := er.constructBinaryOpFunction(leftArg, er.ctxStack[i], ast.EQ)
expr, err := er.constructBinaryOpFunction(args[0], er.ctxStack[i], ast.EQ)
if err != nil {
er.err = err
return
Expand Down
3 changes: 2 additions & 1 deletion plan/physical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ func isCoveringIndex(columns []*model.ColumnInfo, indexColumns []*model.IndexCol
}
isIndexColumn := false
for _, indexCol := range indexColumns {
if colInfo.Name.L == indexCol.Name.L && indexCol.Length == types.UnspecifiedLength {
isFullLen := indexCol.Length == types.UnspecifiedLength || indexCol.Length == colInfo.Flen
if colInfo.Name.L == indexCol.Name.L && isFullLen {
isIndexColumn = true
break
}
Expand Down
Loading

0 comments on commit 99f2d25

Please sign in to comment.