Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: remove goCtx from session struct #5174

Merged
merged 8 commits into from
Nov 22, 2017
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
address comment
  • Loading branch information
tiancaiamao committed Nov 21, 2017
commit 79a06b38323185e506b17c98e5df9f1f2684146e
3 changes: 1 addition & 2 deletions ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
package ast

import (
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -182,7 +181,7 @@ type Statement interface {
OriginText() string

// Exec executes SQL and gets a Recordset.
Exec(goCtx goctx.Context, ctx context.Context) (RecordSet, error)
Exec(goCtx goctx.Context) (RecordSet, error)

// IsPrepared returns whether this statement is prepared statement.
IsPrepared() bool
Expand Down
7 changes: 3 additions & 4 deletions ddl/ddl_db_change_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,16 +244,16 @@ func (t *testExecInfo) parseSQLs(p *parser.Parser) error {
}

func (t *testExecInfo) compileSQL(idx int) (err error) {
compiler := executor.Compiler{}
for _, info := range t.sqlInfos {
c := info.cases[idx]
compiler := executor.Compiler{c.session}
se := c.session
goCtx := goctx.TODO()
se.PrepareTxnCtx(goCtx)
ctx := se.(context.Context)
executor.ResetStmtCtx(ctx, c.rawStmt)

c.stmt, err = compiler.Compile(goCtx, ctx, c.rawStmt)
c.stmt, err = compiler.Compile(goCtx, c.rawStmt)
if err != nil {
return errors.Trace(err)
}
Expand All @@ -264,8 +264,7 @@ func (t *testExecInfo) compileSQL(idx int) (err error) {
func (t *testExecInfo) execSQL(idx int) error {
for _, sqlInfo := range t.sqlInfos {
c := sqlInfo.cases[idx]
ctx := c.session.(context.Context)
_, err := c.stmt.Exec(goctx.TODO(), ctx)
_, err := c.stmt.Exec(goctx.TODO())
if c.expectedErr != nil {
if err == nil {
err = errors.Errorf("expected error %s but got nil", c.expectedErr)
Expand Down
21 changes: 10 additions & 11 deletions executor/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (a *recordSet) Fields() []*ast.ResultField {
for _, col := range a.executor.Schema().Columns {
dbName := col.DBName.O
if dbName == "" && col.TblName.L != "" {
dbName = a.stmt.ctx.GetSessionVars().CurrentDB
dbName = a.stmt.Ctx.GetSessionVars().CurrentDB
}
rf := &ast.ResultField{
ColumnAsName: col.ColName,
Expand All @@ -81,13 +81,13 @@ func (a *recordSet) Next() (types.Row, error) {
}
if row == nil {
if a.stmt != nil {
a.stmt.ctx.GetSessionVars().LastFoundRows = a.stmt.ctx.GetSessionVars().StmtCtx.FoundRows()
a.stmt.Ctx.GetSessionVars().LastFoundRows = a.stmt.Ctx.GetSessionVars().StmtCtx.FoundRows()
}
return nil, nil
}

if a.stmt != nil {
a.stmt.ctx.GetSessionVars().StmtCtx.AddFoundRows(1)
a.stmt.Ctx.GetSessionVars().StmtCtx.AddFoundRows(1)
}
return row, nil
}
Expand All @@ -101,12 +101,12 @@ func (a *recordSet) NextChunk(chk *chunk.Chunk) error {
numRows := chk.NumRows()
if numRows == 0 {
if a.stmt != nil {
a.stmt.ctx.GetSessionVars().LastFoundRows = a.stmt.ctx.GetSessionVars().StmtCtx.FoundRows()
a.stmt.Ctx.GetSessionVars().LastFoundRows = a.stmt.Ctx.GetSessionVars().StmtCtx.FoundRows()
}
return nil
}
if a.stmt != nil {
a.stmt.ctx.GetSessionVars().StmtCtx.AddFoundRows(uint64(numRows))
a.stmt.Ctx.GetSessionVars().StmtCtx.AddFoundRows(uint64(numRows))
}
return nil
}
Expand Down Expand Up @@ -141,7 +141,7 @@ type ExecStmt struct {
// Text represents the origin query text.
Text string

ctx context.Context
Ctx context.Context
startTime time.Time
isPreparedStmt bool

Expand All @@ -168,10 +168,9 @@ func (a *ExecStmt) IsReadOnly() bool {
// This function builds an Executor from a plan. If the Executor doesn't return result,
// like the INSERT, UPDATE statements, it executes in this function, if the Executor returns
// result, execution is done after this function returns, in the returned ast.RecordSet Next method.
func (a *ExecStmt) Exec(goCtx goctx.Context, ctx context.Context) (ast.RecordSet, error) {
func (a *ExecStmt) Exec(goCtx goctx.Context) (ast.RecordSet, error) {
a.startTime = time.Now()
a.ctx = ctx

ctx := a.Ctx
if _, ok := a.Plan.(*plan.Analyze); ok && ctx.GetSessionVars().InRestrictedSQL {
oriStats := ctx.GetSessionVars().Systems[variable.TiDBBuildStatsConcurrency]
oriScan := ctx.GetSessionVars().DistSQLScanConcurrency
Expand Down Expand Up @@ -324,8 +323,8 @@ func (a *ExecStmt) logSlowQuery(txnTS uint64, succ bool) {
if len(sql) > cfg.Log.QueryLogMaxLen {
sql = fmt.Sprintf("%.*q(len:%d)", cfg.Log.QueryLogMaxLen, sql, len(a.Text))
}
connID := a.ctx.GetSessionVars().ConnectionID
currentDB := a.ctx.GetSessionVars().CurrentDB
connID := a.Ctx.GetSessionVars().ConnectionID
currentDB := a.Ctx.GetSessionVars().CurrentDB
logEntry := log.NewEntry(logutil.SlowQueryLogger)
logEntry.Data = log.Fields{
"connectionId": connID,
Expand Down
12 changes: 7 additions & 5 deletions executor/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,22 @@ import (

// Compiler compiles an ast.StmtNode to a physical plan.
type Compiler struct {
Ctx context.Context
}

// Compile compiles an ast.StmtNode to a physical plan.
func (c *Compiler) Compile(goCtx goctx.Context, ctx context.Context, stmtNode ast.StmtNode) (*ExecStmt, error) {
func (c *Compiler) Compile(goCtx goctx.Context, stmtNode ast.StmtNode) (*ExecStmt, error) {
if span := opentracing.SpanFromContext(goCtx); span != nil {
span1 := opentracing.StartSpan("executor.Compile", opentracing.ChildOf(span.Context()))
defer span1.Finish()
}

infoSchema := GetInfoSchema(ctx)
if err := plan.Preprocess(ctx, stmtNode, infoSchema, false); err != nil {
infoSchema := GetInfoSchema(c.Ctx)
if err := plan.Preprocess(c.Ctx, stmtNode, infoSchema, false); err != nil {
return nil, errors.Trace(err)
}

finalPlan, err := plan.Optimize(ctx, stmtNode, infoSchema)
finalPlan, err := plan.Optimize(c.Ctx, stmtNode, infoSchema)
if err != nil {
return nil, errors.Trace(err)
}
Expand All @@ -53,10 +54,11 @@ func (c *Compiler) Compile(goCtx goctx.Context, ctx context.Context, stmtNode as
return &ExecStmt{
InfoSchema: infoSchema,
Plan: finalPlan,
Expensive: stmtCount(stmtNode, finalPlan, ctx.GetSessionVars().InRestrictedSQL),
Expensive: stmtCount(stmtNode, finalPlan, c.Ctx.GetSessionVars().InRestrictedSQL),
Cacheable: plan.Cacheable(stmtNode),
Text: stmtNode.Text(),
ReadOnly: ast.IsReadOnly(readOnlyCheckStmt),
Ctx: c.Ctx,
}, nil
}

Expand Down
7 changes: 3 additions & 4 deletions executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1520,17 +1520,16 @@ func (s *testSuite) TestAdapterStatement(c *C) {
se, err := tidb.CreateSession(s.store)
c.Check(err, IsNil)
se.GetSessionVars().TxnCtx.InfoSchema = sessionctx.GetDomain(se).InfoSchema()
compiler := &executor.Compiler{}
ctx := se.(context.Context)
compiler := &executor.Compiler{se}
stmtNode, err := s.ParseOneStmt("select 1", "", "")
c.Check(err, IsNil)
stmt, err := compiler.Compile(goctx.TODO(), ctx, stmtNode)
stmt, err := compiler.Compile(goctx.TODO(), stmtNode)
c.Check(err, IsNil)
c.Check(stmt.OriginText(), Equals, "select 1")

stmtNode, err = s.ParseOneStmt("create table test.t (a int)", "", "")
c.Check(err, IsNil)
stmt, err = compiler.Compile(goctx.TODO(), ctx, stmtNode)
stmt, err = compiler.Compile(goctx.TODO(), stmtNode)
c.Check(err, IsNil)
c.Check(stmt.OriginText(), Equals, "create table test.t (a int)")
}
Expand Down
1 change: 1 addition & 0 deletions executor/prepared.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ func CompileExecutePreparedStmt(ctx context.Context, ID uint32, args ...interfac
InfoSchema: GetInfoSchema(ctx),
Plan: execPlan,
ReadOnly: readOnly,
Ctx: ctx,
}
if prepared, ok := ctx.GetSessionVars().PreparedStmts[ID].(*plan.Prepared); ok {
stmt.Text = prepared.Stmt.Text()
Expand Down
40 changes: 24 additions & 16 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,12 @@ type session struct {
txn kv.Transaction // current transaction
txnFuture *txnFuture

// cancelFunc is used for cancelling the execution of current transaction.
cancelFunc goctx.CancelFunc

mu struct {
sync.RWMutex
values map[fmt.Stringer]interface{}

// cancelFunc is used for cancelling the execution of current transaction.
cancelFunc goctx.CancelFunc
}

store kv.Storage
Expand All @@ -144,7 +144,9 @@ type session struct {
func (s *session) Cancel() {
// TODO: How to wait for the resource to release and make sure
// it's not leak?
s.cancelFunc()
s.mu.RLock()
s.mu.cancelFunc()
s.mu.RUnlock()
}

func (s *session) cleanRetryInfo() {
Expand Down Expand Up @@ -441,7 +443,7 @@ func (s *session) retry(goCtx goctx.Context, maxCnt int, infoSchemaChanged bool)
}
s.sessionVars.StmtCtx = sr.stmtCtx
s.sessionVars.StmtCtx.ResetForRetry()
_, err = st.Exec(goCtx, s)
_, err = st.Exec(goCtx)
if err != nil {
break
}
Expand Down Expand Up @@ -703,8 +705,11 @@ func (s *session) Execute(goCtx goctx.Context, sql string) (recordSets []ast.Rec
span, goCtx1 := opentracing.StartSpanFromContext(goCtx, "session.Execute")
defer span.Finish()

goCtx, s.cancelFunc = goctx.WithCancel(goCtx1)
s.PrepareTxnCtx(goCtx)
goCtx2, cancelFunc := goctx.WithCancel(goCtx1)
s.mu.Lock()
s.mu.cancelFunc = cancelFunc
s.mu.Unlock()
s.PrepareTxnCtx(goCtx2)
var (
cacheKey kvcache.Key
cacheValue kvcache.Value
Expand All @@ -728,37 +733,38 @@ func (s *session) Execute(goCtx goctx.Context, sql string) (recordSets []ast.Rec
Expensive: cacheValue.(*plan.SQLCacheValue).Expensive,
Text: stmtNode.Text(),
ReadOnly: ast.IsReadOnly(stmtNode),
Ctx: s,
}

s.PrepareTxnCtx(goCtx)
s.PrepareTxnCtx(goCtx2)
executor.ResetStmtCtx(s, stmtNode)
if recordSets, err = s.executeStatement(goCtx, connID, stmtNode, stmt, recordSets); err != nil {
if recordSets, err = s.executeStatement(goCtx2, connID, stmtNode, stmt, recordSets); err != nil {
return nil, errors.Trace(err)
}
} else {
charset, collation := s.sessionVars.GetCharsetInfo()

// Step1: Compile query string to abstract syntax trees(ASTs).
startTS := time.Now()
stmtNodes, err := s.ParseSQL(goCtx, sql, charset, collation)
stmtNodes, err := s.ParseSQL(goCtx2, sql, charset, collation)
if err != nil {
log.Warnf("[%d] parse error:\n%v\n%s", connID, err, sql)
return nil, errors.Trace(err)
}
sessionExecuteParseDuration.Observe(time.Since(startTS).Seconds())

compiler := executor.Compiler{}
compiler := executor.Compiler{s}
for _, stmtNode := range stmtNodes {
s.PrepareTxnCtx(goCtx)
s.PrepareTxnCtx(goCtx2)

// Step2: Transform abstract syntax tree to a physical plan(stored in executor.ExecStmt).
startTS = time.Now()
// Some executions are done in compile stage, so we reset them before compile.
executor.ResetStmtCtx(s, stmtNode)
stmt, err := compiler.Compile(goCtx, s, stmtNode)
stmt, err := compiler.Compile(goCtx2, stmtNode)
if err != nil {
log.Warnf("[%d] compile error:\n%v\n%s", connID, err, sql)
terror.Log(errors.Trace(s.RollbackTxn(goCtx)))
terror.Log(errors.Trace(s.RollbackTxn(goCtx2)))
return nil, errors.Trace(err)
}
sessionExecuteCompileDuration.Observe(time.Since(startTS).Seconds())
Expand All @@ -769,7 +775,7 @@ func (s *session) Execute(goCtx goctx.Context, sql string) (recordSets []ast.Rec
}

// Step4: Execute the physical plan.
if recordSets, err = s.executeStatement(goCtx, connID, stmtNode, stmt, recordSets); err != nil {
if recordSets, err = s.executeStatement(goCtx2, connID, stmtNode, stmt, recordSets); err != nil {
return nil, errors.Trace(err)
}
}
Expand Down Expand Up @@ -846,7 +852,9 @@ func (s *session) ExecutePreparedStmt(stmtID uint32, args ...interface{}) (ast.R
return nil, errors.Trace(err)
}
goCtx, cancelFunc := goctx.WithCancel(goctx.TODO())
s.cancelFunc = cancelFunc
s.mu.Lock()
s.mu.cancelFunc = cancelFunc
s.mu.Unlock()
s.PrepareTxnCtx(goCtx)
st, err := executor.CompileExecutePreparedStmt(s, stmtID, args...)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ func Parse(ctx context.Context, src string) ([]ast.StmtNode, error) {

// Compile is safe for concurrent use by multiple goroutines.
func Compile(goCtx goctx.Context, ctx context.Context, stmtNode ast.StmtNode) (ast.Statement, error) {
compiler := executor.Compiler{}
stmt, err := compiler.Compile(goCtx, ctx, stmtNode)
compiler := executor.Compiler{ctx}
stmt, err := compiler.Compile(goCtx, stmtNode)
return stmt, errors.Trace(err)
}

Expand All @@ -152,7 +152,7 @@ func runStmt(goCtx goctx.Context, ctx context.Context, s ast.Statement) (ast.Rec
var err error
var rs ast.RecordSet
se := ctx.(*session)
rs, err = s.Exec(goCtx, ctx)
rs, err = s.Exec(goCtx)
span.SetTag("txn.id", se.sessionVars.TxnCtx.StartTS)
// All the history should be added here.
GetHistory(ctx).Add(0, s, se.sessionVars.StmtCtx)
Expand Down