Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sessionctx: support encoding and decoding statement context #35688

Merged
merged 5 commits into from
Jun 27, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1919,14 +1919,19 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
sc.IgnoreTruncate = true
sc.IgnoreZeroInDate = true
sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode()
if stmt.Tp == ast.ShowWarnings || stmt.Tp == ast.ShowErrors {
if stmt.Tp == ast.ShowWarnings || stmt.Tp == ast.ShowErrors || stmt.Tp == ast.ShowSessionStates {
sc.InShowWarning = true
sc.SetWarnings(vars.StmtCtx.GetWarnings())
}
case *ast.SplitRegionStmt:
sc.IgnoreTruncate = false
sc.IgnoreZeroInDate = true
sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode()
case *ast.SetSessionStatesStmt:
sc.InSetSessionStatesStmt = true
sc.IgnoreTruncate = true
sc.IgnoreZeroInDate = true
sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode()
default:
sc.IgnoreTruncate = true
sc.IgnoreZeroInDate = true
Expand All @@ -1945,7 +1950,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
sc.PrevLastInsertID = vars.StmtCtx.PrevLastInsertID
}
sc.PrevAffectedRows = 0
if vars.StmtCtx.InUpdateStmt || vars.StmtCtx.InDeleteStmt || vars.StmtCtx.InInsertStmt {
if vars.StmtCtx.InUpdateStmt || vars.StmtCtx.InDeleteStmt || vars.StmtCtx.InInsertStmt || vars.StmtCtx.InSetSessionStatesStmt {
sc.PrevAffectedRows = int64(vars.StmtCtx.AffectedRows())
} else if vars.StmtCtx.InSelectStmt {
sc.PrevAffectedRows = -1
Expand Down
9 changes: 5 additions & 4 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -3543,15 +3543,16 @@ func (s *session) EncodeSessionStates(ctx context.Context, sctx sessionctx.Conte

// DecodeSessionStates implements SessionStatesHandler.DecodeSessionStates interface.
func (s *session) DecodeSessionStates(ctx context.Context, sctx sessionctx.Context, sessionStates *sessionstates.SessionStates) (err error) {
if err = s.sessionVars.DecodeSessionStates(ctx, sessionStates); err != nil {
return err
}

// Decode session variables.
for name, val := range sessionStates.SystemVars {
if err = variable.SetSessionSystemVar(s.sessionVars, name, val); err != nil {
return err
}
}

// Decode stmt ctx after session vars because setting session vars may override stmt ctx, such as warnings.
if err = s.sessionVars.DecodeSessionStates(ctx, sessionStates); err != nil {
return err
}
return err
}
4 changes: 4 additions & 0 deletions sessionctx/sessionstates/session_states.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"time"

ptypes "github.com/pingcap/tidb/parser/types"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/types"
)

Expand Down Expand Up @@ -52,4 +53,7 @@ type SessionStates struct {
FoundInBinding bool `json:"in-binding,omitempty"`
SequenceLatestValues map[int64]int64 `json:"seq-values,omitempty"`
MPPStoreLastFailTime map[string]time.Time `json:"store-fail-time,omitempty"`
LastAffectedRows int64 `json:"affected-rows,omitempty"`
LastInsertID uint64 `json:"last-insert-id,omitempty"`
Warnings []stmtctx.SQLWarn `json:"warnings,omitempty"`
}
119 changes: 119 additions & 0 deletions sessionctx/sessionstates/session_states_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,125 @@ func TestSessionCtx(t *testing.T) {
}
}

func TestStatementCtx(t *testing.T) {
store, clean := testkit.CreateMockStore(t)
defer clean()
tk := testkit.NewTestKit(t, store)
tk.MustExec("create table test.t1(id int auto_increment primary key, str char(1))")

tests := []struct {
setFunc func(tk *testkit.TestKit) any
checkFunc func(tk *testkit.TestKit, param any)
}{
{
// check LastAffectedRows
setFunc: func(tk *testkit.TestKit) any {
tk.MustQuery("show warnings")
return nil
},
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("select row_count()").Check(testkit.Rows("0"))
tk.MustQuery("select row_count()").Check(testkit.Rows("-1"))
},
},
{
// check LastAffectedRows
setFunc: func(tk *testkit.TestKit) any {
tk.MustQuery("select 1")
return nil
},
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("select row_count()").Check(testkit.Rows("-1"))
},
},
{
// check LastAffectedRows
setFunc: func(tk *testkit.TestKit) any {
tk.MustExec("insert into test.t1(str) value('a'), ('b'), ('c')")
return nil
},
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("select row_count()").Check(testkit.Rows("3"))
tk.MustQuery("select row_count()").Check(testkit.Rows("-1"))
},
},
{
// check LastInsertID
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("select @@last_insert_id").Check(testkit.Rows("0"))
},
},
{
// check LastInsertID
setFunc: func(tk *testkit.TestKit) any {
tk.MustExec("insert into test.t1(str) value('d')")
rows := tk.MustQuery("select @@last_insert_id").Rows()
require.NotEqual(t, "0", rows[0][0].(string))
return rows
},
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("select @@last_insert_id").Check(param.([][]any))
},
},
{
// check Warning
setFunc: func(tk *testkit.TestKit) any {
tk.MustQuery("select 1")
tk.MustQuery("show warnings").Check(testkit.Rows())
tk.MustQuery("show errors").Check(testkit.Rows())
return nil
},
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("show warnings").Check(testkit.Rows())
tk.MustQuery("show errors").Check(testkit.Rows())
tk.MustQuery("select @@warning_count, @@error_count").Check(testkit.Rows("0 0"))
},
},
{
// check Warning
setFunc: func(tk *testkit.TestKit) any {
tk.MustGetErrCode("insert into test.t1(str) value('ef')", errno.ErrDataTooLong)
rows := tk.MustQuery("show warnings").Rows()
require.Equal(t, 1, len(rows))
tk.MustQuery("show errors").Check(rows)
return rows
},
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("show warnings").Check(param.([][]any))
tk.MustQuery("show errors").Check(param.([][]any))
tk.MustQuery("select @@warning_count, @@error_count").Check(testkit.Rows("1 1"))
},
},
{
// check Warning
setFunc: func(tk *testkit.TestKit) any {
tk.MustExec("set sql_mode=''")
tk.MustExec("insert into test.t1(str) value('ef'), ('ef')")
rows := tk.MustQuery("show warnings").Rows()
require.Equal(t, 2, len(rows))
tk.MustQuery("show errors").Check(testkit.Rows())
return rows
},
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("show warnings").Check(param.([][]any))
tk.MustQuery("show errors").Check(testkit.Rows())
tk.MustQuery("select @@warning_count, @@error_count").Check(testkit.Rows("2 0"))
},
},
}

for _, tt := range tests {
tk1 := testkit.NewTestKit(t, store)
var param any
if tt.setFunc != nil {
param = tt.setFunc(tk1)
}
tk2 := testkit.NewTestKit(t, store)
showSessionStatesAndSet(t, tk1, tk2)
tt.checkFunc(tk2, param)
}
}

func showSessionStatesAndSet(t *testing.T, tk1, tk2 *testkit.TestKit) {
rows := tk1.MustQuery("show session_states").Rows()
require.Len(t, rows, 1)
Expand Down
49 changes: 49 additions & 0 deletions sessionctx/stmtctx/stmtctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,20 @@
package stmtctx

import (
"encoding/json"
"math"
"sort"
"strconv"
"sync"
"sync/atomic"
"time"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/parser"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/util/disk"
"github.com/pingcap/tidb/util/execdetails"
"github.com/pingcap/tidb/util/memory"
Expand Down Expand Up @@ -60,6 +63,43 @@ type SQLWarn struct {
Err error
}

type jsonSQLWarn struct {
Level string `json:"level"`
SQLErr *terror.Error `json:"err,omitempty"`
Msg string `json:"msg,omitempty"`
}

// MarshalJSON implements the Marshaler.MarshalJSON interface.
func (warn *SQLWarn) MarshalJSON() ([]byte, error) {
w := &jsonSQLWarn{
Level: warn.Level,
}
e := errors.Cause(warn.Err)
switch x := e.(type) {
case *terror.Error:
// Omit outter errors because only the most inner error matters.
w.SQLErr = x
default:
w.Msg = e.Error()
}
return json.Marshal(w)
}

// UnmarshalJSON implements the Unmarshaler.UnmarshalJSON interface.
func (warn *SQLWarn) UnmarshalJSON(data []byte) error {
var w jsonSQLWarn
if err := json.Unmarshal(data, &w); err != nil {
return err
}
warn.Level = w.Level
if w.SQLErr != nil {
warn.Err = w.SQLErr
} else {
warn.Err = errors.New(w.Msg)
}
return nil
}

// StatementContext contains variables for a statement.
// It should be reset before executing a statement.
type StatementContext struct {
Expand All @@ -76,6 +116,7 @@ type StatementContext struct {
InLoadDataStmt bool
InExplainStmt bool
InCreateOrAlterStmt bool
InSetSessionStatesStmt bool
InPreparedPlanBuilding bool
IgnoreTruncate bool
IgnoreZeroInDate bool
Expand Down Expand Up @@ -406,6 +447,13 @@ func (sc *StatementContext) AddAffectedRows(rows uint64) {
sc.mu.affectedRows += rows
}

// SetAffectedRows sets affected rows.
func (sc *StatementContext) SetAffectedRows(rows uint64) {
sc.mu.Lock()
sc.mu.affectedRows = rows
sc.mu.Unlock()
}

// AffectedRows gets affected rows.
func (sc *StatementContext) AffectedRows() uint64 {
sc.mu.Lock()
Expand Down Expand Up @@ -558,6 +606,7 @@ func (sc *StatementContext) SetWarnings(warns []SQLWarn) {
sc.mu.Lock()
defer sc.mu.Unlock()
sc.mu.warnings = warns
sc.mu.errorCount = 0
for _, w := range warns {
if w.Level == WarnLevelError {
sc.mu.errorCount++
Expand Down
43 changes: 43 additions & 0 deletions sessionctx/stmtctx/stmtctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@ package stmtctx_test

import (
"context"
"encoding/json"
"fmt"
"testing"
"time"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/testkit"
"github.com/pingcap/tidb/util/execdetails"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -143,3 +146,43 @@ func TestWeakConsistencyRead(t *testing.T) {
execAndCheck("execute s", testkit.Rows("1 1 2"), kv.SI)
tk.MustExec("rollback")
}

func TestMarshalSQLWarn(t *testing.T) {
warns := []stmtctx.SQLWarn{
{
Level: stmtctx.WarnLevelError,
Err: errors.New("any error"),
},
{
Level: stmtctx.WarnLevelError,
Err: errors.Trace(errors.New("any error")),
},
{
Level: stmtctx.WarnLevelWarning,
Err: variable.ErrUnknownSystemVar.GenWithStackByArgs("unknown"),
},
{
Level: stmtctx.WarnLevelWarning,
Err: errors.Trace(variable.ErrUnknownSystemVar.GenWithStackByArgs("unknown")),
},
}

store, clean := testkit.CreateMockStore(t)
defer clean()
tk := testkit.NewTestKit(t, store)
// First query can trigger loading global variables, which produces warnings.
tk.MustQuery("select 1")
tk.Session().GetSessionVars().StmtCtx.SetWarnings(warns)
rows := tk.MustQuery("show warnings").Rows()
require.Equal(t, len(warns), len(rows))

// The unmarshalled result doesn't need to be exactly the same with the original one.
// We only need that the results of `show warnings` are the same.
bytes, err := json.Marshal(warns)
require.NoError(t, err)
var newWarns []stmtctx.SQLWarn
err = json.Unmarshal(bytes, &newWarns)
require.NoError(t, err)
tk.Session().GetSessionVars().StmtCtx.SetWarnings(newWarns)
tk.MustQuery("show warnings").Check(rows)
}
10 changes: 10 additions & 0 deletions sessionctx/variable/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1867,6 +1867,11 @@ func (s *SessionVars) EncodeSessionStates(ctx context.Context, sessionStates *se
sessionStates.MPPStoreLastFailTime = s.MPPStoreLastFailTime
sessionStates.FoundInPlanCache = s.PrevFoundInPlanCache
sessionStates.FoundInBinding = s.PrevFoundInBinding

// Encode StatementContext. We encode it here to avoid circle dependency.
sessionStates.LastAffectedRows = s.StmtCtx.PrevAffectedRows
sessionStates.LastInsertID = s.StmtCtx.PrevLastInsertID
sessionStates.Warnings = s.StmtCtx.GetWarnings()
return
}

Expand Down Expand Up @@ -1902,6 +1907,11 @@ func (s *SessionVars) DecodeSessionStates(ctx context.Context, sessionStates *se
}
s.FoundInPlanCache = sessionStates.FoundInPlanCache
s.FoundInBinding = sessionStates.FoundInBinding

// Decode StatementContext.
s.StmtCtx.SetAffectedRows(uint64(sessionStates.LastAffectedRows))
s.StmtCtx.PrevLastInsertID = sessionStates.LastInsertID
s.StmtCtx.SetWarnings(sessionStates.Warnings)
return
}

Expand Down