Skip to content

Commit

Permalink
*: fix data race in the statsCache (pingcap#37753) (pingcap#37770)
Browse files Browse the repository at this point in the history
  • Loading branch information
ti-srebot authored Sep 13, 2022
1 parent 101368f commit 0d11336
Showing 1 changed file with 31 additions and 19 deletions.
50 changes: 31 additions & 19 deletions executor/infoschema_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ func getColLengthAllTables(ctx context.Context, sctx sessionctx.Context) (map[ta
return colLengthMap, nil
}

func getDataAndIndexLength(info *model.TableInfo, physicalID int64, rowCount uint64, columnLengthMap map[tableHistID]uint64) (uint64, uint64) {
func getDataAndIndexLength(info *model.TableInfo, physicalID int64, rowCount uint64) (uint64, uint64) {
columnLength := make(map[string]uint64, len(info.Columns))
for _, col := range info.Columns {
if col.State != model.StatePublic {
Expand All @@ -239,7 +239,7 @@ func getDataAndIndexLength(info *model.TableInfo, physicalID int64, rowCount uin
if length != types.VarStorageLen {
columnLength[col.Name.L] = rowCount * uint64(length)
} else {
length := columnLengthMap[tableHistID{tableID: physicalID, histID: col.ID}]
length := tableStatsCache.GetColLength(tableHistID{tableID: physicalID, histID: col.ID})
columnLength[col.Name.L] = length
}
}
Expand Down Expand Up @@ -274,33 +274,44 @@ var tableStatsCache = &statsCache{}
// TableStatsCacheExpiry is the expiry time for table stats cache.
var TableStatsCacheExpiry = 3 * time.Second

func (c *statsCache) get(ctx context.Context, sctx sessionctx.Context) (map[int64]uint64, map[tableHistID]uint64, error) {
func (c *statsCache) GetTableRows(id int64) uint64 {
c.mu.RLock()
defer c.mu.RUnlock()
return c.tableRows[id]
}

func (c *statsCache) GetColLength(id tableHistID) uint64 {
c.mu.RLock()
defer c.mu.RUnlock()
return c.colLength[id]
}

func (c *statsCache) update(ctx context.Context, sctx sessionctx.Context) error {
c.mu.RLock()
if time.Since(c.modifyTime) < TableStatsCacheExpiry {
tableRows, colLength := c.tableRows, c.colLength
c.mu.RUnlock()
return tableRows, colLength, nil
return nil
}
c.mu.RUnlock()

c.mu.Lock()
defer c.mu.Unlock()
if time.Since(c.modifyTime) < TableStatsCacheExpiry {
return c.tableRows, c.colLength, nil
return nil
}
tableRows, err := getRowCountAllTable(ctx, sctx)
if err != nil {
return nil, nil, err
return err
}
colLength, err := getColLengthAllTables(ctx, sctx)
if err != nil {
return nil, nil, err
return err
}

c.tableRows = tableRows
c.colLength = colLength
c.modifyTime = time.Now()
return tableRows, colLength, nil
return nil
}

func getAutoIncrementID(ctx sessionctx.Context, schema *model.DBInfo, tblInfo *model.TableInfo) (int64, error) {
Expand Down Expand Up @@ -505,7 +516,7 @@ func (e *memtableRetriever) setDataFromReferConst(ctx context.Context, sctx sess
}

func (e *memtableRetriever) setDataFromTables(ctx context.Context, sctx sessionctx.Context, schemas []*model.DBInfo) error {
tableRowsMap, colLengthMap, err := tableStatsCache.get(ctx, sctx)
err := tableStatsCache.update(ctx, sctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -549,12 +560,13 @@ func (e *memtableRetriever) setDataFromTables(ctx context.Context, sctx sessionc

var rowCount, dataLength, indexLength uint64
if table.GetPartitionInfo() == nil {
rowCount = tableRowsMap[table.ID]
dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount, colLengthMap)
rowCount = tableStatsCache.GetTableRows(table.ID)
dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount)
} else {
for _, pi := range table.GetPartitionInfo().Definitions {
rowCount += tableRowsMap[pi.ID]
parDataLen, parIndexLen := getDataAndIndexLength(table, pi.ID, tableRowsMap[pi.ID], colLengthMap)
piRowCnt := tableStatsCache.GetTableRows(pi.ID)
rowCount += piRowCnt
parDataLen, parIndexLen := getDataAndIndexLength(table, pi.ID, piRowCnt)
dataLength += parDataLen
indexLength += parIndexLen
}
Expand Down Expand Up @@ -870,7 +882,7 @@ func calcCharOctLength(lenInChar int, cs string) int {
}

func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sessionctx.Context, schemas []*model.DBInfo) error {
tableRowsMap, colLengthMap, err := tableStatsCache.get(ctx, sctx)
err := tableStatsCache.update(ctx, sctx)
if err != nil {
return err
}
Expand All @@ -886,8 +898,8 @@ func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sess

var rowCount, dataLength, indexLength uint64
if table.GetPartitionInfo() == nil {
rowCount = tableRowsMap[table.ID]
dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount, colLengthMap)
rowCount = tableStatsCache.GetTableRows(table.ID)
dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount)
avgRowLength := uint64(0)
if rowCount != 0 {
avgRowLength = dataLength / rowCount
Expand Down Expand Up @@ -924,8 +936,8 @@ func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sess
rows = append(rows, record)
} else {
for i, pi := range table.GetPartitionInfo().Definitions {
rowCount = tableRowsMap[pi.ID]
dataLength, indexLength = getDataAndIndexLength(table, pi.ID, tableRowsMap[pi.ID], colLengthMap)
rowCount = tableStatsCache.GetTableRows(pi.ID)
dataLength, indexLength = getDataAndIndexLength(table, pi.ID, rowCount)

avgRowLength := uint64(0)
if rowCount != 0 {
Expand Down

0 comments on commit 0d11336

Please sign in to comment.