From bf8163cc49762ac609fe5f314e87b4abf0873392 Mon Sep 17 00:00:00 2001 From: iamzhoug37 <2541781827@qq.com> Date: Sun, 28 Apr 2019 02:18:56 +0800 Subject: [PATCH 01/18] support query-binding --- bindinfo/bind_test.go | 156 +++++++++++++++++++++++++++++++++++++ bindinfo/cache.go | 8 +- bindinfo/handle.go | 66 +++++++++++++++- bindinfo/session_handle.go | 14 +++- domain/domain.go | 16 +++- executor/bind.go | 1 + executor/compiler.go | 149 +++++++++++++++++++++++++++++++++++ session/session.go | 72 +++++++++++++++-- 8 files changed, 467 insertions(+), 15 deletions(-) diff --git a/bindinfo/bind_test.go b/bindinfo/bind_test.go index 423ef629bc3d8..f3f4ce24f559f 100644 --- a/bindinfo/bind_test.go +++ b/bindinfo/bind_test.go @@ -257,3 +257,159 @@ func (s *testSuite) TestSessionBinding(c *C) { c.Check(err, IsNil) c.Check(chk.NumRows(), Equals, 0) } + +func (s *testSuite) TestGlobalAndSessionBindingBothExist(c *C) { + tk := testkit.NewTestKit(c, s.store) + s.cleanBindingEnv(tk) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + tk.MustExec("drop table if exists t2") + tk.MustExec("create table t1(id int)") + tk.MustExec("create table t2(id int)") + + tk.MustQuery("explain SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id").Check(testkit.Rows( + "MergeJoin_7 12487.50 root inner join, left key:test.t1.id, right key:test.t2.id", + "├─Sort_11 9990.00 root test.t1.id:asc", + "│ └─TableReader_10 9990.00 root data:Selection_9", + "│ └─Selection_9 9990.00 cop not(isnull(test.t1.id))", + "│ └─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─Sort_15 9990.00 root test.t2.id:asc", + " └─TableReader_14 9990.00 root data:Selection_13", + " └─Selection_13 9990.00 cop not(isnull(test.t2.id))", + " └─TableScan_12 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + + tk.MustQuery("explain SELECT /*+ TIDB_INLJ(t1, t2) */ * from t1,t2 where t1.id = t2.id").Check(testkit.Rows( + "HashLeftJoin_7 12487.50 root inner join, inner:TableReader_14, equal:[eq(test.t1.id, test.t2.id)]", + "├─TableReader_11 9990.00 root data:Selection_10", + "│ └─Selection_10 9990.00 cop not(isnull(test.t1.id))", + "│ └─TableScan_9 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_14 9990.00 root data:Selection_13", + " └─Selection_13 9990.00 cop not(isnull(test.t2.id))", + " └─TableScan_12 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + + tk.MustExec("create global binding for SELECT * from t1,t2 where t1.id = t2.id using SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id") + + tk.MustQuery("explain SELECT * from t1,t2 where t1.id = t2.id").Check(testkit.Rows( + "MergeJoin_7 12487.50 root inner join, left key:test.t1.id, right key:test.t2.id", + "├─Sort_11 9990.00 root test.t1.id:asc", + "│ └─TableReader_10 9990.00 root data:Selection_9", + "│ └─Selection_9 9990.00 cop not(isnull(test.t1.id))", + "│ └─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─Sort_15 9990.00 root test.t2.id:asc", + " └─TableReader_14 9990.00 root data:Selection_13", + " └─Selection_13 9990.00 cop not(isnull(test.t2.id))", + " └─TableScan_12 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + + tk.MustExec("drop global binding for SELECT * from t1,t2 where t1.id = t2.id") +} + +func (s *testSuite) TestComplexSqlBinding(c *C) { + tk := testkit.NewTestKit(c, s.store) + s.cleanBindingEnv(tk) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + tk.MustExec("drop table if exists t2") + tk.MustExec("create table t1(id int)") + tk.MustExec("create table t2(id int)") + tk.MustExec("create index index_t1 on t1(id)") + tk.MustExec("create index index_t1 on t2(id)") + + tk.MustQuery("explain SELECT * from t1,t2 where t1.id = t2.id").Check(testkit.Rows( + "MergeJoin_8 12487.50 root inner join, left key:test.t1.id, right key:test.t2.id", + "├─IndexReader_20 9990.00 root index:IndexScan_19", + "│ └─IndexScan_19 9990.00 cop table:t1, index:id, range:[-inf,+inf], keep order:true, stats:pseudo", + "└─IndexReader_22 9990.00 root index:IndexScan_21", + " └─IndexScan_21 9990.00 cop table:t2, index:id, range:[-inf,+inf], keep order:true, stats:pseudo", + )) + + tk.MustExec("create global binding for SELECT * from t1,t2 where t1.id = t2.id using select * from t1 use index(index_t1), t2 use index(index_t2) where t1.id = t2.id") + + tk.MustQuery("explain SELECT * from t1,t2 where t1.id = t2.id").Check(testkit.Rows( + "MergeJoin_8 12487.50 root inner join, left key:test.t1.id, right key:test.t2.id", + "├─IndexReader_20 9990.00 root index:IndexScan_19", + "│ └─IndexScan_19 9990.00 cop table:t1, index:id, range:[-inf,+inf], keep order:true, stats:pseudo", + "└─IndexReader_22 9990.00 root index:IndexScan_21", + " └─IndexScan_21 9990.00 cop table:t2, index:id, range:[-inf,+inf], keep order:true, stats:pseudo", + )) + + tk.MustExec("drop global binding for SELECT * from t1,t2 where t1.id = t2.id") +} + +func (s *testSuite) TestExplain(c *C) { + tk := testkit.NewTestKit(c, s.store) + s.cleanBindingEnv(tk) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + tk.MustExec("drop table if exists t2") + tk.MustExec("create table t1(id int)") + tk.MustExec("create table t2(id int)") + + tk.MustQuery("explain SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id").Check(testkit.Rows( + "MergeJoin_7 12487.50 root inner join, left key:test.t1.id, right key:test.t2.id", + "├─Sort_11 9990.00 root test.t1.id:asc", + "│ └─TableReader_10 9990.00 root data:Selection_9", + "│ └─Selection_9 9990.00 cop not(isnull(test.t1.id))", + "│ └─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─Sort_15 9990.00 root test.t2.id:asc", + " └─TableReader_14 9990.00 root data:Selection_13", + " └─Selection_13 9990.00 cop not(isnull(test.t2.id))", + " └─TableScan_12 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + + tk.MustExec("create global binding for SELECT * from t1,t2 where t1.id = t2.id using SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id") + + tk.MustQuery("explain SELECT * from t1,t2 where t1.id = t2.id").Check(testkit.Rows( + "MergeJoin_7 12487.50 root inner join, left key:test.t1.id, right key:test.t2.id", + "├─Sort_11 9990.00 root test.t1.id:asc", + "│ └─TableReader_10 9990.00 root data:Selection_9", + "│ └─Selection_9 9990.00 cop not(isnull(test.t1.id))", + "│ └─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─Sort_15 9990.00 root test.t2.id:asc", + " └─TableReader_14 9990.00 root data:Selection_13", + " └─Selection_13 9990.00 cop not(isnull(test.t2.id))", + " └─TableScan_12 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + + tk.MustExec("drop global binding for SELECT * from t1,t2 where t1.id = t2.id") +} + +func (s *testSuite) TestErrorBind(c *C) { + tk := testkit.NewTestKit(c, s.store) + s.cleanBindingEnv(tk) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t(i int, s varchar(20))") + tk.MustExec("create table t1(i int, s varchar(20))") + tk.MustExec("create index index_t on t(i,s)") + + _, err := tk.Exec("create global binding for select * from t where i>100 using select * from t use index(index_t) where i>100") + c.Assert(err, IsNil, Commentf("err %v", err)) + + bindData := s.domain.BindHandle().GetBindRecord("select * from t where i > ?", "test") + c.Check(bindData, NotNil) + c.Check(bindData.OriginalSQL, Equals, "select * from t where i > ?") + c.Check(bindData.BindSQL, Equals, "select * from t use index(index_t) where i>100") + c.Check(bindData.Db, Equals, "test") + c.Check(bindData.Status, Equals, "using") + c.Check(bindData.Charset, NotNil) + c.Check(bindData.Collation, NotNil) + c.Check(bindData.CreateTime, NotNil) + c.Check(bindData.UpdateTime, NotNil) + + tk.MustExec("drop index index_t on t") + _, err = tk.Exec("select * from t where i > 10") + c.Check(err, IsNil) + + s.domain.BindHandle().HandleDropBindRecord() + + rs, err := tk.Exec("show global bindings") + c.Assert(err, IsNil) + chk := rs.NewRecordBatch() + err = rs.Next(context.TODO(), chk) + c.Check(err, IsNil) + c.Check(chk.NumRows(), Equals, 0) +} diff --git a/bindinfo/cache.go b/bindinfo/cache.go index b8731628ec080..f64d02f3b4cfb 100644 --- a/bindinfo/cache.go +++ b/bindinfo/cache.go @@ -20,16 +20,18 @@ import ( ) const ( - // using is the bind info's in use status. - using = "using" + // Using is the bind info's in use status. + Using = "using" // deleted is the bind info's deleted status. deleted = "deleted" + // Invalid is the bind info's invalid status. + Invalid = "invalid" ) // bindMeta stores the basic bind info and bindSql astNode. type bindMeta struct { *BindRecord - ast ast.StmtNode //ast will be used to do query sql bind check + Ast ast.StmtNode //ast will be used to do query sql bind check } // cache is a k-v map, key is original sql, value is a slice of bindMeta. diff --git a/bindinfo/handle.go b/bindinfo/handle.go index 0be680217c186..4bfff3632e945 100644 --- a/bindinfo/handle.go +++ b/bindinfo/handle.go @@ -20,6 +20,7 @@ import ( "go.uber.org/zap" "sync" "sync/atomic" + "time" "github.com/pingcap/parser" "github.com/pingcap/parser/mysql" @@ -60,15 +61,26 @@ type BindHandle struct { atomic.Value } + dropBindRecordMap struct { + sync.Mutex + atomic.Value + } + parser *parser.Parser lastUpdateTime types.Time } +type droppedBindRecord struct { + bindRecord *BindRecord + droppedTime time.Time +} + // NewBindHandle creates a new BindHandle. func NewBindHandle(ctx sessionctx.Context, parser *parser.Parser) *BindHandle { handle := &BindHandle{parser: parser} handle.sctx.Context = ctx handle.bindInfo.Value.Store(make(cache, 32)) + handle.dropBindRecordMap.Value.Store(make(map[string]*droppedBindRecord)) return handle } @@ -106,7 +118,7 @@ func (h *BindHandle) Update(fullLoad bool) (err error) { } newCache.removeStaleBindMetas(hash, meta) - if meta.Status == using { + if meta.Status == Using { newCache[hash] = append(newCache[hash], meta) } } @@ -163,7 +175,7 @@ func (h *BindHandle) AddBindRecord(record *BindRecord) (err error) { Fsp: 3, } record.UpdateTime = record.CreateTime - record.Status = using + record.Status = Using record.BindSQL = h.getEscapeCharacter(record.BindSQL) // insert the BindRecord to the storage. @@ -217,6 +229,46 @@ func (h *BindHandle) DropBindRecord(record *BindRecord) (err error) { return err } +// HandleDropBindRecord execute the drop bindRecord task. +func (h *BindHandle) HandleDropBindRecord() { + h.dropBindRecordMap.Lock() + dropBindRecordMap := copyDroppedBindRecordMap(h.dropBindRecordMap.Load().(map[string]*droppedBindRecord)) + for key, droppedBindRecord := range dropBindRecordMap { + if droppedBindRecord.droppedTime.IsZero() { + err := h.DropBindRecord(droppedBindRecord.bindRecord) + if err != nil { + logutil.Logger(context.Background()).Error("handleDropBindRecord failed", zap.Error(err)) + } + droppedBindRecord.droppedTime = time.Now() + continue + } + + if time.Since(droppedBindRecord.droppedTime) > 6*time.Second { + delete(dropBindRecordMap, key) + } + } + h.dropBindRecordMap.Store(dropBindRecordMap) + h.dropBindRecordMap.Unlock() +} + +// AddDropBindRecordTask add bindRecord to dropBindRecordMap when the bindRecord need to be deleted. +func (h *BindHandle) AddDropBindRecordTask(dropBindRecord *BindRecord) { + key := dropBindRecord.OriginalSQL + ":" + dropBindRecord.Db + if _, ok := h.dropBindRecordMap.Value.Load().(map[string]*droppedBindRecord)[key]; ok { + return + } + h.dropBindRecordMap.Lock() + if _, ok := h.dropBindRecordMap.Value.Load().(map[string]*droppedBindRecord)[key]; ok { + return + } + newMap := copyDroppedBindRecordMap(h.dropBindRecordMap.Value.Load().(map[string]*droppedBindRecord)) + newMap[key] = &droppedBindRecord{ + bindRecord: dropBindRecord, + } + h.dropBindRecordMap.Store(newMap) + h.dropBindRecordMap.Unlock() +} + // Size return the size of bind info cache. func (h *BindHandle) Size() int { size := 0 @@ -255,7 +307,7 @@ func (h *BindHandle) newBindMeta(record *BindRecord) (hash string, meta *bindMet if err != nil { return "", nil, err } - meta = &bindMeta{BindRecord: record, ast: stmtNodes[0]} + meta = &bindMeta{BindRecord: record, Ast: stmtNodes[0]} return hash, meta, nil } @@ -337,6 +389,14 @@ func (c cache) copy() cache { return newCache } +func copyDroppedBindRecordMap(oldMap map[string]*droppedBindRecord) map[string]*droppedBindRecord { + newMap := make(map[string]*droppedBindRecord, len(oldMap)) + for k, v := range oldMap { + newMap[k] = v + } + return newMap +} + // isStale checks whether this bindMeta is stale compared with the other bindMeta. func (m *bindMeta) isStale(other *bindMeta) bool { return m.OriginalSQL == other.OriginalSQL && m.Db == other.Db && diff --git a/bindinfo/session_handle.go b/bindinfo/session_handle.go index e6ae2f0d37089..1340386115b15 100644 --- a/bindinfo/session_handle.go +++ b/bindinfo/session_handle.go @@ -48,13 +48,12 @@ func (h *SessionHandle) newBindMeta(record *BindRecord) (hash string, meta *bind if err != nil { return "", nil, err } - meta = &bindMeta{BindRecord: record, ast: stmtNodes[0]} + meta = &bindMeta{BindRecord: record, Ast: stmtNodes[0]} return hash, meta, nil } // AddBindRecord new a BindRecord with bindMeta, add it to the cache. func (h *SessionHandle) AddBindRecord(record *BindRecord) error { - record.Status = using record.CreateTime = types.Time{ Time: types.FromGoTime(time.Now()), Type: mysql.TypeDatetime, @@ -84,6 +83,17 @@ func (h *SessionHandle) GetBindRecord(normdOrigSQL, db string) *bindMeta { return nil } +// DropBindRecord remove the bindRecord from session cache. +func (h *SessionHandle) DropBindRecord(normdOrigSQL, db string) { + hash := parser.DigestHash(normdOrigSQL) + record := &BindRecord{ + OriginalSQL: normdOrigSQL, + Db: db, + } + meta := &bindMeta{BindRecord: record} + h.ch.removeDeletedBindMeta(hash, meta) +} + // sessionBindInfoKeyType is a dummy type to avoid naming collision in context. type sessionBindInfoKeyType int diff --git a/domain/domain.go b/domain/domain.go index 45b8a1d5cec5a..523fc0d53af80 100644 --- a/domain/domain.go +++ b/domain/domain.go @@ -792,7 +792,7 @@ func (do *Domain) LoadBindInfoLoop(ctx sessionctx.Context, parser *parser.Parser } duration := 3 * time.Second - do.wg.Add(1) + do.wg.Add(2) go func() { defer do.wg.Done() defer recoverInDomain("loadBindInfoLoop", false) @@ -808,6 +808,20 @@ func (do *Domain) LoadBindInfoLoop(ctx sessionctx.Context, parser *parser.Parser } } }() + + handleInvaildTaskDuration := 3 * time.Second + go func() { + defer do.wg.Done() + defer recoverInDomain("loadBindInfoLoop-dropInvalidBindInfo", false) + for { + select { + case <-do.exit: + return + case <-time.After(handleInvaildTaskDuration): + } + do.bindHandle.HandleDropBindRecord() + } + }() return nil } diff --git a/executor/bind.go b/executor/bind.go index f23ba3cc1bb41..c7049efe606aa 100644 --- a/executor/bind.go +++ b/executor/bind.go @@ -75,6 +75,7 @@ func (e *SQLBindExec) createSQLBind() error { Db: e.ctx.GetSessionVars().CurrentDB, Charset: e.charset, Collation: e.collation, + Status: bindinfo.Using, } if !e.isGlobal { handle := e.ctx.Value(bindinfo.SessionBindInfoKeyType).(*bindinfo.SessionHandle) diff --git a/executor/compiler.go b/executor/compiler.go index 23b2239475ca9..6db7c3258d4db 100644 --- a/executor/compiler.go +++ b/executor/compiler.go @@ -18,8 +18,11 @@ import ( "fmt" "github.com/opentracing/opentracing-go" + "github.com/pingcap/parser" "github.com/pingcap/parser/ast" + "github.com/pingcap/tidb/bindinfo" "github.com/pingcap/tidb/config" + "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/planner" @@ -49,11 +52,24 @@ type Compiler struct { // Compile compiles an ast.StmtNode to a physical plan. func (c *Compiler) Compile(ctx context.Context, stmtNode ast.StmtNode) (*ExecStmt, error) { + return c.compile(ctx, stmtNode, false) +} + +// SkipBindCompile compiles an ast.StmtNode to a physical plan without SQL bind. +func (c *Compiler) SkipBindCompile(ctx context.Context, node ast.StmtNode) (*ExecStmt, error) { + return c.compile(ctx, node, true) +} + +func (c *Compiler) compile(ctx context.Context, stmtNode ast.StmtNode, skipBind bool) (*ExecStmt, error) { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("executor.Compile", opentracing.ChildOf(span.Context())) defer span1.Finish() } + if !skipBind { + stmtNode = addHint(c.Ctx, stmtNode) + } + infoSchema := GetInfoSchema(c.Ctx) if err := plannercore.Preprocess(c.Ctx, stmtNode, infoSchema); err != nil { return nil, err @@ -367,3 +383,136 @@ func GetInfoSchema(ctx sessionctx.Context) infoschema.InfoSchema { } return is } + +func addHint(ctx sessionctx.Context, stmtNode ast.StmtNode) ast.StmtNode { + switch x := stmtNode.(type) { + case *ast.ExplainStmt: + switch x.Stmt.(type) { + case *ast.SelectStmt: + x.Stmt.SetText(x.Text()[len("explain "):]) + x.Stmt = addHintForSelect(ctx, x.Stmt) + } + return x + case *ast.SelectStmt: + return addHintForSelect(ctx, x) + default: + return stmtNode + } +} + +func addHintForSelect(ctx sessionctx.Context, stmt ast.StmtNode) ast.StmtNode { + if ctx.Value(bindinfo.SessionBindInfoKeyType) == nil { //when the domain is initializing, the bind will be nil. + return stmt + } + + normdOrigSQL := parser.Normalize(stmt.Text()) + sessionHandle := ctx.Value(bindinfo.SessionBindInfoKeyType).(*bindinfo.SessionHandle) + bindRecord := sessionHandle.GetBindRecord(normdOrigSQL, ctx.GetSessionVars().CurrentDB) + if bindRecord != nil { + if bindRecord.Status == bindinfo.Invalid { + return stmt + } + if bindRecord.Status == bindinfo.Using { + return bindHint(stmt, bindRecord.Ast) + } + } + globalHandle := domain.GetDomain(ctx).BindHandle() + bindRecord = globalHandle.GetBindRecord(normdOrigSQL, ctx.GetSessionVars().CurrentDB) + if bindRecord == nil { + bindRecord = globalHandle.GetBindRecord(normdOrigSQL, "") + } + if bindRecord != nil { + return bindHint(stmt, bindRecord.Ast) + } + return stmt +} + +func bindHint(originStmt, hintedStmt ast.StmtNode) ast.StmtNode { + switch x := originStmt.(type) { + case *ast.SelectStmt: + return selectBind(x, hintedStmt.(*ast.SelectStmt)) + default: + return originStmt + } +} + +func selectBind(originalNode, hintedNode *ast.SelectStmt) *ast.SelectStmt { + if hintedNode.TableHints != nil { + originalNode.TableHints = hintedNode.TableHints + } + if originalNode.From != nil { + originalNode.From.TableRefs = resultSetNodeBind(originalNode.From.TableRefs, hintedNode.From.TableRefs).(*ast.Join) + } + if originalNode.Where != nil { + originalNode.Where = selectionBind(originalNode.Where, hintedNode.Where).(ast.ExprNode) + } + return originalNode +} + +func selectionBind(where ast.ExprNode, hintedWhere ast.ExprNode) ast.ExprNode { + switch v := where.(type) { + case *ast.SubqueryExpr: + if v.Query != nil { + v.Query = resultSetNodeBind(v.Query, hintedWhere.(*ast.SubqueryExpr).Query) + } + case *ast.ExistsSubqueryExpr: + if v.Sel != nil { + v.Sel.(*ast.SubqueryExpr).Query = resultSetNodeBind(v.Sel.(*ast.SubqueryExpr).Query, hintedWhere.(*ast.ExistsSubqueryExpr).Sel.(*ast.SubqueryExpr).Query) + } + case *ast.PatternInExpr: + if v.Sel != nil { + v.Sel.(*ast.SubqueryExpr).Query = resultSetNodeBind(v.Sel.(*ast.SubqueryExpr).Query, hintedWhere.(*ast.PatternInExpr).Sel.(*ast.SubqueryExpr).Query) + } + } + return where +} + +func resultSetNodeBind(originalNode, hintedNode ast.ResultSetNode) ast.ResultSetNode { + switch x := originalNode.(type) { + case *ast.Join: + return joinBind(x, hintedNode.(*ast.Join)) + case *ast.TableSource: + ts, _ := hintedNode.(*ast.TableSource) + switch v := x.Source.(type) { + case *ast.SelectStmt: + x.Source = selectBind(v, ts.Source.(*ast.SelectStmt)) + case *ast.UnionStmt: + x.Source = unionSelectBind(v, hintedNode.(*ast.TableSource).Source.(*ast.UnionStmt)) + case *ast.TableName: + x.Source = dataSourceBind(v, ts.Source.(*ast.TableName)) + } + return x + case *ast.SelectStmt: + return selectBind(x, hintedNode.(*ast.SelectStmt)) + case *ast.UnionStmt: + return unionSelectBind(x, hintedNode.(*ast.UnionStmt)) + default: + return x + } +} + +func dataSourceBind(originalNode, hintedNode *ast.TableName) *ast.TableName { + originalNode.IndexHints = hintedNode.IndexHints + return originalNode +} + +func joinBind(originalNode, hintedNode *ast.Join) *ast.Join { + if originalNode.Left != nil { + originalNode.Left = resultSetNodeBind(originalNode.Left, hintedNode.Left) + } + + if hintedNode.Right != nil { + originalNode.Right = resultSetNodeBind(originalNode.Right, hintedNode.Right) + } + + return originalNode +} + +func unionSelectBind(originalNode, hintedNode *ast.UnionStmt) ast.ResultSetNode { + selects := originalNode.SelectList.Selects + for i := len(selects) - 1; i >= 0; i-- { + originalNode.SelectList.Selects[i] = selectBind(selects[i], hintedNode.SelectList.Selects[i]) + } + + return originalNode +} diff --git a/session/session.go b/session/session.go index 9685a4aa9e5c5..bfd9f03907a19 100644 --- a/session/session.go +++ b/session/session.go @@ -997,8 +997,9 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec sessionExecuteParseDurationGeneral.Observe(time.Since(startTS).Seconds()) } + var tempStmtNodes []ast.StmtNode compiler := executor.Compiler{Ctx: s} - for _, stmtNode := range stmtNodes { + for idx, stmtNode := range stmtNodes { s.PrepareTxnCtx(ctx) // Step2: Transform abstract syntax tree to a physical plan(stored in executor.ExecStmt). @@ -1009,11 +1010,19 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec } stmt, err := compiler.Compile(ctx, stmtNode) if err != nil { - s.rollbackOnError(ctx) - logutil.Logger(ctx).Warn("compile sql error", - zap.Error(err), - zap.String("sql", sql)) - return nil, err + if tempStmtNodes == nil { + tempStmtNodes, _, _ = s.ParseSQL(ctx, sql, charsetInfo, collation) + } + stmtNode = tempStmtNodes[idx] + stmt, err = compiler.SkipBindCompile(ctx, stmtNode) + if err != nil { + s.rollbackOnError(ctx) + logutil.Logger(ctx).Warn("compile sql error", + zap.Error(err), + zap.String("sql", sql)) + return nil, err + } + s.handleInValidBindRecord(ctx, stmtNode) } if isInternal { sessionExecuteCompileDurationInternal.Observe(time.Since(startTS).Seconds()) @@ -1039,6 +1048,57 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec return recordSets, nil } +func (s *session) handleInValidBindRecord(ctx context.Context, stmtNode ast.StmtNode) { + var stmt ast.StmtNode + switch x := stmtNode.(type) { + case *ast.ExplainStmt: + switch x.Stmt.(type) { + case *ast.SelectStmt: + x.Stmt.SetText(x.Text()[len("explain "):]) + stmt = x.Stmt + } + case *ast.SelectStmt: + stmt = x + default: + return + } + normdOrigSQL := parser.Normalize(stmt.Text()) + sessionHandle := s.Value(bindinfo.SessionBindInfoKeyType).(*bindinfo.SessionHandle) + bindMeta := sessionHandle.GetBindRecord(normdOrigSQL, s.GetSessionVars().CurrentDB) + if bindMeta != nil { + bindMeta.Status = bindinfo.Invalid + return + } + + globalHandle := domain.GetDomain(s).BindHandle() + bindMeta = globalHandle.GetBindRecord(normdOrigSQL, s.GetSessionVars().CurrentDB) + if bindMeta == nil { + bindMeta = globalHandle.GetBindRecord(normdOrigSQL, "") + } + if bindMeta != nil { + record := &bindinfo.BindRecord{ + OriginalSQL: bindMeta.OriginalSQL, + BindSQL: bindMeta.BindSQL, + Db: s.GetSessionVars().CurrentDB, + Charset: bindMeta.Charset, + Collation: bindMeta.Collation, + Status: bindinfo.Invalid, + } + + err := sessionHandle.AddBindRecord(record) + if err != nil { + logutil.Logger(ctx).Warn("handleInValidBindRecord failed", zap.Error(err)) + } + + globalHandle := domain.GetDomain(s).BindHandle() + dropBindRecord := &bindinfo.BindRecord{ + OriginalSQL: bindMeta.OriginalSQL, + Db: bindMeta.Db, + } + globalHandle.AddDropBindRecordTask(dropBindRecord) + } +} + // rollbackOnError makes sure the next statement starts a new transaction with the latest InfoSchema. func (s *session) rollbackOnError(ctx context.Context) { if !s.sessionVars.InTxn() { From f3b67c0e0d53c115ff838f35dd1cc1190392c95c Mon Sep 17 00:00:00 2001 From: iamzhoug37 <2541781827@qq.com> Date: Sun, 28 Apr 2019 11:21:54 +0800 Subject: [PATCH 02/18] error check --- session/session.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/session/session.go b/session/session.go index bfd9f03907a19..9fbcbcd9bc599 100644 --- a/session/session.go +++ b/session/session.go @@ -1011,7 +1011,10 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec stmt, err := compiler.Compile(ctx, stmtNode) if err != nil { if tempStmtNodes == nil { - tempStmtNodes, _, _ = s.ParseSQL(ctx, sql, charsetInfo, collation) + tempStmtNodes, warns, err = s.ParseSQL(ctx, sql, charsetInfo, collation) + if err != nil || warns != nil{ + //just skip errcheck, because parse will not return an error. + } } stmtNode = tempStmtNodes[idx] stmt, err = compiler.SkipBindCompile(ctx, stmtNode) From 3af4a08375b90bcf818f030dce680812f06b63e5 Mon Sep 17 00:00:00 2001 From: iamzhoug37 <2541781827@qq.com> Date: Sun, 28 Apr 2019 11:22:28 +0800 Subject: [PATCH 03/18] fmt --- session/session.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/session/session.go b/session/session.go index 9fbcbcd9bc599..67eb01d1eca20 100644 --- a/session/session.go +++ b/session/session.go @@ -1012,7 +1012,7 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec if err != nil { if tempStmtNodes == nil { tempStmtNodes, warns, err = s.ParseSQL(ctx, sql, charsetInfo, collation) - if err != nil || warns != nil{ + if err != nil || warns != nil { //just skip errcheck, because parse will not return an error. } } From 7347e2537bcb803a0e3014faa5f7fce1d4d32108 Mon Sep 17 00:00:00 2001 From: iamzhoug37 <2541781827@qq.com> Date: Sun, 28 Apr 2019 15:29:37 +0800 Subject: [PATCH 04/18] add lower case --- executor/compiler.go | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/executor/compiler.go b/executor/compiler.go index 6db7c3258d4db..45601bd8bea92 100644 --- a/executor/compiler.go +++ b/executor/compiler.go @@ -16,6 +16,7 @@ package executor import ( "context" "fmt" + "strings" "github.com/opentracing/opentracing-go" "github.com/pingcap/parser" @@ -389,23 +390,25 @@ func addHint(ctx sessionctx.Context, stmtNode ast.StmtNode) ast.StmtNode { case *ast.ExplainStmt: switch x.Stmt.(type) { case *ast.SelectStmt: - x.Stmt.SetText(x.Text()[len("explain "):]) - x.Stmt = addHintForSelect(ctx, x.Stmt) + normalizeExplainSQL := parser.Normalize(x.Text()) + lowerSQL := strings.ToLower(normalizeExplainSQL) + idx := strings.Index(lowerSQL, "select") + normalizeSQL := normalizeExplainSQL[idx:] + x.Stmt = addHintForSelect(normalizeSQL, ctx, x.Stmt) } return x case *ast.SelectStmt: - return addHintForSelect(ctx, x) + return addHintForSelect(parser.Normalize(x.Text()), ctx, x) default: return stmtNode } } -func addHintForSelect(ctx sessionctx.Context, stmt ast.StmtNode) ast.StmtNode { +func addHintForSelect(normdOrigSQL string, ctx sessionctx.Context, stmt ast.StmtNode) ast.StmtNode { if ctx.Value(bindinfo.SessionBindInfoKeyType) == nil { //when the domain is initializing, the bind will be nil. return stmt } - normdOrigSQL := parser.Normalize(stmt.Text()) sessionHandle := ctx.Value(bindinfo.SessionBindInfoKeyType).(*bindinfo.SessionHandle) bindRecord := sessionHandle.GetBindRecord(normdOrigSQL, ctx.GetSessionVars().CurrentDB) if bindRecord != nil { From 7a48661d6ae78f8b03d863c4a24ea6e5483e5d72 Mon Sep 17 00:00:00 2001 From: iamzhoug37 <2541781827@qq.com> Date: Sun, 28 Apr 2019 21:14:39 +0800 Subject: [PATCH 05/18] fix comment --- bindinfo/bind_test.go | 55 +++++++++++++------------------------------ bindinfo/handle.go | 2 +- 2 files changed, 17 insertions(+), 40 deletions(-) diff --git a/bindinfo/bind_test.go b/bindinfo/bind_test.go index d219ae7aebeea..ec577ff3232c8 100644 --- a/bindinfo/bind_test.go +++ b/bindinfo/bind_test.go @@ -290,6 +290,15 @@ func (s *testSuite) TestGlobalAndSessionBindingBothExist(c *C) { tk.MustExec("create table t1(id int)") tk.MustExec("create table t2(id int)") + tk.MustQuery("explain SELECT * from t1,t2 where t1.id = t2.id").Check(testkit.Rows( + "HashLeftJoin_8 12487.50 root inner join, inner:TableReader_15, equal:[eq(test.t1.id, test.t2.id)]", + "├─TableReader_12 9990.00 root data:Selection_11", + "│ └─Selection_11 9990.00 cop not(isnull(test.t1.id))", + "│ └─TableScan_10 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_15 9990.00 root data:Selection_14", + " └─Selection_14 9990.00 cop not(isnull(test.t2.id))", + " └─TableScan_13 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) tk.MustQuery("explain SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id").Check(testkit.Rows( "MergeJoin_7 12487.50 root inner join, left key:test.t1.id, right key:test.t2.id", "├─Sort_11 9990.00 root test.t1.id:asc", @@ -302,16 +311,6 @@ func (s *testSuite) TestGlobalAndSessionBindingBothExist(c *C) { " └─TableScan_12 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", )) - tk.MustQuery("explain SELECT /*+ TIDB_INLJ(t1, t2) */ * from t1,t2 where t1.id = t2.id").Check(testkit.Rows( - "HashLeftJoin_7 12487.50 root inner join, inner:TableReader_14, equal:[eq(test.t1.id, test.t2.id)]", - "├─TableReader_11 9990.00 root data:Selection_10", - "│ └─Selection_10 9990.00 cop not(isnull(test.t1.id))", - "│ └─TableScan_9 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", - "└─TableReader_14 9990.00 root data:Selection_13", - " └─Selection_13 9990.00 cop not(isnull(test.t2.id))", - " └─TableScan_12 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", - )) - tk.MustExec("create global binding for SELECT * from t1,t2 where t1.id = t2.id using SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id") tk.MustQuery("explain SELECT * from t1,t2 where t1.id = t2.id").Check(testkit.Rows( @@ -327,38 +326,16 @@ func (s *testSuite) TestGlobalAndSessionBindingBothExist(c *C) { )) tk.MustExec("drop global binding for SELECT * from t1,t2 where t1.id = t2.id") -} - -func (s *testSuite) TestComplexSqlBinding(c *C) { - tk := testkit.NewTestKit(c, s.store) - s.cleanBindingEnv(tk) - tk.MustExec("use test") - tk.MustExec("drop table if exists t1") - tk.MustExec("drop table if exists t2") - tk.MustExec("create table t1(id int)") - tk.MustExec("create table t2(id int)") - tk.MustExec("create index index_t1 on t1(id)") - tk.MustExec("create index index_t1 on t2(id)") tk.MustQuery("explain SELECT * from t1,t2 where t1.id = t2.id").Check(testkit.Rows( - "MergeJoin_8 12487.50 root inner join, left key:test.t1.id, right key:test.t2.id", - "├─IndexReader_20 9990.00 root index:IndexScan_19", - "│ └─IndexScan_19 9990.00 cop table:t1, index:id, range:[-inf,+inf], keep order:true, stats:pseudo", - "└─IndexReader_22 9990.00 root index:IndexScan_21", - " └─IndexScan_21 9990.00 cop table:t2, index:id, range:[-inf,+inf], keep order:true, stats:pseudo", + "HashLeftJoin_8 12487.50 root inner join, inner:TableReader_15, equal:[eq(test.t1.id, test.t2.id)]", + "├─TableReader_12 9990.00 root data:Selection_11", + "│ └─Selection_11 9990.00 cop not(isnull(test.t1.id))", + "│ └─TableScan_10 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_15 9990.00 root data:Selection_14", + " └─Selection_14 9990.00 cop not(isnull(test.t2.id))", + " └─TableScan_13 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", )) - - tk.MustExec("create global binding for SELECT * from t1,t2 where t1.id = t2.id using select * from t1 use index(index_t1), t2 use index(index_t2) where t1.id = t2.id") - - tk.MustQuery("explain SELECT * from t1,t2 where t1.id = t2.id").Check(testkit.Rows( - "MergeJoin_8 12487.50 root inner join, left key:test.t1.id, right key:test.t2.id", - "├─IndexReader_20 9990.00 root index:IndexScan_19", - "│ └─IndexScan_19 9990.00 cop table:t1, index:id, range:[-inf,+inf], keep order:true, stats:pseudo", - "└─IndexReader_22 9990.00 root index:IndexScan_21", - " └─IndexScan_21 9990.00 cop table:t2, index:id, range:[-inf,+inf], keep order:true, stats:pseudo", - )) - - tk.MustExec("drop global binding for SELECT * from t1,t2 where t1.id = t2.id") } func (s *testSuite) TestExplain(c *C) { diff --git a/bindinfo/handle.go b/bindinfo/handle.go index 6d483e483b117..a0019aaf6985d 100644 --- a/bindinfo/handle.go +++ b/bindinfo/handle.go @@ -258,6 +258,7 @@ func (h *BindHandle) AddDropBindRecordTask(dropBindRecord *BindRecord) { return } h.dropBindRecordMap.Lock() + defer h.dropBindRecordMap.Unlock() if _, ok := h.dropBindRecordMap.Value.Load().(map[string]*droppedBindRecord)[key]; ok { return } @@ -266,7 +267,6 @@ func (h *BindHandle) AddDropBindRecordTask(dropBindRecord *BindRecord) { bindRecord: dropBindRecord, } h.dropBindRecordMap.Store(newMap) - h.dropBindRecordMap.Unlock() } // Size return the size of bind info cache. From 87532ea6660520429dd7fc4e6bc2658440cac1f9 Mon Sep 17 00:00:00 2001 From: iamzhoug37 <2541781827@qq.com> Date: Sun, 28 Apr 2019 21:29:52 +0800 Subject: [PATCH 06/18] code format --- bindinfo/bind.go | 107 ++++++++++++++++++++++++++++++++++++++++++ bindinfo/bind_test.go | 11 +++++ domain/domain.go | 18 +++++-- executor/compiler.go | 94 +------------------------------------ session/session.go | 13 +++-- 5 files changed, 141 insertions(+), 102 deletions(-) create mode 100644 bindinfo/bind.go diff --git a/bindinfo/bind.go b/bindinfo/bind.go new file mode 100644 index 0000000000000..6ba6bdfa200dc --- /dev/null +++ b/bindinfo/bind.go @@ -0,0 +1,107 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package bindinfo + +import "github.com/pingcap/parser/ast" + +// BindHint will add hints for originStmt according to hintedStmt' hints. +func BindHint(originStmt, hintedStmt ast.StmtNode) ast.StmtNode { + switch x := originStmt.(type) { + case *ast.SelectStmt: + return selectBind(x, hintedStmt.(*ast.SelectStmt)) + default: + return originStmt + } +} + +func selectBind(originalNode, hintedNode *ast.SelectStmt) *ast.SelectStmt { + if hintedNode.TableHints != nil { + originalNode.TableHints = hintedNode.TableHints + } + if originalNode.From != nil { + originalNode.From.TableRefs = resultSetNodeBind(originalNode.From.TableRefs, hintedNode.From.TableRefs).(*ast.Join) + } + if originalNode.Where != nil { + originalNode.Where = selectionBind(originalNode.Where, hintedNode.Where).(ast.ExprNode) + } + return originalNode +} + +func selectionBind(where ast.ExprNode, hintedWhere ast.ExprNode) ast.ExprNode { + switch v := where.(type) { + case *ast.SubqueryExpr: + if v.Query != nil { + v.Query = resultSetNodeBind(v.Query, hintedWhere.(*ast.SubqueryExpr).Query) + } + case *ast.ExistsSubqueryExpr: + if v.Sel != nil { + v.Sel.(*ast.SubqueryExpr).Query = resultSetNodeBind(v.Sel.(*ast.SubqueryExpr).Query, hintedWhere.(*ast.ExistsSubqueryExpr).Sel.(*ast.SubqueryExpr).Query) + } + case *ast.PatternInExpr: + if v.Sel != nil { + v.Sel.(*ast.SubqueryExpr).Query = resultSetNodeBind(v.Sel.(*ast.SubqueryExpr).Query, hintedWhere.(*ast.PatternInExpr).Sel.(*ast.SubqueryExpr).Query) + } + } + return where +} + +func resultSetNodeBind(originalNode, hintedNode ast.ResultSetNode) ast.ResultSetNode { + switch x := originalNode.(type) { + case *ast.Join: + return joinBind(x, hintedNode.(*ast.Join)) + case *ast.TableSource: + ts, _ := hintedNode.(*ast.TableSource) + switch v := x.Source.(type) { + case *ast.SelectStmt: + x.Source = selectBind(v, ts.Source.(*ast.SelectStmt)) + case *ast.UnionStmt: + x.Source = unionSelectBind(v, hintedNode.(*ast.TableSource).Source.(*ast.UnionStmt)) + case *ast.TableName: + x.Source = dataSourceBind(v, ts.Source.(*ast.TableName)) + } + return x + case *ast.SelectStmt: + return selectBind(x, hintedNode.(*ast.SelectStmt)) + case *ast.UnionStmt: + return unionSelectBind(x, hintedNode.(*ast.UnionStmt)) + default: + return x + } +} + +func dataSourceBind(originalNode, hintedNode *ast.TableName) *ast.TableName { + originalNode.IndexHints = hintedNode.IndexHints + return originalNode +} + +func joinBind(originalNode, hintedNode *ast.Join) *ast.Join { + if originalNode.Left != nil { + originalNode.Left = resultSetNodeBind(originalNode.Left, hintedNode.Left) + } + + if hintedNode.Right != nil { + originalNode.Right = resultSetNodeBind(originalNode.Right, hintedNode.Right) + } + + return originalNode +} + +func unionSelectBind(originalNode, hintedNode *ast.UnionStmt) ast.ResultSetNode { + selects := originalNode.SelectList.Selects + for i := len(selects) - 1; i >= 0; i-- { + originalNode.SelectList.Selects[i] = selectBind(selects[i], hintedNode.SelectList.Selects[i]) + } + + return originalNode +} diff --git a/bindinfo/bind_test.go b/bindinfo/bind_test.go index ec577ff3232c8..069207d2c6a03 100644 --- a/bindinfo/bind_test.go +++ b/bindinfo/bind_test.go @@ -299,6 +299,7 @@ func (s *testSuite) TestGlobalAndSessionBindingBothExist(c *C) { " └─Selection_14 9990.00 cop not(isnull(test.t2.id))", " └─TableScan_13 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", )) + tk.MustQuery("explain SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id").Check(testkit.Rows( "MergeJoin_7 12487.50 root inner join, left key:test.t1.id, right key:test.t2.id", "├─Sort_11 9990.00 root test.t1.id:asc", @@ -347,6 +348,16 @@ func (s *testSuite) TestExplain(c *C) { tk.MustExec("create table t1(id int)") tk.MustExec("create table t2(id int)") + tk.MustQuery("explain SELECT * from t1,t2 where t1.id = t2.id").Check(testkit.Rows( + "HashLeftJoin_8 12487.50 root inner join, inner:TableReader_15, equal:[eq(test.t1.id, test.t2.id)]", + "├─TableReader_12 9990.00 root data:Selection_11", + "│ └─Selection_11 9990.00 cop not(isnull(test.t1.id))", + "│ └─TableScan_10 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_15 9990.00 root data:Selection_14", + " └─Selection_14 9990.00 cop not(isnull(test.t2.id))", + " └─TableScan_13 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", + )) + tk.MustQuery("explain SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id").Check(testkit.Rows( "MergeJoin_7 12487.50 root inner join, left key:test.t1.id, right key:test.t2.id", "├─Sort_11 9990.00 root test.t1.id:asc", diff --git a/domain/domain.go b/domain/domain.go index 523fc0d53af80..b0c8f4914c838 100644 --- a/domain/domain.go +++ b/domain/domain.go @@ -791,8 +791,14 @@ func (do *Domain) LoadBindInfoLoop(ctx sessionctx.Context, parser *parser.Parser return err } + do.loadBindInfoLoop() + do.handleInvalidBindTaskLoop() + return nil +} + +func (do *Domain) loadBindInfoLoop() { duration := 3 * time.Second - do.wg.Add(2) + do.wg.Add(1) go func() { defer do.wg.Done() defer recoverInDomain("loadBindInfoLoop", false) @@ -802,14 +808,17 @@ func (do *Domain) LoadBindInfoLoop(ctx sessionctx.Context, parser *parser.Parser return case <-time.After(duration): } - err = do.bindHandle.Update(false) + err := do.bindHandle.Update(false) if err != nil { logutil.Logger(context.Background()).Error("update bindinfo failed", zap.Error(err)) } } }() +} - handleInvaildTaskDuration := 3 * time.Second +func (do *Domain) handleInvalidBindTaskLoop() { + handleInvalidTaskDuration := 3 * time.Second + do.wg.Add(1) go func() { defer do.wg.Done() defer recoverInDomain("loadBindInfoLoop-dropInvalidBindInfo", false) @@ -817,12 +826,11 @@ func (do *Domain) LoadBindInfoLoop(ctx sessionctx.Context, parser *parser.Parser select { case <-do.exit: return - case <-time.After(handleInvaildTaskDuration): + case <-time.After(handleInvalidTaskDuration): } do.bindHandle.HandleDropBindRecord() } }() - return nil } // StatsHandle returns the statistic handle. diff --git a/executor/compiler.go b/executor/compiler.go index 45601bd8bea92..5195ca6e43d5d 100644 --- a/executor/compiler.go +++ b/executor/compiler.go @@ -416,7 +416,7 @@ func addHintForSelect(normdOrigSQL string, ctx sessionctx.Context, stmt ast.Stmt return stmt } if bindRecord.Status == bindinfo.Using { - return bindHint(stmt, bindRecord.Ast) + return bindinfo.BindHint(stmt, bindRecord.Ast) } } globalHandle := domain.GetDomain(ctx).BindHandle() @@ -425,97 +425,7 @@ func addHintForSelect(normdOrigSQL string, ctx sessionctx.Context, stmt ast.Stmt bindRecord = globalHandle.GetBindRecord(normdOrigSQL, "") } if bindRecord != nil { - return bindHint(stmt, bindRecord.Ast) + return bindinfo.BindHint(stmt, bindRecord.Ast) } return stmt } - -func bindHint(originStmt, hintedStmt ast.StmtNode) ast.StmtNode { - switch x := originStmt.(type) { - case *ast.SelectStmt: - return selectBind(x, hintedStmt.(*ast.SelectStmt)) - default: - return originStmt - } -} - -func selectBind(originalNode, hintedNode *ast.SelectStmt) *ast.SelectStmt { - if hintedNode.TableHints != nil { - originalNode.TableHints = hintedNode.TableHints - } - if originalNode.From != nil { - originalNode.From.TableRefs = resultSetNodeBind(originalNode.From.TableRefs, hintedNode.From.TableRefs).(*ast.Join) - } - if originalNode.Where != nil { - originalNode.Where = selectionBind(originalNode.Where, hintedNode.Where).(ast.ExprNode) - } - return originalNode -} - -func selectionBind(where ast.ExprNode, hintedWhere ast.ExprNode) ast.ExprNode { - switch v := where.(type) { - case *ast.SubqueryExpr: - if v.Query != nil { - v.Query = resultSetNodeBind(v.Query, hintedWhere.(*ast.SubqueryExpr).Query) - } - case *ast.ExistsSubqueryExpr: - if v.Sel != nil { - v.Sel.(*ast.SubqueryExpr).Query = resultSetNodeBind(v.Sel.(*ast.SubqueryExpr).Query, hintedWhere.(*ast.ExistsSubqueryExpr).Sel.(*ast.SubqueryExpr).Query) - } - case *ast.PatternInExpr: - if v.Sel != nil { - v.Sel.(*ast.SubqueryExpr).Query = resultSetNodeBind(v.Sel.(*ast.SubqueryExpr).Query, hintedWhere.(*ast.PatternInExpr).Sel.(*ast.SubqueryExpr).Query) - } - } - return where -} - -func resultSetNodeBind(originalNode, hintedNode ast.ResultSetNode) ast.ResultSetNode { - switch x := originalNode.(type) { - case *ast.Join: - return joinBind(x, hintedNode.(*ast.Join)) - case *ast.TableSource: - ts, _ := hintedNode.(*ast.TableSource) - switch v := x.Source.(type) { - case *ast.SelectStmt: - x.Source = selectBind(v, ts.Source.(*ast.SelectStmt)) - case *ast.UnionStmt: - x.Source = unionSelectBind(v, hintedNode.(*ast.TableSource).Source.(*ast.UnionStmt)) - case *ast.TableName: - x.Source = dataSourceBind(v, ts.Source.(*ast.TableName)) - } - return x - case *ast.SelectStmt: - return selectBind(x, hintedNode.(*ast.SelectStmt)) - case *ast.UnionStmt: - return unionSelectBind(x, hintedNode.(*ast.UnionStmt)) - default: - return x - } -} - -func dataSourceBind(originalNode, hintedNode *ast.TableName) *ast.TableName { - originalNode.IndexHints = hintedNode.IndexHints - return originalNode -} - -func joinBind(originalNode, hintedNode *ast.Join) *ast.Join { - if originalNode.Left != nil { - originalNode.Left = resultSetNodeBind(originalNode.Left, hintedNode.Left) - } - - if hintedNode.Right != nil { - originalNode.Right = resultSetNodeBind(originalNode.Right, hintedNode.Right) - } - - return originalNode -} - -func unionSelectBind(originalNode, hintedNode *ast.UnionStmt) ast.ResultSetNode { - selects := originalNode.SelectList.Selects - for i := len(selects) - 1; i >= 0; i-- { - originalNode.SelectList.Selects[i] = selectBind(selects[i], hintedNode.SelectList.Selects[i]) - } - - return originalNode -} diff --git a/session/session.go b/session/session.go index 67eb01d1eca20..5b710e189069b 100644 --- a/session/session.go +++ b/session/session.go @@ -1052,20 +1052,23 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec } func (s *session) handleInValidBindRecord(ctx context.Context, stmtNode ast.StmtNode) { - var stmt ast.StmtNode + var normdOrigSQL string switch x := stmtNode.(type) { case *ast.ExplainStmt: switch x.Stmt.(type) { case *ast.SelectStmt: - x.Stmt.SetText(x.Text()[len("explain "):]) - stmt = x.Stmt + normalizeExplainSQL := parser.Normalize(x.Text()) + lowerSQL := strings.ToLower(normalizeExplainSQL) + idx := strings.Index(lowerSQL, "select") + normdOrigSQL = normalizeExplainSQL[idx:] + default: + return } case *ast.SelectStmt: - stmt = x + normdOrigSQL = parser.Normalize(x.Text()) default: return } - normdOrigSQL := parser.Normalize(stmt.Text()) sessionHandle := s.Value(bindinfo.SessionBindInfoKeyType).(*bindinfo.SessionHandle) bindMeta := sessionHandle.GetBindRecord(normdOrigSQL, s.GetSessionVars().CurrentDB) if bindMeta != nil { From e87dc6f7e6d5662c5cd013a7c7cb9e41edee231a Mon Sep 17 00:00:00 2001 From: iamzhoug37 <2541781827@qq.com> Date: Mon, 29 Apr 2019 01:04:05 +0800 Subject: [PATCH 07/18] add bind exprNode --- bindinfo/bind.go | 35 +++++++++++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/bindinfo/bind.go b/bindinfo/bind.go index 6ba6bdfa200dc..e932fc7622ce8 100644 --- a/bindinfo/bind.go +++ b/bindinfo/bind.go @@ -33,12 +33,43 @@ func selectBind(originalNode, hintedNode *ast.SelectStmt) *ast.SelectStmt { originalNode.From.TableRefs = resultSetNodeBind(originalNode.From.TableRefs, hintedNode.From.TableRefs).(*ast.Join) } if originalNode.Where != nil { - originalNode.Where = selectionBind(originalNode.Where, hintedNode.Where).(ast.ExprNode) + originalNode.Where = exprBind(originalNode.Where, hintedNode.Where).(ast.ExprNode) } + + if originalNode.Having != nil { + originalNode.Having = havingBind(originalNode.Having, hintedNode.Having) + } + + if originalNode.OrderBy != nil { + originalNode.OrderBy = orderByBind(originalNode.OrderBy, hintedNode.OrderBy) + } + + if originalNode.Fields != nil { + for idx := 0; idx < len(originalNode.Fields.Fields); idx++ { + originalNode.Fields.Fields[idx] = selectFieldBind(originalNode.Fields.Fields[idx], hintedNode.Fields.Fields[idx]) + } + } + return originalNode +} + +func selectFieldBind(originalNode, hintedNode *ast.SelectField) *ast.SelectField { + originalNode.Expr = exprBind(originalNode.Expr, hintedNode.Expr) + return originalNode +} + +func orderByBind(originalNode, hintedNode *ast.OrderByClause) *ast.OrderByClause { + for idx := 0; idx < len(originalNode.Items); idx++ { + originalNode.Items[idx].Expr = exprBind(originalNode.Items[idx].Expr, hintedNode.Items[idx].Expr) + } + return originalNode +} + +func havingBind(originalNode, hintedNode *ast.HavingClause) *ast.HavingClause { + originalNode.Expr = exprBind(originalNode.Expr, hintedNode.Expr) return originalNode } -func selectionBind(where ast.ExprNode, hintedWhere ast.ExprNode) ast.ExprNode { +func exprBind(where ast.ExprNode, hintedWhere ast.ExprNode) ast.ExprNode { switch v := where.(type) { case *ast.SubqueryExpr: if v.Query != nil { From 75a5eae5bae177a0bc8ab4d3c0eaf4063d16d840 Mon Sep 17 00:00:00 2001 From: iamzhoug37 <2541781827@qq.com> Date: Mon, 29 Apr 2019 10:41:49 +0800 Subject: [PATCH 08/18] add where supExpr --- bindinfo/bind.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/bindinfo/bind.go b/bindinfo/bind.go index e932fc7622ce8..98a531871ed8e 100644 --- a/bindinfo/bind.go +++ b/bindinfo/bind.go @@ -83,6 +83,16 @@ func exprBind(where ast.ExprNode, hintedWhere ast.ExprNode) ast.ExprNode { if v.Sel != nil { v.Sel.(*ast.SubqueryExpr).Query = resultSetNodeBind(v.Sel.(*ast.SubqueryExpr).Query, hintedWhere.(*ast.PatternInExpr).Sel.(*ast.SubqueryExpr).Query) } + case *ast.BinaryOperationExpr: + if v.L != nil { + switch v.L.(type) { + case *ast.BinaryOperationExpr: + v.L = exprBind(v.L, hintedWhere.(*ast.BinaryOperationExpr).L) + v.R = exprBind(v.R, hintedWhere.(*ast.BinaryOperationExpr).R) + case *ast.PatternInExpr: + v.L = exprBind(v.L, hintedWhere.(*ast.BinaryOperationExpr).L) + } + } } return where } From 179024795f37b4bfe2421d8d9450d84184450d31 Mon Sep 17 00:00:00 2001 From: iamzhoug37 <2541781827@qq.com> Date: Mon, 29 Apr 2019 10:57:08 +0800 Subject: [PATCH 09/18] update v.r --- bindinfo/bind.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/bindinfo/bind.go b/bindinfo/bind.go index 98a531871ed8e..33ba8af509853 100644 --- a/bindinfo/bind.go +++ b/bindinfo/bind.go @@ -88,11 +88,18 @@ func exprBind(where ast.ExprNode, hintedWhere ast.ExprNode) ast.ExprNode { switch v.L.(type) { case *ast.BinaryOperationExpr: v.L = exprBind(v.L, hintedWhere.(*ast.BinaryOperationExpr).L) - v.R = exprBind(v.R, hintedWhere.(*ast.BinaryOperationExpr).R) case *ast.PatternInExpr: v.L = exprBind(v.L, hintedWhere.(*ast.BinaryOperationExpr).L) } } + if v.R != nil{ + switch v.R.(type) { + case *ast.BinaryOperationExpr: + v.R = exprBind(v.R, hintedWhere.(*ast.BinaryOperationExpr).R) + case *ast.PatternInExpr: + v.R = exprBind(v.R, hintedWhere.(*ast.BinaryOperationExpr).R) + } + } } return where } From 0cd298d35b6e68674a1949bfd2531ab39de85041 Mon Sep 17 00:00:00 2001 From: iamzhoug37 <2541781827@qq.com> Date: Mon, 29 Apr 2019 11:39:49 +0800 Subject: [PATCH 10/18] add like isTruth and compare expr bind --- bindinfo/bind.go | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/bindinfo/bind.go b/bindinfo/bind.go index 33ba8af509853..756b5d0a29313 100644 --- a/bindinfo/bind.go +++ b/bindinfo/bind.go @@ -92,7 +92,7 @@ func exprBind(where ast.ExprNode, hintedWhere ast.ExprNode) ast.ExprNode { v.L = exprBind(v.L, hintedWhere.(*ast.BinaryOperationExpr).L) } } - if v.R != nil{ + if v.R != nil { switch v.R.(type) { case *ast.BinaryOperationExpr: v.R = exprBind(v.R, hintedWhere.(*ast.BinaryOperationExpr).R) @@ -100,6 +100,25 @@ func exprBind(where ast.ExprNode, hintedWhere ast.ExprNode) ast.ExprNode { v.R = exprBind(v.R, hintedWhere.(*ast.BinaryOperationExpr).R) } } + case *ast.IsNullExpr: + if v.Expr != nil { + v.Expr = exprBind(v.Expr, hintedWhere.(*ast.IsNullExpr).Expr) + } + case *ast.IsTruthExpr: + if v.Expr != nil { + v.Expr = exprBind(v.Expr, hintedWhere.(*ast.IsTruthExpr).Expr) + } + case *ast.PatternLikeExpr: + if v.Pattern != nil { + v.Pattern = exprBind(v.Pattern, hintedWhere.(*ast.PatternLikeExpr).Pattern) + } + case *ast.CompareSubqueryExpr: + if v.L != nil { + v.L = exprBind(v.L, hintedWhere.(*ast.CompareSubqueryExpr).L) + } + if v.R != nil { + v.R = exprBind(v.R, hintedWhere.(*ast.CompareSubqueryExpr).R) + } } return where } From 47b2759158ed8cfaa4f3013f50e2a3e87a61dc3f Mon Sep 17 00:00:00 2001 From: iamzhoug37 <2541781827@qq.com> Date: Mon, 29 Apr 2019 11:41:08 +0800 Subject: [PATCH 11/18] defer --- bindinfo/handle.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bindinfo/handle.go b/bindinfo/handle.go index a0019aaf6985d..176e0b59c64c2 100644 --- a/bindinfo/handle.go +++ b/bindinfo/handle.go @@ -232,6 +232,7 @@ func (h *BindHandle) DropBindRecord(record *BindRecord) (err error) { // HandleDropBindRecord execute the drop bindRecord task. func (h *BindHandle) HandleDropBindRecord() { h.dropBindRecordMap.Lock() + defer h.dropBindRecordMap.Unlock() dropBindRecordMap := copyDroppedBindRecordMap(h.dropBindRecordMap.Load().(map[string]*droppedBindRecord)) for key, droppedBindRecord := range dropBindRecordMap { if droppedBindRecord.droppedTime.IsZero() { @@ -248,7 +249,6 @@ func (h *BindHandle) HandleDropBindRecord() { } } h.dropBindRecordMap.Store(dropBindRecordMap) - h.dropBindRecordMap.Unlock() } // AddDropBindRecordTask add bindRecord to dropBindRecordMap when the bindRecord need to be deleted. From f4735c39ffaa875a2ffade34b512340c02e5f937 Mon Sep 17 00:00:00 2001 From: iamzhoug37 <2541781827@qq.com> Date: Mon, 29 Apr 2019 12:13:17 +0800 Subject: [PATCH 12/18] code remove --- bindinfo/bind.go | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/bindinfo/bind.go b/bindinfo/bind.go index 756b5d0a29313..99c69a746c378 100644 --- a/bindinfo/bind.go +++ b/bindinfo/bind.go @@ -85,20 +85,10 @@ func exprBind(where ast.ExprNode, hintedWhere ast.ExprNode) ast.ExprNode { } case *ast.BinaryOperationExpr: if v.L != nil { - switch v.L.(type) { - case *ast.BinaryOperationExpr: - v.L = exprBind(v.L, hintedWhere.(*ast.BinaryOperationExpr).L) - case *ast.PatternInExpr: - v.L = exprBind(v.L, hintedWhere.(*ast.BinaryOperationExpr).L) - } + v.L = exprBind(v.L, hintedWhere.(*ast.BinaryOperationExpr).L) } if v.R != nil { - switch v.R.(type) { - case *ast.BinaryOperationExpr: - v.R = exprBind(v.R, hintedWhere.(*ast.BinaryOperationExpr).R) - case *ast.PatternInExpr: - v.R = exprBind(v.R, hintedWhere.(*ast.BinaryOperationExpr).R) - } + v.R = exprBind(v.R, hintedWhere.(*ast.BinaryOperationExpr).R) } case *ast.IsNullExpr: if v.Expr != nil { From 167e011818f9fc859b9c79dd8bb7f5b7a58d4c6b Mon Sep 17 00:00:00 2001 From: iamzhoug37 <2541781827@qq.com> Date: Mon, 29 Apr 2019 12:21:27 +0800 Subject: [PATCH 13/18] remove lock --- bindinfo/handle.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/bindinfo/handle.go b/bindinfo/handle.go index 176e0b59c64c2..3c1529dbf61fb 100644 --- a/bindinfo/handle.go +++ b/bindinfo/handle.go @@ -231,8 +231,6 @@ func (h *BindHandle) DropBindRecord(record *BindRecord) (err error) { // HandleDropBindRecord execute the drop bindRecord task. func (h *BindHandle) HandleDropBindRecord() { - h.dropBindRecordMap.Lock() - defer h.dropBindRecordMap.Unlock() dropBindRecordMap := copyDroppedBindRecordMap(h.dropBindRecordMap.Load().(map[string]*droppedBindRecord)) for key, droppedBindRecord := range dropBindRecordMap { if droppedBindRecord.droppedTime.IsZero() { From 03e7b15d8e194f69971255b4d57e44ddf5e35e13 Mon Sep 17 00:00:00 2001 From: iamzhoug37 <2541781827@qq.com> Date: Mon, 29 Apr 2019 15:12:44 +0800 Subject: [PATCH 14/18] update --- bindinfo/bind.go | 62 +++++++++++++++++++++++++------------------- executor/compiler.go | 50 ++++++++++++++++++++++++++++++----- session/session.go | 6 ++--- 3 files changed, 83 insertions(+), 35 deletions(-) diff --git a/bindinfo/bind.go b/bindinfo/bind.go index 99c69a746c378..dc44e5c9fc2a7 100644 --- a/bindinfo/bind.go +++ b/bindinfo/bind.go @@ -45,18 +45,15 @@ func selectBind(originalNode, hintedNode *ast.SelectStmt) *ast.SelectStmt { } if originalNode.Fields != nil { - for idx := 0; idx < len(originalNode.Fields.Fields); idx++ { - originalNode.Fields.Fields[idx] = selectFieldBind(originalNode.Fields.Fields[idx], hintedNode.Fields.Fields[idx]) + origFields := originalNode.Fields.Fields + hintFields := hintedNode.Fields.Fields + for idx := range origFields { + origFields[idx].Expr = exprBind(origFields[idx].Expr, hintFields[idx].Expr) } } return originalNode } -func selectFieldBind(originalNode, hintedNode *ast.SelectField) *ast.SelectField { - originalNode.Expr = exprBind(originalNode.Expr, hintedNode.Expr) - return originalNode -} - func orderByBind(originalNode, hintedNode *ast.OrderByClause) *ast.OrderByClause { for idx := 0; idx < len(originalNode.Items); idx++ { originalNode.Items[idx].Expr = exprBind(originalNode.Items[idx].Expr, hintedNode.Items[idx].Expr) @@ -69,48 +66,66 @@ func havingBind(originalNode, hintedNode *ast.HavingClause) *ast.HavingClause { return originalNode } -func exprBind(where ast.ExprNode, hintedWhere ast.ExprNode) ast.ExprNode { - switch v := where.(type) { +func exprBind(originalNode, hintedNode ast.ExprNode) ast.ExprNode { + switch v := originalNode.(type) { case *ast.SubqueryExpr: if v.Query != nil { - v.Query = resultSetNodeBind(v.Query, hintedWhere.(*ast.SubqueryExpr).Query) + v.Query = resultSetNodeBind(v.Query, hintedNode.(*ast.SubqueryExpr).Query) } case *ast.ExistsSubqueryExpr: if v.Sel != nil { - v.Sel.(*ast.SubqueryExpr).Query = resultSetNodeBind(v.Sel.(*ast.SubqueryExpr).Query, hintedWhere.(*ast.ExistsSubqueryExpr).Sel.(*ast.SubqueryExpr).Query) + v.Sel.(*ast.SubqueryExpr).Query = resultSetNodeBind(v.Sel.(*ast.SubqueryExpr).Query, hintedNode.(*ast.ExistsSubqueryExpr).Sel.(*ast.SubqueryExpr).Query) } case *ast.PatternInExpr: if v.Sel != nil { - v.Sel.(*ast.SubqueryExpr).Query = resultSetNodeBind(v.Sel.(*ast.SubqueryExpr).Query, hintedWhere.(*ast.PatternInExpr).Sel.(*ast.SubqueryExpr).Query) + v.Sel.(*ast.SubqueryExpr).Query = resultSetNodeBind(v.Sel.(*ast.SubqueryExpr).Query, hintedNode.(*ast.PatternInExpr).Sel.(*ast.SubqueryExpr).Query) } case *ast.BinaryOperationExpr: if v.L != nil { - v.L = exprBind(v.L, hintedWhere.(*ast.BinaryOperationExpr).L) + v.L = exprBind(v.L, hintedNode.(*ast.BinaryOperationExpr).L) } if v.R != nil { - v.R = exprBind(v.R, hintedWhere.(*ast.BinaryOperationExpr).R) + v.R = exprBind(v.R, hintedNode.(*ast.BinaryOperationExpr).R) } case *ast.IsNullExpr: if v.Expr != nil { - v.Expr = exprBind(v.Expr, hintedWhere.(*ast.IsNullExpr).Expr) + v.Expr = exprBind(v.Expr, hintedNode.(*ast.IsNullExpr).Expr) } case *ast.IsTruthExpr: if v.Expr != nil { - v.Expr = exprBind(v.Expr, hintedWhere.(*ast.IsTruthExpr).Expr) + v.Expr = exprBind(v.Expr, hintedNode.(*ast.IsTruthExpr).Expr) } case *ast.PatternLikeExpr: if v.Pattern != nil { - v.Pattern = exprBind(v.Pattern, hintedWhere.(*ast.PatternLikeExpr).Pattern) + v.Pattern = exprBind(v.Pattern, hintedNode.(*ast.PatternLikeExpr).Pattern) } case *ast.CompareSubqueryExpr: if v.L != nil { - v.L = exprBind(v.L, hintedWhere.(*ast.CompareSubqueryExpr).L) + v.L = exprBind(v.L, hintedNode.(*ast.CompareSubqueryExpr).L) } if v.R != nil { - v.R = exprBind(v.R, hintedWhere.(*ast.CompareSubqueryExpr).R) + v.R = exprBind(v.R, hintedNode.(*ast.CompareSubqueryExpr).R) + } + case *ast.BetweenExpr: + if v.Left != nil { + v.Left = exprBind(v.Left, hintedNode.(*ast.BetweenExpr).Left) + } + if v.Right != nil { + v.Right = exprBind(v.Right, hintedNode.(*ast.BetweenExpr).Right) + } + case *ast.UnaryOperationExpr: + if v.V != nil { + v.V = exprBind(v.V, hintedNode.(*ast.UnaryOperationExpr).V) + } + case *ast.CaseExpr: + if v.Value != nil { + v.Value = exprBind(v.Value, hintedNode.(*ast.CaseExpr).Value) + } + if v.ElseClause != nil { + v.ElseClause = exprBind(v.ElseClause, hintedNode.(*ast.CaseExpr).ElseClause) } } - return where + return originalNode } func resultSetNodeBind(originalNode, hintedNode ast.ResultSetNode) ast.ResultSetNode { @@ -125,7 +140,7 @@ func resultSetNodeBind(originalNode, hintedNode ast.ResultSetNode) ast.ResultSet case *ast.UnionStmt: x.Source = unionSelectBind(v, hintedNode.(*ast.TableSource).Source.(*ast.UnionStmt)) case *ast.TableName: - x.Source = dataSourceBind(v, ts.Source.(*ast.TableName)) + x.Source.(*ast.TableName).IndexHints = ts.Source.(*ast.TableName).IndexHints } return x case *ast.SelectStmt: @@ -137,11 +152,6 @@ func resultSetNodeBind(originalNode, hintedNode ast.ResultSetNode) ast.ResultSet } } -func dataSourceBind(originalNode, hintedNode *ast.TableName) *ast.TableName { - originalNode.IndexHints = hintedNode.IndexHints - return originalNode -} - func joinBind(originalNode, hintedNode *ast.Join) *ast.Join { if originalNode.Left != nil { originalNode.Left = resultSetNodeBind(originalNode.Left, hintedNode.Left) diff --git a/executor/compiler.go b/executor/compiler.go index 5195ca6e43d5d..ceb4e294f030f 100644 --- a/executor/compiler.go +++ b/executor/compiler.go @@ -95,6 +95,46 @@ func (c *Compiler) compile(ctx context.Context, stmtNode ast.StmtNode, skipBind }, nil } +// GetSelectTextFromStmtNode return stmtNode's select text if select text exist. +func (c *Compiler) GetSelectTextFromStmtNode(stmtNode ast.StmtNode) string { + switch x := stmtNode.(type) { + case *ast.ExplainStmt: + switch x.Stmt.(type) { + case *ast.SelectStmt: + normalizeExplainSQL := parser.Normalize(x.Text()) + idx := strings.Index(normalizeExplainSQL, "select") + return normalizeExplainSQL[idx:] + } + case *ast.SelectStmt: + return parser.Normalize(x.Text()) + } + return "" +} + +// GetBindMeta return normdOriginSQL's bindMeta in session bind cache or global bind cache. +func (c *Compiler) GetBindMeta(ctx sessionctx.Context, normalizeSQL string) *bindinfo.BindMeta { + if ctx.Value(bindinfo.SessionBindInfoKeyType) == nil { //when the domain is initializing, the bind will be nil. + return nil + } + sessionHandle := ctx.Value(bindinfo.SessionBindInfoKeyType).(*bindinfo.SessionHandle) + bindMeta := sessionHandle.GetBindRecord(normalizeSQL, ctx.GetSessionVars().CurrentDB) + if bindMeta != nil { + if bindMeta.Status == bindinfo.Invalid { + return nil + } + if bindMeta.Status == bindinfo.Using { + return bindMeta + } + } + + globalHandle := domain.GetDomain(ctx).BindHandle() + bindMeta = globalHandle.GetBindRecord(normalizeSQL, ctx.GetSessionVars().CurrentDB) + if bindMeta == nil { + bindMeta = globalHandle.GetBindRecord(normalizeSQL, "") + } + return bindMeta +} + func logExpensiveQuery(stmtNode ast.StmtNode, finalPlan plannercore.Plan) (expensive bool) { expensive = isExpensiveQuery(finalPlan) if !expensive { @@ -386,13 +426,15 @@ func GetInfoSchema(ctx sessionctx.Context) infoschema.InfoSchema { } func addHint(ctx sessionctx.Context, stmtNode ast.StmtNode) ast.StmtNode { + if ctx.Value(bindinfo.SessionBindInfoKeyType) == nil { //when the domain is initializing, the bind will be nil. + return stmtNode + } switch x := stmtNode.(type) { case *ast.ExplainStmt: switch x.Stmt.(type) { case *ast.SelectStmt: normalizeExplainSQL := parser.Normalize(x.Text()) - lowerSQL := strings.ToLower(normalizeExplainSQL) - idx := strings.Index(lowerSQL, "select") + idx := strings.Index(normalizeExplainSQL, "select") normalizeSQL := normalizeExplainSQL[idx:] x.Stmt = addHintForSelect(normalizeSQL, ctx, x.Stmt) } @@ -405,10 +447,6 @@ func addHint(ctx sessionctx.Context, stmtNode ast.StmtNode) ast.StmtNode { } func addHintForSelect(normdOrigSQL string, ctx sessionctx.Context, stmt ast.StmtNode) ast.StmtNode { - if ctx.Value(bindinfo.SessionBindInfoKeyType) == nil { //when the domain is initializing, the bind will be nil. - return stmt - } - sessionHandle := ctx.Value(bindinfo.SessionBindInfoKeyType).(*bindinfo.SessionHandle) bindRecord := sessionHandle.GetBindRecord(normdOrigSQL, ctx.GetSessionVars().CurrentDB) if bindRecord != nil { diff --git a/session/session.go b/session/session.go index 5b710e189069b..3250f4c12c424 100644 --- a/session/session.go +++ b/session/session.go @@ -1010,7 +1010,8 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec } stmt, err := compiler.Compile(ctx, stmtNode) if err != nil { - if tempStmtNodes == nil { + normolizedSQL := compiler.GetSelectTextFromStmtNode(stmtNode) + if tempStmtNodes == nil && normolizedSQL != "" && compiler.GetBindMeta(s, normolizedSQL) != nil { tempStmtNodes, warns, err = s.ParseSQL(ctx, sql, charsetInfo, collation) if err != nil || warns != nil { //just skip errcheck, because parse will not return an error. @@ -1058,8 +1059,7 @@ func (s *session) handleInValidBindRecord(ctx context.Context, stmtNode ast.Stmt switch x.Stmt.(type) { case *ast.SelectStmt: normalizeExplainSQL := parser.Normalize(x.Text()) - lowerSQL := strings.ToLower(normalizeExplainSQL) - idx := strings.Index(lowerSQL, "select") + idx := strings.Index(normalizeExplainSQL, "select") normdOrigSQL = normalizeExplainSQL[idx:] default: return From 924c2a4256456b05c3c44de23e7f89f32052aeeb Mon Sep 17 00:00:00 2001 From: iamzhoug37 <2541781827@qq.com> Date: Mon, 29 Apr 2019 15:18:23 +0800 Subject: [PATCH 15/18] expreNode bind --- bindinfo/bind.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/bindinfo/bind.go b/bindinfo/bind.go index dc44e5c9fc2a7..9d70aadf198a4 100644 --- a/bindinfo/bind.go +++ b/bindinfo/bind.go @@ -37,7 +37,7 @@ func selectBind(originalNode, hintedNode *ast.SelectStmt) *ast.SelectStmt { } if originalNode.Having != nil { - originalNode.Having = havingBind(originalNode.Having, hintedNode.Having) + originalNode.Having.Expr = exprBind(originalNode.Having.Expr, hintedNode.Having.Expr) } if originalNode.OrderBy != nil { @@ -61,11 +61,6 @@ func orderByBind(originalNode, hintedNode *ast.OrderByClause) *ast.OrderByClause return originalNode } -func havingBind(originalNode, hintedNode *ast.HavingClause) *ast.HavingClause { - originalNode.Expr = exprBind(originalNode.Expr, hintedNode.Expr) - return originalNode -} - func exprBind(originalNode, hintedNode ast.ExprNode) ast.ExprNode { switch v := originalNode.(type) { case *ast.SubqueryExpr: From c3303b6b3ccc92c0cb4ce2ae06a48f76599523c1 Mon Sep 17 00:00:00 2001 From: iamzhoug37 <2541781827@qq.com> Date: Mon, 29 Apr 2019 15:27:36 +0800 Subject: [PATCH 16/18] remove check --- executor/compiler.go | 40 ---------------------------------------- session/session.go | 3 +-- 2 files changed, 1 insertion(+), 42 deletions(-) diff --git a/executor/compiler.go b/executor/compiler.go index ceb4e294f030f..f233bdfe833cd 100644 --- a/executor/compiler.go +++ b/executor/compiler.go @@ -95,46 +95,6 @@ func (c *Compiler) compile(ctx context.Context, stmtNode ast.StmtNode, skipBind }, nil } -// GetSelectTextFromStmtNode return stmtNode's select text if select text exist. -func (c *Compiler) GetSelectTextFromStmtNode(stmtNode ast.StmtNode) string { - switch x := stmtNode.(type) { - case *ast.ExplainStmt: - switch x.Stmt.(type) { - case *ast.SelectStmt: - normalizeExplainSQL := parser.Normalize(x.Text()) - idx := strings.Index(normalizeExplainSQL, "select") - return normalizeExplainSQL[idx:] - } - case *ast.SelectStmt: - return parser.Normalize(x.Text()) - } - return "" -} - -// GetBindMeta return normdOriginSQL's bindMeta in session bind cache or global bind cache. -func (c *Compiler) GetBindMeta(ctx sessionctx.Context, normalizeSQL string) *bindinfo.BindMeta { - if ctx.Value(bindinfo.SessionBindInfoKeyType) == nil { //when the domain is initializing, the bind will be nil. - return nil - } - sessionHandle := ctx.Value(bindinfo.SessionBindInfoKeyType).(*bindinfo.SessionHandle) - bindMeta := sessionHandle.GetBindRecord(normalizeSQL, ctx.GetSessionVars().CurrentDB) - if bindMeta != nil { - if bindMeta.Status == bindinfo.Invalid { - return nil - } - if bindMeta.Status == bindinfo.Using { - return bindMeta - } - } - - globalHandle := domain.GetDomain(ctx).BindHandle() - bindMeta = globalHandle.GetBindRecord(normalizeSQL, ctx.GetSessionVars().CurrentDB) - if bindMeta == nil { - bindMeta = globalHandle.GetBindRecord(normalizeSQL, "") - } - return bindMeta -} - func logExpensiveQuery(stmtNode ast.StmtNode, finalPlan plannercore.Plan) (expensive bool) { expensive = isExpensiveQuery(finalPlan) if !expensive { diff --git a/session/session.go b/session/session.go index 3250f4c12c424..4b61570bb52bb 100644 --- a/session/session.go +++ b/session/session.go @@ -1010,8 +1010,7 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec } stmt, err := compiler.Compile(ctx, stmtNode) if err != nil { - normolizedSQL := compiler.GetSelectTextFromStmtNode(stmtNode) - if tempStmtNodes == nil && normolizedSQL != "" && compiler.GetBindMeta(s, normolizedSQL) != nil { + if tempStmtNodes == nil { tempStmtNodes, warns, err = s.ParseSQL(ctx, sql, charsetInfo, collation) if err != nil || warns != nil { //just skip errcheck, because parse will not return an error. From dadeb8715d6d087f69e59b1205cc0c1f1dddd603 Mon Sep 17 00:00:00 2001 From: iamzhoug37 <2541781827@qq.com> Date: Mon, 29 Apr 2019 16:26:33 +0800 Subject: [PATCH 17/18] rename droppedBindRecord etc --- bindinfo/bind_test.go | 2 +- bindinfo/handle.go | 48 ++++++++++++++++++++++--------------------- domain/domain.go | 2 +- session/session.go | 8 ++++---- 4 files changed, 31 insertions(+), 29 deletions(-) diff --git a/bindinfo/bind_test.go b/bindinfo/bind_test.go index 069207d2c6a03..e37ae0f45eff3 100644 --- a/bindinfo/bind_test.go +++ b/bindinfo/bind_test.go @@ -415,7 +415,7 @@ func (s *testSuite) TestErrorBind(c *C) { _, err = tk.Exec("select * from t where i > 10") c.Check(err, IsNil) - s.domain.BindHandle().HandleDropBindRecord() + s.domain.BindHandle().DropInvalidBindRecord() rs, err := tk.Exec("show global bindings") c.Assert(err, IsNil) diff --git a/bindinfo/handle.go b/bindinfo/handle.go index 3c1529dbf61fb..8536fb153dc3e 100644 --- a/bindinfo/handle.go +++ b/bindinfo/handle.go @@ -61,6 +61,8 @@ type BindHandle struct { atomic.Value } + // invalidBindRecordMap indicates the invalid bind records found during querying. + // A record will be deleted from this map, after 2 bind-lease, after it is dropped from the kv. dropBindRecordMap struct { sync.Mutex atomic.Value @@ -70,7 +72,7 @@ type BindHandle struct { lastUpdateTime types.Time } -type droppedBindRecord struct { +type invalidBindRecordMap struct { bindRecord *BindRecord droppedTime time.Time } @@ -80,7 +82,7 @@ func NewBindHandle(ctx sessionctx.Context, parser *parser.Parser) *BindHandle { handle := &BindHandle{parser: parser} handle.sctx.Context = ctx handle.bindInfo.Value.Store(make(cache, 32)) - handle.dropBindRecordMap.Value.Store(make(map[string]*droppedBindRecord)) + handle.dropBindRecordMap.Value.Store(make(map[string]*invalidBindRecordMap)) return handle } @@ -229,40 +231,40 @@ func (h *BindHandle) DropBindRecord(record *BindRecord) (err error) { return err } -// HandleDropBindRecord execute the drop bindRecord task. -func (h *BindHandle) HandleDropBindRecord() { - dropBindRecordMap := copyDroppedBindRecordMap(h.dropBindRecordMap.Load().(map[string]*droppedBindRecord)) - for key, droppedBindRecord := range dropBindRecordMap { - if droppedBindRecord.droppedTime.IsZero() { - err := h.DropBindRecord(droppedBindRecord.bindRecord) +// DropInvalidBindRecord execute the drop bindRecord task. +func (h *BindHandle) DropInvalidBindRecord() { + invalidBindRecordMap := copyInvalidBindRecordMap(h.dropBindRecordMap.Load().(map[string]*invalidBindRecordMap)) + for key, invalidBindRecord := range invalidBindRecordMap { + if invalidBindRecord.droppedTime.IsZero() { + err := h.DropBindRecord(invalidBindRecord.bindRecord) if err != nil { - logutil.Logger(context.Background()).Error("handleDropBindRecord failed", zap.Error(err)) + logutil.Logger(context.Background()).Error("DropInvalidBindRecord failed", zap.Error(err)) } - droppedBindRecord.droppedTime = time.Now() + invalidBindRecord.droppedTime = time.Now() continue } - if time.Since(droppedBindRecord.droppedTime) > 6*time.Second { - delete(dropBindRecordMap, key) + if time.Since(invalidBindRecord.droppedTime) > 6*time.Second { + delete(invalidBindRecordMap, key) } } - h.dropBindRecordMap.Store(dropBindRecordMap) + h.dropBindRecordMap.Store(invalidBindRecordMap) } -// AddDropBindRecordTask add bindRecord to dropBindRecordMap when the bindRecord need to be deleted. -func (h *BindHandle) AddDropBindRecordTask(dropBindRecord *BindRecord) { - key := dropBindRecord.OriginalSQL + ":" + dropBindRecord.Db - if _, ok := h.dropBindRecordMap.Value.Load().(map[string]*droppedBindRecord)[key]; ok { +// AddDropInvalidBindTask add bindRecord to dropBindRecordMap when the bindRecord need to be deleted. +func (h *BindHandle) AddDropInvalidBindTask(invalidBindRecord *BindRecord) { + key := invalidBindRecord.OriginalSQL + ":" + invalidBindRecord.Db + if _, ok := h.dropBindRecordMap.Value.Load().(map[string]*invalidBindRecordMap)[key]; ok { return } h.dropBindRecordMap.Lock() defer h.dropBindRecordMap.Unlock() - if _, ok := h.dropBindRecordMap.Value.Load().(map[string]*droppedBindRecord)[key]; ok { + if _, ok := h.dropBindRecordMap.Value.Load().(map[string]*invalidBindRecordMap)[key]; ok { return } - newMap := copyDroppedBindRecordMap(h.dropBindRecordMap.Value.Load().(map[string]*droppedBindRecord)) - newMap[key] = &droppedBindRecord{ - bindRecord: dropBindRecord, + newMap := copyInvalidBindRecordMap(h.dropBindRecordMap.Value.Load().(map[string]*invalidBindRecordMap)) + newMap[key] = &invalidBindRecordMap{ + bindRecord: invalidBindRecord, } h.dropBindRecordMap.Store(newMap) } @@ -378,8 +380,8 @@ func (c cache) copy() cache { return newCache } -func copyDroppedBindRecordMap(oldMap map[string]*droppedBindRecord) map[string]*droppedBindRecord { - newMap := make(map[string]*droppedBindRecord, len(oldMap)) +func copyInvalidBindRecordMap(oldMap map[string]*invalidBindRecordMap) map[string]*invalidBindRecordMap { + newMap := make(map[string]*invalidBindRecordMap, len(oldMap)) for k, v := range oldMap { newMap[k] = v } diff --git a/domain/domain.go b/domain/domain.go index b0c8f4914c838..008348cee7eff 100644 --- a/domain/domain.go +++ b/domain/domain.go @@ -828,7 +828,7 @@ func (do *Domain) handleInvalidBindTaskLoop() { return case <-time.After(handleInvalidTaskDuration): } - do.bindHandle.HandleDropBindRecord() + do.bindHandle.DropInvalidBindRecord() } }() } diff --git a/session/session.go b/session/session.go index 4b61570bb52bb..a20beb1f4f67d 100644 --- a/session/session.go +++ b/session/session.go @@ -1025,7 +1025,7 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec zap.String("sql", sql)) return nil, err } - s.handleInValidBindRecord(ctx, stmtNode) + s.handleInvalidBindRecord(ctx, stmtNode) } if isInternal { sessionExecuteCompileDurationInternal.Observe(time.Since(startTS).Seconds()) @@ -1051,7 +1051,7 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec return recordSets, nil } -func (s *session) handleInValidBindRecord(ctx context.Context, stmtNode ast.StmtNode) { +func (s *session) handleInvalidBindRecord(ctx context.Context, stmtNode ast.StmtNode) { var normdOrigSQL string switch x := stmtNode.(type) { case *ast.ExplainStmt: @@ -1092,7 +1092,7 @@ func (s *session) handleInValidBindRecord(ctx context.Context, stmtNode ast.Stmt err := sessionHandle.AddBindRecord(record) if err != nil { - logutil.Logger(ctx).Warn("handleInValidBindRecord failed", zap.Error(err)) + logutil.Logger(ctx).Warn("handleInvalidBindRecord failed", zap.Error(err)) } globalHandle := domain.GetDomain(s).BindHandle() @@ -1100,7 +1100,7 @@ func (s *session) handleInValidBindRecord(ctx context.Context, stmtNode ast.Stmt OriginalSQL: bindMeta.OriginalSQL, Db: bindMeta.Db, } - globalHandle.AddDropBindRecordTask(dropBindRecord) + globalHandle.AddDropInvalidBindTask(dropBindRecord) } } From 8db7550f79cb9c990b277a1fa7670e2e8e8dd223 Mon Sep 17 00:00:00 2001 From: iamzhoug37 <2541781827@qq.com> Date: Mon, 29 Apr 2019 16:30:30 +0800 Subject: [PATCH 18/18] rename --- bindinfo/handle.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/bindinfo/handle.go b/bindinfo/handle.go index 8536fb153dc3e..fc33f2cd3442c 100644 --- a/bindinfo/handle.go +++ b/bindinfo/handle.go @@ -63,7 +63,7 @@ type BindHandle struct { // invalidBindRecordMap indicates the invalid bind records found during querying. // A record will be deleted from this map, after 2 bind-lease, after it is dropped from the kv. - dropBindRecordMap struct { + invalidBindRecordMap struct { sync.Mutex atomic.Value } @@ -82,7 +82,7 @@ func NewBindHandle(ctx sessionctx.Context, parser *parser.Parser) *BindHandle { handle := &BindHandle{parser: parser} handle.sctx.Context = ctx handle.bindInfo.Value.Store(make(cache, 32)) - handle.dropBindRecordMap.Value.Store(make(map[string]*invalidBindRecordMap)) + handle.invalidBindRecordMap.Value.Store(make(map[string]*invalidBindRecordMap)) return handle } @@ -233,7 +233,7 @@ func (h *BindHandle) DropBindRecord(record *BindRecord) (err error) { // DropInvalidBindRecord execute the drop bindRecord task. func (h *BindHandle) DropInvalidBindRecord() { - invalidBindRecordMap := copyInvalidBindRecordMap(h.dropBindRecordMap.Load().(map[string]*invalidBindRecordMap)) + invalidBindRecordMap := copyInvalidBindRecordMap(h.invalidBindRecordMap.Load().(map[string]*invalidBindRecordMap)) for key, invalidBindRecord := range invalidBindRecordMap { if invalidBindRecord.droppedTime.IsZero() { err := h.DropBindRecord(invalidBindRecord.bindRecord) @@ -248,25 +248,25 @@ func (h *BindHandle) DropInvalidBindRecord() { delete(invalidBindRecordMap, key) } } - h.dropBindRecordMap.Store(invalidBindRecordMap) + h.invalidBindRecordMap.Store(invalidBindRecordMap) } -// AddDropInvalidBindTask add bindRecord to dropBindRecordMap when the bindRecord need to be deleted. +// AddDropInvalidBindTask add bindRecord to invalidBindRecordMap when the bindRecord need to be deleted. func (h *BindHandle) AddDropInvalidBindTask(invalidBindRecord *BindRecord) { key := invalidBindRecord.OriginalSQL + ":" + invalidBindRecord.Db - if _, ok := h.dropBindRecordMap.Value.Load().(map[string]*invalidBindRecordMap)[key]; ok { + if _, ok := h.invalidBindRecordMap.Value.Load().(map[string]*invalidBindRecordMap)[key]; ok { return } - h.dropBindRecordMap.Lock() - defer h.dropBindRecordMap.Unlock() - if _, ok := h.dropBindRecordMap.Value.Load().(map[string]*invalidBindRecordMap)[key]; ok { + h.invalidBindRecordMap.Lock() + defer h.invalidBindRecordMap.Unlock() + if _, ok := h.invalidBindRecordMap.Value.Load().(map[string]*invalidBindRecordMap)[key]; ok { return } - newMap := copyInvalidBindRecordMap(h.dropBindRecordMap.Value.Load().(map[string]*invalidBindRecordMap)) + newMap := copyInvalidBindRecordMap(h.invalidBindRecordMap.Value.Load().(map[string]*invalidBindRecordMap)) newMap[key] = &invalidBindRecordMap{ bindRecord: invalidBindRecord, } - h.dropBindRecordMap.Store(newMap) + h.invalidBindRecordMap.Store(newMap) } // Size return the size of bind info cache.