Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: handle max_allowed_packet warnings for pad functions #7171

Merged
merged 10 commits into from
Jul 31, 2018
50 changes: 46 additions & 4 deletions expression/builtin_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/charset"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -1760,24 +1761,33 @@ func (c *lpadFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio
bf.tp.Flen = getFlen4LpadAndRpad(bf.ctx, args[1])
SetBinFlagOrBinStr(args[0].GetType(), bf.tp)
SetBinFlagOrBinStr(args[2].GetType(), bf.tp)

valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, errors.Trace(err)
}

if types.IsBinaryStr(args[0].GetType()) || types.IsBinaryStr(args[2].GetType()) {
sig := &builtinLpadBinarySig{bf}
sig := &builtinLpadBinarySig{bf, maxAllowedPacket}
return sig, nil
}
if bf.tp.Flen *= 4; bf.tp.Flen > mysql.MaxBlobWidth {
bf.tp.Flen = mysql.MaxBlobWidth
}
sig := &builtinLpadSig{bf}
sig := &builtinLpadSig{bf, maxAllowedPacket}
return sig, nil
}

type builtinLpadBinarySig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinLpadBinarySig) Clone() builtinFunc {
newSig := &builtinLpadBinarySig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand All @@ -1796,6 +1806,11 @@ func (b *builtinLpadBinarySig) evalString(row chunk.Row) (string, bool, error) {
}
targetLength := int(length)

if uint64(targetLength) > b.maxAllowedPacket {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("lpad", b.maxAllowedPacket))
return "", true, nil
}

padStr, isNull, err := b.args[2].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
Expand All @@ -1815,11 +1830,13 @@ func (b *builtinLpadBinarySig) evalString(row chunk.Row) (string, bool, error) {

type builtinLpadSig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinLpadSig) Clone() builtinFunc {
newSig := &builtinLpadSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand All @@ -1838,6 +1855,11 @@ func (b *builtinLpadSig) evalString(row chunk.Row) (string, bool, error) {
}
targetLength := int(length)

if uint64(targetLength)*uint64(mysql.MaxBytesOfCharacter) > b.maxAllowedPacket {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("lpad", b.maxAllowedPacket))
return "", true, nil
}

padStr, isNull, err := b.args[2].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
Expand Down Expand Up @@ -1867,24 +1889,33 @@ func (c *rpadFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio
bf.tp.Flen = getFlen4LpadAndRpad(bf.ctx, args[1])
SetBinFlagOrBinStr(args[0].GetType(), bf.tp)
SetBinFlagOrBinStr(args[2].GetType(), bf.tp)

valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, errors.Trace(err)
}

if types.IsBinaryStr(args[0].GetType()) || types.IsBinaryStr(args[2].GetType()) {
sig := &builtinRpadBinarySig{bf}
sig := &builtinRpadBinarySig{bf, maxAllowedPacket}
return sig, nil
}
if bf.tp.Flen *= 4; bf.tp.Flen > mysql.MaxBlobWidth {
bf.tp.Flen = mysql.MaxBlobWidth
}
sig := &builtinRpadSig{bf}
sig := &builtinRpadSig{bf, maxAllowedPacket}
return sig, nil
}

type builtinRpadBinarySig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinRpadBinarySig) Clone() builtinFunc {
newSig := &builtinRpadBinarySig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand All @@ -1902,6 +1933,10 @@ func (b *builtinRpadBinarySig) evalString(row chunk.Row) (string, bool, error) {
return "", true, errors.Trace(err)
}
targetLength := int(length)
if uint64(targetLength) > b.maxAllowedPacket {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("rpad", b.maxAllowedPacket))
return "", true, nil
}

padStr, isNull, err := b.args[2].EvalString(b.ctx, row)
if isNull || err != nil {
Expand All @@ -1922,11 +1957,13 @@ func (b *builtinRpadBinarySig) evalString(row chunk.Row) (string, bool, error) {

type builtinRpadSig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinRpadSig) Clone() builtinFunc {
newSig := &builtinRpadSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand All @@ -1945,6 +1982,11 @@ func (b *builtinRpadSig) evalString(row chunk.Row) (string, bool, error) {
}
targetLength := int(length)

if uint64(targetLength)*uint64(mysql.MaxBytesOfCharacter) > b.maxAllowedPacket {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("rpad", b.maxAllowedPacket))
return "", true, nil
}

padStr, isNull, err := b.args[2].EvalString(b.ctx, row)
if isNull || err != nil {
return "", true, errors.Trace(err)
Expand Down
42 changes: 42 additions & 0 deletions expression/builtin_string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/charset"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -1266,6 +1267,47 @@ func (s *testEvaluatorSuite) TestRpad(c *C) {
}
}

func (s *testEvaluatorSuite) TestRpadSig(c *C) {
colTypes := []*types.FieldType{
{Tp: mysql.TypeVarchar},
{Tp: mysql.TypeLonglong},
{Tp: mysql.TypeVarchar},
}
resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: 1000}

args := []Expression{
&Column{Index: 0, RetType: colTypes[0]},
&Column{Index: 1, RetType: colTypes[1]},
&Column{Index: 2, RetType: colTypes[2]},
}

base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType}
rpad := &builtinRpadSig{base, 1000}

input := chunk.NewChunkWithCapacity(colTypes, 10)
input.AppendString(0, "abc")
input.AppendString(0, "abc")
input.AppendInt64(1, 6)
input.AppendInt64(1, 10000)
input.AppendString(2, "123")
input.AppendString(2, "123")

res, isNull, err := rpad.evalString(input.GetRow(0))
c.Assert(res, Equals, "abc123")
c.Assert(isNull, IsFalse)
c.Assert(err, IsNil)

res, isNull, err = rpad.evalString(input.GetRow(1))
c.Assert(res, Equals, "")
c.Assert(isNull, IsTrue)
c.Assert(err, IsNil)

warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings()
c.Assert(len(warnings), Equals, 1)
lastWarn := warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue)
}

func (s *testEvaluatorSuite) TestInstr(c *C) {
defer testleak.AfterTest(c)()
tbl := []struct {
Expand Down
8 changes: 6 additions & 2 deletions expression/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,22 @@ import (

// Error instances.
var (
// All the exported errors are defined here:
ErrIncorrectParameterCount = terror.ClassExpression.New(mysql.ErrWrongParamcountToNativeFct, mysql.MySQLErrName[mysql.ErrWrongParamcountToNativeFct])
ErrDivisionByZero = terror.ClassExpression.New(mysql.ErrDivisionByZero, mysql.MySQLErrName[mysql.ErrDivisionByZero])
ErrRegexp = terror.ClassExpression.New(mysql.ErrRegexp, mysql.MySQLErrName[mysql.ErrRegexp])
ErrOperandColumns = terror.ClassExpression.New(mysql.ErrOperandColumns, mysql.MySQLErrName[mysql.ErrOperandColumns])
ErrCutValueGroupConcat = terror.ClassExpression.New(mysql.ErrCutValueGroupConcat, mysql.MySQLErrName[mysql.ErrCutValueGroupConcat])

// All the un-exported errors are defined here:
errFunctionNotExists = terror.ClassExpression.New(mysql.ErrSpDoesNotExist, mysql.MySQLErrName[mysql.ErrSpDoesNotExist])
errZlibZData = terror.ClassTypes.New(mysql.ErrZlibZData, mysql.MySQLErrName[mysql.ErrZlibZData])
errIncorrectArgs = terror.ClassExpression.New(mysql.ErrWrongArguments, mysql.MySQLErrName[mysql.ErrWrongArguments])
errUnknownCharacterSet = terror.ClassExpression.New(mysql.ErrUnknownCharacterSet, mysql.MySQLErrName[mysql.ErrUnknownCharacterSet])
errDefaultValue = terror.ClassExpression.New(mysql.ErrInvalidDefault, "invalid default value")
errDeprecatedSyntaxNoReplacement = terror.ClassExpression.New(mysql.ErrWarnDeprecatedSyntaxNoReplacement, mysql.MySQLErrName[mysql.ErrWarnDeprecatedSyntaxNoReplacement])
errBadField = terror.ClassExpression.New(mysql.ErrBadField, mysql.MySQLErrName[mysql.ErrBadField])
ErrOperandColumns = terror.ClassExpression.New(mysql.ErrOperandColumns, mysql.MySQLErrName[mysql.ErrOperandColumns])
ErrCutValueGroupConcat = terror.ClassExpression.New(mysql.ErrCutValueGroupConcat, mysql.MySQLErrName[mysql.ErrCutValueGroupConcat])
errWarnAllowedPacketOverflowed = terror.ClassExpression.New(mysql.ErrWarnAllowedPacketOverflowed, mysql.MySQLErrName[mysql.ErrWarnAllowedPacketOverflowed])
)

func init() {
Expand All @@ -49,6 +52,7 @@ func init() {
mysql.ErrWarnDeprecatedSyntaxNoReplacement: mysql.ErrWarnDeprecatedSyntaxNoReplacement,
mysql.ErrOperandColumns: mysql.ErrOperandColumns,
mysql.ErrRegexp: mysql.ErrRegexp,
mysql.ErrWarnAllowedPacketOverflowed: mysql.ErrWarnAllowedPacketOverflowed,
}
terror.ErrClassToMySQLCodes[terror.ClassExpression] = expressionMySQLErrCodes
}
Expand Down
2 changes: 2 additions & 0 deletions expression/evaluator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func (s *testEvaluatorSuite) SetUpSuite(c *C) {
s.Parser = parser.New()
s.ctx = mock.NewContext()
s.ctx.GetSessionVars().StmtCtx.TimeZone = time.Local
s.ctx.GetSessionVars().SetSystemVar("max_allowed_packet", "67108864")
}

func (s *testEvaluatorSuite) TearDownSuite(c *C) {
Expand All @@ -58,6 +59,7 @@ func (s *testEvaluatorSuite) SetUpTest(c *C) {
}

func (s *testEvaluatorSuite) TearDownTest(c *C) {
s.ctx.GetSessionVars().StmtCtx.SetWarnings(nil)
testleak.AfterTest(c)()
}

Expand Down
2 changes: 1 addition & 1 deletion mysql/errname.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ var MySQLErrName = map[uint16]string{
ErrUnknownTimeZone: "Unknown or incorrect time zone: '%-.64s'",
ErrWarnInvalidTimestamp: "Invalid TIMESTAMP value in column '%s' at row %d",
ErrInvalidCharacterString: "Invalid %s character string: '%.64s'",
ErrWarnAllowedPacketOverflowed: "Result of %s() was larger than maxAllowedPacket (%d) - truncated",
ErrWarnAllowedPacketOverflowed: "Result of %s() was larger than max_allowed_packet (%d) - truncated",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is original warning msg in mysql?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this error message comes from the mysql-server code base.

ErrConflictingDeclarations: "Conflicting declarations: '%s%s' and '%s%s'",
ErrSpNoRecursiveCreate: "Can't create a %s from within another stored routine",
ErrSpAlreadyExists: "%s %s already exists",
Expand Down