diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go index a3f1e8256b449..4ef95f7f4a701 100644 --- a/executor/aggfuncs/builder.go +++ b/executor/aggfuncs/builder.go @@ -71,6 +71,8 @@ func BuildWindowFunctions(ctx sessionctx.Context, windowFuncDesc *aggregation.Ag return buildLastValue(windowFuncDesc, ordinal) case ast.WindowFuncCumeDist: return buildCumeDist(ordinal, orderByCols) + case ast.WindowFuncNthValue: + return buildNthValue(windowFuncDesc, ordinal) default: return Build(ctx, windowFuncDesc, ordinal) } @@ -374,3 +376,13 @@ func buildCumeDist(ordinal int, orderByCols []*expression.Column) AggFunc { r := &cumeDist{baseAggFunc: base, rowComparer: buildRowComparer(orderByCols)} return r } + +func buildNthValue(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { + base := baseAggFunc{ + args: aggFuncDesc.Args, + ordinal: ordinal, + } + // Already checked when building the function description. + nth, _, _ := expression.GetUint64FromConstant(aggFuncDesc.Args[1]) + return &nthValue{baseAggFunc: base, tp: aggFuncDesc.RetTp, nth: nth} +} diff --git a/executor/aggfuncs/func_value.go b/executor/aggfuncs/func_value.go index 18c552e7a1bed..36d6535eca97f 100644 --- a/executor/aggfuncs/func_value.go +++ b/executor/aggfuncs/func_value.go @@ -300,3 +300,50 @@ func (v *lastValue) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialR } return nil } + +type nthValue struct { + baseAggFunc + + tp *types.FieldType + nth uint64 +} + +type partialResult4NthValue struct { + seenRows uint64 + evaluator valueEvaluator +} + +func (v *nthValue) AllocPartialResult() PartialResult { + return PartialResult(&partialResult4NthValue{evaluator: buildValueEvaluator(v.tp)}) +} + +func (v *nthValue) ResetPartialResult(pr PartialResult) { + p := (*partialResult4NthValue)(pr) + p.seenRows = 0 +} + +func (v *nthValue) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + if v.nth == 0 { + return nil + } + p := (*partialResult4NthValue)(pr) + numRows := uint64(len(rowsInGroup)) + if v.nth > p.seenRows && v.nth-p.seenRows <= numRows { + err := p.evaluator.evaluateRow(sctx, v.args[0], rowsInGroup[v.nth-p.seenRows-1]) + if err != nil { + return err + } + } + p.seenRows += numRows + return nil +} + +func (v *nthValue) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4NthValue)(pr) + if v.nth == 0 || p.seenRows < v.nth { + chk.AppendNull(v.ordinal) + } else { + p.evaluator.appendResult(chk, v.ordinal) + } + return nil +} diff --git a/executor/window_test.go b/executor/window_test.go index 462afca70f414..76cc408a654d9 100644 --- a/executor/window_test.go +++ b/executor/window_test.go @@ -110,4 +110,13 @@ func (s *testSuite2) TestWindowFunctions(c *C) { result.Check(testkit.Rows("1 1 0.5", "1 2 0.5", "2 1 1", "2 2 1")) result = tk.MustQuery("select a, b, cume_dist() over(order by a, b) from t") result.Check(testkit.Rows("1 1 0.25", "1 2 0.5", "2 1 0.75", "2 2 1")) + + result = tk.MustQuery("select a, nth_value(a, null) over() from t") + result.Check(testkit.Rows("1 ", "1 ", "2 ", "2 ")) + result = tk.MustQuery("select a, nth_value(a, 1) over() from t") + result.Check(testkit.Rows("1 1", "1 1", "2 1", "2 1")) + result = tk.MustQuery("select a, nth_value(a, 4) over() from t") + result.Check(testkit.Rows("1 2", "1 2", "2 2", "2 2")) + result = tk.MustQuery("select a, nth_value(a, 5) over() from t") + result.Check(testkit.Rows("1 ", "1 ", "2 ", "2 ")) } diff --git a/expression/aggregation/base_func.go b/expression/aggregation/base_func.go index 306e47ec9569d..9e8db54b5d560 100644 --- a/expression/aggregation/base_func.go +++ b/expression/aggregation/base_func.go @@ -92,7 +92,7 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) { case ast.AggFuncGroupConcat: a.typeInfer4GroupConcat(ctx) case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstRow, - ast.WindowFuncFirstValue, ast.WindowFuncLastValue: + ast.WindowFuncFirstValue, ast.WindowFuncLastValue, ast.WindowFuncNthValue: a.typeInfer4MaxMin(ctx) case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor: a.typeInfer4BitFuncs(ctx) diff --git a/expression/aggregation/window_func.go b/expression/aggregation/window_func.go index a0e629350aec6..aa436b957aab7 100644 --- a/expression/aggregation/window_func.go +++ b/expression/aggregation/window_func.go @@ -28,6 +28,13 @@ type WindowFuncDesc struct { // NewWindowFuncDesc creates a window function signature descriptor. func NewWindowFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression) *WindowFuncDesc { + if strings.ToLower(name) == ast.WindowFuncNthValue { + val, isNull, ok := expression.GetUint64FromConstant(args[1]) + // nth_value does not allow `0`, but allows `null`. + if !ok || (val == 0 && !isNull) { + return nil + } + } return &WindowFuncDesc{newBaseFuncDesc(ctx, name, args)} } diff --git a/expression/util.go b/expression/util.go index 98220b9275fb6..82b1071b409ca 100644 --- a/expression/util.go +++ b/expression/util.go @@ -20,6 +20,7 @@ import ( "unicode" "github.com/pingcap/errors" + "github.com/pingcap/log" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/opcode" @@ -28,6 +29,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tidb/util/chunk" + "go.uber.org/zap" "golang.org/x/tools/container/intsets" ) @@ -670,3 +672,34 @@ func RemoveDupExprs(ctx sessionctx.Context, exprs []Expression) []Expression { } return res } + +// GetUint64FromConstant gets a uint64 from constant expression. +func GetUint64FromConstant(expr Expression) (uint64, bool, bool) { + con, ok := expr.(*Constant) + if !ok { + log.Warn("not a constant expression", zap.Any("value", expr)) + return 0, false, false + } + dt := con.Value + if con.DeferredExpr != nil { + var err error + dt, err = con.DeferredExpr.Eval(chunk.Row{}) + if err != nil { + log.Warn("eval deferred expr failed", zap.Error(err)) + return 0, false, false + } + } + switch dt.Kind() { + case types.KindNull: + return 0, true, true + case types.KindInt64: + val := dt.GetInt64() + if val < 0 { + return 0, false, false + } + return uint64(val), false, true + case types.KindUint64: + return dt.GetUint64(), false, true + } + return 0, false, false +} diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 6380cd8e0ed94..e53772deb12dd 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -2783,8 +2783,9 @@ func (b *PlanBuilder) buildProjectionForWindow(p LogicalPlan, expr *ast.WindowFu return nil, nil, nil, nil, err } p = np - if col, ok := newArg.(*expression.Column); ok { - newArgList = append(newArgList, col) + switch newArg.(type) { + case *expression.Column, *expression.Constant: + newArgList = append(newArgList, newArg) continue } proj.Exprs = append(proj.Exprs, newArg) @@ -2966,6 +2967,9 @@ func (b *PlanBuilder) buildWindowFunction(p LogicalPlan, expr *ast.WindowFuncExp return nil, err } desc := aggregation.NewWindowFuncDesc(b.ctx, expr.F, args) + if desc == nil { + return nil, ErrWrongArguments.GenWithStackByArgs(expr.F) + } // TODO: Check if the function is aggregation function after we support more functions. desc.WrapCastForAggArgs(b.ctx) window := LogicalWindow{ diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index c7618c1ef7a29..889d4a325bb79 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -2199,6 +2199,14 @@ func (s *testPlanSuite) TestWindowFunction(c *C) { sql: "select row_number() over(rows between 1 preceding and 1 following) from t", result: "TableReader(Table(t))->Window(row_number() over())->Projection", }, + { + sql: "select nth_value(a, 1.0) over() from t", + result: "[planner:1210]Incorrect arguments to nth_value", + }, + { + sql: "select nth_value(a, 0) over() from t", + result: "[planner:1210]Incorrect arguments to nth_value", + }, } s.Parser.EnableWindowFunc(true)