Skip to content

Commit

Permalink
cherry pick pingcap#37753 to release-6.1
Browse files Browse the repository at this point in the history
Signed-off-by: ti-srebot <ti-srebot@pingcap.com>
  • Loading branch information
hawkingrei authored and ti-srebot committed Sep 13, 2022
1 parent 101368f commit 93c8df8
Showing 1 changed file with 64 additions and 15 deletions.
79 changes: 64 additions & 15 deletions executor/infoschema_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ func getColLengthAllTables(ctx context.Context, sctx sessionctx.Context) (map[ta
return colLengthMap, nil
}

func getDataAndIndexLength(info *model.TableInfo, physicalID int64, rowCount uint64, columnLengthMap map[tableHistID]uint64) (uint64, uint64) {
func getDataAndIndexLength(info *model.TableInfo, physicalID int64, rowCount uint64) (uint64, uint64) {
columnLength := make(map[string]uint64, len(info.Columns))
for _, col := range info.Columns {
if col.State != model.StatePublic {
Expand All @@ -239,7 +239,7 @@ func getDataAndIndexLength(info *model.TableInfo, physicalID int64, rowCount uin
if length != types.VarStorageLen {
columnLength[col.Name.L] = rowCount * uint64(length)
} else {
length := columnLengthMap[tableHistID{tableID: physicalID, histID: col.ID}]
length := tableStatsCache.GetColLength(tableHistID{tableID: physicalID, histID: col.ID})
columnLength[col.Name.L] = length
}
}
Expand Down Expand Up @@ -274,12 +274,56 @@ var tableStatsCache = &statsCache{}
// TableStatsCacheExpiry is the expiry time for table stats cache.
var TableStatsCacheExpiry = 3 * time.Second

<<<<<<< HEAD
func (c *statsCache) get(ctx context.Context, sctx sessionctx.Context) (map[int64]uint64, map[tableHistID]uint64, error) {
c.mu.RLock()
if time.Since(c.modifyTime) < TableStatsCacheExpiry {
tableRows, colLength := c.tableRows, c.colLength
c.mu.RUnlock()
return tableRows, colLength, nil
=======
func invalidInfoSchemaStatCache(tblID int64) {
tableStatsCache.mu.Lock()
defer tableStatsCache.mu.Unlock()
tableStatsCache.dirtyIDs = append(tableStatsCache.dirtyIDs, tblID)
}

func (c *statsCache) GetTableRows(id int64) uint64 {
c.mu.RLock()
defer c.mu.RUnlock()
return c.tableRows[id]
}

func (c *statsCache) GetColLength(id tableHistID) uint64 {
c.mu.RLock()
defer c.mu.RUnlock()
return c.colLength[id]
}

func (c *statsCache) update(ctx context.Context, sctx sessionctx.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnStats)
if time.Since(c.modifyTime) < TableStatsCacheExpiry {
if len(c.dirtyIDs) > 0 {
tableRows, err := getRowCountTables(ctx, sctx, c.dirtyIDs...)
if err != nil {
return err
}
for id, tr := range tableRows {
c.tableRows[id] = tr
}
colLength, err := getColLengthTables(ctx, sctx, c.dirtyIDs...)
if err != nil {
return err
}
for id, cl := range colLength {
c.colLength[id] = cl
}
c.dirtyIDs = nil
}
return nil
>>>>>>> 04a564ee4... *: fix data race in the statsCache (#37753)
}
c.mu.RUnlock()

Expand All @@ -290,17 +334,21 @@ func (c *statsCache) get(ctx context.Context, sctx sessionctx.Context) (map[int6
}
tableRows, err := getRowCountAllTable(ctx, sctx)
if err != nil {
return nil, nil, err
return err
}
colLength, err := getColLengthAllTables(ctx, sctx)
if err != nil {
return nil, nil, err
return err
}

c.tableRows = tableRows
c.colLength = colLength
c.modifyTime = time.Now()
<<<<<<< HEAD
return tableRows, colLength, nil
=======
c.dirtyIDs = nil
return nil
>>>>>>> 04a564ee4... *: fix data race in the statsCache (#37753)
}

func getAutoIncrementID(ctx sessionctx.Context, schema *model.DBInfo, tblInfo *model.TableInfo) (int64, error) {
Expand Down Expand Up @@ -505,7 +553,7 @@ func (e *memtableRetriever) setDataFromReferConst(ctx context.Context, sctx sess
}

func (e *memtableRetriever) setDataFromTables(ctx context.Context, sctx sessionctx.Context, schemas []*model.DBInfo) error {
tableRowsMap, colLengthMap, err := tableStatsCache.get(ctx, sctx)
err := tableStatsCache.update(ctx, sctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -549,12 +597,13 @@ func (e *memtableRetriever) setDataFromTables(ctx context.Context, sctx sessionc

var rowCount, dataLength, indexLength uint64
if table.GetPartitionInfo() == nil {
rowCount = tableRowsMap[table.ID]
dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount, colLengthMap)
rowCount = tableStatsCache.GetTableRows(table.ID)
dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount)
} else {
for _, pi := range table.GetPartitionInfo().Definitions {
rowCount += tableRowsMap[pi.ID]
parDataLen, parIndexLen := getDataAndIndexLength(table, pi.ID, tableRowsMap[pi.ID], colLengthMap)
piRowCnt := tableStatsCache.GetTableRows(pi.ID)
rowCount += piRowCnt
parDataLen, parIndexLen := getDataAndIndexLength(table, pi.ID, piRowCnt)
dataLength += parDataLen
indexLength += parIndexLen
}
Expand Down Expand Up @@ -870,7 +919,7 @@ func calcCharOctLength(lenInChar int, cs string) int {
}

func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sessionctx.Context, schemas []*model.DBInfo) error {
tableRowsMap, colLengthMap, err := tableStatsCache.get(ctx, sctx)
err := tableStatsCache.update(ctx, sctx)
if err != nil {
return err
}
Expand All @@ -886,8 +935,8 @@ func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sess

var rowCount, dataLength, indexLength uint64
if table.GetPartitionInfo() == nil {
rowCount = tableRowsMap[table.ID]
dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount, colLengthMap)
rowCount = tableStatsCache.GetTableRows(table.ID)
dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount)
avgRowLength := uint64(0)
if rowCount != 0 {
avgRowLength = dataLength / rowCount
Expand Down Expand Up @@ -924,8 +973,8 @@ func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sess
rows = append(rows, record)
} else {
for i, pi := range table.GetPartitionInfo().Definitions {
rowCount = tableRowsMap[pi.ID]
dataLength, indexLength = getDataAndIndexLength(table, pi.ID, tableRowsMap[pi.ID], colLengthMap)
rowCount = tableStatsCache.GetTableRows(pi.ID)
dataLength, indexLength = getDataAndIndexLength(table, pi.ID, rowCount)

avgRowLength := uint64(0)
if rowCount != 0 {
Expand Down

0 comments on commit 93c8df8

Please sign in to comment.