Skip to content

Commit

Permalink
sessionctx: support encoding and decoding statement context (#35688)
Browse files Browse the repository at this point in the history
close #35664
  • Loading branch information
djshow832 authored Jun 27, 2022
1 parent 0998cba commit 31c92c6
Show file tree
Hide file tree
Showing 7 changed files with 237 additions and 6 deletions.
9 changes: 7 additions & 2 deletions executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1928,14 +1928,19 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
sc.IgnoreTruncate = true
sc.IgnoreZeroInDate = true
sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode()
if stmt.Tp == ast.ShowWarnings || stmt.Tp == ast.ShowErrors {
if stmt.Tp == ast.ShowWarnings || stmt.Tp == ast.ShowErrors || stmt.Tp == ast.ShowSessionStates {
sc.InShowWarning = true
sc.SetWarnings(vars.StmtCtx.GetWarnings())
}
case *ast.SplitRegionStmt:
sc.IgnoreTruncate = false
sc.IgnoreZeroInDate = true
sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode()
case *ast.SetSessionStatesStmt:
sc.InSetSessionStatesStmt = true
sc.IgnoreTruncate = true
sc.IgnoreZeroInDate = true
sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode()
default:
sc.IgnoreTruncate = true
sc.IgnoreZeroInDate = true
Expand All @@ -1954,7 +1959,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
sc.PrevLastInsertID = vars.StmtCtx.PrevLastInsertID
}
sc.PrevAffectedRows = 0
if vars.StmtCtx.InUpdateStmt || vars.StmtCtx.InDeleteStmt || vars.StmtCtx.InInsertStmt {
if vars.StmtCtx.InUpdateStmt || vars.StmtCtx.InDeleteStmt || vars.StmtCtx.InInsertStmt || vars.StmtCtx.InSetSessionStatesStmt {
sc.PrevAffectedRows = int64(vars.StmtCtx.AffectedRows())
} else if vars.StmtCtx.InSelectStmt {
sc.PrevAffectedRows = -1
Expand Down
9 changes: 5 additions & 4 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -3543,15 +3543,16 @@ func (s *session) EncodeSessionStates(ctx context.Context, sctx sessionctx.Conte

// DecodeSessionStates implements SessionStatesHandler.DecodeSessionStates interface.
func (s *session) DecodeSessionStates(ctx context.Context, sctx sessionctx.Context, sessionStates *sessionstates.SessionStates) (err error) {
if err = s.sessionVars.DecodeSessionStates(ctx, sessionStates); err != nil {
return err
}

// Decode session variables.
for name, val := range sessionStates.SystemVars {
if err = variable.SetSessionSystemVar(s.sessionVars, name, val); err != nil {
return err
}
}

// Decode stmt ctx after session vars because setting session vars may override stmt ctx, such as warnings.
if err = s.sessionVars.DecodeSessionStates(ctx, sessionStates); err != nil {
return err
}
return err
}
4 changes: 4 additions & 0 deletions sessionctx/sessionstates/session_states.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"time"

ptypes "github.com/pingcap/tidb/parser/types"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/types"
)

Expand Down Expand Up @@ -52,4 +53,7 @@ type SessionStates struct {
FoundInBinding bool `json:"in-binding,omitempty"`
SequenceLatestValues map[int64]int64 `json:"seq-values,omitempty"`
MPPStoreLastFailTime map[string]time.Time `json:"store-fail-time,omitempty"`
LastAffectedRows int64 `json:"affected-rows,omitempty"`
LastInsertID uint64 `json:"last-insert-id,omitempty"`
Warnings []stmtctx.SQLWarn `json:"warnings,omitempty"`
}
119 changes: 119 additions & 0 deletions sessionctx/sessionstates/session_states_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,125 @@ func TestSessionCtx(t *testing.T) {
}
}

func TestStatementCtx(t *testing.T) {
store, clean := testkit.CreateMockStore(t)
defer clean()
tk := testkit.NewTestKit(t, store)
tk.MustExec("create table test.t1(id int auto_increment primary key, str char(1))")

tests := []struct {
setFunc func(tk *testkit.TestKit) any
checkFunc func(tk *testkit.TestKit, param any)
}{
{
// check LastAffectedRows
setFunc: func(tk *testkit.TestKit) any {
tk.MustQuery("show warnings")
return nil
},
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("select row_count()").Check(testkit.Rows("0"))
tk.MustQuery("select row_count()").Check(testkit.Rows("-1"))
},
},
{
// check LastAffectedRows
setFunc: func(tk *testkit.TestKit) any {
tk.MustQuery("select 1")
return nil
},
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("select row_count()").Check(testkit.Rows("-1"))
},
},
{
// check LastAffectedRows
setFunc: func(tk *testkit.TestKit) any {
tk.MustExec("insert into test.t1(str) value('a'), ('b'), ('c')")
return nil
},
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("select row_count()").Check(testkit.Rows("3"))
tk.MustQuery("select row_count()").Check(testkit.Rows("-1"))
},
},
{
// check LastInsertID
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("select @@last_insert_id").Check(testkit.Rows("0"))
},
},
{
// check LastInsertID
setFunc: func(tk *testkit.TestKit) any {
tk.MustExec("insert into test.t1(str) value('d')")
rows := tk.MustQuery("select @@last_insert_id").Rows()
require.NotEqual(t, "0", rows[0][0].(string))
return rows
},
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("select @@last_insert_id").Check(param.([][]any))
},
},
{
// check Warning
setFunc: func(tk *testkit.TestKit) any {
tk.MustQuery("select 1")
tk.MustQuery("show warnings").Check(testkit.Rows())
tk.MustQuery("show errors").Check(testkit.Rows())
return nil
},
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("show warnings").Check(testkit.Rows())
tk.MustQuery("show errors").Check(testkit.Rows())
tk.MustQuery("select @@warning_count, @@error_count").Check(testkit.Rows("0 0"))
},
},
{
// check Warning
setFunc: func(tk *testkit.TestKit) any {
tk.MustGetErrCode("insert into test.t1(str) value('ef')", errno.ErrDataTooLong)
rows := tk.MustQuery("show warnings").Rows()
require.Equal(t, 1, len(rows))
tk.MustQuery("show errors").Check(rows)
return rows
},
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("show warnings").Check(param.([][]any))
tk.MustQuery("show errors").Check(param.([][]any))
tk.MustQuery("select @@warning_count, @@error_count").Check(testkit.Rows("1 1"))
},
},
{
// check Warning
setFunc: func(tk *testkit.TestKit) any {
tk.MustExec("set sql_mode=''")
tk.MustExec("insert into test.t1(str) value('ef'), ('ef')")
rows := tk.MustQuery("show warnings").Rows()
require.Equal(t, 2, len(rows))
tk.MustQuery("show errors").Check(testkit.Rows())
return rows
},
checkFunc: func(tk *testkit.TestKit, param any) {
tk.MustQuery("show warnings").Check(param.([][]any))
tk.MustQuery("show errors").Check(testkit.Rows())
tk.MustQuery("select @@warning_count, @@error_count").Check(testkit.Rows("2 0"))
},
},
}

for _, tt := range tests {
tk1 := testkit.NewTestKit(t, store)
var param any
if tt.setFunc != nil {
param = tt.setFunc(tk1)
}
tk2 := testkit.NewTestKit(t, store)
showSessionStatesAndSet(t, tk1, tk2)
tt.checkFunc(tk2, param)
}
}

func showSessionStatesAndSet(t *testing.T, tk1, tk2 *testkit.TestKit) {
rows := tk1.MustQuery("show session_states").Rows()
require.Len(t, rows, 1)
Expand Down
49 changes: 49 additions & 0 deletions sessionctx/stmtctx/stmtctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,20 @@
package stmtctx

import (
"encoding/json"
"math"
"sort"
"strconv"
"sync"
"sync/atomic"
"time"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/parser"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/util/disk"
"github.com/pingcap/tidb/util/execdetails"
"github.com/pingcap/tidb/util/memory"
Expand Down Expand Up @@ -60,6 +63,43 @@ type SQLWarn struct {
Err error
}

type jsonSQLWarn struct {
Level string `json:"level"`
SQLErr *terror.Error `json:"err,omitempty"`
Msg string `json:"msg,omitempty"`
}

// MarshalJSON implements the Marshaler.MarshalJSON interface.
func (warn *SQLWarn) MarshalJSON() ([]byte, error) {
w := &jsonSQLWarn{
Level: warn.Level,
}
e := errors.Cause(warn.Err)
switch x := e.(type) {
case *terror.Error:
// Omit outter errors because only the most inner error matters.
w.SQLErr = x
default:
w.Msg = e.Error()
}
return json.Marshal(w)
}

// UnmarshalJSON implements the Unmarshaler.UnmarshalJSON interface.
func (warn *SQLWarn) UnmarshalJSON(data []byte) error {
var w jsonSQLWarn
if err := json.Unmarshal(data, &w); err != nil {
return err
}
warn.Level = w.Level
if w.SQLErr != nil {
warn.Err = w.SQLErr
} else {
warn.Err = errors.New(w.Msg)
}
return nil
}

// StatementContext contains variables for a statement.
// It should be reset before executing a statement.
type StatementContext struct {
Expand All @@ -76,6 +116,7 @@ type StatementContext struct {
InLoadDataStmt bool
InExplainStmt bool
InCreateOrAlterStmt bool
InSetSessionStatesStmt bool
InPreparedPlanBuilding bool
IgnoreTruncate bool
IgnoreZeroInDate bool
Expand Down Expand Up @@ -406,6 +447,13 @@ func (sc *StatementContext) AddAffectedRows(rows uint64) {
sc.mu.affectedRows += rows
}

// SetAffectedRows sets affected rows.
func (sc *StatementContext) SetAffectedRows(rows uint64) {
sc.mu.Lock()
sc.mu.affectedRows = rows
sc.mu.Unlock()
}

// AffectedRows gets affected rows.
func (sc *StatementContext) AffectedRows() uint64 {
sc.mu.Lock()
Expand Down Expand Up @@ -558,6 +606,7 @@ func (sc *StatementContext) SetWarnings(warns []SQLWarn) {
sc.mu.Lock()
defer sc.mu.Unlock()
sc.mu.warnings = warns
sc.mu.errorCount = 0
for _, w := range warns {
if w.Level == WarnLevelError {
sc.mu.errorCount++
Expand Down
43 changes: 43 additions & 0 deletions sessionctx/stmtctx/stmtctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@ package stmtctx_test

import (
"context"
"encoding/json"
"fmt"
"testing"
"time"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/testkit"
"github.com/pingcap/tidb/util/execdetails"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -143,3 +146,43 @@ func TestWeakConsistencyRead(t *testing.T) {
execAndCheck("execute s", testkit.Rows("1 1 2"), kv.SI)
tk.MustExec("rollback")
}

func TestMarshalSQLWarn(t *testing.T) {
warns := []stmtctx.SQLWarn{
{
Level: stmtctx.WarnLevelError,
Err: errors.New("any error"),
},
{
Level: stmtctx.WarnLevelError,
Err: errors.Trace(errors.New("any error")),
},
{
Level: stmtctx.WarnLevelWarning,
Err: variable.ErrUnknownSystemVar.GenWithStackByArgs("unknown"),
},
{
Level: stmtctx.WarnLevelWarning,
Err: errors.Trace(variable.ErrUnknownSystemVar.GenWithStackByArgs("unknown")),
},
}

store, clean := testkit.CreateMockStore(t)
defer clean()
tk := testkit.NewTestKit(t, store)
// First query can trigger loading global variables, which produces warnings.
tk.MustQuery("select 1")
tk.Session().GetSessionVars().StmtCtx.SetWarnings(warns)
rows := tk.MustQuery("show warnings").Rows()
require.Equal(t, len(warns), len(rows))

// The unmarshalled result doesn't need to be exactly the same with the original one.
// We only need that the results of `show warnings` are the same.
bytes, err := json.Marshal(warns)
require.NoError(t, err)
var newWarns []stmtctx.SQLWarn
err = json.Unmarshal(bytes, &newWarns)
require.NoError(t, err)
tk.Session().GetSessionVars().StmtCtx.SetWarnings(newWarns)
tk.MustQuery("show warnings").Check(rows)
}
10 changes: 10 additions & 0 deletions sessionctx/variable/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1867,6 +1867,11 @@ func (s *SessionVars) EncodeSessionStates(ctx context.Context, sessionStates *se
sessionStates.MPPStoreLastFailTime = s.MPPStoreLastFailTime
sessionStates.FoundInPlanCache = s.PrevFoundInPlanCache
sessionStates.FoundInBinding = s.PrevFoundInBinding

// Encode StatementContext. We encode it here to avoid circle dependency.
sessionStates.LastAffectedRows = s.StmtCtx.PrevAffectedRows
sessionStates.LastInsertID = s.StmtCtx.PrevLastInsertID
sessionStates.Warnings = s.StmtCtx.GetWarnings()
return
}

Expand Down Expand Up @@ -1902,6 +1907,11 @@ func (s *SessionVars) DecodeSessionStates(ctx context.Context, sessionStates *se
}
s.FoundInPlanCache = sessionStates.FoundInPlanCache
s.FoundInBinding = sessionStates.FoundInBinding

// Decode StatementContext.
s.StmtCtx.SetAffectedRows(uint64(sessionStates.LastAffectedRows))
s.StmtCtx.PrevLastInsertID = sessionStates.LastInsertID
s.StmtCtx.SetWarnings(sessionStates.Warnings)
return
}

Expand Down

0 comments on commit 31c92c6

Please sign in to comment.