From 5fed20518fc2e15c5132de8a3e6e2052541fea13 Mon Sep 17 00:00:00 2001 From: Zhang Jian Date: Fri, 1 Feb 2019 14:07:50 +0800 Subject: [PATCH] expression: handle empty input and improve compatibility for `format` (#8797) (#9235) --- expression/builtin_string.go | 76 +++++++++++++++++-------- expression/builtin_string_test.go | 93 ++++++++++++++++++++----------- expression/errors.go | 2 + 3 files changed, 117 insertions(+), 54 deletions(-) diff --git a/expression/builtin_string.go b/expression/builtin_string.go index d2fb81287d754..cf21aebd21bdc 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -2917,7 +2917,13 @@ func (c *formatFunctionClass) getFunction(ctx sessionctx.Context, args []Express return nil, errors.Trace(err) } argTps := make([]types.EvalType, 2, 3) - argTps[0], argTps[1] = types.ETString, types.ETString + argTps[1] = types.ETInt + argTp := args[0].GetType().EvalType() + if argTp == types.ETDecimal || argTp == types.ETInt { + argTps[0] = types.ETDecimal + } else { + argTps[0] = types.ETReal + } if len(args) == 3 { argTps = append(argTps, types.ETString) } @@ -2932,6 +2938,41 @@ func (c *formatFunctionClass) getFunction(ctx sessionctx.Context, args []Express return sig, nil } +// formatMaxDecimals limits the maximum number of decimal digits for result of +// function `format`, this value is same as `FORMAT_MAX_DECIMALS` in MySQL source code. +const formatMaxDecimals int64 = 30 + +// evalNumDecArgsForFormat evaluates first 2 arguments, i.e, x and d, for function `format`. +func evalNumDecArgsForFormat(f builtinFunc, row chunk.Row) (string, string, bool, error) { + var xStr string + arg0, arg1 := f.getArgs()[0], f.getArgs()[1] + ctx := f.getCtx() + if arg0.GetType().EvalType() == types.ETDecimal { + x, isNull, err := arg0.EvalDecimal(ctx, row) + if isNull || err != nil { + return "", "", isNull, err + } + xStr = x.String() + } else { + x, isNull, err := arg0.EvalReal(ctx, row) + if isNull || err != nil { + return "", "", isNull, err + } + xStr = strconv.FormatFloat(x, 'f', -1, 64) + } + d, isNull, err := arg1.EvalInt(ctx, row) + if isNull || err != nil { + return "", "", isNull, err + } + if d < 0 { + d = 0 + } else if d > formatMaxDecimals { + d = formatMaxDecimals + } + dStr := strconv.FormatInt(d, 10) + return xStr, dStr, false, nil +} + type builtinFormatWithLocaleSig struct { baseBuiltinFunc } @@ -2945,23 +2986,20 @@ func (b *builtinFormatWithLocaleSig) Clone() builtinFunc { // evalString evals FORMAT(X,D,locale). // See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_format func (b *builtinFormatWithLocaleSig) evalString(row chunk.Row) (string, bool, error) { - x, isNull, err := b.args[0].EvalString(b.ctx, row) - if isNull || err != nil { - return "", true, errors.Trace(err) - } - - d, isNull, err := b.args[1].EvalString(b.ctx, row) + x, d, isNull, err := evalNumDecArgsForFormat(b, row) if isNull || err != nil { - return "", true, errors.Trace(err) + return "", isNull, err } - locale, isNull, err := b.args[2].EvalString(b.ctx, row) - if isNull || err != nil { - return "", true, errors.Trace(err) + if err != nil { + return "", false, err + } + if isNull { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errUnknownLocale.GenWithStackByArgs("NULL")) + locale = "en_US" } - formatString, err := mysql.GetLocaleFormatFunction(locale)(x, d) - return formatString, err != nil, errors.Trace(err) + return formatString, false, err } type builtinFormatSig struct { @@ -2977,18 +3015,12 @@ func (b *builtinFormatSig) Clone() builtinFunc { // evalString evals FORMAT(X,D). // See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_format func (b *builtinFormatSig) evalString(row chunk.Row) (string, bool, error) { - x, isNull, err := b.args[0].EvalString(b.ctx, row) - if isNull || err != nil { - return "", true, errors.Trace(err) - } - - d, isNull, err := b.args[1].EvalString(b.ctx, row) + x, d, isNull, err := evalNumDecArgsForFormat(b, row) if isNull || err != nil { - return "", true, errors.Trace(err) + return "", isNull, err } - formatString, err := mysql.GetLocaleFormatFunction("en_US")(x, d) - return formatString, err != nil, errors.Trace(err) + return formatString, false, err } type fromBase64FunctionClass struct { diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index 5db38ff2b5575..9a51ba5f17393 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -1599,38 +1599,42 @@ func (s *testEvaluatorSuite) TestFormat(c *C) { locale string ret interface{} }{ - {12332.1234561111111111111111111111111111111111111, 4, "en_US", "12,332.1234"}, + {12332.12341111111111111111111111111111111111111, 4, "en_US", "12,332.1234"}, {nil, 22, "en_US", nil}, } formatTests1 := []struct { number interface{} precision interface{} ret interface{} + warnings int }{ - {12332.123456, 4, "12,332.1234"}, - {12332.123456, 0, "12,332"}, - {12332.123456, -4, "12,332"}, - {-12332.123456, 4, "-12,332.1234"}, - {-12332.123456, 0, "-12,332"}, - {-12332.123456, -4, "-12,332"}, - {"12332.123456", "4", "12,332.1234"}, - {"12332.123456A", "4", "12,332.1234"}, - {"-12332.123456", "4", "-12,332.1234"}, - {"-12332.123456A", "4", "-12,332.1234"}, - {"A123345", "4", "0.0000"}, - {"-A123345", "4", "0.0000"}, - {"-12332.123456", "A", "-12,332"}, - {"12332.123456", "A", "12,332"}, - {"-12332.123456", "4A", "-12,332.1234"}, - {"12332.123456", "4A", "12,332.1234"}, - {"-A12332.123456", "A", "0"}, - {"A12332.123456", "A", "0"}, - {"-A12332.123456", "4A", "0.0000"}, - {"A12332.123456", "4A", "0.0000"}, - {"-.12332.123456", "4A", "-0.1233"}, - {".12332.123456", "4A", "0.1233"}, - {"12332.1234567890123456789012345678901", 22, "12,332.1234567890123456789012"}, - {nil, 22, nil}, + {12332.123444, 4, "12,332.1234", 0}, + {12332.123444, 0, "12,332", 0}, + {12332.123444, -4, "12,332", 0}, + {-12332.123444, 4, "-12,332.1234", 0}, + {-12332.123444, 0, "-12,332", 0}, + {-12332.123444, -4, "-12,332", 0}, + {"12332.123444", "4", "12,332.1234", 0}, + {"12332.123444A", "4", "12,332.1234", 1}, + {"-12332.123444", "4", "-12,332.1234", 0}, + {"-12332.123444A", "4", "-12,332.1234", 1}, + {"A123345", "4", "0.0000", 1}, + {"-A123345", "4", "0.0000", 1}, + {"-12332.123444", "A", "-12,332", 1}, + {"12332.123444", "A", "12,332", 1}, + {"-12332.123444", "4A", "-12,332.1234", 1}, + {"12332.123444", "4A", "12,332.1234", 1}, + {"-A12332.123444", "A", "0", 2}, + {"A12332.123444", "A", "0", 2}, + {"-A12332.123444", "4A", "0.0000", 2}, + {"A12332.123444", "4A", "0.0000", 2}, + {"-.12332.123444", "4A", "-0.1233", 2}, + {".12332.123444", "4A", "0.1233", 2}, + {"12332.1234567890123456789012345678901", 22, "12,332.1234567890110000000000", 0}, + {nil, 22, nil, 0}, + {1, 1024, "1.000000000000000000000000000000", 0}, + {"", 1, "0.0", 1}, + {1, "", "1", 1}, } formatTests2 := struct { number interface{} @@ -1644,9 +1648,15 @@ func (s *testEvaluatorSuite) TestFormat(c *C) { locale string ret interface{} }{"-12332.123456", "4", "de_GE", nil} + formatTests4 := struct { + number interface{} + precision interface{} + locale interface{} + ret interface{} + }{1, 4, nil, "1.0000"} + fc := funcs[ast.Format] for _, tt := range formatTests { - fc := funcs[ast.Format] f, err := fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(tt.number, tt.precision, tt.locale))) c.Assert(err, IsNil) c.Assert(f, NotNil) @@ -1655,31 +1665,50 @@ func (s *testEvaluatorSuite) TestFormat(c *C) { c.Assert(r, testutil.DatumEquals, types.NewDatum(tt.ret)) } + origConfig := s.ctx.GetSessionVars().StmtCtx.TruncateAsWarning + s.ctx.GetSessionVars().StmtCtx.TruncateAsWarning = true for _, tt := range formatTests1 { - fc := funcs[ast.Format] f, err := fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(tt.number, tt.precision))) c.Assert(err, IsNil) c.Assert(f, NotNil) r, err := evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) - c.Assert(r, testutil.DatumEquals, types.NewDatum(tt.ret)) + c.Assert(r, testutil.DatumEquals, types.NewDatum(tt.ret), Commentf("test %v", tt)) + if tt.warnings > 0 { + warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() + c.Assert(len(warnings), Equals, tt.warnings, Commentf("test %v", tt)) + for i := 0; i < tt.warnings; i++ { + c.Assert(terror.ErrorEqual(types.ErrTruncated, warnings[i].Err), IsTrue, Commentf("test %v", tt)) + } + s.ctx.GetSessionVars().StmtCtx.SetWarnings([]stmtctx.SQLWarn{}) + } } + s.ctx.GetSessionVars().StmtCtx.TruncateAsWarning = origConfig - fc2 := funcs[ast.Format] - f2, err := fc2.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(formatTests2.number, formatTests2.precision, formatTests2.locale))) + f2, err := fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(formatTests2.number, formatTests2.precision, formatTests2.locale))) c.Assert(err, IsNil) c.Assert(f2, NotNil) r2, err := evalBuiltinFunc(f2, chunk.Row{}) c.Assert(types.NewDatum(err), testutil.DatumEquals, types.NewDatum(errors.New("not implemented"))) c.Assert(r2, testutil.DatumEquals, types.NewDatum(formatTests2.ret)) - fc3 := funcs[ast.Format] - f3, err := fc3.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(formatTests3.number, formatTests3.precision, formatTests3.locale))) + f3, err := fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(formatTests3.number, formatTests3.precision, formatTests3.locale))) c.Assert(err, IsNil) c.Assert(f3, NotNil) r3, err := evalBuiltinFunc(f3, chunk.Row{}) c.Assert(types.NewDatum(err), testutil.DatumEquals, types.NewDatum(errors.New("not support for the specific locale"))) c.Assert(r3, testutil.DatumEquals, types.NewDatum(formatTests3.ret)) + + f4, err := fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(formatTests4.number, formatTests4.precision, formatTests4.locale))) + c.Assert(err, IsNil) + c.Assert(f4, NotNil) + r4, err := evalBuiltinFunc(f4, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(r4, testutil.DatumEquals, types.NewDatum(formatTests4.ret)) + warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() + c.Assert(len(warnings), Equals, 1) + c.Assert(terror.ErrorEqual(errUnknownLocale, warnings[0].Err), IsTrue) + s.ctx.GetSessionVars().StmtCtx.SetWarnings([]stmtctx.SQLWarn{}) } func (s *testEvaluatorSuite) TestFromBase64(c *C) { diff --git a/expression/errors.go b/expression/errors.go index ce35d125bddbf..1ebc4ffdaba7b 100644 --- a/expression/errors.go +++ b/expression/errors.go @@ -41,6 +41,7 @@ var ( errWarnAllowedPacketOverflowed = terror.ClassExpression.New(mysql.ErrWarnAllowedPacketOverflowed, mysql.MySQLErrName[mysql.ErrWarnAllowedPacketOverflowed]) errWarnOptionIgnored = terror.ClassExpression.New(mysql.WarnOptionIgnored, mysql.MySQLErrName[mysql.WarnOptionIgnored]) errTruncatedWrongValue = terror.ClassExpression.New(mysql.ErrTruncatedWrongValue, mysql.MySQLErrName[mysql.ErrTruncatedWrongValue]) + errUnknownLocale = terror.ClassExpression.New(mysql.ErrUnknownLocale, mysql.MySQLErrName[mysql.ErrUnknownLocale]) ) func init() { @@ -60,6 +61,7 @@ func init() { mysql.ErrWarnAllowedPacketOverflowed: mysql.ErrWarnAllowedPacketOverflowed, mysql.WarnOptionIgnored: mysql.WarnOptionIgnored, mysql.ErrTruncatedWrongValue: mysql.ErrTruncatedWrongValue, + mysql.ErrUnknownLocale: mysql.ErrUnknownLocale, } terror.ErrClassToMySQLCodes[terror.ClassExpression] = expressionMySQLErrCodes }