diff --git a/expression/builtin_control.go b/expression/builtin_control.go index 80e8d5e233e14..58c83eb7a7cd8 100644 --- a/expression/builtin_control.go +++ b/expression/builtin_control.go @@ -169,11 +169,14 @@ func (c *caseWhenFunctionClass) getFunction(ctx sessionctx.Context, args []Expre } // Set retType to BINARY(0) if all arguments are of type NULL. if fieldTp.Tp == mysql.TypeNull { - fieldTp.Flen, fieldTp.Decimal = 0, -1 + fieldTp.Flen, fieldTp.Decimal = 0, types.UnspecifiedLength types.SetBinChsClnFlag(fieldTp) } argTps := make([]types.EvalType, 0, l) for i := 0; i < l-1; i += 2 { + if args[i], err = wrapWithIsTrue(ctx, true, args[i]); err != nil { + return nil, err + } argTps = append(argTps, types.ETInt, tp) } if l%2 == 1 { @@ -221,7 +224,7 @@ func (b *builtinCaseWhenIntSig) Clone() builtinFunc { } // evalInt evals a builtinCaseWhenIntSig. -// See https://dev.mysql.com/doc/refman/5.7/en/case.html +// See https://dev.mysql.com/doc/refman/5.7/en/control-flow-functions.html#operator_case func (b *builtinCaseWhenIntSig) evalInt(row chunk.Row) (ret int64, isNull bool, err error) { var condition int64 args, l := b.getArgs(), len(b.getArgs()) @@ -257,7 +260,7 @@ func (b *builtinCaseWhenRealSig) Clone() builtinFunc { } // evalReal evals a builtinCaseWhenRealSig. -// See https://dev.mysql.com/doc/refman/5.7/en/case.html +// See https://dev.mysql.com/doc/refman/5.7/en/control-flow-functions.html#operator_case func (b *builtinCaseWhenRealSig) evalReal(row chunk.Row) (ret float64, isNull bool, err error) { var condition int64 args, l := b.getArgs(), len(b.getArgs()) @@ -293,7 +296,7 @@ func (b *builtinCaseWhenDecimalSig) Clone() builtinFunc { } // evalDecimal evals a builtinCaseWhenDecimalSig. -// See https://dev.mysql.com/doc/refman/5.7/en/case.html +// See https://dev.mysql.com/doc/refman/5.7/en/control-flow-functions.html#operator_case func (b *builtinCaseWhenDecimalSig) evalDecimal(row chunk.Row) (ret *types.MyDecimal, isNull bool, err error) { var condition int64 args, l := b.getArgs(), len(b.getArgs()) @@ -329,7 +332,7 @@ func (b *builtinCaseWhenStringSig) Clone() builtinFunc { } // evalString evals a builtinCaseWhenStringSig. -// See https://dev.mysql.com/doc/refman/5.7/en/case.html +// See https://dev.mysql.com/doc/refman/5.7/en/control-flow-functions.html#operator_case func (b *builtinCaseWhenStringSig) evalString(row chunk.Row) (ret string, isNull bool, err error) { var condition int64 args, l := b.getArgs(), len(b.getArgs()) @@ -365,7 +368,7 @@ func (b *builtinCaseWhenTimeSig) Clone() builtinFunc { } // evalTime evals a builtinCaseWhenTimeSig. -// See https://dev.mysql.com/doc/refman/5.7/en/case.html +// See https://dev.mysql.com/doc/refman/5.7/en/control-flow-functions.html#operator_case func (b *builtinCaseWhenTimeSig) evalTime(row chunk.Row) (ret types.Time, isNull bool, err error) { var condition int64 args, l := b.getArgs(), len(b.getArgs()) @@ -401,7 +404,7 @@ func (b *builtinCaseWhenDurationSig) Clone() builtinFunc { } // evalDuration evals a builtinCaseWhenDurationSig. -// See https://dev.mysql.com/doc/refman/5.7/en/case.html +// See https://dev.mysql.com/doc/refman/5.7/en/control-flow-functions.html#operator_case func (b *builtinCaseWhenDurationSig) evalDuration(row chunk.Row) (ret types.Duration, isNull bool, err error) { var condition int64 args, l := b.getArgs(), len(b.getArgs()) @@ -437,7 +440,7 @@ func (b *builtinCaseWhenJSONSig) Clone() builtinFunc { } // evalJSON evals a builtinCaseWhenJSONSig. -// See https://dev.mysql.com/doc/refman/5.7/en/case.html +// See https://dev.mysql.com/doc/refman/5.7/en/control-flow-functions.html#operator_case func (b *builtinCaseWhenJSONSig) evalJSON(row chunk.Row) (ret json.BinaryJSON, isNull bool, err error) { var condition int64 args, l := b.getArgs(), len(b.getArgs()) diff --git a/expression/builtin_control_test.go b/expression/builtin_control_test.go index 23a59ebdc22b8..00a4aef065412 100644 --- a/expression/builtin_control_test.go +++ b/expression/builtin_control_test.go @@ -40,6 +40,8 @@ func (s *testEvaluatorSuite) TestCaseWhen(c *C) { {[]interface{}{nil, 1, false, 2, 3}, 3}, {[]interface{}{1, jsonInt.GetMysqlJSON(), nil}, 3}, {[]interface{}{0, jsonInt.GetMysqlJSON(), nil}, nil}, + {[]interface{}{0.1, 1, 2}, 1}, + {[]interface{}{0.0, 1, 0.1, 2}, 2}, } fc := funcs[ast.Case] for _, t := range tbl {