diff --git a/executor/prepared.go b/executor/prepared.go index d2cf1c6b9ba03..bdf1e5288d69a 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -26,7 +26,6 @@ import ( plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" - "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/sqlexec" @@ -161,7 +160,7 @@ func (e *PrepareExec) Next(ctx context.Context, chk *chunk.Chunk) error { // We try to build the real statement of preparedStmt. for i := range prepared.Params { - prepared.Params[i].(*driver.ParamMarkerExpr).Datum = types.NewIntDatum(0) + prepared.Params[i].(*driver.ParamMarkerExpr).Datum.SetNull() } var p plannercore.Plan p, err = plannercore.BuildLogicalPlan(e.ctx, stmt, e.is) diff --git a/executor/prepared_test.go b/executor/prepared_test.go index 07e54ef4ff356..1e56b37ee94e9 100644 --- a/executor/prepared_test.go +++ b/executor/prepared_test.go @@ -688,3 +688,60 @@ func (s *testSuite) TestPrepareDealloc(c *C) { tk.MustExec("deallocate prepare stmt4") c.Assert(tk.Se.PreparedPlanCache().Size(), Equals, 0) } + +func (s *testSuite) TestPreparedIssue8153(c *C) { + orgEnable := plannercore.PreparedPlanCacheEnabled() + orgCapacity := plannercore.PreparedPlanCacheCapacity + orgMemGuardRatio := plannercore.PreparedPlanCacheMemoryGuardRatio + orgMaxMemory := plannercore.PreparedPlanCacheMaxMemory + defer func() { + plannercore.SetPreparedPlanCache(orgEnable) + plannercore.PreparedPlanCacheCapacity = orgCapacity + plannercore.PreparedPlanCacheMemoryGuardRatio = orgMemGuardRatio + plannercore.PreparedPlanCacheMaxMemory = orgMaxMemory + }() + flags := []bool{false, true} + for _, flag := range flags { + var err error + plannercore.SetPreparedPlanCache(flag) + plannercore.PreparedPlanCacheCapacity = 100 + plannercore.PreparedPlanCacheMemoryGuardRatio = 0.1 + plannercore.PreparedPlanCacheMaxMemory, err = memory.MemTotal() + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a int, b int)") + tk.MustExec("insert into t (a, b) values (1,3), (2,2), (3,1)") + + tk.MustExec(`prepare stmt from 'select * from t order by ? asc'`) + r := tk.MustQuery(`execute stmt using @param;`) + r.Check(testkit.Rows("1 3", "2 2", "3 1")) + + tk.MustExec(`set @param = 1`) + r = tk.MustQuery(`execute stmt using @param;`) + r.Check(testkit.Rows("1 3", "2 2", "3 1")) + + tk.MustExec(`set @param = 2`) + r = tk.MustQuery(`execute stmt using @param;`) + r.Check(testkit.Rows("3 1", "2 2", "1 3")) + + tk.MustExec(`set @param = 3`) + _, err = tk.Exec(`execute stmt using @param;`) + c.Assert(err.Error(), Equals, "[planner:1054]Unknown column '?' in 'order clause'") + + tk.MustExec(`set @param = '##'`) + r = tk.MustQuery(`execute stmt using @param;`) + r.Check(testkit.Rows("1 3", "2 2", "3 1")) + + tk.MustExec("insert into t (a, b) values (1,1), (1,2), (2,1), (2,3), (3,2), (3,3)") + tk.MustExec(`prepare stmt from 'select ?, sum(a) from t group by ?'`) + + tk.MustExec(`set @a=1,@b=1`) + r = tk.MustQuery(`execute stmt using @a,@b;`) + r.Check(testkit.Rows("1 18")) + + tk.MustExec(`set @a=1,@b=2`) + _, err = tk.Exec(`execute stmt using @a,@b;`) + c.Assert(err.Error(), Equals, "[planner:1056]Can't group on 'sum(a)'") + } +} diff --git a/expression/simple_rewriter.go b/expression/simple_rewriter.go index 46bdac88a329c..518c833452226 100644 --- a/expression/simple_rewriter.go +++ b/expression/simple_rewriter.go @@ -151,7 +151,7 @@ func (sr *simpleRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok boo } case *driver.ParamMarkerExpr: var value Expression - value, sr.err = GetParamExpression(sr.ctx, v, sr.useCache()) + value, sr.err = GetParamExpression(sr.ctx, v) if sr.err != nil { return retNode, false } diff --git a/expression/util.go b/expression/util.go index e5805fb3ddcab..c5e301b232d40 100644 --- a/expression/util.go +++ b/expression/util.go @@ -511,7 +511,8 @@ func DatumToConstant(d types.Datum, tp byte) *Constant { } // GetParamExpression generate a getparam function expression. -func GetParamExpression(ctx sessionctx.Context, v *driver.ParamMarkerExpr, useCache bool) (Expression, error) { +func GetParamExpression(ctx sessionctx.Context, v *driver.ParamMarkerExpr) (Expression, error) { + useCache := ctx.GetSessionVars().StmtCtx.UseCache tp := types.NewFieldType(mysql.TypeUnspecified) types.DefaultParamTypeForValue(v.GetValue(), tp) value := &Constant{Value: v.Datum, RetType: tp} @@ -526,3 +527,51 @@ func GetParamExpression(ctx sessionctx.Context, v *driver.ParamMarkerExpr, useCa } return value, nil } + +// ConstructPositionExpr constructs PositionExpr with the given ParamMarkerExpr. +func ConstructPositionExpr(p *driver.ParamMarkerExpr) *ast.PositionExpr { + return &ast.PositionExpr{P: p} +} + +// PosFromPositionExpr generates a position value from PositionExpr. +func PosFromPositionExpr(ctx sessionctx.Context, v *ast.PositionExpr) (int, bool, error) { + if v.P == nil { + return v.N, false, nil + } + value, err := GetParamExpression(ctx, v.P.(*driver.ParamMarkerExpr)) + if err != nil { + return 0, true, err + } + pos, isNull, err := GetIntFromConstant(ctx, value) + if err != nil || isNull { + return 0, true, errors.Trace(err) + } + return pos, false, nil +} + +// GetStringFromConstant gets a string value from the Constant expression. +func GetStringFromConstant(ctx sessionctx.Context, value Expression) (string, bool, error) { + con, ok := value.(*Constant) + if !ok { + err := errors.Errorf("Not a Constant expression %+v", value) + return "", true, errors.Trace(err) + } + str, isNull, err := con.EvalString(ctx, chunk.Row{}) + if err != nil || isNull { + return "", true, errors.Trace(err) + } + return str, false, nil +} + +// GetIntFromConstant gets an interger value from the Constant expression. +func GetIntFromConstant(ctx sessionctx.Context, value Expression) (int, bool, error) { + str, isNull, err := GetStringFromConstant(ctx, value) + if err != nil || isNull { + return 0, true, errors.Trace(err) + } + intNum, err := strconv.Atoi(str) + if err != nil { + return 0, true, nil + } + return intNum, false, nil +} diff --git a/planner/core/cacheable_checker.go b/planner/core/cacheable_checker.go index 76edbd01a2d12..49e08eb8227b1 100644 --- a/planner/core/cacheable_checker.go +++ b/planner/core/cacheable_checker.go @@ -55,6 +55,20 @@ func (checker *cacheableChecker) Enter(in ast.Node) (out ast.Node, skipChildren checker.cacheable = false return in, true } + case *ast.OrderByClause: + for _, item := range node.Items { + if _, isParamMarker := item.Expr.(*driver.ParamMarkerExpr); isParamMarker { + checker.cacheable = false + return in, true + } + } + case *ast.GroupByClause: + for _, item := range node.Items { + if _, isParamMarker := item.Expr.(*driver.ParamMarkerExpr); isParamMarker { + checker.cacheable = false + return in, true + } + } case *ast.Limit: if node.Count != nil { if _, isParamMarker := node.Count.(*driver.ParamMarkerExpr); isParamMarker { diff --git a/planner/core/cacheable_checker_test.go b/planner/core/cacheable_checker_test.go index 67278f5cfe99b..8f3d287701533 100644 --- a/planner/core/cacheable_checker_test.go +++ b/planner/core/cacheable_checker_test.go @@ -177,4 +177,18 @@ func (s *testCacheableSuite) TestCacheable(c *C) { Limit: limitStmt, } c.Assert(Cacheable(stmt), IsTrue) + + paramExpr := &driver.ParamMarkerExpr{} + orderByClause := &ast.OrderByClause{Items: []*ast.ByItem{{Expr: paramExpr}}} + stmt = &ast.SelectStmt{ + OrderBy: orderByClause, + } + c.Assert(Cacheable(stmt), IsFalse) + + valExpr := &driver.ValueExpr{} + orderByClause = &ast.OrderByClause{Items: []*ast.ByItem{{Expr: valExpr}}} + stmt = &ast.SelectStmt{ + OrderBy: orderByClause, + } + c.Assert(Cacheable(stmt), IsTrue) } diff --git a/planner/core/errors.go b/planner/core/errors.go index 89c81fa47b086..48645a062988a 100644 --- a/planner/core/errors.go +++ b/planner/core/errors.go @@ -28,6 +28,7 @@ const ( codeWrongUsage = mysql.ErrWrongUsage codeAmbiguous = mysql.ErrNonUniq + codeUnknown = mysql.ErrUnknown codeUnknownColumn = mysql.ErrBadField codeUnknownTable = mysql.ErrUnknownTable codeWrongArguments = mysql.ErrWrongArguments @@ -64,6 +65,7 @@ var ( ErrWrongUsage = terror.ClassOptimizer.New(codeWrongUsage, mysql.MySQLErrName[mysql.ErrWrongUsage]) ErrAmbiguous = terror.ClassOptimizer.New(codeAmbiguous, mysql.MySQLErrName[mysql.ErrNonUniq]) + ErrUnknown = terror.ClassOptimizer.New(codeUnknown, mysql.MySQLErrName[mysql.ErrUnknown]) ErrUnknownColumn = terror.ClassOptimizer.New(codeUnknownColumn, mysql.MySQLErrName[mysql.ErrBadField]) ErrUnknownTable = terror.ClassOptimizer.New(codeUnknownTable, mysql.MySQLErrName[mysql.ErrUnknownTable]) ErrWrongArguments = terror.ClassOptimizer.New(codeWrongArguments, mysql.MySQLErrName[mysql.ErrWrongArguments]) diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 1da8e11ed81ce..dec7ba06af6cf 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -756,7 +756,7 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok er.ctxStack = append(er.ctxStack, value) case *driver.ParamMarkerExpr: var value expression.Expression - value, er.err = expression.GetParamExpression(er.ctx, v, er.useCache()) + value, er.err = expression.GetParamExpression(er.ctx, v) if er.err != nil { return retNode, false } @@ -941,10 +941,26 @@ func (er *expressionRewriter) isNullToExpression(v *ast.IsNullExpr) { } func (er *expressionRewriter) positionToScalarFunc(v *ast.PositionExpr) { - if v.N > 0 && v.N <= er.schema.Len() { - er.ctxStack = append(er.ctxStack, er.schema.Columns[v.N-1]) + pos := v.N + str := strconv.Itoa(pos) + if v.P != nil { + stkLen := len(er.ctxStack) + val := er.ctxStack[stkLen-1] + intNum, isNull, err := expression.GetIntFromConstant(er.ctx, val) + str = "?" + if err == nil { + if isNull { + return + } + pos = intNum + er.ctxStack = er.ctxStack[:stkLen-1] + } + er.err = err + } + if er.err == nil && pos > 0 && pos <= er.schema.Len() { + er.ctxStack = append(er.ctxStack, er.schema.Columns[pos-1]) } else { - er.err = ErrUnknownColumn.GenWithStackByArgs(strconv.Itoa(v.N), clauseMsg[er.b.curClause]) + er.err = ErrUnknownColumn.GenWithStackByArgs(str, clauseMsg[er.b.curClause]) } } diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index fd5b5ffd58460..8179ee33fbf87 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -34,7 +34,6 @@ import ( "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/sessionctx" - "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" @@ -818,6 +817,23 @@ func (by *ByItems) Clone() *ByItems { return &ByItems{Expr: by.Expr.Clone(), Desc: by.Desc} } +// itemTransformer transforms ParamMarkerExpr to PositionExpr in the context of ByItem +type itemTransformer struct { +} + +func (t *itemTransformer) Enter(inNode ast.Node) (ast.Node, bool) { + switch n := inNode.(type) { + case *driver.ParamMarkerExpr: + newNode := expression.ConstructPositionExpr(n) + return newNode, true + } + return inNode, false +} + +func (t *itemTransformer) Leave(inNode ast.Node) (ast.Node, bool) { + return inNode, false +} + func (b *PlanBuilder) buildSort(p LogicalPlan, byItems []*ast.ByItem, aggMapper map[*ast.AggregateFuncExpr]int) (*LogicalSort, error) { if _, isUnion := p.(*LogicalUnionAll); isUnion { b.curClause = globalOrderByClause @@ -826,7 +842,10 @@ func (b *PlanBuilder) buildSort(p LogicalPlan, byItems []*ast.ByItem, aggMapper } sort := LogicalSort{}.Init(b.ctx) exprs := make([]*ByItems, 0, len(byItems)) + transformer := &itemTransformer{} for _, item := range byItems { + newExpr, _ := item.Expr.Accept(transformer) + item.Expr = newExpr.(ast.ExprNode) it, np, err := b.rewrite(item.Expr, p, aggMapper, true) if err != nil { return nil, errors.Trace(err) @@ -843,7 +862,27 @@ func (b *PlanBuilder) buildSort(p LogicalPlan, byItems []*ast.ByItem, aggMapper // getUintForLimitOffset gets uint64 value for limit/offset. // For ordinary statement, limit/offset should be uint64 constant value. // For prepared statement, limit/offset is string. We should convert it to uint64. -func getUintForLimitOffset(sc *stmtctx.StatementContext, val interface{}) (uint64, error) { +func getUintForLimitOffset(ctx sessionctx.Context, n ast.Node) (uint64, error) { + var val interface{} + switch v := n.(type) { + case *driver.ValueExpr: + val = v.GetValue() + case *driver.ParamMarkerExpr: + param, err := expression.GetParamExpression(ctx, v) + if err != nil { + return 0, errors.Trace(err) + } + str, isNull, err := expression.GetStringFromConstant(ctx, param) + if err != nil { + return 0, errors.Trace(err) + } + if isNull { + return 0, nil + } + val = str + default: + return 0, errors.Errorf("Invalid type %T for LogicalLimit/Offset", v) + } switch v := val.(type) { case uint64: return v, nil @@ -852,22 +891,23 @@ func getUintForLimitOffset(sc *stmtctx.StatementContext, val interface{}) (uint6 return uint64(v), nil } case string: + sc := ctx.GetSessionVars().StmtCtx uVal, err := types.StrToUint(sc, v) return uVal, errors.Trace(err) } return 0, errors.Errorf("Invalid type %T for LogicalLimit/Offset", val) } -func extractLimitCountOffset(sc *stmtctx.StatementContext, limit *ast.Limit) (count uint64, +func extractLimitCountOffset(ctx sessionctx.Context, limit *ast.Limit) (count uint64, offset uint64, err error) { if limit.Count != nil { - count, err = getUintForLimitOffset(sc, limit.Count.(ast.ValueExpr).GetValue()) + count, err = getUintForLimitOffset(ctx, limit.Count) if err != nil { return 0, 0, ErrWrongArguments.GenWithStackByArgs("LIMIT") } } if limit.Offset != nil { - offset, err = getUintForLimitOffset(sc, limit.Offset.(ast.ValueExpr).GetValue()) + offset, err = getUintForLimitOffset(ctx, limit.Offset) if err != nil { return 0, 0, ErrWrongArguments.GenWithStackByArgs("LIMIT") } @@ -881,8 +921,7 @@ func (b *PlanBuilder) buildLimit(src LogicalPlan, limit *ast.Limit) (LogicalPlan offset, count uint64 err error ) - sc := b.ctx.GetSessionVars().StmtCtx - if count, offset, err = extractLimitCountOffset(sc, limit); err != nil { + if count, offset, err = extractLimitCountOffset(b.ctx, limit); err != nil { return nil, err } @@ -1152,16 +1191,22 @@ func (b *PlanBuilder) extractAggFuncs(fields []*ast.SelectField) ([]*ast.Aggrega // gbyResolver resolves group by items from select fields. type gbyResolver struct { - fields []*ast.SelectField - schema *expression.Schema - err error - inExpr bool + ctx sessionctx.Context + fields []*ast.SelectField + schema *expression.Schema + err error + inExpr bool + isParam bool } func (g *gbyResolver) Enter(inNode ast.Node) (ast.Node, bool) { - switch inNode.(type) { + switch n := inNode.(type) { case *ast.SubqueryExpr, *ast.CompareSubqueryExpr, *ast.ExistsSubqueryExpr: return inNode, true + case *driver.ParamMarkerExpr: + newNode := expression.ConstructPositionExpr(n) + g.isParam = true + return newNode, true case *driver.ValueExpr, *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.ColumnName: default: g.inExpr = true @@ -1196,14 +1241,21 @@ func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) { return inNode, false } case *ast.PositionExpr: - if v.N < 1 || v.N > len(g.fields) { - g.err = errors.Errorf("Unknown column '%d' in 'group statement'", v.N) + pos, isNull, err := expression.PosFromPositionExpr(g.ctx, v) + if err != nil { + g.err = ErrUnknown.GenWithStackByArgs() + } + if err != nil || isNull { return inNode, false } - ret := g.fields[v.N-1].Expr + if pos < 1 || pos > len(g.fields) { + g.err = errors.Errorf("Unknown column '%d' in 'group statement'", pos) + return inNode, false + } + ret := g.fields[pos-1].Expr ret.Accept(extractor) if len(extractor.AggFuncs) != 0 { - g.err = ErrWrongGroupField.GenWithStackByArgs(g.fields[v.N-1].Text()) + g.err = ErrWrongGroupField.GenWithStackByArgs(g.fields[pos-1].Text()) return inNode, false } return ret, true @@ -1561,6 +1613,7 @@ func (b *PlanBuilder) resolveGbyExprs(p LogicalPlan, gby *ast.GroupByClause, fie b.curClause = groupByClause exprs := make([]expression.Expression, 0, len(gby.Items)) resolver := &gbyResolver{ + ctx: b.ctx, fields: fields, schema: p.Schema(), } @@ -1570,9 +1623,12 @@ func (b *PlanBuilder) resolveGbyExprs(p LogicalPlan, gby *ast.GroupByClause, fie if resolver.err != nil { return nil, nil, errors.Trace(resolver.err) } + if !resolver.isParam { + item.Expr = retExpr.(ast.ExprNode) + } - item.Expr = retExpr.(ast.ExprNode) - expr, np, err := b.rewrite(item.Expr, p, nil, true) + itemExpr := retExpr.(ast.ExprNode) + expr, np, err := b.rewrite(itemExpr, p, nil, true) if err != nil { return nil, nil, errors.Trace(err) } diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index eb177aabfcc7a..ea73910f3d645 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -148,8 +148,7 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP if selStmt.Having != nil || selStmt.LockTp != ast.SelectLockNone { return nil } else if selStmt.Limit != nil { - sc := ctx.GetSessionVars().StmtCtx - count, offset, err := extractLimitCountOffset(sc, selStmt.Limit) + count, offset, err := extractLimitCountOffset(ctx, selStmt.Limit) if err != nil || count == 0 || offset > 0 { return nil }