Skip to content

Commit

Permalink
*: fix data race in the statsCache (#37753)
Browse files Browse the repository at this point in the history
close #37603
  • Loading branch information
hawkingrei authored Sep 13, 2022
1 parent 506dc05 commit 04a564e
Showing 1 changed file with 32 additions and 21 deletions.
53 changes: 32 additions & 21 deletions executor/infoschema_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ func getColLengthTables(ctx context.Context, sctx sessionctx.Context, tableIDs .
return colLengthMap, nil
}

func getDataAndIndexLength(info *model.TableInfo, physicalID int64, rowCount uint64, columnLengthMap map[tableHistID]uint64) (uint64, uint64) {
func getDataAndIndexLength(info *model.TableInfo, physicalID int64, rowCount uint64) (uint64, uint64) {
columnLength := make(map[string]uint64, len(info.Columns))
for _, col := range info.Columns {
if col.State != model.StatePublic {
Expand All @@ -275,7 +275,7 @@ func getDataAndIndexLength(info *model.TableInfo, physicalID int64, rowCount uin
if length != types.VarStorageLen {
columnLength[col.Name.L] = rowCount * uint64(length)
} else {
length := columnLengthMap[tableHistID{tableID: physicalID, histID: col.ID}]
length := tableStatsCache.GetColLength(tableHistID{tableID: physicalID, histID: col.ID})
columnLength[col.Name.L] = length
}
}
Expand Down Expand Up @@ -317,45 +317,55 @@ func invalidInfoSchemaStatCache(tblID int64) {
tableStatsCache.dirtyIDs = append(tableStatsCache.dirtyIDs, tblID)
}

func (c *statsCache) get(ctx context.Context, sctx sessionctx.Context) (map[int64]uint64, map[tableHistID]uint64, error) {
func (c *statsCache) GetTableRows(id int64) uint64 {
c.mu.RLock()
defer c.mu.RUnlock()
return c.tableRows[id]
}

func (c *statsCache) GetColLength(id tableHistID) uint64 {
c.mu.RLock()
defer c.mu.RUnlock()
return c.colLength[id]
}

func (c *statsCache) update(ctx context.Context, sctx sessionctx.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnStats)
if time.Since(c.modifyTime) < TableStatsCacheExpiry {
if len(c.dirtyIDs) > 0 {
tableRows, err := getRowCountTables(ctx, sctx, c.dirtyIDs...)
if err != nil {
return nil, nil, err
return err
}
for id, tr := range tableRows {
c.tableRows[id] = tr
}
colLength, err := getColLengthTables(ctx, sctx, c.dirtyIDs...)
if err != nil {
return nil, nil, err
return err
}
for id, cl := range colLength {
c.colLength[id] = cl
}
c.dirtyIDs = nil
}
tableRows, colLength := c.tableRows, c.colLength
return tableRows, colLength, nil
return nil
}
tableRows, err := getRowCountTables(ctx, sctx)
if err != nil {
return nil, nil, err
return err
}
colLength, err := getColLengthTables(ctx, sctx)
if err != nil {
return nil, nil, err
return err
}

c.tableRows = tableRows
c.colLength = colLength
c.modifyTime = time.Now()
c.dirtyIDs = nil
return tableRows, colLength, nil
return nil
}

func getAutoIncrementID(ctx sessionctx.Context, schema *model.DBInfo, tblInfo *model.TableInfo) (int64, error) {
Expand Down Expand Up @@ -616,7 +626,7 @@ func (e *memtableRetriever) setDataFromReferConst(ctx context.Context, sctx sess
}

func (e *memtableRetriever) setDataFromTables(ctx context.Context, sctx sessionctx.Context, schemas []*model.DBInfo) error {
tableRowsMap, colLengthMap, err := tableStatsCache.get(ctx, sctx)
err := tableStatsCache.update(ctx, sctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -660,12 +670,13 @@ func (e *memtableRetriever) setDataFromTables(ctx context.Context, sctx sessionc

var rowCount, dataLength, indexLength uint64
if table.GetPartitionInfo() == nil {
rowCount = tableRowsMap[table.ID]
dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount, colLengthMap)
rowCount = tableStatsCache.GetTableRows(table.ID)
dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount)
} else {
for _, pi := range table.GetPartitionInfo().Definitions {
rowCount += tableRowsMap[pi.ID]
parDataLen, parIndexLen := getDataAndIndexLength(table, pi.ID, tableRowsMap[pi.ID], colLengthMap)
piRowCnt := tableStatsCache.GetTableRows(pi.ID)
rowCount += piRowCnt
parDataLen, parIndexLen := getDataAndIndexLength(table, pi.ID, piRowCnt)
dataLength += parDataLen
indexLength += parIndexLen
}
Expand Down Expand Up @@ -993,7 +1004,7 @@ func calcCharOctLength(lenInChar int, cs string) int {
}

func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sessionctx.Context, schemas []*model.DBInfo) error {
tableRowsMap, colLengthMap, err := tableStatsCache.get(ctx, sctx)
err := tableStatsCache.update(ctx, sctx)
if err != nil {
return err
}
Expand All @@ -1009,8 +1020,8 @@ func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sess

var rowCount, dataLength, indexLength uint64
if table.GetPartitionInfo() == nil {
rowCount = tableRowsMap[table.ID]
dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount, colLengthMap)
rowCount = tableStatsCache.GetTableRows(table.ID)
dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount)
avgRowLength := uint64(0)
if rowCount != 0 {
avgRowLength = dataLength / rowCount
Expand Down Expand Up @@ -1047,8 +1058,8 @@ func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sess
rows = append(rows, record)
} else {
for i, pi := range table.GetPartitionInfo().Definitions {
rowCount = tableRowsMap[pi.ID]
dataLength, indexLength = getDataAndIndexLength(table, pi.ID, tableRowsMap[pi.ID], colLengthMap)
rowCount = tableStatsCache.GetTableRows(pi.ID)
dataLength, indexLength = getDataAndIndexLength(table, pi.ID, rowCount)

avgRowLength := uint64(0)
if rowCount != 0 {
Expand Down

0 comments on commit 04a564e

Please sign in to comment.