diff --git a/planner/cascades/enforcer_rules.go b/planner/cascades/enforcer_rules.go index b965ca6b474fd..8c96d955d0b94 100644 --- a/planner/cascades/enforcer_rules.go +++ b/planner/cascades/enforcer_rules.go @@ -79,7 +79,7 @@ func (e *OrderEnforcer) OnEnforce(reqProp *property.PhysicalProperty, child memo func (e *OrderEnforcer) GetEnforceCost(g *memo.Group) float64 { // We need a SessionCtx to calculate the cost of a sort. sctx := g.Equivalents.Front().Value.(*memo.GroupExpr).ExprNode.SCtx() - sort := plannercore.PhysicalSort{}.Init(sctx, nil, 0, nil) - cost := sort.GetCost(g.Prop.Stats.RowCount) + sort := plannercore.PhysicalSort{}.Init(sctx, g.Prop.Stats, 0, nil) + cost := sort.GetCost(g.Prop.Stats.RowCount, g.Prop.Schema) return cost } diff --git a/planner/core/task.go b/planner/core/task.go index 35bfe2195f323..1a079e58ecc32 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/expression/aggregation" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/planner/property" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/types" @@ -425,12 +426,12 @@ func (p *PhysicalIndexJoin) GetCost(outerTask, innerTask task) float64 { return outerTask.cost() + innerPlanCost + cpuCost + memoryCost } -func (p *PhysicalHashJoin) avgRowSize(inner PhysicalPlan) (size float64) { - if inner.statsInfo().HistColl != nil { - size = inner.statsInfo().HistColl.GetAvgRowSizeListInDisk(inner.Schema().Columns) +func getAvgRowSize(stats *property.StatsInfo, schema *expression.Schema) (size float64) { + if stats.HistColl != nil { + size = stats.HistColl.GetAvgRowSizeListInDisk(schema.Columns) } else { // Estimate using just the type info. - cols := inner.Schema().Columns + cols := schema.Columns for _, col := range cols { size += float64(chunk.EstimateTypeWidth(col.GetType())) } @@ -450,7 +451,7 @@ func (p *PhysicalHashJoin) GetCost(lCnt, rCnt float64) float64 { sessVars := p.ctx.GetSessionVars() oomUseTmpStorage := config.GetGlobalConfig().OOMUseTmpStorage memQuota := sessVars.StmtCtx.MemTracker.GetBytesLimit() // sessVars.MemQuotaQuery && hint - rowSize := p.avgRowSize(build) + rowSize := getAvgRowSize(build.statsInfo(), build.Schema()) spill := oomUseTmpStorage && memQuota > 0 && rowSize*buildCnt > float64(memQuota) // Cost of building hash table. cpuCost := buildCnt * sessVars.CPUFactor @@ -842,18 +843,31 @@ func (p *PhysicalTopN) allColsFromSchema(schema *expression.Schema) bool { } // GetCost computes the cost of in memory sort. -func (p *PhysicalSort) GetCost(count float64) float64 { +func (p *PhysicalSort) GetCost(count float64, schema *expression.Schema) float64 { if count < 2.0 { count = 2.0 } sessVars := p.ctx.GetSessionVars() - return count*math.Log2(count)*sessVars.CPUFactor + count*sessVars.MemoryFactor + cpuCost := count * math.Log2(count) * sessVars.CPUFactor + memoryCost := count * sessVars.MemoryFactor + + oomUseTmpStorage := config.GetGlobalConfig().OOMUseTmpStorage + memQuota := sessVars.StmtCtx.MemTracker.GetBytesLimit() // sessVars.MemQuotaQuery && hint + rowSize := getAvgRowSize(p.statsInfo(), schema) + spill := oomUseTmpStorage && memQuota > 0 && rowSize*count > float64(memQuota) + diskCost := count * sessVars.DiskFactor * rowSize + if !spill { + diskCost = 0 + } else { + memoryCost *= float64(memQuota) / (rowSize * count) + } + return cpuCost + memoryCost + diskCost } func (p *PhysicalSort) attach2Task(tasks ...task) task { t := tasks[0].copy() t = attachPlan2Task(p, t) - t.addCost(p.GetCost(t.count())) + t.addCost(p.GetCost(t.count(), p.Schema())) return t } diff --git a/planner/implementation/sort.go b/planner/implementation/sort.go index 9e2e8a5e2a07c..92025048954e4 100644 --- a/planner/implementation/sort.go +++ b/planner/implementation/sort.go @@ -34,7 +34,7 @@ func NewSortImpl(sort *plannercore.PhysicalSort) *SortImpl { func (impl *SortImpl) CalcCost(outCount float64, children ...memo.Implementation) float64 { cnt := math.Min(children[0].GetPlan().Stats().RowCount, impl.plan.GetChildReqProps(0).ExpectedCnt) sort := impl.plan.(*plannercore.PhysicalSort) - impl.cost = sort.GetCost(cnt) + children[0].GetCost() + impl.cost = sort.GetCost(cnt, children[0].GetPlan().Schema()) + children[0].GetCost() return impl.cost }