From 93c8df8233c6aeb2acdab8215f999d48a456e03c Mon Sep 17 00:00:00 2001 From: Weizhen Wang Date: Tue, 13 Sep 2022 12:20:59 +0800 Subject: [PATCH] cherry pick #37753 to release-6.1 Signed-off-by: ti-srebot --- executor/infoschema_reader.go | 79 ++++++++++++++++++++++++++++------- 1 file changed, 64 insertions(+), 15 deletions(-) diff --git a/executor/infoschema_reader.go b/executor/infoschema_reader.go index ebc13d6f6f245..a82816969f533 100644 --- a/executor/infoschema_reader.go +++ b/executor/infoschema_reader.go @@ -229,7 +229,7 @@ func getColLengthAllTables(ctx context.Context, sctx sessionctx.Context) (map[ta return colLengthMap, nil } -func getDataAndIndexLength(info *model.TableInfo, physicalID int64, rowCount uint64, columnLengthMap map[tableHistID]uint64) (uint64, uint64) { +func getDataAndIndexLength(info *model.TableInfo, physicalID int64, rowCount uint64) (uint64, uint64) { columnLength := make(map[string]uint64, len(info.Columns)) for _, col := range info.Columns { if col.State != model.StatePublic { @@ -239,7 +239,7 @@ func getDataAndIndexLength(info *model.TableInfo, physicalID int64, rowCount uin if length != types.VarStorageLen { columnLength[col.Name.L] = rowCount * uint64(length) } else { - length := columnLengthMap[tableHistID{tableID: physicalID, histID: col.ID}] + length := tableStatsCache.GetColLength(tableHistID{tableID: physicalID, histID: col.ID}) columnLength[col.Name.L] = length } } @@ -274,12 +274,56 @@ var tableStatsCache = &statsCache{} // TableStatsCacheExpiry is the expiry time for table stats cache. var TableStatsCacheExpiry = 3 * time.Second +<<<<<<< HEAD func (c *statsCache) get(ctx context.Context, sctx sessionctx.Context) (map[int64]uint64, map[tableHistID]uint64, error) { c.mu.RLock() if time.Since(c.modifyTime) < TableStatsCacheExpiry { tableRows, colLength := c.tableRows, c.colLength c.mu.RUnlock() return tableRows, colLength, nil +======= +func invalidInfoSchemaStatCache(tblID int64) { + tableStatsCache.mu.Lock() + defer tableStatsCache.mu.Unlock() + tableStatsCache.dirtyIDs = append(tableStatsCache.dirtyIDs, tblID) +} + +func (c *statsCache) GetTableRows(id int64) uint64 { + c.mu.RLock() + defer c.mu.RUnlock() + return c.tableRows[id] +} + +func (c *statsCache) GetColLength(id tableHistID) uint64 { + c.mu.RLock() + defer c.mu.RUnlock() + return c.colLength[id] +} + +func (c *statsCache) update(ctx context.Context, sctx sessionctx.Context) error { + c.mu.Lock() + defer c.mu.Unlock() + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnStats) + if time.Since(c.modifyTime) < TableStatsCacheExpiry { + if len(c.dirtyIDs) > 0 { + tableRows, err := getRowCountTables(ctx, sctx, c.dirtyIDs...) + if err != nil { + return err + } + for id, tr := range tableRows { + c.tableRows[id] = tr + } + colLength, err := getColLengthTables(ctx, sctx, c.dirtyIDs...) + if err != nil { + return err + } + for id, cl := range colLength { + c.colLength[id] = cl + } + c.dirtyIDs = nil + } + return nil +>>>>>>> 04a564ee4... *: fix data race in the statsCache (#37753) } c.mu.RUnlock() @@ -290,17 +334,21 @@ func (c *statsCache) get(ctx context.Context, sctx sessionctx.Context) (map[int6 } tableRows, err := getRowCountAllTable(ctx, sctx) if err != nil { - return nil, nil, err + return err } colLength, err := getColLengthAllTables(ctx, sctx) if err != nil { - return nil, nil, err + return err } - c.tableRows = tableRows c.colLength = colLength c.modifyTime = time.Now() +<<<<<<< HEAD return tableRows, colLength, nil +======= + c.dirtyIDs = nil + return nil +>>>>>>> 04a564ee4... *: fix data race in the statsCache (#37753) } func getAutoIncrementID(ctx sessionctx.Context, schema *model.DBInfo, tblInfo *model.TableInfo) (int64, error) { @@ -505,7 +553,7 @@ func (e *memtableRetriever) setDataFromReferConst(ctx context.Context, sctx sess } func (e *memtableRetriever) setDataFromTables(ctx context.Context, sctx sessionctx.Context, schemas []*model.DBInfo) error { - tableRowsMap, colLengthMap, err := tableStatsCache.get(ctx, sctx) + err := tableStatsCache.update(ctx, sctx) if err != nil { return err } @@ -549,12 +597,13 @@ func (e *memtableRetriever) setDataFromTables(ctx context.Context, sctx sessionc var rowCount, dataLength, indexLength uint64 if table.GetPartitionInfo() == nil { - rowCount = tableRowsMap[table.ID] - dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount, colLengthMap) + rowCount = tableStatsCache.GetTableRows(table.ID) + dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount) } else { for _, pi := range table.GetPartitionInfo().Definitions { - rowCount += tableRowsMap[pi.ID] - parDataLen, parIndexLen := getDataAndIndexLength(table, pi.ID, tableRowsMap[pi.ID], colLengthMap) + piRowCnt := tableStatsCache.GetTableRows(pi.ID) + rowCount += piRowCnt + parDataLen, parIndexLen := getDataAndIndexLength(table, pi.ID, piRowCnt) dataLength += parDataLen indexLength += parIndexLen } @@ -870,7 +919,7 @@ func calcCharOctLength(lenInChar int, cs string) int { } func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sessionctx.Context, schemas []*model.DBInfo) error { - tableRowsMap, colLengthMap, err := tableStatsCache.get(ctx, sctx) + err := tableStatsCache.update(ctx, sctx) if err != nil { return err } @@ -886,8 +935,8 @@ func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sess var rowCount, dataLength, indexLength uint64 if table.GetPartitionInfo() == nil { - rowCount = tableRowsMap[table.ID] - dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount, colLengthMap) + rowCount = tableStatsCache.GetTableRows(table.ID) + dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount) avgRowLength := uint64(0) if rowCount != 0 { avgRowLength = dataLength / rowCount @@ -924,8 +973,8 @@ func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sess rows = append(rows, record) } else { for i, pi := range table.GetPartitionInfo().Definitions { - rowCount = tableRowsMap[pi.ID] - dataLength, indexLength = getDataAndIndexLength(table, pi.ID, tableRowsMap[pi.ID], colLengthMap) + rowCount = tableStatsCache.GetTableRows(pi.ID) + dataLength, indexLength = getDataAndIndexLength(table, pi.ID, rowCount) avgRowLength := uint64(0) if rowCount != 0 {