diff --git a/expression/builtin_string.go b/expression/builtin_string.go index bf1fbe1022569..c21ecae092aeb 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -31,6 +31,7 @@ import ( "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/chunk" @@ -1760,24 +1761,33 @@ func (c *lpadFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio bf.tp.Flen = getFlen4LpadAndRpad(bf.ctx, args[1]) SetBinFlagOrBinStr(args[0].GetType(), bf.tp) SetBinFlagOrBinStr(args[2].GetType(), bf.tp) + + valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket) + maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64) + if err != nil { + return nil, errors.Trace(err) + } + if types.IsBinaryStr(args[0].GetType()) || types.IsBinaryStr(args[2].GetType()) { - sig := &builtinLpadBinarySig{bf} + sig := &builtinLpadBinarySig{bf, maxAllowedPacket} return sig, nil } if bf.tp.Flen *= 4; bf.tp.Flen > mysql.MaxBlobWidth { bf.tp.Flen = mysql.MaxBlobWidth } - sig := &builtinLpadSig{bf} + sig := &builtinLpadSig{bf, maxAllowedPacket} return sig, nil } type builtinLpadBinarySig struct { baseBuiltinFunc + maxAllowedPacket uint64 } func (b *builtinLpadBinarySig) Clone() builtinFunc { newSig := &builtinLpadBinarySig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.maxAllowedPacket = b.maxAllowedPacket return newSig } @@ -1796,6 +1806,11 @@ func (b *builtinLpadBinarySig) evalString(row chunk.Row) (string, bool, error) { } targetLength := int(length) + if uint64(targetLength) > b.maxAllowedPacket { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("lpad", b.maxAllowedPacket)) + return "", true, nil + } + padStr, isNull, err := b.args[2].EvalString(b.ctx, row) if isNull || err != nil { return "", true, errors.Trace(err) @@ -1815,11 +1830,13 @@ func (b *builtinLpadBinarySig) evalString(row chunk.Row) (string, bool, error) { type builtinLpadSig struct { baseBuiltinFunc + maxAllowedPacket uint64 } func (b *builtinLpadSig) Clone() builtinFunc { newSig := &builtinLpadSig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.maxAllowedPacket = b.maxAllowedPacket return newSig } @@ -1838,6 +1855,11 @@ func (b *builtinLpadSig) evalString(row chunk.Row) (string, bool, error) { } targetLength := int(length) + if uint64(targetLength)*uint64(mysql.MaxBytesOfCharacter) > b.maxAllowedPacket { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("lpad", b.maxAllowedPacket)) + return "", true, nil + } + padStr, isNull, err := b.args[2].EvalString(b.ctx, row) if isNull || err != nil { return "", true, errors.Trace(err) @@ -1867,24 +1889,33 @@ func (c *rpadFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio bf.tp.Flen = getFlen4LpadAndRpad(bf.ctx, args[1]) SetBinFlagOrBinStr(args[0].GetType(), bf.tp) SetBinFlagOrBinStr(args[2].GetType(), bf.tp) + + valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket) + maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64) + if err != nil { + return nil, errors.Trace(err) + } + if types.IsBinaryStr(args[0].GetType()) || types.IsBinaryStr(args[2].GetType()) { - sig := &builtinRpadBinarySig{bf} + sig := &builtinRpadBinarySig{bf, maxAllowedPacket} return sig, nil } if bf.tp.Flen *= 4; bf.tp.Flen > mysql.MaxBlobWidth { bf.tp.Flen = mysql.MaxBlobWidth } - sig := &builtinRpadSig{bf} + sig := &builtinRpadSig{bf, maxAllowedPacket} return sig, nil } type builtinRpadBinarySig struct { baseBuiltinFunc + maxAllowedPacket uint64 } func (b *builtinRpadBinarySig) Clone() builtinFunc { newSig := &builtinRpadBinarySig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.maxAllowedPacket = b.maxAllowedPacket return newSig } @@ -1902,6 +1933,10 @@ func (b *builtinRpadBinarySig) evalString(row chunk.Row) (string, bool, error) { return "", true, errors.Trace(err) } targetLength := int(length) + if uint64(targetLength) > b.maxAllowedPacket { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("rpad", b.maxAllowedPacket)) + return "", true, nil + } padStr, isNull, err := b.args[2].EvalString(b.ctx, row) if isNull || err != nil { @@ -1922,11 +1957,13 @@ func (b *builtinRpadBinarySig) evalString(row chunk.Row) (string, bool, error) { type builtinRpadSig struct { baseBuiltinFunc + maxAllowedPacket uint64 } func (b *builtinRpadSig) Clone() builtinFunc { newSig := &builtinRpadSig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.maxAllowedPacket = b.maxAllowedPacket return newSig } @@ -1945,6 +1982,11 @@ func (b *builtinRpadSig) evalString(row chunk.Row) (string, bool, error) { } targetLength := int(length) + if uint64(targetLength)*uint64(mysql.MaxBytesOfCharacter) > b.maxAllowedPacket { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("rpad", b.maxAllowedPacket)) + return "", true, nil + } + padStr, isNull, err := b.args[2].EvalString(b.ctx, row) if isNull || err != nil { return "", true, errors.Trace(err) diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index 64c304dc6c617..46dde9237fa0d 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -23,6 +23,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/terror" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/chunk" @@ -1266,6 +1267,47 @@ func (s *testEvaluatorSuite) TestRpad(c *C) { } } +func (s *testEvaluatorSuite) TestRpadSig(c *C) { + colTypes := []*types.FieldType{ + {Tp: mysql.TypeVarchar}, + {Tp: mysql.TypeLonglong}, + {Tp: mysql.TypeVarchar}, + } + resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: 1000} + + args := []Expression{ + &Column{Index: 0, RetType: colTypes[0]}, + &Column{Index: 1, RetType: colTypes[1]}, + &Column{Index: 2, RetType: colTypes[2]}, + } + + base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType} + rpad := &builtinRpadSig{base, 1000} + + input := chunk.NewChunkWithCapacity(colTypes, 10) + input.AppendString(0, "abc") + input.AppendString(0, "abc") + input.AppendInt64(1, 6) + input.AppendInt64(1, 10000) + input.AppendString(2, "123") + input.AppendString(2, "123") + + res, isNull, err := rpad.evalString(input.GetRow(0)) + c.Assert(res, Equals, "abc123") + c.Assert(isNull, IsFalse) + c.Assert(err, IsNil) + + res, isNull, err = rpad.evalString(input.GetRow(1)) + c.Assert(res, Equals, "") + c.Assert(isNull, IsTrue) + c.Assert(err, IsNil) + + warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() + c.Assert(len(warnings), Equals, 1) + lastWarn := warnings[len(warnings)-1] + c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue) +} + func (s *testEvaluatorSuite) TestInstr(c *C) { defer testleak.AfterTest(c)() tbl := []struct { diff --git a/expression/errors.go b/expression/errors.go index aceed242e00be..f342bf42916ae 100644 --- a/expression/errors.go +++ b/expression/errors.go @@ -22,10 +22,14 @@ import ( // Error instances. var ( + // All the exported errors are defined here: ErrIncorrectParameterCount = terror.ClassExpression.New(mysql.ErrWrongParamcountToNativeFct, mysql.MySQLErrName[mysql.ErrWrongParamcountToNativeFct]) ErrDivisionByZero = terror.ClassExpression.New(mysql.ErrDivisionByZero, mysql.MySQLErrName[mysql.ErrDivisionByZero]) ErrRegexp = terror.ClassExpression.New(mysql.ErrRegexp, mysql.MySQLErrName[mysql.ErrRegexp]) + ErrOperandColumns = terror.ClassExpression.New(mysql.ErrOperandColumns, mysql.MySQLErrName[mysql.ErrOperandColumns]) + ErrCutValueGroupConcat = terror.ClassExpression.New(mysql.ErrCutValueGroupConcat, mysql.MySQLErrName[mysql.ErrCutValueGroupConcat]) + // All the un-exported errors are defined here: errFunctionNotExists = terror.ClassExpression.New(mysql.ErrSpDoesNotExist, mysql.MySQLErrName[mysql.ErrSpDoesNotExist]) errZlibZData = terror.ClassTypes.New(mysql.ErrZlibZData, mysql.MySQLErrName[mysql.ErrZlibZData]) errIncorrectArgs = terror.ClassExpression.New(mysql.ErrWrongArguments, mysql.MySQLErrName[mysql.ErrWrongArguments]) @@ -33,8 +37,7 @@ var ( errDefaultValue = terror.ClassExpression.New(mysql.ErrInvalidDefault, "invalid default value") errDeprecatedSyntaxNoReplacement = terror.ClassExpression.New(mysql.ErrWarnDeprecatedSyntaxNoReplacement, mysql.MySQLErrName[mysql.ErrWarnDeprecatedSyntaxNoReplacement]) errBadField = terror.ClassExpression.New(mysql.ErrBadField, mysql.MySQLErrName[mysql.ErrBadField]) - ErrOperandColumns = terror.ClassExpression.New(mysql.ErrOperandColumns, mysql.MySQLErrName[mysql.ErrOperandColumns]) - ErrCutValueGroupConcat = terror.ClassExpression.New(mysql.ErrCutValueGroupConcat, mysql.MySQLErrName[mysql.ErrCutValueGroupConcat]) + errWarnAllowedPacketOverflowed = terror.ClassExpression.New(mysql.ErrWarnAllowedPacketOverflowed, mysql.MySQLErrName[mysql.ErrWarnAllowedPacketOverflowed]) ) func init() { @@ -49,6 +52,7 @@ func init() { mysql.ErrWarnDeprecatedSyntaxNoReplacement: mysql.ErrWarnDeprecatedSyntaxNoReplacement, mysql.ErrOperandColumns: mysql.ErrOperandColumns, mysql.ErrRegexp: mysql.ErrRegexp, + mysql.ErrWarnAllowedPacketOverflowed: mysql.ErrWarnAllowedPacketOverflowed, } terror.ErrClassToMySQLCodes[terror.ClassExpression] = expressionMySQLErrCodes } diff --git a/expression/evaluator_test.go b/expression/evaluator_test.go index 386d214a2fc9c..95999a46b6615 100644 --- a/expression/evaluator_test.go +++ b/expression/evaluator_test.go @@ -47,6 +47,7 @@ func (s *testEvaluatorSuite) SetUpSuite(c *C) { s.Parser = parser.New() s.ctx = mock.NewContext() s.ctx.GetSessionVars().StmtCtx.TimeZone = time.Local + s.ctx.GetSessionVars().SetSystemVar("max_allowed_packet", "67108864") } func (s *testEvaluatorSuite) TearDownSuite(c *C) { @@ -58,6 +59,7 @@ func (s *testEvaluatorSuite) SetUpTest(c *C) { } func (s *testEvaluatorSuite) TearDownTest(c *C) { + s.ctx.GetSessionVars().StmtCtx.SetWarnings(nil) testleak.AfterTest(c)() } diff --git a/mysql/errname.go b/mysql/errname.go index 617bd69f331ad..f6283b38e5e59 100644 --- a/mysql/errname.go +++ b/mysql/errname.go @@ -316,7 +316,7 @@ var MySQLErrName = map[uint16]string{ ErrUnknownTimeZone: "Unknown or incorrect time zone: '%-.64s'", ErrWarnInvalidTimestamp: "Invalid TIMESTAMP value in column '%s' at row %d", ErrInvalidCharacterString: "Invalid %s character string: '%.64s'", - ErrWarnAllowedPacketOverflowed: "Result of %s() was larger than maxAllowedPacket (%d) - truncated", + ErrWarnAllowedPacketOverflowed: "Result of %s() was larger than max_allowed_packet (%d) - truncated", ErrConflictingDeclarations: "Conflicting declarations: '%s%s' and '%s%s'", ErrSpNoRecursiveCreate: "Can't create a %s from within another stored routine", ErrSpAlreadyExists: "%s %s already exists",