diff --git a/expression/builtin_op.go b/expression/builtin_op.go index c7af218cdb231..096f817cc5ab5 100644 --- a/expression/builtin_op.go +++ b/expression/builtin_op.go @@ -712,16 +712,15 @@ func (c *unaryMinusFunctionClass) handleIntOverflow(arg *Constant) (overflow boo // typeInfer infers unaryMinus function return type. when the arg is an int constant and overflow, // typerInfer will infers the return type as types.ETDecimal, not types.ETInt. -func (c *unaryMinusFunctionClass) typeInfer(ctx sessionctx.Context, argExpr Expression) (types.EvalType, bool) { +func (c *unaryMinusFunctionClass) typeInfer(argExpr Expression) (types.EvalType, bool) { tp := argExpr.GetType().EvalType() if tp != types.ETInt && tp != types.ETDecimal { tp = types.ETReal } - sc := ctx.GetSessionVars().StmtCtx overflow := false // TODO: Handle float overflow. - if arg, ok := argExpr.(*Constant); sc.InSelectStmt && ok && tp == types.ETInt { + if arg, ok := argExpr.(*Constant); ok && tp == types.ETInt { overflow = c.handleIntOverflow(arg) if overflow { tp = types.ETDecimal @@ -736,7 +735,7 @@ func (c *unaryMinusFunctionClass) getFunction(ctx sessionctx.Context, args []Exp } argExpr, argExprTp := args[0], args[0].GetType() - _, intOverflow := c.typeInfer(ctx, argExpr) + _, intOverflow := c.typeInfer(argExpr) var bf baseBuiltinFunc switch argExprTp.EvalType() { diff --git a/expression/integration_test.go b/expression/integration_test.go index 7aad56f99c9bc..79d2352b96752 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -2098,6 +2098,12 @@ func (s *testIntegrationSuite) TestOpBuiltin(c *C) { // for unaryPlus result = tk.MustQuery(`select +1, +0, +(-9), +(-0.001), +0.999, +null, +"aaa"`) result.Check(testkit.Rows("1 0 -9 -0.001 0.999 aaa")) + // for unaryMinus + tk.MustExec("drop table if exists f") + tk.MustExec("create table f(a decimal(65,0))") + tk.MustExec("insert into f value (-17000000000000000000)") + result = tk.MustQuery("select a from f") + result.Check(testkit.Rows("-17000000000000000000")) } func (s *testIntegrationSuite) TestDatetimeOverflow(c *C) {