diff --git a/br/pkg/gluetidb/glue.go b/br/pkg/gluetidb/glue.go index 45c8d84862351..06af5615ff451 100644 --- a/br/pkg/gluetidb/glue.go +++ b/br/pkg/gluetidb/glue.go @@ -61,6 +61,10 @@ type tidbSession struct { // GetDomain implements glue.Glue. func (Glue) GetDomain(store kv.Storage) (*domain.Domain, error) { + initStatsSe, err := session.CreateSession(store) + if err != nil { + return nil, errors.Trace(err) + } se, err := session.CreateSession(store) if err != nil { return nil, errors.Trace(err) @@ -74,7 +78,7 @@ func (Glue) GetDomain(store kv.Storage) (*domain.Domain, error) { return nil, err } // create stats handler for backup and restore. - err = dom.UpdateTableStatsLoop(se) + err = dom.UpdateTableStatsLoop(se, initStatsSe) if err != nil { return nil, errors.Trace(err) } diff --git a/domain/domain.go b/domain/domain.go index 1b53789dd0535..977694d773998 100644 --- a/domain/domain.go +++ b/domain/domain.go @@ -1836,8 +1836,8 @@ func (do *Domain) StatsHandle() *handle.Handle { } // CreateStatsHandle is used only for test. -func (do *Domain) CreateStatsHandle(ctx sessionctx.Context) error { - h, err := handle.NewHandle(ctx, do.statsLease, do.sysSessionPool, &do.sysProcesses, do.ServerID) +func (do *Domain) CreateStatsHandle(ctx, initStatsCtx sessionctx.Context) error { + h, err := handle.NewHandle(ctx, initStatsCtx, do.statsLease, do.sysSessionPool, &do.sysProcesses, do.ServerID) if err != nil { return err } @@ -1900,8 +1900,8 @@ func (do *Domain) SetupAnalyzeExec(ctxs []sessionctx.Context) { } // LoadAndUpdateStatsLoop loads and updates stats info. -func (do *Domain) LoadAndUpdateStatsLoop(ctxs []sessionctx.Context) error { - if err := do.UpdateTableStatsLoop(ctxs[0]); err != nil { +func (do *Domain) LoadAndUpdateStatsLoop(ctxs []sessionctx.Context, initStatsCtx sessionctx.Context) error { + if err := do.UpdateTableStatsLoop(ctxs[0], initStatsCtx); err != nil { return err } do.StartLoadStatsSubWorkers(ctxs[1:]) @@ -1911,9 +1911,9 @@ func (do *Domain) LoadAndUpdateStatsLoop(ctxs []sessionctx.Context) error { // UpdateTableStatsLoop creates a goroutine loads stats info and updates stats info in a loop. // It will also start a goroutine to analyze tables automatically. // It should be called only once in BootstrapSession. -func (do *Domain) UpdateTableStatsLoop(ctx sessionctx.Context) error { +func (do *Domain) UpdateTableStatsLoop(ctx, initStatsCtx sessionctx.Context) error { ctx.GetSessionVars().InRestrictedSQL = true - statsHandle, err := handle.NewHandle(ctx, do.statsLease, do.sysSessionPool, &do.sysProcesses, do.ServerID) + statsHandle, err := handle.NewHandle(ctx, initStatsCtx, do.statsLease, do.sysSessionPool, &do.sysProcesses, do.ServerID) if err != nil { return err } diff --git a/planner/core/mock.go b/planner/core/mock.go index 472c980fa3777..93d389352e3af 100644 --- a/planner/core/mock.go +++ b/planner/core/mock.go @@ -401,9 +401,13 @@ func MockContext() sessionctx.Context { ctx.Store = &mock.Store{ Client: &mock.Client{}, } + initStatsCtx := mock.NewContext() + initStatsCtx.Store = &mock.Store{ + Client: &mock.Client{}, + } ctx.GetSessionVars().CurrentDB = "test" do := domain.NewMockDomain() - if err := do.CreateStatsHandle(ctx); err != nil { + if err := do.CreateStatsHandle(ctx, initStatsCtx); err != nil { panic(fmt.Sprintf("create mock context panic: %+v", err)) } domain.BindDomain(ctx, do) diff --git a/session/session.go b/session/session.go index 641af7164e616..2b77540b238c0 100644 --- a/session/session.go +++ b/session/session.go @@ -3416,7 +3416,11 @@ func BootstrapSession(store kv.Storage) (*domain.Domain, error) { for i := 0; i < cnt; i++ { subCtxs[i] = sessionctx.Context(syncStatsCtxs[i]) } - if err = dom.LoadAndUpdateStatsLoop(subCtxs); err != nil { + initStatsCtx, err := createSession(store) + if err != nil { + return nil, err + } + if err = dom.LoadAndUpdateStatsLoop(subCtxs, initStatsCtx); err != nil { return nil, err } diff --git a/statistics/handle/bootstrap.go b/statistics/handle/bootstrap.go index edbf04c532f28..05e971488b360 100644 --- a/statistics/handle/bootstrap.go +++ b/statistics/handle/bootstrap.go @@ -63,7 +63,7 @@ func (h *Handle) initStatsMeta4Chunk(is infoschema.InfoSchema, cache *statsCache func (h *Handle) initStatsMeta(is infoschema.InfoSchema) (statsCache, error) { ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) sql := "select HIGH_PRIORITY version, table_id, modify_count, count from mysql.stats_meta" - rc, err := h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql) + rc, err := h.initStatsCtx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql) if err != nil { return statsCache{}, errors.Trace(err) } @@ -167,7 +167,7 @@ func (h *Handle) initStatsHistograms4Chunk(is infoschema.InfoSchema, cache *stat func (h *Handle) initStatsHistograms(is infoschema.InfoSchema, cache *statsCache) error { ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) sql := "select HIGH_PRIORITY table_id, is_index, hist_id, distinct_count, version, null_count, cm_sketch, tot_col_size, stats_ver, correlation, flag, last_analyze_pos from mysql.stats_histograms" - rc, err := h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql) + rc, err := h.initStatsCtx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql) if err != nil { return errors.Trace(err) } @@ -214,7 +214,7 @@ func (h *Handle) initStatsTopN4Chunk(cache *statsCache, iter *chunk.Iterator4Chu func (h *Handle) initStatsTopN(cache *statsCache) error { ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) sql := "select HIGH_PRIORITY table_id, hist_id, value, count from mysql.stats_top_n where is_index = 1" - rc, err := h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql) + rc, err := h.initStatsCtx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql) if err != nil { return errors.Trace(err) } @@ -263,7 +263,7 @@ func (h *Handle) initStatsFMSketch4Chunk(cache *statsCache, iter *chunk.Iterator func (h *Handle) initStatsFMSketch(cache *statsCache) error { ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) sql := "select HIGH_PRIORITY table_id, is_index, hist_id, value from mysql.stats_fm_sketch" - rc, err := h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql) + rc, err := h.initStatsCtx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql) if err != nil { return errors.Trace(err) } @@ -357,7 +357,7 @@ func (h *Handle) initTopNCountSum(tableID, colID int64) (int64, error) { func (h *Handle) initStatsBuckets(cache *statsCache) error { ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) sql := "select HIGH_PRIORITY table_id, is_index, hist_id, count, repeats, lower_bound, upper_bound, ndv from mysql.stats_buckets order by table_id, is_index, hist_id, bucket_id" - rc, err := h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql) + rc, err := h.initStatsCtx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql) if err != nil { return errors.Trace(err) } @@ -398,15 +398,13 @@ func (h *Handle) initStatsBuckets(cache *statsCache) error { func (h *Handle) InitStats(is infoschema.InfoSchema) (err error) { loadFMSketch := config.GetGlobalConfig().Performance.EnableLoadFMSketch ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) - h.mu.Lock() defer func() { - _, err1 := h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, "commit") + _, err1 := h.initStatsCtx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, "commit") if err == nil && err1 != nil { err = err1 } - h.mu.Unlock() }() - _, err = h.mu.ctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, "begin") + _, err = h.initStatsCtx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, "begin") if err != nil { return err } diff --git a/statistics/handle/handle.go b/statistics/handle/handle.go index ab900ffc43392..79a5382779208 100644 --- a/statistics/handle/handle.go +++ b/statistics/handle/handle.go @@ -65,11 +65,19 @@ const ( // Handle can update stats info periodically. type Handle struct { + + // initStatsCtx is the ctx only used for initStats + initStatsCtx sessionctx.Context + mu struct { sync.RWMutex ctx sessionctx.Context // rateMap contains the error rate delta from feedback. rateMap errorRateDeltaMap + } + + schemaMu struct { + sync.RWMutex // pid2tid is the map from partition ID to table ID. pid2tid map[int64]int64 // schemaVersion is the version of information schema when `pid2tid` is built. @@ -460,7 +468,7 @@ type sessionPool interface { } // NewHandle creates a Handle for update stats. -func NewHandle(ctx sessionctx.Context, lease time.Duration, pool sessionPool, tracker sessionctx.SysProcTracker, serverIDGetter func() uint64) (*Handle, error) { +func NewHandle(ctx, initStatsCtx sessionctx.Context, lease time.Duration, pool sessionPool, tracker sessionctx.SysProcTracker, serverIDGetter func() uint64) (*Handle, error) { cfg := config.GetGlobalConfig() handle := &Handle{ ddlEventCh: make(chan *ddlUtil.Event, 1000), @@ -470,6 +478,7 @@ func NewHandle(ctx sessionctx.Context, lease time.Duration, pool sessionPool, tr sysProcTracker: tracker, serverIDGetter: serverIDGetter, } + handle.initStatsCtx = initStatsCtx handle.lease.Store(lease) handle.statsCache.memTracker = memory.NewTracker(memory.LabelForStatsCache, -1) handle.mu.ctx = ctx @@ -933,11 +942,13 @@ func (h *Handle) mergeGlobalStatsTopNByConcurrency(mergeConcurrency, mergeBatchS } func (h *Handle) getTableByPhysicalID(is infoschema.InfoSchema, physicalID int64) (table.Table, bool) { - if is.SchemaMetaVersion() != h.mu.schemaVersion { - h.mu.schemaVersion = is.SchemaMetaVersion() - h.mu.pid2tid = buildPartitionID2TableID(is) + h.schemaMu.Lock() + defer h.schemaMu.Unlock() + if is.SchemaMetaVersion() != h.schemaMu.schemaVersion { + h.schemaMu.schemaVersion = is.SchemaMetaVersion() + h.schemaMu.pid2tid = buildPartitionID2TableID(is) } - if id, ok := h.mu.pid2tid[physicalID]; ok { + if id, ok := h.schemaMu.pid2tid[physicalID]; ok { return is.TableByID(id) } return is.TableByID(physicalID) diff --git a/statistics/handle/handle_test.go b/statistics/handle/handle_test.go index ac9936bed11fe..2b0669033f8c9 100644 --- a/statistics/handle/handle_test.go +++ b/statistics/handle/handle_test.go @@ -344,6 +344,7 @@ func TestDurationToTS(t *testing.T) { func TestVersion(t *testing.T) { store, dom := testkit.CreateMockStoreAndDomain(t) + testKit2 := testkit.NewTestKit(t, store) testKit := testkit.NewTestKit(t, store) testKit.MustExec("use test") testKit.MustExec("create table t1 (c1 int, c2 int)") @@ -353,7 +354,7 @@ func TestVersion(t *testing.T) { tbl1, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t1")) require.NoError(t, err) tableInfo1 := tbl1.Meta() - h, err := handle.NewHandle(testKit.Session(), time.Millisecond, do.SysSessionPool(), do.SysProcTracker(), do.ServerID) + h, err := handle.NewHandle(testKit.Session(), testKit2.Session(), time.Millisecond, do.SysSessionPool(), do.SysProcTracker(), do.ServerID) require.NoError(t, err) unit := oracle.ComposeTS(1, 0) testKit.MustExec("update mysql.stats_meta set version = ? where table_id = ?", 2*unit, tableInfo1.ID)