diff --git a/executor/point_get.go b/executor/point_get.go index e9063386fbe0c..bf3ff62de8231 100644 --- a/executor/point_get.go +++ b/executor/point_get.go @@ -46,6 +46,7 @@ func (b *executorBuilder) buildPointGet(p *plannercore.PointGetPlan) Executor { idxVals: p.IndexValues, handle: p.Handle, startTS: startTS, + done: p.UnsignedHandle && p.Handle < 0, } } diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index 0b36025af5d7a..0be899e4fb87e 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -41,6 +41,7 @@ type PointGetPlan struct { IndexInfo *model.IndexInfo Handle int64 HandleParam *driver.ParamMarkerExpr + UnsignedHandle bool IndexValues []types.Datum IndexValueParams []*driver.ParamMarkerExpr expr expression.Expression @@ -185,7 +186,7 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP if pairs == nil { return nil } - handlePair := findPKHandle(tbl, pairs) + handlePair, unsigned := findPKHandle(tbl, pairs) if handlePair.value.Kind() != types.KindNull && len(pairs) == 1 { schema := buildSchemaFromFields(ctx, tblName.Schema, tbl, selStmt.Fields.Fields) if schema == nil { @@ -197,6 +198,7 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP if err != nil { return nil } + p.UnsignedHandle = unsigned p.HandleParam = handlePair.param return p } @@ -362,20 +364,20 @@ func getNameValuePairs(nvPairs []nameValuePair, expr ast.ExprNode) []nameValuePa return nil } -func findPKHandle(tblInfo *model.TableInfo, pairs []nameValuePair) (handlePair nameValuePair) { +func findPKHandle(tblInfo *model.TableInfo, pairs []nameValuePair) (handlePair nameValuePair, unsigned bool) { if !tblInfo.PKIsHandle { - return handlePair + return handlePair, unsigned } for _, col := range tblInfo.Columns { if mysql.HasPriKeyFlag(col.Flag) { i := findInPairs(col.Name.L, pairs) if i == -1 { - return handlePair + return handlePair, unsigned } - return pairs[i] + return pairs[i], mysql.HasUnsignedFlag(col.Flag) } } - return handlePair + return handlePair, unsigned } func getIndexValues(idxInfo *model.IndexInfo, pairs []nameValuePair) ([]types.Datum, []*driver.ParamMarkerExpr) { diff --git a/planner/core/point_get_plan_test.go b/planner/core/point_get_plan_test.go index 6a548bca53cdc..7d8eece9df1f8 100644 --- a/planner/core/point_get_plan_test.go +++ b/planner/core/point_get_plan_test.go @@ -54,7 +54,7 @@ func (s *testPointGetSuite) TestPointGetPlanCache(c *C) { core.PreparedPlanCacheMaxMemory.Store(math.MaxUint64) tk.MustExec("use test") tk.MustExec("drop table if exists t") - tk.MustExec("create table t(a int primary key, b int, c int, key idx_bc(b,c))") + tk.MustExec("create table t(a bigint unsigned primary key, b int, c int, key idx_bc(b,c))") tk.MustExec("insert into t values(1, 1, 1), (2, 2, 2), (3, 3, 3)") tk.MustQuery("explain select * from t where a = 1").Check(testkit.Rows( "Point_Get_1 1.00 root table:t, handle:1", @@ -68,6 +68,11 @@ func (s *testPointGetSuite) TestPointGetPlanCache(c *C) { tk.MustQuery("explain delete from t where a = 1").Check(testkit.Rows( "Point_Get_1 1.00 root table:t, handle:1", )) + tk.MustQuery("explain select a from t where a = -1").Check(testkit.Rows( + "TableDual_5 0.00 root rows:0")) + tk.MustExec(`prepare stmt0 from "select a from t where a = ?"`) + tk.MustExec("set @p0 = -1") + tk.MustQuery("execute stmt0 using @p0").Check(testkit.Rows()) metrics.ResettablePlanCacheCounterFortTest = true metrics.PlanCacheCounter.Reset() counter := metrics.PlanCacheCounter.WithLabelValues("prepare") @@ -137,4 +142,13 @@ func (s *testPointGetSuite) TestPointGetPlanCache(c *C) { counter.Write(pb) hit = pb.GetCounter().GetValue() c.Check(hit, Equals, float64(2)) + tk.MustExec("insert into t (a, b, c) values (18446744073709551615, 4, 4)") + tk.MustExec("set @p1=-1") + tk.MustExec("set @p2=1") + tk.MustExec(`prepare stmt7 from "select a from t where a = ?"`) + tk.MustQuery("execute stmt7 using @p1").Check(testkit.Rows()) + tk.MustQuery("execute stmt7 using @p2").Check(testkit.Rows("1")) + counter.Write(pb) + hit = pb.GetCounter().GetValue() + c.Check(hit, Equals, float64(3)) }