From bfb563227eb2b1c4d697e3d1dc42664dd951feda Mon Sep 17 00:00:00 2001 From: Zhang Jian Date: Tue, 31 Jul 2018 15:12:31 +0800 Subject: [PATCH] expression: handle max_allowed_packet warnings for pad functions (#7171) --- expression/builtin_string.go | 50 ++++++++++++++++++++++++++++--- expression/builtin_string_test.go | 42 ++++++++++++++++++++++++++ expression/errors.go | 4 +++ expression/evaluator_test.go | 2 ++ mysql/errname.go | 2 +- 5 files changed, 95 insertions(+), 5 deletions(-) diff --git a/expression/builtin_string.go b/expression/builtin_string.go index d94b5e6e700f4..f14cd2aa8f224 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -31,6 +31,7 @@ import ( "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/hack" @@ -1759,24 +1760,33 @@ func (c *lpadFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio bf.tp.Flen = getFlen4LpadAndRpad(bf.ctx, args[1]) SetBinFlagOrBinStr(args[0].GetType(), bf.tp) SetBinFlagOrBinStr(args[2].GetType(), bf.tp) + + valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket) + maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64) + if err != nil { + return nil, errors.Trace(err) + } + if types.IsBinaryStr(args[0].GetType()) || types.IsBinaryStr(args[2].GetType()) { - sig := &builtinLpadBinarySig{bf} + sig := &builtinLpadBinarySig{bf, maxAllowedPacket} return sig, nil } if bf.tp.Flen *= 4; bf.tp.Flen > mysql.MaxBlobWidth { bf.tp.Flen = mysql.MaxBlobWidth } - sig := &builtinLpadSig{bf} + sig := &builtinLpadSig{bf, maxAllowedPacket} return sig, nil } type builtinLpadBinarySig struct { baseBuiltinFunc + maxAllowedPacket uint64 } func (b *builtinLpadBinarySig) Clone() builtinFunc { newSig := &builtinLpadBinarySig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.maxAllowedPacket = b.maxAllowedPacket return newSig } @@ -1795,6 +1805,11 @@ func (b *builtinLpadBinarySig) evalString(row types.Row) (string, bool, error) { } targetLength := int(length) + if uint64(targetLength) > b.maxAllowedPacket { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("lpad", b.maxAllowedPacket)) + return "", true, nil + } + padStr, isNull, err := b.args[2].EvalString(b.ctx, row) if isNull || err != nil { return "", true, errors.Trace(err) @@ -1814,11 +1829,13 @@ func (b *builtinLpadBinarySig) evalString(row types.Row) (string, bool, error) { type builtinLpadSig struct { baseBuiltinFunc + maxAllowedPacket uint64 } func (b *builtinLpadSig) Clone() builtinFunc { newSig := &builtinLpadSig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.maxAllowedPacket = b.maxAllowedPacket return newSig } @@ -1837,6 +1854,11 @@ func (b *builtinLpadSig) evalString(row types.Row) (string, bool, error) { } targetLength := int(length) + if uint64(targetLength)*uint64(mysql.MaxBytesOfCharacter) > b.maxAllowedPacket { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("lpad", b.maxAllowedPacket)) + return "", true, nil + } + padStr, isNull, err := b.args[2].EvalString(b.ctx, row) if isNull || err != nil { return "", true, errors.Trace(err) @@ -1866,24 +1888,33 @@ func (c *rpadFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio bf.tp.Flen = getFlen4LpadAndRpad(bf.ctx, args[1]) SetBinFlagOrBinStr(args[0].GetType(), bf.tp) SetBinFlagOrBinStr(args[2].GetType(), bf.tp) + + valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket) + maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64) + if err != nil { + return nil, errors.Trace(err) + } + if types.IsBinaryStr(args[0].GetType()) || types.IsBinaryStr(args[2].GetType()) { - sig := &builtinRpadBinarySig{bf} + sig := &builtinRpadBinarySig{bf, maxAllowedPacket} return sig, nil } if bf.tp.Flen *= 4; bf.tp.Flen > mysql.MaxBlobWidth { bf.tp.Flen = mysql.MaxBlobWidth } - sig := &builtinRpadSig{bf} + sig := &builtinRpadSig{bf, maxAllowedPacket} return sig, nil } type builtinRpadBinarySig struct { baseBuiltinFunc + maxAllowedPacket uint64 } func (b *builtinRpadBinarySig) Clone() builtinFunc { newSig := &builtinRpadBinarySig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.maxAllowedPacket = b.maxAllowedPacket return newSig } @@ -1901,6 +1932,10 @@ func (b *builtinRpadBinarySig) evalString(row types.Row) (string, bool, error) { return "", true, errors.Trace(err) } targetLength := int(length) + if uint64(targetLength) > b.maxAllowedPacket { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("rpad", b.maxAllowedPacket)) + return "", true, nil + } padStr, isNull, err := b.args[2].EvalString(b.ctx, row) if isNull || err != nil { @@ -1921,11 +1956,13 @@ func (b *builtinRpadBinarySig) evalString(row types.Row) (string, bool, error) { type builtinRpadSig struct { baseBuiltinFunc + maxAllowedPacket uint64 } func (b *builtinRpadSig) Clone() builtinFunc { newSig := &builtinRpadSig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.maxAllowedPacket = b.maxAllowedPacket return newSig } @@ -1944,6 +1981,11 @@ func (b *builtinRpadSig) evalString(row types.Row) (string, bool, error) { } targetLength := int(length) + if uint64(targetLength)*uint64(mysql.MaxBytesOfCharacter) > b.maxAllowedPacket { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("rpad", b.maxAllowedPacket)) + return "", true, nil + } + padStr, isNull, err := b.args[2].EvalString(b.ctx, row) if isNull || err != nil { return "", true, errors.Trace(err) diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index 24bc7122fef71..5b6115b71f4bc 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -23,6 +23,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/terror" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/mock" @@ -1265,6 +1266,47 @@ func (s *testEvaluatorSuite) TestRpad(c *C) { } } +func (s *testEvaluatorSuite) TestRpadSig(c *C) { + colTypes := []*types.FieldType{ + {Tp: mysql.TypeVarchar}, + {Tp: mysql.TypeLonglong}, + {Tp: mysql.TypeVarchar}, + } + resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: 1000} + + args := []Expression{ + &Column{Index: 0, RetType: colTypes[0]}, + &Column{Index: 1, RetType: colTypes[1]}, + &Column{Index: 2, RetType: colTypes[2]}, + } + + base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType} + rpad := &builtinRpadSig{base, 1000} + + input := chunk.NewChunkWithCapacity(colTypes, 10) + input.AppendString(0, "abc") + input.AppendString(0, "abc") + input.AppendInt64(1, 6) + input.AppendInt64(1, 10000) + input.AppendString(2, "123") + input.AppendString(2, "123") + + res, isNull, err := rpad.evalString(input.GetRow(0)) + c.Assert(res, Equals, "abc123") + c.Assert(isNull, IsFalse) + c.Assert(err, IsNil) + + res, isNull, err = rpad.evalString(input.GetRow(1)) + c.Assert(res, Equals, "") + c.Assert(isNull, IsTrue) + c.Assert(err, IsNil) + + warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() + c.Assert(len(warnings), Equals, 1) + lastWarn := warnings[len(warnings)-1] + c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue) +} + func (s *testEvaluatorSuite) TestInstr(c *C) { defer testleak.AfterTest(c)() tbl := []struct { diff --git a/expression/errors.go b/expression/errors.go index 3f11882c78f95..951ae09ce5518 100644 --- a/expression/errors.go +++ b/expression/errors.go @@ -22,15 +22,18 @@ import ( // Error instances. var ( + // All the exported errors are defined here: ErrIncorrectParameterCount = terror.ClassExpression.New(mysql.ErrWrongParamcountToNativeFct, mysql.MySQLErrName[mysql.ErrWrongParamcountToNativeFct]) ErrDivisionByZero = terror.ClassExpression.New(mysql.ErrDivisionByZero, mysql.MySQLErrName[mysql.ErrDivisionByZero]) + // All the un-exported errors are defined here: errFunctionNotExists = terror.ClassExpression.New(mysql.ErrSpDoesNotExist, mysql.MySQLErrName[mysql.ErrSpDoesNotExist]) errZlibZData = terror.ClassTypes.New(mysql.ErrZlibZData, mysql.MySQLErrName[mysql.ErrZlibZData]) errIncorrectArgs = terror.ClassExpression.New(mysql.ErrWrongArguments, mysql.MySQLErrName[mysql.ErrWrongArguments]) errUnknownCharacterSet = terror.ClassExpression.New(mysql.ErrUnknownCharacterSet, mysql.MySQLErrName[mysql.ErrUnknownCharacterSet]) errDefaultValue = terror.ClassExpression.New(mysql.ErrInvalidDefault, "invalid default value") errDeprecatedSyntaxNoReplacement = terror.ClassExpression.New(mysql.ErrWarnDeprecatedSyntaxNoReplacement, mysql.MySQLErrName[mysql.ErrWarnDeprecatedSyntaxNoReplacement]) + errWarnAllowedPacketOverflowed = terror.ClassExpression.New(mysql.ErrWarnAllowedPacketOverflowed, mysql.MySQLErrName[mysql.ErrWarnAllowedPacketOverflowed]) ) func init() { @@ -43,6 +46,7 @@ func init() { mysql.ErrUnknownCharacterSet: mysql.ErrUnknownCharacterSet, mysql.ErrInvalidDefault: mysql.ErrInvalidDefault, mysql.ErrWarnDeprecatedSyntaxNoReplacement: mysql.ErrWarnDeprecatedSyntaxNoReplacement, + mysql.ErrWarnAllowedPacketOverflowed: mysql.ErrWarnAllowedPacketOverflowed, } terror.ErrClassToMySQLCodes[terror.ClassExpression] = expressionMySQLErrCodes } diff --git a/expression/evaluator_test.go b/expression/evaluator_test.go index 42dc056f21bd1..7e3fde2918b18 100644 --- a/expression/evaluator_test.go +++ b/expression/evaluator_test.go @@ -45,6 +45,7 @@ func (s *testEvaluatorSuite) SetUpSuite(c *C) { s.Parser = parser.New() s.ctx = mock.NewContext() s.ctx.GetSessionVars().StmtCtx.TimeZone = time.Local + s.ctx.GetSessionVars().SetSystemVar("max_allowed_packet", "67108864") } func (s *testEvaluatorSuite) TearDownSuite(c *C) { @@ -55,6 +56,7 @@ func (s *testEvaluatorSuite) SetUpTest(c *C) { } func (s *testEvaluatorSuite) TearDownTest(c *C) { + s.ctx.GetSessionVars().StmtCtx.SetWarnings(nil) testleak.AfterTest(c)() } diff --git a/mysql/errname.go b/mysql/errname.go index 35734f744ea1e..94097498bad6e 100644 --- a/mysql/errname.go +++ b/mysql/errname.go @@ -316,7 +316,7 @@ var MySQLErrName = map[uint16]string{ ErrUnknownTimeZone: "Unknown or incorrect time zone: '%-.64s'", ErrWarnInvalidTimestamp: "Invalid TIMESTAMP value in column '%s' at row %d", ErrInvalidCharacterString: "Invalid %s character string: '%.64s'", - ErrWarnAllowedPacketOverflowed: "Result of %s() was larger than maxAllowedPacket (%d) - truncated", + ErrWarnAllowedPacketOverflowed: "Result of %s() was larger than max_allowed_packet (%d) - truncated", ErrConflictingDeclarations: "Conflicting declarations: '%s%s' and '%s%s'", ErrSpNoRecursiveCreate: "Can't create a %s from within another stored routine", ErrSpAlreadyExists: "%s %s already exists",