Skip to content

Commit

Permalink
*: handle signed/unsigned in the partition pruning (#15436)
Browse files Browse the repository at this point in the history
  • Loading branch information
tiancaiamao authored Mar 24, 2020
1 parent 00dc69c commit 365ff9b
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 73 deletions.
37 changes: 30 additions & 7 deletions expression/simple_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
package expression

import (
"context"

"github.com/pingcap/errors"
"github.com/pingcap/parser"
"github.com/pingcap/parser/ast"
Expand Down Expand Up @@ -41,10 +43,20 @@ type simpleRewriter struct {
// The expression string must only reference the column in table Info.
func ParseSimpleExprWithTableInfo(ctx sessionctx.Context, exprStr string, tableInfo *model.TableInfo) (Expression, error) {
exprStr = "select " + exprStr
stmts, warns, err := parser.New().Parse(exprStr, "", "")
var stmts []ast.StmtNode
var err error
var warns []error
if p, ok := ctx.(interface {
ParseSQL(context.Context, string, string, string) ([]ast.StmtNode, []error, error)
}); ok {
stmts, warns, err = p.ParseSQL(context.Background(), exprStr, "", "")
} else {
stmts, warns, err = parser.New().Parse(exprStr, "", "")
}
for _, warn := range warns {
ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn))
}

if err != nil {
return nil, util.SyntaxError(err)
}
Expand Down Expand Up @@ -80,12 +92,13 @@ func RewriteSimpleExprWithTableInfo(ctx sessionctx.Context, tbl *model.TableInfo
func ParseSimpleExprsWithSchema(ctx sessionctx.Context, exprStr string, schema *Schema) ([]Expression, error) {
exprStr = "select " + exprStr
stmts, warns, err := parser.New().Parse(exprStr, "", "")
for _, warn := range warns {
ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn))
}
if err != nil {
return nil, util.SyntaxWarn(err)
}
for _, warn := range warns {
ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn))
}

fields := stmts[0].(*ast.SelectStmt).Fields.Fields
exprs := make([]Expression, 0, len(fields))
for _, field := range fields {
Expand All @@ -102,13 +115,23 @@ func ParseSimpleExprsWithSchema(ctx sessionctx.Context, exprStr string, schema *
// The expression string must only reference the column in the given NameSlice.
func ParseSimpleExprsWithNames(ctx sessionctx.Context, exprStr string, schema *Schema, names types.NameSlice) ([]Expression, error) {
exprStr = "select " + exprStr
stmts, warns, err := parser.New().Parse(exprStr, "", "")
for _, warn := range warns {
ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn))
var stmts []ast.StmtNode
var err error
var warns []error
if p, ok := ctx.(interface {
ParseSQL(context.Context, string, string, string) ([]ast.StmtNode, []error, error)
}); ok {
stmts, warns, err = p.ParseSQL(context.Background(), exprStr, "", "")
} else {
stmts, warns, err = parser.New().Parse(exprStr, "", "")
}
if err != nil {
return nil, util.SyntaxWarn(err)
}
for _, warn := range warns {
ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn))
}

fields := stmts[0].(*ast.SelectStmt).Fields.Fields
exprs := make([]Expression, 0, len(fields))
for _, field := range fields {
Expand Down
2 changes: 1 addition & 1 deletion planner/core/partition_pruning_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (s *testPartitionPruningSuite) TestPruneUseBinarySearch(c *C) {
}

for i, ca := range cases {
start, end := pruneUseBinarySearch(lessThan, ca.input)
start, end := pruneUseBinarySearch(lessThan, ca.input, false)
c.Assert(ca.result.start, Equals, start, Commentf("fail = %d", i))
c.Assert(ca.result.end, Equals, end, Commentf("fail = %d", i))
}
Expand Down
2 changes: 1 addition & 1 deletion planner/core/point_get_plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -1193,7 +1193,7 @@ func getHashPartitionColumnName(ctx sessionctx.Context, tbl *model.TableInfo) *a
return nil
}
// PartitionExpr don't need columns and names for hash partition.
partitionExpr, err := table.(partitionTable).PartitionExpr(ctx, nil, nil)
partitionExpr, err := table.(partitionTable).PartitionExpr()
if err != nil {
return nil
}
Expand Down
66 changes: 33 additions & 33 deletions planner/core/rule_partition_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@ package core
import (
"context"
"sort"
"strconv"
"strings"

"github.com/pingcap/errors"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/model"
"github.com/pingcap/parser/mysql"
Expand Down Expand Up @@ -98,7 +96,7 @@ func (s *partitionProcessor) rewriteDataSource(lp LogicalPlan) (LogicalPlan, err

// partitionTable is for those tables which implement partition.
type partitionTable interface {
PartitionExpr(ctx sessionctx.Context, columns []*expression.Column, names types.NameSlice) (*tables.PartitionExpr, error)
PartitionExpr() (*tables.PartitionExpr, error)
}

func generateHashPartitionExpr(t table.Table, ctx sessionctx.Context, columns []*expression.Column, names types.NameSlice) (expression.Expression, error) {
Expand Down Expand Up @@ -185,12 +183,25 @@ func (lt *lessThanDataInt) length() int {
return len(lt.data)
}

func (lt *lessThanDataInt) compare(ith int, v int64) int {
func compareUnsigned(v1, v2 int64) int {
switch {
case uint64(v1) > uint64(v2):
return 1
case uint64(v1) == uint64(v2):
return 0
}
return -1
}

func (lt *lessThanDataInt) compare(ith int, v int64, unsigned bool) int {
if ith == len(lt.data)-1 {
if lt.maxvalue {
return 1
}
}
if unsigned {
return compareUnsigned(lt.data[ith], v)
}
switch {
case lt.data[ith] > v:
return 1
Expand Down Expand Up @@ -328,37 +339,24 @@ func (s *partitionProcessor) pruneRangePartition(ds *DataSource, pi *model.Parti
result := fullRange(len(pi.Definitions))
// Extract the partition column, if the column is not null, it's possible to prune.
if col != nil {
// TODO: Store LessThanData in the partitionExpr, avoid allocating here.
lessThan, err := makeLessThanData(pi)
partExpr, err := ds.table.(partitionTable).PartitionExpr()
if err != nil {
return nil, err
}
pruner := rangePruner{lessThan, col, fn}
pruner := rangePruner{
lessThan: lessThanDataInt{
data: partExpr.ForRangePruning.LessThan,
maxvalue: partExpr.ForRangePruning.MaxValue,
},
col: col,
partFn: fn,
}
result = partitionRangeForCNFExpr(ds.ctx, ds.allConds, &pruner, result)
}

return s.makeUnionAllChildren(ds, pi, result)
}

// makeLessThanData extracts the less than parts from 'partition p0 less than xx ... partitoin p1 less than ...'
func makeLessThanData(pi *model.PartitionInfo) (lessThanDataInt, error) {
var maxValue bool
lessThan := make([]int64, len(pi.Definitions))
for i := 0; i < len(pi.Definitions); i++ {
if strings.EqualFold(pi.Definitions[i].LessThan[0], "MAXVALUE") {
// Use a bool flag instead of math.MaxInt64 to avoid the corner cases.
maxValue = true
} else {
var err error
lessThan[i], err = strconv.ParseInt(pi.Definitions[i].LessThan[0], 10, 64)
if err != nil {
return lessThanDataInt{}, errors.WithStack(err)
}
}
}
return lessThanDataInt{lessThan, maxValue}, nil
}

// makePartitionByFnCol extracts the column and function information in 'partition by ... fn(col)'.
func makePartitionByFnCol(sctx sessionctx.Context, columns []*expression.Column, names types.NameSlice, partitionExpr string) (*expression.Column, *expression.ScalarFunction, error) {
schema := expression.NewSchema(columns...)
Expand Down Expand Up @@ -431,7 +429,9 @@ func (p *rangePruner) partitionRangeForExpr(sctx sessionctx.Context, expr expres
if !ok {
return 0, 0, false
}
start, end := pruneUseBinarySearch(p.lessThan, dataForPrune)

unsigned := mysql.HasUnsignedFlag(p.col.RetType.Flag)
start, end := pruneUseBinarySearch(p.lessThan, dataForPrune, unsigned)
return start, end, true
}

Expand Down Expand Up @@ -556,44 +556,44 @@ func relaxOP(op string) string {
return op
}

func pruneUseBinarySearch(lessThan lessThanDataInt, data dataForPrune) (start int, end int) {
func pruneUseBinarySearch(lessThan lessThanDataInt, data dataForPrune, unsigned bool) (start int, end int) {
length := lessThan.length()
switch data.op {
case ast.EQ:
// col = 66, lessThan = [4 7 11 14 17] => [5, 6)
// col = 14, lessThan = [4 7 11 14 17] => [4, 5)
// col = 10, lessThan = [4 7 11 14 17] => [2, 3)
// col = 3, lessThan = [4 7 11 14 17] => [0, 1)
pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c) > 0 })
pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c, unsigned) > 0 })
start, end = pos, pos+1
case ast.LT:
// col < 66, lessThan = [4 7 11 14 17] => [0, 5)
// col < 14, lessThan = [4 7 11 14 17] => [0, 4)
// col < 10, lessThan = [4 7 11 14 17] => [0, 3)
// col < 3, lessThan = [4 7 11 14 17] => [0, 1)
pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c) >= 0 })
pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c, unsigned) >= 0 })
start, end = 0, pos+1
case ast.GE:
// col >= 66, lessThan = [4 7 11 14 17] => [5, 5)
// col >= 14, lessThan = [4 7 11 14 17] => [4, 5)
// col >= 10, lessThan = [4 7 11 14 17] => [2, 5)
// col >= 3, lessThan = [4 7 11 14 17] => [0, 5)
pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c) > 0 })
pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c, unsigned) > 0 })
start, end = pos, length
case ast.GT:
// col > 66, lessThan = [4 7 11 14 17] => [5, 5)
// col > 14, lessThan = [4 7 11 14 17] => [4, 5)
// col > 10, lessThan = [4 7 11 14 17] => [3, 5)
// col > 3, lessThan = [4 7 11 14 17] => [1, 5)
// col > 2, lessThan = [4 7 11 14 17] => [0, 5)
pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c+1) > 0 })
pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c+1, unsigned) > 0 })
start, end = pos, length
case ast.LE:
// col <= 66, lessThan = [4 7 11 14 17] => [0, 6)
// col <= 14, lessThan = [4 7 11 14 17] => [0, 5)
// col <= 10, lessThan = [4 7 11 14 17] => [0, 3)
// col <= 3, lessThan = [4 7 11 14 17] => [0, 1)
pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c) > 0 })
pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c, unsigned) > 0 })
start, end = 0, pos+1
case ast.IsNull:
start, end = 0, 1
Expand Down
75 changes: 53 additions & 22 deletions table/tables/partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ package tables

import (
"bytes"
stderr "errors"
"fmt"
"sort"
"strconv"
"strings"

"github.com/pingcap/errors"
Expand Down Expand Up @@ -113,6 +115,46 @@ type PartitionExpr struct {
OrigExpr ast.ExprNode
// Expr is the hash partition expression.
Expr expression.Expression
// Used in the range pruning process.
*ForRangePruning
}

// ForRangePruning is used for range partition pruning.
type ForRangePruning struct {
LessThan []int64
MaxValue bool
Unsigned bool
}

// dataForRangePruning extracts the less than parts from 'partition p0 less than xx ... partitoin p1 less than ...'
func dataForRangePruning(pi *model.PartitionInfo) (*ForRangePruning, error) {
var maxValue bool
var unsigned bool
lessThan := make([]int64, len(pi.Definitions))
for i := 0; i < len(pi.Definitions); i++ {
if strings.EqualFold(pi.Definitions[i].LessThan[0], "MAXVALUE") {
// Use a bool flag instead of math.MaxInt64 to avoid the corner cases.
maxValue = true
} else {
var err error
lessThan[i], err = strconv.ParseInt(pi.Definitions[i].LessThan[0], 10, 64)
var numErr *strconv.NumError
if stderr.As(err, &numErr) && numErr.Err == strconv.ErrRange {
var tmp uint64
tmp, err = strconv.ParseUint(pi.Definitions[i].LessThan[0], 10, 64)
lessThan[i] = int64(tmp)
unsigned = true
}
if err != nil {
return nil, errors.WithStack(err)
}
}
}
return &ForRangePruning{
LessThan: lessThan,
MaxValue: maxValue,
Unsigned: unsigned,
}, nil
}

// rangePartitionString returns the partition string for a range typed partition.
Expand All @@ -134,13 +176,11 @@ func rangePartitionString(pi *model.PartitionInfo) string {
func generateRangePartitionExpr(ctx sessionctx.Context, pi *model.PartitionInfo,
columns []*expression.Column, names types.NameSlice) (*PartitionExpr, error) {
// The caller should assure partition info is not nil.
partitionPruneExprs := make([]expression.Expression, 0, len(pi.Definitions))
locateExprs := make([]expression.Expression, 0, len(pi.Definitions))
var buf bytes.Buffer
schema := expression.NewSchema(columns...)
partStr := rangePartitionString(pi)
for i := 0; i < len(pi.Definitions); i++ {

if strings.EqualFold(pi.Definitions[i].LessThan[0], "MAXVALUE") {
// Expr less than maxvalue is always true.
fmt.Fprintf(&buf, "true")
Expand All @@ -155,28 +195,19 @@ func generateRangePartitionExpr(ctx sessionctx.Context, pi *model.PartitionInfo,
return nil, errors.Trace(err)
}
locateExprs = append(locateExprs, exprs[0])

if i > 0 {
fmt.Fprintf(&buf, " and ((%s) >= (%s))", partStr, pi.Definitions[i-1].LessThan[0])
} else {
// NULL will locate in the first partition, so its expression is (expr < value or expr is null).
fmt.Fprintf(&buf, " or ((%s) is null)", partStr)
}

exprs, err = expression.ParseSimpleExprsWithNames(ctx, buf.String(), schema, names)
buf.Reset()
}
ret := &PartitionExpr{
UpperBounds: locateExprs,
}
if len(pi.Columns) == 0 {
tmp, err := dataForRangePruning(pi)
if err != nil {
// If it got an error here, ddl may hang forever, so this error log is important.
logutil.BgLogger().Error("wrong table partition expression", zap.String("expression", buf.String()), zap.Error(err))
return nil, errors.Trace(err)
}
// Get a hash code in advance to prevent data race afterwards.
exprs[0].HashCode(ctx.GetSessionVars().StmtCtx)
partitionPruneExprs = append(partitionPruneExprs, exprs[0])
buf.Reset()
ret.ForRangePruning = tmp
}
return &PartitionExpr{
UpperBounds: locateExprs,
}, nil
return ret, nil
}

func generateHashPartitionExpr(ctx sessionctx.Context, pi *model.PartitionInfo,
Expand All @@ -201,13 +232,13 @@ func generateHashPartitionExpr(ctx sessionctx.Context, pi *model.PartitionInfo,
}

// PartitionExpr returns the partition expression.
func (t *partitionedTable) PartitionExpr(ctx sessionctx.Context, columns []*expression.Column, names types.NameSlice) (*PartitionExpr, error) {
func (t *partitionedTable) PartitionExpr() (*PartitionExpr, error) {
pi := t.meta.GetPartitionInfo()
switch pi.Type {
case model.PartitionTypeHash:
return t.partitionExpr, nil
case model.PartitionTypeRange:
return generateRangePartitionExpr(ctx, pi, columns, names)
return t.partitionExpr, nil
}
panic("cannot reach here")
}
Expand Down
Loading

0 comments on commit 365ff9b

Please sign in to comment.