Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: support SQL bind for Update / Delete / Insert / Replace (#20686) #21101

Merged
merged 2 commits into from
Nov 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 239 additions & 0 deletions bindinfo/bind_test.go

Large diffs are not rendered by default.

42 changes: 35 additions & 7 deletions bindinfo/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -596,16 +596,19 @@ func (h *BindHandle) logicalDeleteBindInfoSQL(originalSQL, db string, updateTs t
// CaptureBaselines is used to automatically capture plan baselines.
func (h *BindHandle) CaptureBaselines() {
parser4Capture := parser.New()
schemas, sqls := stmtsummary.StmtSummaryByDigestMap.GetMoreThanOnceSelect()
schemas, sqls := stmtsummary.StmtSummaryByDigestMap.GetMoreThanOnceBindableStmt()
for i := range sqls {
stmt, err := parser4Capture.ParseOneStmt(sqls[i], "", "")
if err != nil {
logutil.BgLogger().Debug("parse SQL failed", zap.String("SQL", sqls[i]), zap.Error(err))
continue
}
normalizedSQL, digiest := parser.NormalizeDigest(sqls[i])
if insertStmt, ok := stmt.(*ast.InsertStmt); ok && insertStmt.Select == nil {
continue
}
normalizedSQL, digest := parser.NormalizeDigest(sqls[i])
dbName := utilparser.GetDefaultDB(stmt, schemas[i])
if r := h.GetBindRecord(digiest, normalizedSQL, dbName); r != nil && r.HasUsingBinding() {
if r := h.GetBindRecord(digest, normalizedSQL, dbName); r != nil && r.HasUsingBinding() {
continue
}
h.sctx.Lock()
Expand Down Expand Up @@ -682,10 +685,35 @@ func GenerateBindSQL(ctx context.Context, stmtNode ast.StmtNode, planHint string
logutil.Logger(ctx).Warn("Restore SQL failed", zap.Error(err))
}
bindSQL := sb.String()
selectIdx := strings.Index(bindSQL, "SELECT")
// Remove possible `explain` prefix.
bindSQL = bindSQL[selectIdx:]
return strings.Replace(bindSQL, "SELECT", fmt.Sprintf("SELECT /*+ %s*/", planHint), 1)
switch n := stmtNode.(type) {
case *ast.DeleteStmt:
deleteIdx := strings.Index(bindSQL, "DELETE")
// Remove possible `explain` prefix.
bindSQL = bindSQL[deleteIdx:]
return strings.Replace(bindSQL, "DELETE", fmt.Sprintf("DELETE /*+ %s*/", planHint), 1)
case *ast.UpdateStmt:
updateIdx := strings.Index(bindSQL, "UPDATE")
// Remove possible `explain` prefix.
bindSQL = bindSQL[updateIdx:]
return strings.Replace(bindSQL, "UPDATE", fmt.Sprintf("UPDATE /*+ %s*/", planHint), 1)
case *ast.SelectStmt:
selectIdx := strings.Index(bindSQL, "SELECT")
// Remove possible `explain` prefix.
bindSQL = bindSQL[selectIdx:]
return strings.Replace(bindSQL, "SELECT", fmt.Sprintf("SELECT /*+ %s*/", planHint), 1)
case *ast.InsertStmt:
insertIdx := int(0)
if n.IsReplace {
insertIdx = strings.Index(bindSQL, "REPLACE")
} else {
insertIdx = strings.Index(bindSQL, "INSERT")
}
// Remove possible `explain` prefix.
bindSQL = bindSQL[insertIdx:]
return strings.Replace(bindSQL, "SELECT", fmt.Sprintf("SELECT /*+ %s*/", planHint), 1)
}
logutil.Logger(ctx).Warn("Unexpected statement type")
return ""
}

type paramMarkerChecker struct {
Expand Down
27 changes: 23 additions & 4 deletions executor/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,17 +218,36 @@ func getStmtDbLabel(stmtNode ast.StmtNode) map[string]struct{} {
}
}
case *ast.CreateBindingStmt:
var resNode ast.ResultSetNode
if x.OriginSel != nil {
originSelect := x.OriginSel.(*ast.SelectStmt)
dbLabels := getDbFromResultNode(originSelect.From.TableRefs)
switch n := x.OriginSel.(type) {
case *ast.SelectStmt:
resNode = n.From.TableRefs
case *ast.DeleteStmt:
resNode = n.TableRefs.TableRefs
case *ast.UpdateStmt:
resNode = n.TableRefs.TableRefs
case *ast.InsertStmt:
resNode = n.Table.TableRefs
}
dbLabels := getDbFromResultNode(resNode)
for _, db := range dbLabels {
dbLabelSet[db] = struct{}{}
}
}

if len(dbLabelSet) == 0 && x.HintedSel != nil {
hintedSelect := x.HintedSel.(*ast.SelectStmt)
dbLabels := getDbFromResultNode(hintedSelect.From.TableRefs)
switch n := x.HintedSel.(type) {
case *ast.SelectStmt:
resNode = n.From.TableRefs
case *ast.DeleteStmt:
resNode = n.TableRefs.TableRefs
case *ast.UpdateStmt:
resNode = n.TableRefs.TableRefs
case *ast.InsertStmt:
resNode = n.Table.TableRefs
}
dbLabels := getDbFromResultNode(resNode)
for _, db := range dbLabels {
dbLabelSet[db] = struct{}{}
}
Expand Down
6 changes: 6 additions & 0 deletions planner/core/hints.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ func GenHintsFromPhysicalPlan(p Plan) []*ast.TableOptimizerHint {
hints = genHintsFromPhysicalPlan(pp.SelectPlan, utilhint.TypeUpdate)
case *Delete:
hints = genHintsFromPhysicalPlan(pp.SelectPlan, utilhint.TypeDelete)
// For Insert, we only generate hints that would be used in select query block.
case *Insert:
hints = genHintsFromPhysicalPlan(pp.SelectPlan, utilhint.TypeSelect)
case PhysicalPlan:
hints = genHintsFromPhysicalPlan(pp, utilhint.TypeSelect)
}
Expand Down Expand Up @@ -104,6 +107,9 @@ func getJoinHints(sctx sessionctx.Context, joinType string, parentOffset int, no
}

func genHintsFromPhysicalPlan(p PhysicalPlan, nodeType utilhint.NodeType) (res []*ast.TableOptimizerHint) {
if p == nil {
return res
}
for _, child := range p.Children() {
res = append(res, genHintsFromPhysicalPlan(child, nodeType)...)
}
Expand Down
49 changes: 46 additions & 3 deletions planner/core/preprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,53 @@ func EraseLastSemicolon(stmt ast.StmtNode) {
}
}

func (p *preprocessor) checkBindGrammar(originSel, hintedSel ast.StmtNode) {
originSQL := parser.Normalize(originSel.(*ast.SelectStmt).Text())
hintedSQL := parser.Normalize(hintedSel.(*ast.SelectStmt).Text())
const (
// TypeInvalid for unexpected types.
TypeInvalid byte = iota
// TypeSelect for SelectStmt.
TypeSelect
// TypeDelete for DeleteStmt.
TypeDelete
// TypeUpdate for UpdateStmt.
TypeUpdate
// TypeInsert for InsertStmt.
TypeInsert
)

func bindableStmtType(node ast.StmtNode) byte {
switch node.(type) {
case *ast.SelectStmt:
return TypeSelect
case *ast.DeleteStmt:
return TypeDelete
case *ast.UpdateStmt:
return TypeUpdate
case *ast.InsertStmt:
return TypeInsert
}
return TypeInvalid
}

func (p *preprocessor) checkBindGrammar(originNode, hintedNode ast.StmtNode) {
origTp := bindableStmtType(originNode)
hintedTp := bindableStmtType(hintedNode)
if origTp == TypeInvalid || hintedTp == TypeInvalid {
p.err = errors.Errorf("create binding doesn't support this type of query")
return
}
if origTp != hintedTp {
p.err = errors.Errorf("hinted sql and original sql have different query types")
return
}
if origTp == TypeInsert {
origInsert, hintedInsert := originNode.(*ast.InsertStmt), hintedNode.(*ast.InsertStmt)
if origInsert.Select == nil || hintedInsert.Select == nil {
p.err = errors.Errorf("create binding only supports INSERT / REPLACE INTO SELECT")
return
}
}
originSQL := parser.Normalize(originNode.Text())
hintedSQL := parser.Normalize(hintedNode.Text())
if originSQL != hintedSQL {
p.err = errors.Errorf("hinted sql and origin sql don't match when hinted sql erase the hint info, after erase hint info, originSQL:%s, hintedSQL:%s", originSQL, hintedSQL)
}
Expand Down
2 changes: 1 addition & 1 deletion planner/core/preprocess_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (s *testValidatorSuite) SetUpTest(c *C) {

func (s *testValidatorSuite) runSQL(c *C, sql string, inPrepare bool, terr error) {
stmts, err1 := session.Parse(s.ctx, sql)
c.Assert(err1, IsNil)
c.Assert(err1, IsNil, Commentf("sql: %s", sql))
c.Assert(stmts, HasLen, 1)
stmt := stmts[0]
var opts []core.PreprocessOpt
Expand Down
2 changes: 1 addition & 1 deletion planner/core/testdata/plan_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,7 @@
{
"SQL": "insert into t select * from t where b < 1 order by d limit 1",
"Best": "TableReader(Table(t)->Sel([lt(test.t.b, 1)])->TopN([test.t.d],0,1))->TopN([test.t.d],0,1)->Insert",
"Hints": ""
"Hints": "use_index(@`sel_1` `test`.`t` )"
},
{
"SQL": "insert into t (a, b, c, e, f, g) values(0,0,0,0,0,0)",
Expand Down
34 changes: 25 additions & 9 deletions planner/optimize.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,12 @@ func Optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in
bestPlanAmongHints = plan
}
}
// 1. If there is already a evolution task, we do not need to handle it again.
// 2. If the origin binding contain `read_from_storage` hint, we should ignore the evolve task.
// 3. If the best plan contain TiFlash hint, we should ignore the evolve task.
if sctx.GetSessionVars().EvolvePlanBaselines && binding == nil &&
// 1. If it is a select query.
// 2. If there is already a evolution task, we do not need to handle it again.
// 3. If the origin binding contain `read_from_storage` hint, we should ignore the evolve task.
// 4. If the best plan contain TiFlash hint, we should ignore the evolve task.
if _, ok := stmtNode.(*ast.SelectStmt); ok &&
sctx.GetSessionVars().EvolvePlanBaselines && binding == nil &&
!originHints.ContainTableHint(plannercore.HintReadFromStorage) &&
!bindRecord.Bindings[0].Hint.ContainTableHint(plannercore.HintReadFromStorage) {
handleEvolveTasks(ctx, sctx, bindRecord, stmtNode, bestPlanHintStr)
Expand Down Expand Up @@ -249,19 +251,33 @@ func optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in
return finalPlan, names, cost, err
}

func extractSelectAndNormalizeDigest(stmtNode ast.StmtNode) (*ast.SelectStmt, string, string) {
func extractSelectAndNormalizeDigest(stmtNode ast.StmtNode) (ast.StmtNode, string, string) {
switch x := stmtNode.(type) {
case *ast.ExplainStmt:
switch x.Stmt.(type) {
case *ast.SelectStmt:
case *ast.SelectStmt, *ast.DeleteStmt, *ast.UpdateStmt, *ast.InsertStmt:
plannercore.EraseLastSemicolon(x)
normalizeExplainSQL := parser.Normalize(x.Text())
idx := strings.Index(normalizeExplainSQL, "select")
idx := int(0)
switch n := x.Stmt.(type) {
case *ast.SelectStmt:
idx = strings.Index(normalizeExplainSQL, "select")
case *ast.DeleteStmt:
idx = strings.Index(normalizeExplainSQL, "delete")
case *ast.UpdateStmt:
idx = strings.Index(normalizeExplainSQL, "update")
case *ast.InsertStmt:
if n.IsReplace {
idx = strings.Index(normalizeExplainSQL, "replace")
} else {
idx = strings.Index(normalizeExplainSQL, "insert")
}
}
normalizeSQL := normalizeExplainSQL[idx:]
hash := parser.DigestNormalized(normalizeSQL)
return x.Stmt.(*ast.SelectStmt), normalizeSQL, hash
return x.Stmt, normalizeSQL, hash
}
case *ast.SelectStmt:
case *ast.SelectStmt, *ast.DeleteStmt, *ast.UpdateStmt, *ast.InsertStmt:
plannercore.EraseLastSemicolon(x)
normalizedSQL, hash := parser.NormalizeDigest(x.Text())
return x, normalizedSQL, hash
Expand Down
4 changes: 2 additions & 2 deletions session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1778,12 +1778,12 @@ func (s *testSessionSuite3) TestUnique(c *C) {
c.Assert(err, NotNil)
// Check error type and error message
c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue, Commentf("err %v", err))
c.Assert(err.Error(), Equals, "previous statement: insert into test(id, val) values(1, 1);: [kv:1062]Duplicate entry '1' for key 'PRIMARY'")
c.Assert(err.Error(), Equals, "previous statement: insert into test(id, val) values(1, 1): [kv:1062]Duplicate entry '1' for key 'PRIMARY'")

_, err = tk1.Exec("commit")
c.Assert(err, NotNil)
c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue, Commentf("err %v", err))
c.Assert(err.Error(), Equals, "previous statement: insert into test(id, val) values(2, 2);: [kv:1062]Duplicate entry '2' for key 'val'")
c.Assert(err.Error(), Equals, "previous statement: insert into test(id, val) values(2, 2): [kv:1062]Duplicate entry '2' for key 'val'")

// Test for https://github.com/pingcap/tidb/issues/463
tk.MustExec("drop table test;")
Expand Down
63 changes: 51 additions & 12 deletions util/hint/hint_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,18 @@ func (hs *HintsSet) ContainTableHint(hint string) bool {
return false
}

// setTableHints4StmtNode sets table hints for select/update/delete.
func setTableHints4StmtNode(node ast.Node, hints []*ast.TableOptimizerHint) {
switch x := node.(type) {
case *ast.SelectStmt:
x.TableHints = hints
case *ast.UpdateStmt:
x.TableHints = hints
case *ast.DeleteStmt:
x.TableHints = hints
}
}

// ExtractTableHintsFromStmtNode extracts table hints from this node.
func ExtractTableHintsFromStmtNode(node ast.Node, sctx sessionctx.Context) []*ast.TableOptimizerHint {
switch x := node.(type) {
Expand Down Expand Up @@ -177,25 +189,32 @@ func (hs *HintsSet) Restore() (string, error) {
type hintProcessor struct {
*HintsSet
// bindHint2Ast indicates the behavior of the processor, `true` for bind hint to ast, `false` for extract hint from ast.
bindHint2Ast bool
tableCounter int
indexCounter int
bindHint2Ast bool
tableCounter int
indexCounter int
selectCounter int
}

func (hp *hintProcessor) Enter(in ast.Node) (ast.Node, bool) {
switch v := in.(type) {
case *ast.SelectStmt:
case *ast.SelectStmt, *ast.UpdateStmt, *ast.DeleteStmt:
if hp.bindHint2Ast {
if hp.tableCounter < len(hp.tableHints) {
v.TableHints = hp.tableHints[hp.tableCounter]
setTableHints4StmtNode(in, hp.tableHints[hp.tableCounter])
} else {
v.TableHints = nil
setTableHints4StmtNode(in, nil)
}
hp.tableCounter++
} else {
hp.tableHints = append(hp.tableHints, v.TableHints)
hp.tableHints = append(hp.tableHints, ExtractTableHintsFromStmtNode(in, nil))
}
if _, ok := in.(*ast.SelectStmt); ok {
hp.selectCounter++
}
case *ast.TableName:
if hp.selectCounter == 0 {
return in, false
}
if hp.bindHint2Ast {
if hp.indexCounter < len(hp.indexHints) {
v.IndexHints = hp.indexHints[hp.indexCounter]
Expand All @@ -211,6 +230,9 @@ func (hp *hintProcessor) Enter(in ast.Node) (ast.Node, bool) {
}

func (hp *hintProcessor) Leave(in ast.Node) (ast.Node, bool) {
if _, ok := in.(*ast.SelectStmt); ok {
hp.selectCounter--
}
return in, true
}

Expand Down Expand Up @@ -240,18 +262,19 @@ func ParseHintsSet(p *parser.Parser, sql, charset, collation, db string) (*Hints
hs := CollectHint(stmtNodes[0])
processor := &BlockHintProcessor{}
stmtNodes[0].Accept(processor)
hintNodeType := nodeType4Stmt(stmtNodes[0])
for i, tblHints := range hs.tableHints {
newHints := make([]*ast.TableOptimizerHint, 0, len(tblHints))
for _, tblHint := range tblHints {
if tblHint.HintName.L == hintQBName {
continue
}
offset := processor.GetHintOffset(tblHint.QBName, TypeSelect, i+1)
if offset < 0 || !processor.checkTableQBName(tblHint.Tables, TypeSelect) {
offset := processor.GetHintOffset(tblHint.QBName, hintNodeType, i+1)
if offset < 0 || !processor.checkTableQBName(tblHint.Tables, hintNodeType) {
hintStr := RestoreTableOptimizerHint(tblHint)
return nil, nil, errors.New(fmt.Sprintf("Unknown query block name in hint %s", hintStr))
}
tblHint.QBName = GenerateQBName(TypeSelect, offset)
tblHint.QBName = GenerateQBName(hintNodeType, offset)
for i, tbl := range tblHint.Tables {
if tbl.DBName.String() == "" {
tblHint.Tables[i].DBName = model.NewCIStr(db)
Expand Down Expand Up @@ -360,8 +383,24 @@ const (
TypeDelete
// TypeSelect for SELECT.
TypeSelect
// TypeInvalid for unexpected statements.
TypeInvalid
)

// nodeType4Stmt returns the NodeType for a statement. The type is used for SQL bind.
func nodeType4Stmt(node ast.StmtNode) NodeType {
switch node.(type) {
// This type is used by SQL bind, we only handle SQL bind for INSERT INTO SELECT, so we treat InsertStmt as TypeSelect.
case *ast.SelectStmt, *ast.InsertStmt:
return TypeSelect
case *ast.UpdateStmt:
return TypeUpdate
case *ast.DeleteStmt:
return TypeDelete
}
return TypeInvalid
}

// getBlockName finds the offset of query block name. It use 0 as offset for top level update or delete,
// -1 for invalid block name.
func (p *BlockHintProcessor) getBlockOffset(blockName model.CIStr, nodeType NodeType) int {
Expand Down Expand Up @@ -428,9 +467,9 @@ func (p *BlockHintProcessor) GetCurrentStmtHints(hints []*ast.TableOptimizerHint

// GenerateQBName builds QBName from offset.
func GenerateQBName(nodeType NodeType, blockOffset int) model.CIStr {
if nodeType == TypeDelete && blockOffset == 0 {
if nodeType == TypeDelete && (blockOffset == 0 || blockOffset == 1) {
return model.NewCIStr(defaultDeleteBlockName)
} else if nodeType == TypeUpdate && blockOffset == 0 {
} else if nodeType == TypeUpdate && (blockOffset == 0 || blockOffset == 1) {
return model.NewCIStr(defaultUpdateBlockName)
}
return model.NewCIStr(fmt.Sprintf("%s%d", defaultSelectBlockPrefix, blockOffset))
Expand Down
Loading