diff --git a/executor/adapter.go b/executor/adapter.go index b0ab16ecfcb2c..e05473765b524 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -136,13 +136,7 @@ func (a *recordSet) NewChunk() *chunk.Chunk { func (a *recordSet) Close() error { err := a.executor.Close() - // `LowSlowQuery` and `SummaryStmt` must be called before recording `PrevStmt`. - a.stmt.LogSlowQuery(a.txnStartTS, a.lastErr == nil, false) - a.stmt.SummaryStmt() - sessVars := a.stmt.Ctx.GetSessionVars() - pps := types.CloneRow(sessVars.PreparedParams) - sessVars.PrevStmt = FormatSQL(a.stmt.OriginText(), pps) - a.stmt.logAudit() + a.stmt.CloseRecordSet(a.txnStartTS, a.lastErr) return err } @@ -338,6 +332,7 @@ type chunkRowRecordSet struct { idx int fields []*ast.ResultField e Executor + stmt *ExecStmt } func (c *chunkRowRecordSet) Fields() []*ast.ResultField { @@ -358,6 +353,7 @@ func (c *chunkRowRecordSet) NewChunk() *chunk.Chunk { } func (c *chunkRowRecordSet) Close() error { + c.stmt.CloseRecordSet(c.stmt.Ctx.GetSessionVars().TxnCtx.StartTS, nil) return nil } @@ -389,7 +385,7 @@ func (a *ExecStmt) runPessimisticSelectForUpdate(ctx context.Context, e Executor } if req.NumRows() == 0 { fields := schema2ResultFields(e.Schema(), a.Ctx.GetSessionVars().CurrentDB) - return &chunkRowRecordSet{rows: rows, fields: fields, e: e}, nil + return &chunkRowRecordSet{rows: rows, fields: fields, e: e, stmt: a}, nil } iter := chunk.NewIterator4Chunk(req) for r := iter.Begin(); r != iter.End(); r = iter.Next() { @@ -686,6 +682,17 @@ func FormatSQL(sql string, pps variable.PreparedParams) stringutil.StringerFunc } } +// CloseRecordSet will finish the execution of current statement and do some record work +func (a *ExecStmt) CloseRecordSet(txnStartTS uint64, lastErr error) { + // `LowSlowQuery` and `SummaryStmt` must be called before recording `PrevStmt`. + a.LogSlowQuery(txnStartTS, lastErr == nil, false) + a.SummaryStmt() + sessVars := a.Ctx.GetSessionVars() + pps := types.CloneRow(sessVars.PreparedParams) + sessVars.PrevStmt = FormatSQL(a.OriginText(), pps) + a.logAudit() +} + // LogSlowQuery is used to print the slow query in the log files. func (a *ExecStmt) LogSlowQuery(txnTS uint64, succ bool, hasMoreResults bool) { sessVars := a.Ctx.GetSessionVars()