Skip to content

Commit

Permalink
plan: support ? in Order By / Group By / Limit Offset clauses (ping…
Browse files Browse the repository at this point in the history
  • Loading branch information
dbjoa authored and zz-jason committed Oct 3, 2019
1 parent 89b35b3 commit c0c1360
Show file tree
Hide file tree
Showing 12 changed files with 262 additions and 36 deletions.
3 changes: 1 addition & 2 deletions executor/prepared.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
plannercore "github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/parser_driver"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/sqlexec"
Expand Down Expand Up @@ -165,7 +164,7 @@ func (e *PrepareExec) Next(ctx context.Context, chk *chunk.Chunk) error {

// We try to build the real statement of preparedStmt.
for i := range prepared.Params {
prepared.Params[i].(*driver.ParamMarkerExpr).Datum = types.NewIntDatum(0)
prepared.Params[i].(*driver.ParamMarkerExpr).Datum.SetNull()
}
var p plannercore.Plan
p, err = plannercore.BuildLogicalPlan(e.ctx, stmt, e.is)
Expand Down
51 changes: 51 additions & 0 deletions executor/prepared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,3 +366,54 @@ func generateBatchSQL(paramCount int) (sql string, paramSlice []interface{}) {
}
return "insert into t values " + strings.Join(placeholders, ","), params
}

func (s *testSuite) TestPreparedIssue8153(c *C) {
orgEnable := plannercore.PreparedPlanCacheEnabled()
orgCapacity := plannercore.PreparedPlanCacheCapacity
defer func() {
plannercore.SetPreparedPlanCache(orgEnable)
plannercore.PreparedPlanCacheCapacity = orgCapacity
}()
flags := []bool{false, true}
for _, flag := range flags {
var err error
plannercore.SetPreparedPlanCache(flag)
plannercore.PreparedPlanCacheCapacity = 100
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t (a int, b int)")
tk.MustExec("insert into t (a, b) values (1,3), (2,2), (3,1)")

tk.MustExec(`prepare stmt from 'select * from t order by ? asc'`)
r := tk.MustQuery(`execute stmt using @param;`)
r.Check(testkit.Rows("1 3", "2 2", "3 1"))

tk.MustExec(`set @param = 1`)
r = tk.MustQuery(`execute stmt using @param;`)
r.Check(testkit.Rows("1 3", "2 2", "3 1"))

tk.MustExec(`set @param = 2`)
r = tk.MustQuery(`execute stmt using @param;`)
r.Check(testkit.Rows("3 1", "2 2", "1 3"))

tk.MustExec(`set @param = 3`)
_, err = tk.Exec(`execute stmt using @param;`)
c.Assert(err.Error(), Equals, "[planner:1054]Unknown column '?' in 'order clause'")

tk.MustExec(`set @param = '##'`)
r = tk.MustQuery(`execute stmt using @param;`)
r.Check(testkit.Rows("1 3", "2 2", "3 1"))

tk.MustExec("insert into t (a, b) values (1,1), (1,2), (2,1), (2,3), (3,2), (3,3)")
tk.MustExec(`prepare stmt from 'select ?, sum(a) from t group by ?'`)

tk.MustExec(`set @a=1,@b=1`)
r = tk.MustQuery(`execute stmt using @a,@b;`)
r.Check(testkit.Rows("1 18"))

tk.MustExec(`set @a=1,@b=2`)
_, err = tk.Exec(`execute stmt using @a,@b;`)
c.Assert(err.Error(), Equals, "[planner:1056]Can't group on 'sum(a)'")
}
}
8 changes: 5 additions & 3 deletions expression/simple_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,11 @@ func (sr *simpleRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok boo
sr.inToExpression(len(v.List), v.Not, &v.Type)
}
case *driver.ParamMarkerExpr:
tp := types.NewFieldType(mysql.TypeUnspecified)
types.DefaultParamTypeForValue(v.GetValue(), tp)
value := &Constant{Value: v.ValueExpr.Datum, RetType: tp}
var value Expression
value, sr.err = GetParamExpression(sr.ctx, v)
if sr.err != nil {
return retNode, false
}
sr.push(value)
case *ast.RowExpr:
sr.rowToScalarFunc(v)
Expand Down
72 changes: 72 additions & 0 deletions expression/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/pingcap/parser/terror"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
driver "github.com/pingcap/tidb/types/parser_driver"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/hack"
"golang.org/x/tools/container/intsets"
Expand Down Expand Up @@ -555,3 +556,74 @@ func DisableParseJSONFlag4Expr(expr Expression) {
}
expr.GetType().Flag &= ^mysql.ParseToJSONFlag
}

// DatumToConstant generates a Constant expression from a Datum.
func DatumToConstant(d types.Datum, tp byte) *Constant {
return &Constant{Value: d, RetType: types.NewFieldType(tp)}
}

// GetParamExpression generate a getparam function expression.
func GetParamExpression(ctx sessionctx.Context, v *driver.ParamMarkerExpr) (Expression, error) {
useCache := ctx.GetSessionVars().StmtCtx.UseCache
tp := types.NewFieldType(mysql.TypeUnspecified)
types.DefaultParamTypeForValue(v.GetValue(), tp)
value := &Constant{Value: v.Datum, RetType: tp}
if useCache {
f, err := NewFunctionBase(ctx, ast.GetParam, &v.Type,
DatumToConstant(types.NewIntDatum(int64(v.Order)), mysql.TypeLonglong))
if err != nil {
return nil, errors.Trace(err)
}
f.GetType().Tp = v.Type.Tp
value.DeferredExpr = f
}
return value, nil
}

// ConstructPositionExpr constructs PositionExpr with the given ParamMarkerExpr.
func ConstructPositionExpr(p *driver.ParamMarkerExpr) *ast.PositionExpr {
return &ast.PositionExpr{P: p}
}

// PosFromPositionExpr generates a position value from PositionExpr.
func PosFromPositionExpr(ctx sessionctx.Context, v *ast.PositionExpr) (int, bool, error) {
if v.P == nil {
return v.N, false, nil
}
value, err := GetParamExpression(ctx, v.P.(*driver.ParamMarkerExpr))
if err != nil {
return 0, true, err
}
pos, isNull, err := GetIntFromConstant(ctx, value)
if err != nil || isNull {
return 0, true, errors.Trace(err)
}
return pos, false, nil
}

// GetStringFromConstant gets a string value from the Constant expression.
func GetStringFromConstant(ctx sessionctx.Context, value Expression) (string, bool, error) {
con, ok := value.(*Constant)
if !ok {
err := errors.Errorf("Not a Constant expression %+v", value)
return "", true, errors.Trace(err)
}
str, isNull, err := con.EvalString(ctx, chunk.Row{})
if err != nil || isNull {
return "", true, errors.Trace(err)
}
return str, false, nil
}

// GetIntFromConstant gets an interger value from the Constant expression.
func GetIntFromConstant(ctx sessionctx.Context, value Expression) (int, bool, error) {
str, isNull, err := GetStringFromConstant(ctx, value)
if err != nil || isNull {
return 0, true, errors.Trace(err)
}
intNum, err := strconv.Atoi(str)
if err != nil {
return 0, true, nil
}
return intNum, false, nil
}
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,5 @@ require (
gopkg.in/natefinch/lumberjack.v2 v2.0.0
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
)

replace github.com/pingcap/parser => github.com/zz-jason/parser v0.0.0-20191003033834-cce7a9500e2e
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,6 @@ github.com/pingcap/kvproto v0.0.0-20190826051950-fc8799546726 h1:AzGIEmaYVYMtmki
github.com/pingcap/kvproto v0.0.0-20190826051950-fc8799546726/go.mod h1:0gwbe1F2iBIjuQ9AH0DbQhL+Dpr5GofU8fgYyXk+ykk=
github.com/pingcap/log v0.0.0-20190307075452-bd41d9273596 h1:t2OQTpPJnrPDGlvA+3FwJptMTt6MEPdzK1Wt99oaefQ=
github.com/pingcap/log v0.0.0-20190307075452-bd41d9273596/go.mod h1:WpHUKhNZ18v116SvGrmjkA9CBhYmuUTKL+p8JC9ANEw=
github.com/pingcap/parser v0.0.0-20190910040957-e998b3c52469 h1:JS/p4qMInVXTyV0kjFz+n0DBGn/n1T0cZDjEYHdTQow=
github.com/pingcap/parser v0.0.0-20190910040957-e998b3c52469/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA=
github.com/pingcap/pd v2.1.12+incompatible h1:6N3LBxx2aSZqT+IWEG730EDNDttP7dXO8J6yvBh+HXw=
github.com/pingcap/pd v2.1.12+incompatible/go.mod h1:nD3+EoYes4+aNNODO99ES59V83MZSI+dFbhyr667a0E=
github.com/pingcap/tidb-tools v2.1.3-0.20190116051332-34c808eef588+incompatible h1:e9Gi/LP9181HT3gBfSOeSBA+5JfemuE4aEAhqNgoE4k=
Expand Down Expand Up @@ -151,6 +149,8 @@ github.com/unrolled/render v0.0.0-20171102162132-65450fb6b2d3/go.mod h1:tu82oB5W
github.com/xiang90/probing v0.0.0-20160813154853-07dd2e8dfe18 h1:MPPkRncZLN9Kh4MEFmbnK4h3BD7AUmskWv2+EeZJCCs=
github.com/xiang90/probing v0.0.0-20160813154853-07dd2e8dfe18/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU=
github.com/yookoala/realpath v1.0.0/go.mod h1:gJJMA9wuX7AcqLy1+ffPatSCySA1FQ2S8Ya9AIoYBpE=
github.com/zz-jason/parser v0.0.0-20191003033834-cce7a9500e2e h1:oxazCGeHJ+CdDGPGVeIpIBzJ4dw0DNqDI5wdXPVZb8Q=
github.com/zz-jason/parser v0.0.0-20191003033834-cce7a9500e2e/go.mod h1:mnf7H9ngMZzobilLo3+bu86/+DSlGQBnmse9S5K8PKQ=
go.etcd.io/bbolt v1.3.3 h1:MUGmc65QhB3pIlaQ5bB4LwqSj6GIonVJXpZiaKNyaKk=
go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
go.uber.org/atomic v1.3.2 h1:2Oa65PReHzfn29GpvgsYwloV9AVFHPDk8tYxt2c2tr4=
Expand Down
14 changes: 14 additions & 0 deletions planner/core/cacheable_checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,20 @@ func (checker *cacheableChecker) Enter(in ast.Node) (out ast.Node, skipChildren
checker.cacheable = false
return in, true
}
case *ast.OrderByClause:
for _, item := range node.Items {
if _, isParamMarker := item.Expr.(*driver.ParamMarkerExpr); isParamMarker {
checker.cacheable = false
return in, true
}
}
case *ast.GroupByClause:
for _, item := range node.Items {
if _, isParamMarker := item.Expr.(*driver.ParamMarkerExpr); isParamMarker {
checker.cacheable = false
return in, true
}
}
case *ast.Limit:
if node.Count != nil {
if _, isParamMarker := node.Count.(*driver.ParamMarkerExpr); isParamMarker {
Expand Down
14 changes: 14 additions & 0 deletions planner/core/cacheable_checker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,18 @@ func (s *testCacheableSuite) TestCacheable(c *C) {
Limit: limitStmt,
}
c.Assert(Cacheable(stmt), IsTrue)

paramExpr := &driver.ParamMarkerExpr{}
orderByClause := &ast.OrderByClause{Items: []*ast.ByItem{{Expr: paramExpr}}}
stmt = &ast.SelectStmt{
OrderBy: orderByClause,
}
c.Assert(Cacheable(stmt), IsFalse)

valExpr := &driver.ValueExpr{}
orderByClause = &ast.OrderByClause{Items: []*ast.ByItem{{Expr: valExpr}}}
stmt = &ast.SelectStmt{
OrderBy: orderByClause,
}
c.Assert(Cacheable(stmt), IsTrue)
}
2 changes: 2 additions & 0 deletions planner/core/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ const (

codeWrongUsage = mysql.ErrWrongUsage
codeAmbiguous = mysql.ErrNonUniq
codeUnknown = mysql.ErrUnknown
codeUnknownColumn = mysql.ErrBadField
codeUnknownTable = mysql.ErrUnknownTable
codeWrongArguments = mysql.ErrWrongArguments
Expand Down Expand Up @@ -65,6 +66,7 @@ var (

ErrWrongUsage = terror.ClassOptimizer.New(codeWrongUsage, mysql.MySQLErrName[mysql.ErrWrongUsage])
ErrAmbiguous = terror.ClassOptimizer.New(codeAmbiguous, mysql.MySQLErrName[mysql.ErrNonUniq])
ErrUnknown = terror.ClassOptimizer.New(codeUnknown, mysql.MySQLErrName[mysql.ErrUnknown])
ErrUnknownColumn = terror.ClassOptimizer.New(codeUnknownColumn, mysql.MySQLErrName[mysql.ErrBadField])
ErrUnknownTable = terror.ClassOptimizer.New(codeUnknownTable, mysql.MySQLErrName[mysql.ErrUnknownTable])
ErrWrongArguments = terror.ClassOptimizer.New(codeWrongArguments, mysql.MySQLErrName[mysql.ErrWrongArguments])
Expand Down
31 changes: 23 additions & 8 deletions planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -820,11 +820,10 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok
value := &expression.Constant{Value: v.Datum, RetType: &v.Type}
er.ctxStack = append(er.ctxStack, value)
case *driver.ParamMarkerExpr:
tp := types.NewFieldType(mysql.TypeUnspecified)
types.DefaultParamTypeForValue(v.GetValue(), tp)
value := &expression.Constant{Value: v.Datum, RetType: tp}
if er.useCache() {
value.DeferredExpr = er.getParamExpression(v)
var value expression.Expression
value, er.err = expression.GetParamExpression(er.ctx, v)
if er.err != nil {
return retNode, false
}
er.ctxStack = append(er.ctxStack, value)
case *ast.VariableExpr:
Expand Down Expand Up @@ -1044,10 +1043,26 @@ func (er *expressionRewriter) isNullToExpression(v *ast.IsNullExpr) {
}

func (er *expressionRewriter) positionToScalarFunc(v *ast.PositionExpr) {
if v.N > 0 && v.N <= er.schema.Len() {
er.ctxStack = append(er.ctxStack, er.schema.Columns[v.N-1])
pos := v.N
str := strconv.Itoa(pos)
if v.P != nil {
stkLen := len(er.ctxStack)
val := er.ctxStack[stkLen-1]
intNum, isNull, err := expression.GetIntFromConstant(er.ctx, val)
str = "?"
if err == nil {
if isNull {
return
}
pos = intNum
er.ctxStack = er.ctxStack[:stkLen-1]
}
er.err = err
}
if er.err == nil && pos > 0 && pos <= er.schema.Len() {
er.ctxStack = append(er.ctxStack, er.schema.Columns[pos-1])
} else {
er.err = ErrUnknownColumn.GenWithStackByArgs(strconv.Itoa(v.N), clauseMsg[er.b.curClause])
er.err = ErrUnknownColumn.GenWithStackByArgs(str, clauseMsg[er.b.curClause])
}
}

Expand Down
Loading

0 comments on commit c0c1360

Please sign in to comment.