diff --git a/executor/infoschema_reader.go b/executor/infoschema_reader.go index d3bf9371e5acd..a7b282e2d2da0 100644 --- a/executor/infoschema_reader.go +++ b/executor/infoschema_reader.go @@ -265,7 +265,7 @@ func getColLengthTables(ctx context.Context, sctx sessionctx.Context, tableIDs . return colLengthMap, nil } -func getDataAndIndexLength(info *model.TableInfo, physicalID int64, rowCount uint64, columnLengthMap map[tableHistID]uint64) (uint64, uint64) { +func getDataAndIndexLength(info *model.TableInfo, physicalID int64, rowCount uint64) (uint64, uint64) { columnLength := make(map[string]uint64, len(info.Columns)) for _, col := range info.Columns { if col.State != model.StatePublic { @@ -275,7 +275,7 @@ func getDataAndIndexLength(info *model.TableInfo, physicalID int64, rowCount uin if length != types.VarStorageLen { columnLength[col.Name.L] = rowCount * uint64(length) } else { - length := columnLengthMap[tableHistID{tableID: physicalID, histID: col.ID}] + length := tableStatsCache.GetColLength(tableHistID{tableID: physicalID, histID: col.ID}) columnLength[col.Name.L] = length } } @@ -317,7 +317,19 @@ func invalidInfoSchemaStatCache(tblID int64) { tableStatsCache.dirtyIDs = append(tableStatsCache.dirtyIDs, tblID) } -func (c *statsCache) get(ctx context.Context, sctx sessionctx.Context) (map[int64]uint64, map[tableHistID]uint64, error) { +func (c *statsCache) GetTableRows(id int64) uint64 { + c.mu.RLock() + defer c.mu.RUnlock() + return c.tableRows[id] +} + +func (c *statsCache) GetColLength(id tableHistID) uint64 { + c.mu.RLock() + defer c.mu.RUnlock() + return c.colLength[id] +} + +func (c *statsCache) update(ctx context.Context, sctx sessionctx.Context) error { c.mu.Lock() defer c.mu.Unlock() ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnStats) @@ -325,37 +337,35 @@ func (c *statsCache) get(ctx context.Context, sctx sessionctx.Context) (map[int6 if len(c.dirtyIDs) > 0 { tableRows, err := getRowCountTables(ctx, sctx, c.dirtyIDs...) if err != nil { - return nil, nil, err + return err } for id, tr := range tableRows { c.tableRows[id] = tr } colLength, err := getColLengthTables(ctx, sctx, c.dirtyIDs...) if err != nil { - return nil, nil, err + return err } for id, cl := range colLength { c.colLength[id] = cl } c.dirtyIDs = nil } - tableRows, colLength := c.tableRows, c.colLength - return tableRows, colLength, nil + return nil } tableRows, err := getRowCountTables(ctx, sctx) if err != nil { - return nil, nil, err + return err } colLength, err := getColLengthTables(ctx, sctx) if err != nil { - return nil, nil, err + return err } - c.tableRows = tableRows c.colLength = colLength c.modifyTime = time.Now() c.dirtyIDs = nil - return tableRows, colLength, nil + return nil } func getAutoIncrementID(ctx sessionctx.Context, schema *model.DBInfo, tblInfo *model.TableInfo) (int64, error) { @@ -616,7 +626,7 @@ func (e *memtableRetriever) setDataFromReferConst(ctx context.Context, sctx sess } func (e *memtableRetriever) setDataFromTables(ctx context.Context, sctx sessionctx.Context, schemas []*model.DBInfo) error { - tableRowsMap, colLengthMap, err := tableStatsCache.get(ctx, sctx) + err := tableStatsCache.update(ctx, sctx) if err != nil { return err } @@ -660,12 +670,13 @@ func (e *memtableRetriever) setDataFromTables(ctx context.Context, sctx sessionc var rowCount, dataLength, indexLength uint64 if table.GetPartitionInfo() == nil { - rowCount = tableRowsMap[table.ID] - dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount, colLengthMap) + rowCount = tableStatsCache.GetTableRows(table.ID) + dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount) } else { for _, pi := range table.GetPartitionInfo().Definitions { - rowCount += tableRowsMap[pi.ID] - parDataLen, parIndexLen := getDataAndIndexLength(table, pi.ID, tableRowsMap[pi.ID], colLengthMap) + piRowCnt := tableStatsCache.GetTableRows(pi.ID) + rowCount += piRowCnt + parDataLen, parIndexLen := getDataAndIndexLength(table, pi.ID, piRowCnt) dataLength += parDataLen indexLength += parIndexLen } @@ -993,7 +1004,7 @@ func calcCharOctLength(lenInChar int, cs string) int { } func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sessionctx.Context, schemas []*model.DBInfo) error { - tableRowsMap, colLengthMap, err := tableStatsCache.get(ctx, sctx) + err := tableStatsCache.update(ctx, sctx) if err != nil { return err } @@ -1009,8 +1020,8 @@ func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sess var rowCount, dataLength, indexLength uint64 if table.GetPartitionInfo() == nil { - rowCount = tableRowsMap[table.ID] - dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount, colLengthMap) + rowCount = tableStatsCache.GetTableRows(table.ID) + dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount) avgRowLength := uint64(0) if rowCount != 0 { avgRowLength = dataLength / rowCount @@ -1047,8 +1058,8 @@ func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sess rows = append(rows, record) } else { for i, pi := range table.GetPartitionInfo().Definitions { - rowCount = tableRowsMap[pi.ID] - dataLength, indexLength = getDataAndIndexLength(table, pi.ID, tableRowsMap[pi.ID], colLengthMap) + rowCount = tableStatsCache.GetTableRows(pi.ID) + dataLength, indexLength = getDataAndIndexLength(table, pi.ID, rowCount) avgRowLength := uint64(0) if rowCount != 0 {