Skip to content

Commit

Permalink
*: fix data race in the statsCache
Browse files Browse the repository at this point in the history
Signed-off-by: Weizhen Wang <wangweizhen@pingcap.com>
  • Loading branch information
hawkingrei committed Sep 10, 2022
1 parent 12ae85f commit 5cad9dc
Showing 1 changed file with 40 additions and 26 deletions.
66 changes: 40 additions & 26 deletions executor/infoschema_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/pingcap/errors"
Expand Down Expand Up @@ -265,7 +266,7 @@ func getColLengthTables(ctx context.Context, sctx sessionctx.Context, tableIDs .
return colLengthMap, nil
}

func getDataAndIndexLength(info *model.TableInfo, physicalID int64, rowCount uint64, columnLengthMap map[tableHistID]uint64) (uint64, uint64) {
func getDataAndIndexLength(info *model.TableInfo, physicalID int64, rowCount uint64) (uint64, uint64) {
columnLength := make(map[string]uint64, len(info.Columns))
for _, col := range info.Columns {
if col.State != model.StatePublic {
Expand All @@ -275,7 +276,7 @@ func getDataAndIndexLength(info *model.TableInfo, physicalID int64, rowCount uin
if length != types.VarStorageLen {
columnLength[col.Name.L] = rowCount * uint64(length)
} else {
length := columnLengthMap[tableHistID{tableID: physicalID, histID: col.ID}]
length := tableStatsCache.GetColLength(tableHistID{tableID: physicalID, histID: col.ID})
columnLength[col.Name.L] = length
}
}
Expand All @@ -300,7 +301,7 @@ func getDataAndIndexLength(info *model.TableInfo, physicalID int64, rowCount uin

type statsCache struct {
mu sync.RWMutex
modifyTime time.Time
modifyTime atomic.Value //time.Time
tableRows map[int64]uint64
colLength map[tableHistID]uint64
dirtyIDs []int64
Expand All @@ -317,45 +318,57 @@ func invalidInfoSchemaStatCache(tblID int64) {
tableStatsCache.dirtyIDs = append(tableStatsCache.dirtyIDs, tblID)
}

func (c *statsCache) get(ctx context.Context, sctx sessionctx.Context) (map[int64]uint64, map[tableHistID]uint64, error) {
c.mu.Lock()
defer c.mu.Unlock()
func (c *statsCache) GetTableRows(id int64) uint64 {
c.mu.RLock()
defer c.mu.RUnlock()
return c.tableRows[id]
}

func (c *statsCache) GetColLength(id tableHistID) uint64 {
c.mu.RLock()
defer c.mu.RUnlock()
return c.colLength[id]
}

func (c *statsCache) update(ctx context.Context, sctx sessionctx.Context) error {
ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnStats)
if time.Since(c.modifyTime) < TableStatsCacheExpiry {
if time.Since(c.modifyTime.Load().(time.Time)) < TableStatsCacheExpiry {
c.mu.Lock()
defer c.mu.Unlock()
if len(c.dirtyIDs) > 0 {
tableRows, err := getRowCountTables(ctx, sctx, c.dirtyIDs...)
if err != nil {
return nil, nil, err
return err
}
for id, tr := range tableRows {
c.tableRows[id] = tr
}
colLength, err := getColLengthTables(ctx, sctx, c.dirtyIDs...)
if err != nil {
return nil, nil, err
return err
}
for id, cl := range colLength {
c.colLength[id] = cl
}
c.dirtyIDs = nil
}
tableRows, colLength := c.tableRows, c.colLength
return tableRows, colLength, nil
return nil
}
tableRows, err := getRowCountTables(ctx, sctx)
if err != nil {
return nil, nil, err
return err
}
colLength, err := getColLengthTables(ctx, sctx)
if err != nil {
return nil, nil, err
return err
}

c.mu.Lock()
c.tableRows = tableRows
c.colLength = colLength
c.modifyTime = time.Now()
c.modifyTime.Store(time.Now())
c.dirtyIDs = nil
return tableRows, colLength, nil
c.mu.Unlock()
return nil
}

func getAutoIncrementID(ctx sessionctx.Context, schema *model.DBInfo, tblInfo *model.TableInfo) (int64, error) {
Expand Down Expand Up @@ -616,7 +629,7 @@ func (e *memtableRetriever) setDataFromReferConst(ctx context.Context, sctx sess
}

func (e *memtableRetriever) setDataFromTables(ctx context.Context, sctx sessionctx.Context, schemas []*model.DBInfo) error {
tableRowsMap, colLengthMap, err := tableStatsCache.get(ctx, sctx)
err := tableStatsCache.update(ctx, sctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -660,12 +673,13 @@ func (e *memtableRetriever) setDataFromTables(ctx context.Context, sctx sessionc

var rowCount, dataLength, indexLength uint64
if table.GetPartitionInfo() == nil {
rowCount = tableRowsMap[table.ID]
dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount, colLengthMap)
rowCount = tableStatsCache.GetTableRows(table.ID)
dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount)
} else {
for _, pi := range table.GetPartitionInfo().Definitions {
rowCount += tableRowsMap[pi.ID]
parDataLen, parIndexLen := getDataAndIndexLength(table, pi.ID, tableRowsMap[pi.ID], colLengthMap)
piRowCnt := tableStatsCache.GetTableRows(pi.ID)
rowCount += piRowCnt
parDataLen, parIndexLen := getDataAndIndexLength(table, pi.ID, piRowCnt)
dataLength += parDataLen
indexLength += parIndexLen
}
Expand Down Expand Up @@ -993,7 +1007,7 @@ func calcCharOctLength(lenInChar int, cs string) int {
}

func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sessionctx.Context, schemas []*model.DBInfo) error {
tableRowsMap, colLengthMap, err := tableStatsCache.get(ctx, sctx)
err := tableStatsCache.update(ctx, sctx)
if err != nil {
return err
}
Expand All @@ -1009,8 +1023,8 @@ func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sess

var rowCount, dataLength, indexLength uint64
if table.GetPartitionInfo() == nil {
rowCount = tableRowsMap[table.ID]
dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount, colLengthMap)
rowCount = tableStatsCache.GetTableRows(table.ID)
dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount)
avgRowLength := uint64(0)
if rowCount != 0 {
avgRowLength = dataLength / rowCount
Expand Down Expand Up @@ -1047,8 +1061,8 @@ func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sess
rows = append(rows, record)
} else {
for i, pi := range table.GetPartitionInfo().Definitions {
rowCount = tableRowsMap[pi.ID]
dataLength, indexLength = getDataAndIndexLength(table, pi.ID, tableRowsMap[pi.ID], colLengthMap)
rowCount = tableStatsCache.GetTableRows(pi.ID)
dataLength, indexLength = getDataAndIndexLength(table, pi.ID, rowCount)

avgRowLength := uint64(0)
if rowCount != 0 {
Expand Down

0 comments on commit 5cad9dc

Please sign in to comment.