diff --git a/executor/infoschema_reader.go b/executor/infoschema_reader.go index ebc13d6f6f245..6109ea611a367 100644 --- a/executor/infoschema_reader.go +++ b/executor/infoschema_reader.go @@ -229,7 +229,7 @@ func getColLengthAllTables(ctx context.Context, sctx sessionctx.Context) (map[ta return colLengthMap, nil } -func getDataAndIndexLength(info *model.TableInfo, physicalID int64, rowCount uint64, columnLengthMap map[tableHistID]uint64) (uint64, uint64) { +func getDataAndIndexLength(info *model.TableInfo, physicalID int64, rowCount uint64) (uint64, uint64) { columnLength := make(map[string]uint64, len(info.Columns)) for _, col := range info.Columns { if col.State != model.StatePublic { @@ -239,7 +239,7 @@ func getDataAndIndexLength(info *model.TableInfo, physicalID int64, rowCount uin if length != types.VarStorageLen { columnLength[col.Name.L] = rowCount * uint64(length) } else { - length := columnLengthMap[tableHistID{tableID: physicalID, histID: col.ID}] + length := tableStatsCache.GetColLength(tableHistID{tableID: physicalID, histID: col.ID}) columnLength[col.Name.L] = length } } @@ -274,33 +274,44 @@ var tableStatsCache = &statsCache{} // TableStatsCacheExpiry is the expiry time for table stats cache. var TableStatsCacheExpiry = 3 * time.Second -func (c *statsCache) get(ctx context.Context, sctx sessionctx.Context) (map[int64]uint64, map[tableHistID]uint64, error) { +func (c *statsCache) GetTableRows(id int64) uint64 { + c.mu.RLock() + defer c.mu.RUnlock() + return c.tableRows[id] +} + +func (c *statsCache) GetColLength(id tableHistID) uint64 { + c.mu.RLock() + defer c.mu.RUnlock() + return c.colLength[id] +} + +func (c *statsCache) update(ctx context.Context, sctx sessionctx.Context) error { c.mu.RLock() if time.Since(c.modifyTime) < TableStatsCacheExpiry { - tableRows, colLength := c.tableRows, c.colLength c.mu.RUnlock() - return tableRows, colLength, nil + return nil } c.mu.RUnlock() c.mu.Lock() defer c.mu.Unlock() if time.Since(c.modifyTime) < TableStatsCacheExpiry { - return c.tableRows, c.colLength, nil + return nil } tableRows, err := getRowCountAllTable(ctx, sctx) if err != nil { - return nil, nil, err + return err } colLength, err := getColLengthAllTables(ctx, sctx) if err != nil { - return nil, nil, err + return err } c.tableRows = tableRows c.colLength = colLength c.modifyTime = time.Now() - return tableRows, colLength, nil + return nil } func getAutoIncrementID(ctx sessionctx.Context, schema *model.DBInfo, tblInfo *model.TableInfo) (int64, error) { @@ -505,7 +516,7 @@ func (e *memtableRetriever) setDataFromReferConst(ctx context.Context, sctx sess } func (e *memtableRetriever) setDataFromTables(ctx context.Context, sctx sessionctx.Context, schemas []*model.DBInfo) error { - tableRowsMap, colLengthMap, err := tableStatsCache.get(ctx, sctx) + err := tableStatsCache.update(ctx, sctx) if err != nil { return err } @@ -549,12 +560,13 @@ func (e *memtableRetriever) setDataFromTables(ctx context.Context, sctx sessionc var rowCount, dataLength, indexLength uint64 if table.GetPartitionInfo() == nil { - rowCount = tableRowsMap[table.ID] - dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount, colLengthMap) + rowCount = tableStatsCache.GetTableRows(table.ID) + dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount) } else { for _, pi := range table.GetPartitionInfo().Definitions { - rowCount += tableRowsMap[pi.ID] - parDataLen, parIndexLen := getDataAndIndexLength(table, pi.ID, tableRowsMap[pi.ID], colLengthMap) + piRowCnt := tableStatsCache.GetTableRows(pi.ID) + rowCount += piRowCnt + parDataLen, parIndexLen := getDataAndIndexLength(table, pi.ID, piRowCnt) dataLength += parDataLen indexLength += parIndexLen } @@ -870,7 +882,7 @@ func calcCharOctLength(lenInChar int, cs string) int { } func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sessionctx.Context, schemas []*model.DBInfo) error { - tableRowsMap, colLengthMap, err := tableStatsCache.get(ctx, sctx) + err := tableStatsCache.update(ctx, sctx) if err != nil { return err } @@ -886,8 +898,8 @@ func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sess var rowCount, dataLength, indexLength uint64 if table.GetPartitionInfo() == nil { - rowCount = tableRowsMap[table.ID] - dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount, colLengthMap) + rowCount = tableStatsCache.GetTableRows(table.ID) + dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount) avgRowLength := uint64(0) if rowCount != 0 { avgRowLength = dataLength / rowCount @@ -924,8 +936,8 @@ func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sess rows = append(rows, record) } else { for i, pi := range table.GetPartitionInfo().Definitions { - rowCount = tableRowsMap[pi.ID] - dataLength, indexLength = getDataAndIndexLength(table, pi.ID, tableRowsMap[pi.ID], colLengthMap) + rowCount = tableStatsCache.GetTableRows(pi.ID) + dataLength, indexLength = getDataAndIndexLength(table, pi.ID, rowCount) avgRowLength := uint64(0) if rowCount != 0 {