diff --git a/executor/executor.go b/executor/executor.go index 5b7c6c1d653db..f6b6688577beb 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -76,6 +76,7 @@ var ( ErrWrongValueCountOnRow = terror.ClassExecutor.New(codeWrongValueCountOnRow, "Column count doesn't match value count at row %d") ErrPasswordFormat = terror.ClassExecutor.New(codePasswordFormat, "The password hash doesn't have the expected format. Check if the correct password algorithm is being used with the PASSWORD() function.") ErrCantChangeTxCharacteristics = terror.ClassExecutor.New(codeErrCantChangeTxCharacteristics, "Transaction characteristics can't be changed while a transaction is in progress") + ErrPsManyParam = terror.ClassExecutor.New(mysql.ErrPsManyParam, mysql.MySQLErrName[mysql.ErrPsManyParam]) ) // Error codes. @@ -91,6 +92,7 @@ const ( codeWrongValueCountOnRow terror.ErrCode = 1136 // MySQL error code codePasswordFormat terror.ErrCode = 1827 // MySQL error code codeErrCantChangeTxCharacteristics terror.ErrCode = 1568 + codeErrPsManyParam terror.ErrCode = 1390 ) type baseExecutor struct { @@ -611,6 +613,7 @@ func init() { codeWrongValueCountOnRow: mysql.ErrWrongValueCountOnRow, codePasswordFormat: mysql.ErrPasswordFormat, codeErrCantChangeTxCharacteristics: mysql.ErrCantChangeTxCharacteristics, + codeErrPsManyParam: mysql.ErrPsManyParam, } terror.ErrClassToMySQLCodes[terror.ClassExecutor] = tableMySQLErrCodes } diff --git a/executor/prepared.go b/executor/prepared.go index 7d640af04381c..ae92d12a25c57 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -126,6 +126,13 @@ func (e *PrepareExec) Next(ctx context.Context, chk *chunk.Chunk) error { } var extractor paramMarkerExtractor stmt.Accept(&extractor) + + // Prepare parameters should NOT over 2 bytes(MaxUint16) + // https://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html#packet-COM_STMT_PREPARE_OK. + if len(extractor.markers) > math.MaxUint16 { + return ErrPsManyParam + } + err = plan.Preprocess(e.ctx, stmt, e.is, true) if err != nil { return errors.Trace(err) diff --git a/executor/prepared_test.go b/executor/prepared_test.go index 1873a9cf2af5c..75f888c657515 100644 --- a/executor/prepared_test.go +++ b/executor/prepared_test.go @@ -14,6 +14,9 @@ package executor_test import ( + "math" + "strings" + "github.com/juju/errors" . "github.com/pingcap/check" "github.com/pingcap/tidb/executor" @@ -267,3 +270,28 @@ func (s *testSuite) TestPreparedNameResolver(c *C) { _, err = tk.Exec("prepare stmt from '(select * FROM t) union all (select * FROM t) order by a limit ?'") c.Assert(err.Error(), Equals, "[planner:1054]Unknown column 'a' in 'order clause'") } + +func (s *testSuite) TestPrepareMaxParamCountCheck(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (v int)") + normalSQL, normalParams := generateBatchSQL(math.MaxUint16) + _, err := tk.Exec(normalSQL, normalParams...) + c.Assert(err, IsNil) + + bigSQL, bigParams := generateBatchSQL(math.MaxUint16 + 2) + _, err = tk.Exec(bigSQL, bigParams...) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "[executor:1390]Prepared statement contains too many placeholders") +} + +func generateBatchSQL(paramCount int) (sql string, paramSlice []interface{}) { + params := make([]interface{}, 0, paramCount) + placeholders := make([]string, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, i) + placeholders = append(placeholders, "(?)") + } + return "insert into t values " + strings.Join(placeholders, ","), params +}