diff --git a/session/session.go b/session/session.go index 594d57af20e97..e4e1522390865 100644 --- a/session/session.go +++ b/session/session.go @@ -2268,6 +2268,9 @@ func (s *session) preparedStmtExec(ctx context.Context, execStmt *ast.ExecuteStm is := sessiontxn.GetTxnManager(s).GetTxnInfoSchema() st, tiFlashPushDown, tiFlashExchangePushDown, err := executor.CompileExecutePreparedStmt(ctx, s, execStmt, is) + if err == nil { + err = sessiontxn.OptimizeWithPlanAndThenWarmUp(s, st.Plan) + } if err != nil { return nil, err } diff --git a/sessiontxn/failpoint.go b/sessiontxn/failpoint.go index b41be21165908..da63ac753870f 100644 --- a/sessiontxn/failpoint.go +++ b/sessiontxn/failpoint.go @@ -43,6 +43,9 @@ var BreakPointBeforeExecutorFirstRun = "beforeExecutorFirstRun" // Only for test var BreakPointOnStmtRetryAfterLockError = "lockErrorAndThenOnStmtRetryCalled" +// TsoRequestCount is the key for recording tso request counts in some places +var TsoRequestCount stringutil.StringerStr = "tsoRequestCount" + // AssertLockErr is used to record the lock errors we encountered // Only for test var AssertLockErr stringutil.StringerStr = "assertLockError" @@ -112,6 +115,17 @@ func AddAssertEntranceForLockError(sctx sessionctx.Context, name string) { } } +// TsoRequestCountInc is used only for test +// When it is called, there is a tso cmd request. +func TsoRequestCountInc(sctx sessionctx.Context) { + count, ok := sctx.Value(TsoRequestCount).(uint64) + if !ok { + count = 0 + } + count += 1 + sctx.SetValue(TsoRequestCount, count) +} + // ExecTestHook is used only for test. It consumes hookKey in session wait do what it gets from it. func ExecTestHook(sctx sessionctx.Context, hookKey fmt.Stringer) { c := sctx.Value(hookKey) diff --git a/sessiontxn/isolation/base.go b/sessiontxn/isolation/base.go index 572bc218f754b..4fe7b3a595dff 100644 --- a/sessiontxn/isolation/base.go +++ b/sessiontxn/isolation/base.go @@ -20,6 +20,7 @@ import ( "github.com/opentracing/opentracing-go" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/kv" @@ -408,6 +409,10 @@ func newOracleFuture(ctx context.Context, sctx sessionctx.Context, scope string) ctx = opentracing.ContextWithSpan(ctx, span1) } + failpoint.Inject("requestTsoFromPD", func() { + sessiontxn.TsoRequestCountInc(sctx) + }) + oracleStore := sctx.GetStore().GetOracle() option := &oracle.Option{TxnScope: scope} diff --git a/sessiontxn/isolation/repeatable_read.go b/sessiontxn/isolation/repeatable_read.go index 2ea4bd41c5996..a70f882758951 100644 --- a/sessiontxn/isolation/repeatable_read.go +++ b/sessiontxn/isolation/repeatable_read.go @@ -18,6 +18,7 @@ import ( "context" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/terror" @@ -107,6 +108,10 @@ func (p *PessimisticRRTxnContextProvider) updateForUpdateTS() (err error) { return errors.Trace(kv.ErrInvalidTxn) } + failpoint.Inject("RequestTsoFromPD", func() { + sessiontxn.TsoRequestCountInc(sctx) + }) + // Because the ForUpdateTS is used for the snapshot for reading data in DML. // We can avoid allocating a global TSO here to speed it up by using the local TSO. version, err := sctx.GetStore().CurrentVersion(sctx.GetSessionVars().TxnCtx.TxnScope) diff --git a/sessiontxn/txn_context_test.go b/sessiontxn/txn_context_test.go index 75a3b72f1ac38..ed41496ee2596 100644 --- a/sessiontxn/txn_context_test.go +++ b/sessiontxn/txn_context_test.go @@ -31,6 +31,7 @@ import ( "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/testkit/testfork" "github.com/pingcap/tidb/testkit/testsetup" + "github.com/pingcap/tidb/types" "github.com/stretchr/testify/require" "go.uber.org/goleak" ) @@ -882,3 +883,98 @@ func TestOptimisticTxnRetryInPessimisticMode(t *testing.T) { } }) } + +func TestTSOCmdCountForPrepareExecute(t *testing.T) { + // This is a mock workload mocks one which discovers that the tso request count is abnormal. + // After the bug fix, the tso request count recovers, so we use this workload to record the current tso request count + // to reject future works that accidentally causes tso request increasing. + // Note, we do not record all tso requests but some typical requests. + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/sessiontxn/isolation/requestTsoFromPD", "return")) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/sessiontxn/isolation/requestTsoFromPD")) + }() + store, clean := testkit.CreateMockStore(t) + defer clean() + + ctx := context.Background() + tk := testkit.NewTestKit(t, store) + sctx := tk.Session() + + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + tk.MustExec("drop table if exists t2") + tk.MustExec("drop table if exists t3") + + tk.MustExec("create table t1(id int, v int, v2 int, primary key (id), unique key uk (v))") + tk.MustExec("create table t2(id int, v int, unique key i1(v))") + tk.MustExec("create table t3(id int, v int, key i1(v))") + + sqlSelectID, _, _, _ := tk.Session().PrepareStmt("select * from t1 where id = ? for update") + sqlUpdateID, _, _, _ := tk.Session().PrepareStmt("update t1 set v = v + 10 where id = ?") + sqlInsertID1, _, _, _ := tk.Session().PrepareStmt("insert into t2 values(?, ?)") + sqlInsertID2, _, _, _ := tk.Session().PrepareStmt("insert into t3 values(?, ?)") + + tk.MustExec("insert into t1 values (1, 1, 1)") + sctx.SetValue(sessiontxn.TsoRequestCount, 0) + + for i := 1; i < 100; i++ { + tk.MustExec("begin pessimistic") + stmt, err := tk.Session().ExecutePreparedStmt(ctx, sqlSelectID, []types.Datum{types.NewDatum(1)}) + require.NoError(t, err) + require.NoError(t, stmt.Close()) + stmt, err = tk.Session().ExecutePreparedStmt(ctx, sqlUpdateID, []types.Datum{types.NewDatum(1)}) + require.NoError(t, err) + require.Nil(t, stmt) + + val := i * 10 + stmt, err = tk.Session().ExecutePreparedStmt(ctx, sqlInsertID1, []types.Datum{types.NewDatum(val), types.NewDatum(val)}) + require.NoError(t, err) + require.Nil(t, stmt) + stmt, err = tk.Session().ExecutePreparedStmt(ctx, sqlInsertID2, []types.Datum{types.NewDatum(val), types.NewDatum(val)}) + require.NoError(t, err) + require.Nil(t, stmt) + tk.MustExec("commit") + } + count := sctx.Value(sessiontxn.TsoRequestCount) + require.Equal(t, uint64(99), count) + +} + +func TestTSOCmdCountForTextSql(t *testing.T) { + // This is a mock workload mocks one which discovers that the tso request count is abnormal. + // After the bug fix, the tso request count recovers, so we use this workload to record the current tso request count + // to reject future works that accidentally causes tso request increasing. + // Note, we do not record all tso requests but some typical requests. + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/sessiontxn/isolation/requestTsoFromPD", "return")) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/sessiontxn/isolation/requestTsoFromPD")) + }() + store, clean := testkit.CreateMockStore(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + sctx := tk.Session() + + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + tk.MustExec("drop table if exists t2") + tk.MustExec("drop table if exists t3") + + tk.MustExec("create table t1(id int, v int, v2 int, primary key (id), unique key uk (v))") + tk.MustExec("create table t2(id int, v int, unique key i1(v))") + tk.MustExec("create table t3(id int, v int, key i1(v))") + + tk.MustExec("insert into t1 values (1, 1, 1)") + sctx.SetValue(sessiontxn.TsoRequestCount, 0) + for i := 1; i < 100; i++ { + tk.MustExec("begin pessimistic") + tk.MustQuery("select * from t1 where id = 1 for update") + tk.MustExec("update t1 set v = v + 10 where id = 1") + val := i * 10 + tk.MustExec(fmt.Sprintf("insert into t2 values(%v, %v)", val, val)) + tk.MustExec(fmt.Sprintf("insert into t3 values(%v, %v)", val, val)) + tk.MustExec("commit") + } + count := sctx.Value(sessiontxn.TsoRequestCount) + require.Equal(t, uint64(99), count) +}