Skip to content

Commit

Permalink
expression: BuildContext read location from EvalContext instead o…
Browse files Browse the repository at this point in the history
…f `SessionVars` (pingcap#52451)

ref pingcap#52366
  • Loading branch information
lcwangchao authored and 3AceShowHand committed Apr 16, 2024
1 parent a1ac4bb commit eaa6926
Show file tree
Hide file tree
Showing 13 changed files with 52 additions and 25 deletions.
2 changes: 1 addition & 1 deletion pkg/ddl/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -1347,7 +1347,7 @@ func (w *updateColumnWorker) getRowRecord(handle kv.Handle, recordKey []byte, ra
val := w.rowMap[w.oldColInfo.ID]
col := w.newColInfo
if val.Kind() == types.KindNull && col.FieldType.GetType() == mysql.TypeTimestamp && mysql.HasNotNullFlag(col.GetFlag()) {
if v, err := expression.GetTimeCurrentTimestamp(w.sessCtx.GetExprCtx(), col.GetType(), col.GetDecimal()); err == nil {
if v, err := expression.GetTimeCurrentTimestamp(w.sessCtx.GetExprCtx().GetEvalCtx(), col.GetType(), col.GetDecimal()); err == nil {
// convert null value to timestamp should be substituted with current timestamp if NOT_NULL flag is set.
w.rowMap[w.oldColInfo.ID] = v
}
Expand Down
1 change: 1 addition & 0 deletions pkg/ddl/internal/session/session_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ func (sg *Pool) Get() (sessionctx.Context, error) {
}
ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusAutocommit, true)
ctx.GetSessionVars().InRestrictedSQL = true
ctx.GetSessionVars().StmtCtx.SetTimeZone(ctx.GetSessionVars().Location())
infosync.StoreInternalSession(ctx)
return ctx, nil
}
Expand Down
12 changes: 2 additions & 10 deletions pkg/expression/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (

"github.com/pingcap/tidb/pkg/errctx"
"github.com/pingcap/tidb/pkg/expression/context"
"github.com/pingcap/tidb/pkg/expression/contextopt"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
"github.com/pingcap/tidb/pkg/types"
Expand Down Expand Up @@ -96,16 +95,9 @@ func wrapEvalAssert(ctx EvalContext, fn builtinFunc) (ret *assertionEvalContext)
}

func checkEvalCtx(ctx EvalContext) {
loc := ctx.Location().String()
tc := ctx.TypeCtx()
tcLoc := tc.Location().String()
intest.Assert(loc == tcLoc, "location mismatch, evalCtx: %s, typeCtx: %s", loc, tcLoc)
if ctx.GetOptionalPropSet().Contains(context.OptPropSessionVars) {
vars, err := contextopt.SessionVarsPropReader{}.GetSessionVars(ctx)
intest.AssertNoError(err)
stmtLoc := vars.StmtCtx.TimeZone().String()
intest.Assert(loc == stmtLoc, "location mismatch, evalCtx: %s, stmtCtx: %s", loc, stmtLoc)
}
intest.Assert(ctx.Location() == tc.Location(),
"location is not equal, ctxLoc: %s, tcLoc: %s", ctx.Location(), tc.Location())
}

func (ctx *assertionEvalContext) GetOptionalPropProvider(key OptionalEvalPropKey) (OptionalEvalPropProvider, bool) {
Expand Down
12 changes: 12 additions & 0 deletions pkg/expression/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/intest"
"github.com/pingcap/tidb/pkg/util/mathutil"
)

Expand Down Expand Up @@ -109,3 +110,14 @@ type ExprContext interface {
// GetGroupConcatMaxLen returns the value of the 'group_concat_max_len' system variable.
GetGroupConcatMaxLen() uint64
}

// AssertLocationWithSessionVars asserts the location in the context and session variables are the same.
// It is only used for testing.
func AssertLocationWithSessionVars(ctxLoc *time.Location, vars *variable.SessionVars) {
varsLoc := vars.Location()
stmtLoc := vars.StmtCtx.TimeZone()
intest.Assert(ctxLoc == varsLoc && ctxLoc == stmtLoc,
"location mismatch, ctxLoc: %s, varsLoc: %s, stmtLoc: %s",
ctxLoc.String(), varsLoc.String(), stmtLoc.String(),
)
}
8 changes: 6 additions & 2 deletions pkg/expression/contextimpl/sessionctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,12 @@ func (ctx *SessionEvalContext) SQLMode() mysql.SQLMode {
}

// TypeCtx returns the types.Context
func (ctx *SessionEvalContext) TypeCtx() types.Context {
return ctx.sctx.GetSessionVars().StmtCtx.TypeCtx()
func (ctx *SessionEvalContext) TypeCtx() (tc types.Context) {
tc = ctx.sctx.GetSessionVars().StmtCtx.TypeCtx()
if intest.InTest {
exprctx.AssertLocationWithSessionVars(tc.Location(), ctx.sctx.GetSessionVars())
}
return
}

// ErrCtx returns the errctx.Context
Expand Down
5 changes: 4 additions & 1 deletion pkg/expression/contextimpl/sessionctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,10 @@ func TestSessionEvalContextOptProps(t *testing.T) {
require.Equal(t, ctx.GetSessionVars().ActiveRoles, roles)

// test for OptPropSessionVars
gotVars := getProvider[*contextopt.SessionVarsPropProvider](t, impl, context.OptPropSessionVars).GetSessionVars()
sessVarsProvider := getProvider[*contextopt.SessionVarsPropProvider](t, impl, context.OptPropSessionVars)
require.NotNil(t, sessVarsProvider)
gotVars, err := contextopt.SessionVarsPropReader{}.GetSessionVars(impl)
require.NoError(t, err)
require.Same(t, ctx.GetSessionVars(), gotVars)

// test for OptPropAdvisoryLock
Expand Down
9 changes: 7 additions & 2 deletions pkg/expression/contextopt/optional_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package contextopt

import (
"testing"
"time"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/expression/context"
Expand All @@ -39,6 +40,10 @@ type mockEvalCtx struct {
props OptionalEvalPropProviders
}

func (ctx *mockEvalCtx) Location() *time.Location {
return time.UTC
}

func (ctx *mockEvalCtx) GetOptionalPropProvider(
key context.OptionalEvalPropKey,
) (context.OptionalEvalPropProvider, bool) {
Expand Down Expand Up @@ -92,15 +97,15 @@ func TestOptionalEvalPropProviders(t *testing.T) {
}
case context.OptPropSessionVars:
vars := variable.NewSessionVars(nil)
vars.TimeZone = time.UTC
vars.StmtCtx.SetTimeZone(time.UTC)
p = NewSessionVarsProvider(mockSessionVarsProvider{vars: vars})
r := SessionVarsPropReader{}
reader = r
verifyNoProvider = func(ctx context.EvalContext) {
assertReaderFuncReturnErr(t, ctx, r.GetSessionVars)
}
verifyProvider = func(ctx context.EvalContext, val context.OptionalEvalPropProvider) {
got := val.(*SessionVarsPropProvider).GetSessionVars()
require.Same(t, vars, got)
require.Same(t, vars, assertReaderFuncValue(t, ctx, r.GetSessionVars))
}
case context.OptPropInfoSchema:
Expand Down
11 changes: 8 additions & 3 deletions pkg/expression/contextopt/sessionvars.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ var _ RequireOptionalEvalProps = SessionVarsPropReader{}

// SessionVarsPropProvider is a provider to get the session variables
type SessionVarsPropProvider struct {
variable.SessionVarsProvider
vars variable.SessionVarsProvider
}

// NewSessionVarsProvider returns a new SessionVarsPropProvider
func NewSessionVarsProvider(provider variable.SessionVarsProvider) *SessionVarsPropProvider {
intest.AssertNotNil(provider)
return &SessionVarsPropProvider{provider}
return &SessionVarsPropProvider{vars: provider}
}

// Desc implements the OptionalEvalPropProvider interface.
Expand All @@ -53,5 +53,10 @@ func (SessionVarsPropReader) GetSessionVars(ctx context.EvalContext) (*variable.
if err != nil {
return nil, err
}
return p.GetSessionVars(), nil

if intest.InTest {
context.AssertLocationWithSessionVars(ctx.Location(), p.vars.GetSessionVars())
}

return p.vars.GetSessionVars(), nil
}
10 changes: 5 additions & 5 deletions pkg/expression/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func IsValidCurrentTimestampExpr(exprNode ast.ExprNode, fieldType *types.FieldTy
}

// GetTimeCurrentTimestamp is used for generating a timestamp for some special cases: cast null value to timestamp type with not null flag.
func GetTimeCurrentTimestamp(ctx BuildContext, tp byte, fsp int) (d types.Datum, err error) {
func GetTimeCurrentTimestamp(ctx EvalContext, tp byte, fsp int) (d types.Datum, err error) {
var t types.Time
t, err = getTimeCurrentTimeStamp(ctx, tp, fsp)
if err != nil {
Expand All @@ -65,15 +65,15 @@ func GetTimeCurrentTimestamp(ctx BuildContext, tp byte, fsp int) (d types.Datum,
return d, nil
}

func getTimeCurrentTimeStamp(ctx BuildContext, tp byte, fsp int) (t types.Time, err error) {
func getTimeCurrentTimeStamp(ctx EvalContext, tp byte, fsp int) (t types.Time, err error) {
value := types.NewTime(types.ZeroCoreTime, tp, fsp)
defaultTime, err := getStmtTimestamp(ctx.GetEvalCtx())
defaultTime, err := getStmtTimestamp(ctx)
if err != nil {
return value, err
}
value.SetCoreTime(types.FromGoTime(defaultTime.Truncate(time.Duration(math.Pow10(9-fsp)) * time.Nanosecond)))
if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime || tp == mysql.TypeDate {
err = value.ConvertTimeZone(time.Local, ctx.GetSessionVars().Location())
err = value.ConvertTimeZone(time.Local, ctx.Location())
if err != nil {
return value, err
}
Expand All @@ -93,7 +93,7 @@ func GetTimeValue(ctx BuildContext, v any, tp byte, fsp int, explicitTz *time.Lo
case string:
lowerX := strings.ToLower(x)
if lowerX == ast.CurrentTimestamp || lowerX == ast.CurrentDate {
if value, err = getTimeCurrentTimeStamp(ctx, tp, fsp); err != nil {
if value, err = getTimeCurrentTimeStamp(ctx.GetEvalCtx(), tp, fsp); err != nil {
return d, err
}
} else if lowerX == types.ZeroDatetimeStr {
Expand Down
2 changes: 2 additions & 0 deletions pkg/expression/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ func TestCurrentTimestampTimeZone(t *testing.T) {
require.NoError(t, err)
err = sessionVars.SetSystemVar("time_zone", "+00:00")
require.NoError(t, err)
sessionVars.StmtCtx.SetTimeZone(sessionVars.Location())
v, err := GetTimeValue(ctx, ast.CurrentTimestamp, mysql.TypeTimestamp, types.MinFsp, nil)
require.NoError(t, err)
require.EqualValues(t, types.NewTime(
Expand All @@ -174,6 +175,7 @@ func TestCurrentTimestampTimeZone(t *testing.T) {
// would get different value.
err = sessionVars.SetSystemVar("time_zone", "+08:00")
require.NoError(t, err)
sessionVars.StmtCtx.SetTimeZone(sessionVars.Location())
v, err = GetTimeValue(ctx, ast.CurrentTimestamp, mysql.TypeTimestamp, types.MinFsp, nil)
require.NoError(t, err)
require.EqualValues(t, types.NewTime(
Expand Down
2 changes: 2 additions & 0 deletions pkg/planner/core/memtable_predicate_extractor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,7 @@ func TestMetricTableExtractor(t *testing.T) {
quantiles: []float64{0},
},
}
se.GetSessionVars().TimeZone = time.Local
se.GetSessionVars().StmtCtx.SetTimeZone(time.Local)
for _, ca := range cases {
logicalMemTable := getLogicalMemTable(t, dom, se, parser, ca.sql)
Expand Down Expand Up @@ -1048,6 +1049,7 @@ func TestTiDBHotRegionsHistoryTableExtractor(t *testing.T) {

se, err := session.CreateSession4Test(store)
require.NoError(t, err)
se.GetSessionVars().TimeZone = time.Local
se.GetSessionVars().StmtCtx.SetTimeZone(time.Local)

var cases = []struct {
Expand Down
1 change: 1 addition & 0 deletions pkg/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -3046,6 +3046,7 @@ func CreateSession4TestWithOpt(store kv.Storage, opt *Opt) (types.Session, error
s.GetSessionVars().MaxChunkSize = 32
s.GetSessionVars().MinPagingSize = variable.DefMinPagingSize
s.GetSessionVars().EnablePaging = variable.DefTiDBEnablePaging
s.GetSessionVars().StmtCtx.SetTimeZone(s.GetSessionVars().Location())
err = s.GetSessionVars().SetSystemVarWithoutValidation(variable.CharacterSetConnection, "utf8mb4")
}
return s, err
Expand Down
2 changes: 1 addition & 1 deletion pkg/table/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ func getColDefaultValue(ctx expression.BuildContext, col *model.ColumnInfo, defa
// If the column's default value is not ZeroDatetimeStr or CurrentTimestamp, convert the default value to the current session time zone.
if needChangeTimeZone {
t := value.GetMysqlTime()
err = t.ConvertTimeZone(explicitTz, ctx.GetSessionVars().Location())
err = t.ConvertTimeZone(explicitTz, ctx.GetEvalCtx().Location())
if err != nil {
return value, err
}
Expand Down

0 comments on commit eaa6926

Please sign in to comment.