diff --git a/br/pkg/lightning/checkpoints/checkpoints.go b/br/pkg/lightning/checkpoints/checkpoints.go index 44f2349b672b2..13817e28eb668 100644 --- a/br/pkg/lightning/checkpoints/checkpoints.go +++ b/br/pkg/lightning/checkpoints/checkpoints.go @@ -517,7 +517,15 @@ func OpenCheckpointsDB(ctx context.Context, cfg *config.Config) (DB, error) { switch cfg.Checkpoint.Driver { case config.CheckpointDriverMySQL: - db, err := common.ConnectMySQL(cfg.Checkpoint.DSN) + var ( + db *sql.DB + err error + ) + if cfg.Checkpoint.MySQLParam != nil { + db, err = cfg.Checkpoint.MySQLParam.Connect() + } else { + db, err = sql.Open("mysql", cfg.Checkpoint.DSN) + } if err != nil { return nil, errors.Trace(err) } @@ -546,7 +554,15 @@ func IsCheckpointsDBExists(ctx context.Context, cfg *config.Config) (bool, error } switch cfg.Checkpoint.Driver { case config.CheckpointDriverMySQL: - db, err := sql.Open("mysql", cfg.Checkpoint.DSN) + var ( + db *sql.DB + err error + ) + if cfg.Checkpoint.MySQLParam != nil { + db, err = cfg.Checkpoint.MySQLParam.Connect() + } else { + db, err = sql.Open("mysql", cfg.Checkpoint.DSN) + } if err != nil { return false, errors.Trace(err) } diff --git a/br/pkg/lightning/common/util.go b/br/pkg/lightning/common/util.go index cc03f0ec68dca..679ba6cc5d48b 100644 --- a/br/pkg/lightning/common/util.go +++ b/br/pkg/lightning/common/util.go @@ -23,7 +23,6 @@ import ( "io" "net" "net/http" - "net/url" "os" "strconv" "strings" @@ -58,28 +57,38 @@ type MySQLConnectParam struct { Vars map[string]string } -func (param *MySQLConnectParam) ToDSN() string { - hostPort := net.JoinHostPort(param.Host, strconv.Itoa(param.Port)) - dsn := fmt.Sprintf("%s:%s@tcp(%s)/?charset=utf8mb4&sql_mode='%s'&maxAllowedPacket=%d&tls=%s", - param.User, param.Password, hostPort, - param.SQLMode, param.MaxAllowedPacket, param.TLS) +func (param *MySQLConnectParam) ToDriverConfig() *mysql.Config { + cfg := mysql.NewConfig() + cfg.Params = make(map[string]string) + + cfg.User = param.User + cfg.Passwd = param.Password + cfg.Net = "tcp" + cfg.Addr = net.JoinHostPort(param.Host, strconv.Itoa(param.Port)) + cfg.Params["charset"] = "utf8mb4" + cfg.Params["sql_mode"] = fmt.Sprintf("'%s'", param.SQLMode) + cfg.MaxAllowedPacket = int(param.MaxAllowedPacket) + cfg.TLSConfig = param.TLS for k, v := range param.Vars { - dsn += fmt.Sprintf("&%s='%s'", k, url.QueryEscape(v)) + cfg.Params[k] = fmt.Sprintf("'%s'", v) } - - return dsn + return cfg } -func tryConnectMySQL(dsn string) (*sql.DB, error) { - driverName := "mysql" - failpoint.Inject("MockMySQLDriver", func(val failpoint.Value) { - driverName = val.(string) +func tryConnectMySQL(cfg *mysql.Config) (*sql.DB, error) { + failpoint.Inject("MustMySQLPassword", func(val failpoint.Value) { + pwd := val.(string) + if cfg.Passwd != pwd { + failpoint.Return(nil, &mysql.MySQLError{Number: tmysql.ErrAccessDenied, Message: "access denied"}) + } + failpoint.Return(nil, nil) }) - db, err := sql.Open(driverName, dsn) + c, err := mysql.NewConnector(cfg) if err != nil { return nil, errors.Trace(err) } + db := sql.OpenDB(c) if err = db.Ping(); err != nil { _ = db.Close() return nil, errors.Trace(err) @@ -89,13 +98,9 @@ func tryConnectMySQL(dsn string) (*sql.DB, error) { // ConnectMySQL connects MySQL with the dsn. If access is denied and the password is a valid base64 encoding, // we will try to connect MySQL with the base64 decoding of the password. -func ConnectMySQL(dsn string) (*sql.DB, error) { - cfg, err := mysql.ParseDSN(dsn) - if err != nil { - return nil, errors.Trace(err) - } +func ConnectMySQL(cfg *mysql.Config) (*sql.DB, error) { // Try plain password first. - db, firstErr := tryConnectMySQL(dsn) + db, firstErr := tryConnectMySQL(cfg) if firstErr == nil { return db, nil } @@ -104,9 +109,9 @@ func ConnectMySQL(dsn string) (*sql.DB, error) { // If password is encoded by base64, try the decoded string as well. if password, decodeErr := base64.StdEncoding.DecodeString(cfg.Passwd); decodeErr == nil && string(password) != cfg.Passwd { cfg.Passwd = string(password) - db, err = tryConnectMySQL(cfg.FormatDSN()) + db2, err := tryConnectMySQL(cfg) if err == nil { - return db, nil + return db2, nil } } } @@ -115,7 +120,7 @@ func ConnectMySQL(dsn string) (*sql.DB, error) { } func (param *MySQLConnectParam) Connect() (*sql.DB, error) { - db, err := ConnectMySQL(param.ToDSN()) + db, err := ConnectMySQL(param.ToDriverConfig()) if err != nil { return nil, errors.Trace(err) } diff --git a/br/pkg/lightning/common/util_test.go b/br/pkg/lightning/common/util_test.go index c7c95b44f69bf..a192ecea11906 100644 --- a/br/pkg/lightning/common/util_test.go +++ b/br/pkg/lightning/common/util_test.go @@ -16,16 +16,12 @@ package common_test import ( "context" - "database/sql" - "database/sql/driver" "encoding/base64" "encoding/json" "fmt" "io" - "math/rand" "net/http" "net/http/httptest" - "strconv" "testing" "time" @@ -35,7 +31,6 @@ import ( "github.com/pingcap/failpoint" "github.com/pingcap/tidb/br/pkg/lightning/common" "github.com/pingcap/tidb/br/pkg/lightning/log" - tmysql "github.com/pingcap/tidb/errno" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -85,66 +80,14 @@ func TestGetJSON(t *testing.T) { require.Regexp(t, ".*http status code != 200.*", err.Error()) } -func TestToDSN(t *testing.T) { - param := common.MySQLConnectParam{ - Host: "127.0.0.1", - Port: 4000, - User: "root", - Password: "123456", - SQLMode: "strict", - MaxAllowedPacket: 1234, - TLS: "cluster", - Vars: map[string]string{ - "tidb_distsql_scan_concurrency": "1", - }, - } - require.Equal(t, "root:123456@tcp(127.0.0.1:4000)/?charset=utf8mb4&sql_mode='strict'&maxAllowedPacket=1234&tls=cluster&tidb_distsql_scan_concurrency='1'", param.ToDSN()) - - param.Host = "::1" - require.Equal(t, "root:123456@tcp([::1]:4000)/?charset=utf8mb4&sql_mode='strict'&maxAllowedPacket=1234&tls=cluster&tidb_distsql_scan_concurrency='1'", param.ToDSN()) -} - -type mockDriver struct { - driver.Driver - plainPsw string -} - -func (m *mockDriver) Open(dsn string) (driver.Conn, error) { - cfg, err := mysql.ParseDSN(dsn) - if err != nil { - return nil, err - } - accessDenied := cfg.Passwd != m.plainPsw - return &mockConn{accessDenied: accessDenied}, nil -} - -type mockConn struct { - driver.Conn - driver.Pinger - accessDenied bool -} - -func (c *mockConn) Ping(ctx context.Context) error { - if c.accessDenied { - return &mysql.MySQLError{Number: tmysql.ErrAccessDenied, Message: "access denied"} - } - return nil -} - -func (c *mockConn) Close() error { - return nil -} - func TestConnect(t *testing.T) { plainPsw := "dQAUoDiyb1ucWZk7" - driverName := "mysql-mock-" + strconv.Itoa(rand.Int()) - sql.Register(driverName, &mockDriver{plainPsw: plainPsw}) require.NoError(t, failpoint.Enable( - "github.com/pingcap/tidb/br/pkg/lightning/common/MockMySQLDriver", - fmt.Sprintf("return(\"%s\")", driverName))) + "github.com/pingcap/tidb/br/pkg/lightning/common/MustMySQLPassword", + fmt.Sprintf("return(\"%s\")", plainPsw))) defer func() { - require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/br/pkg/lightning/common/MockMySQLDriver")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/br/pkg/lightning/common/MustMySQLPassword")) }() param := common.MySQLConnectParam{ @@ -155,13 +98,11 @@ func TestConnect(t *testing.T) { SQLMode: "strict", MaxAllowedPacket: 1234, } - db, err := param.Connect() + _, err := param.Connect() require.NoError(t, err) - require.NoError(t, db.Close()) param.Password = base64.StdEncoding.EncodeToString([]byte(plainPsw)) - db, err = param.Connect() + _, err = param.Connect() require.NoError(t, err) - require.NoError(t, db.Close()) } func TestIsContextCanceledError(t *testing.T) { diff --git a/br/pkg/lightning/config/config.go b/br/pkg/lightning/config/config.go index 4c1af0d2baff3..638784ff3ed1e 100644 --- a/br/pkg/lightning/config/config.go +++ b/br/pkg/lightning/config/config.go @@ -553,11 +553,12 @@ type TikvImporter struct { } type Checkpoint struct { - Schema string `toml:"schema" json:"schema"` - DSN string `toml:"dsn" json:"-"` // DSN may contain password, don't expose this to JSON. - Driver string `toml:"driver" json:"driver"` - Enable bool `toml:"enable" json:"enable"` - KeepAfterSuccess CheckpointKeepStrategy `toml:"keep-after-success" json:"keep-after-success"` + Schema string `toml:"schema" json:"schema"` + DSN string `toml:"dsn" json:"-"` // DSN may contain password, don't expose this to JSON. + MySQLParam *common.MySQLConnectParam `toml:"-" json:"-"` // For some security reason, we use MySQLParam instead of DSN. + Driver string `toml:"driver" json:"driver"` + Enable bool `toml:"enable" json:"enable"` + KeepAfterSuccess CheckpointKeepStrategy `toml:"keep-after-success" json:"keep-after-success"` } type Cron struct { @@ -1142,7 +1143,7 @@ func (cfg *Config) AdjustCheckPoint() { MaxAllowedPacket: defaultMaxAllowedPacket, TLS: cfg.TiDB.TLS, } - cfg.Checkpoint.DSN = param.ToDSN() + cfg.Checkpoint.MySQLParam = ¶m case CheckpointDriverFile: cfg.Checkpoint.DSN = "/tmp/" + cfg.Checkpoint.Schema + ".pb" } diff --git a/br/pkg/lightning/config/config_test.go b/br/pkg/lightning/config/config_test.go index 2a4dcbe7cdad9..e74094a6b9066 100644 --- a/br/pkg/lightning/config/config_test.go +++ b/br/pkg/lightning/config/config_test.go @@ -32,7 +32,6 @@ import ( "github.com/BurntSushi/toml" "github.com/pingcap/tidb/br/pkg/lightning/common" "github.com/pingcap/tidb/br/pkg/lightning/config" - "github.com/pingcap/tidb/parser/mysql" "github.com/stretchr/testify/require" ) @@ -626,7 +625,9 @@ func TestLoadConfig(t *testing.T) { taskCfg.TiDB.DistSQLScanConcurrency = 1 err = taskCfg.Adjust(context.Background()) require.NoError(t, err) - require.Equal(t, "guest:12345@tcp(172.16.30.11:4001)/?charset=utf8mb4&sql_mode='"+mysql.DefaultSQLMode+"'&maxAllowedPacket=67108864&tls=false", taskCfg.Checkpoint.DSN) + equivalentDSN := taskCfg.Checkpoint.MySQLParam.ToDriverConfig().FormatDSN() + expectedDSN := "guest:12345@tcp(172.16.30.11:4001)/?tls=false&maxAllowedPacket=67108864&charset=utf8mb4&sql_mode=%27ONLY_FULL_GROUP_BY%2CSTRICT_TRANS_TABLES%2CNO_ZERO_IN_DATE%2CNO_ZERO_DATE%2CERROR_FOR_DIVISION_BY_ZERO%2CNO_AUTO_CREATE_USER%2CNO_ENGINE_SUBSTITUTION%27" + require.Equal(t, expectedDSN, equivalentDSN) result := taskCfg.String() require.Regexp(t, `.*"pd-addr":"172.16.30.11:2379,172.16.30.12:2379".*`, result) diff --git a/cmd/explaintest/r/imdbload.result b/cmd/explaintest/r/imdbload.result index c3ee5badab7e6..00543e6a1640f 100644 --- a/cmd/explaintest/r/imdbload.result +++ b/cmd/explaintest/r/imdbload.result @@ -286,7 +286,7 @@ IndexLookUp_7 1005030.94 root └─TableRowIDScan_6(Probe) 1005030.94 cop[tikv] table:char_name keep order:false trace plan target = 'estimation' select * from char_name where ((imdb_index = 'I') and (surname_pcode < 'E436')) or ((imdb_index = 'L') and (surname_pcode < 'E436')); CE_trace -[{"table_name":"char_name","type":"Column Stats-Point","expr":"((imdb_index = 'I'))","row_count":0},{"table_name":"char_name","type":"Column Stats-Point","expr":"((imdb_index = 'L'))","row_count":0},{"table_name":"char_name","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":4314864},{"table_name":"char_name","type":"Index Stats-Range","expr":"((imdb_index = 'I') and (surname_pcode < 'E436')) or ((imdb_index = 'L') and (surname_pcode < 'E436'))","row_count":0},{"table_name":"char_name","type":"Index Stats-Range","expr":"((surname_pcode < 'E436'))","row_count":1005030},{"table_name":"char_name","type":"Table Stats-Expression-CNF","expr":"`or`(`and`(`eq`(imdbload.char_name.imdb_index, 'I'), `lt`(imdbload.char_name.surname_pcode, 'E436')), `and`(`eq`(imdbload.char_name.imdb_index, 'L'), `lt`(imdbload.char_name.surname_pcode, 'E436')))","row_count":804024}] +[{"table_name":"char_name","type":"Column Stats-Point","expr":"((imdb_index = 'I'))","row_count":0},{"table_name":"char_name","type":"Column Stats-Point","expr":"((imdb_index = 'L'))","row_count":0},{"table_name":"char_name","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":4314864},{"table_name":"char_name","type":"Column Stats-Range","expr":"((surname_pcode < 'E436'))","row_count":1005030},{"table_name":"char_name","type":"Index Stats-Range","expr":"((imdb_index = 'I') and (surname_pcode < 'E436')) or ((imdb_index = 'L') and (surname_pcode < 'E436'))","row_count":0},{"table_name":"char_name","type":"Index Stats-Range","expr":"((surname_pcode < 'E436'))","row_count":1005030},{"table_name":"char_name","type":"Table Stats-Expression-CNF","expr":"`or`(`and`(`eq`(imdbload.char_name.imdb_index, 'I'), `lt`(imdbload.char_name.surname_pcode, 'E436')), `and`(`eq`(imdbload.char_name.imdb_index, 'L'), `lt`(imdbload.char_name.surname_pcode, 'E436')))","row_count":804024}] explain select * from char_name where ((imdb_index = 'V') and (surname_pcode < 'L3416')); id estRows task access object operator info @@ -356,7 +356,7 @@ IndexLookUp_11 901.00 root └─TableRowIDScan_9 901.00 cop[tikv] table:keyword keep order:false trace plan target = 'estimation' select * from keyword where ((phonetic_code = 'R1652') and (keyword > 'ecg-monitor' and keyword < 'killers')); CE_trace -[{"table_name":"keyword","type":"Column Stats-Point","expr":"((phonetic_code = 'R1652'))","row_count":23480},{"table_name":"keyword","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":236627},{"table_name":"keyword","type":"Column Stats-Range","expr":"((keyword > 'ecg-monitor' and keyword < 'killers'))","row_count":44075},{"table_name":"keyword","type":"Index Stats-Point","expr":"((phonetic_code = 'R1652'))","row_count":23480},{"table_name":"keyword","type":"Index Stats-Range","expr":"((keyword > 'ecg-monitor' and keyword < 'killers'))","row_count":44036},{"table_name":"keyword","type":"Index Stats-Range","expr":"((keyword >= 'ecg-m' and keyword <= 'kille'))","row_count":44036},{"table_name":"keyword","type":"Index Stats-Range","expr":"((phonetic_code = 'R1652') and (keyword > 'ecg-monitor' and keyword < 'killers'))","row_count":901},{"table_name":"keyword","type":"Table Stats-Expression-CNF","expr":"`and`(`eq`(imdbload.keyword.phonetic_code, 'R1652'), `and`(`gt`(imdbload.keyword.keyword, 'ecg-monitor'), `lt`(imdbload.keyword.keyword, 'killers')))","row_count":901}] +[{"table_name":"keyword","type":"Column Stats-Point","expr":"((phonetic_code = 'R1652'))","row_count":23480},{"table_name":"keyword","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":236627},{"table_name":"keyword","type":"Column Stats-Range","expr":"((keyword > 'ecg-monitor' and keyword < 'killers'))","row_count":44075},{"table_name":"keyword","type":"Index Stats-Point","expr":"((phonetic_code = 'R1652'))","row_count":23480},{"table_name":"keyword","type":"Index Stats-Range","expr":"((keyword >= 'ecg-m' and keyword <= 'kille'))","row_count":44036},{"table_name":"keyword","type":"Index Stats-Range","expr":"((phonetic_code = 'R1652') and (keyword > 'ecg-monitor' and keyword < 'killers'))","row_count":901},{"table_name":"keyword","type":"Table Stats-Expression-CNF","expr":"`and`(`eq`(imdbload.keyword.phonetic_code, 'R1652'), `and`(`gt`(imdbload.keyword.keyword, 'ecg-monitor'), `lt`(imdbload.keyword.keyword, 'killers')))","row_count":901}] explain select * from cast_info where (nr_order is null) and (person_role_id = 2) and (note >= '(key set pa: Florida'); id estRows task access object operator info diff --git a/cmd/importer/db.go b/cmd/importer/db.go index 8b0d7353b9adf..b8ecf83abfc4b 100644 --- a/cmd/importer/db.go +++ b/cmd/importer/db.go @@ -22,7 +22,7 @@ import ( "strconv" "strings" - _ "github.com/go-sql-driver/mysql" + mysql2 "github.com/go-sql-driver/mysql" "github.com/pingcap/errors" "github.com/pingcap/log" "github.com/pingcap/tidb/parser/mysql" @@ -318,13 +318,18 @@ func execSQL(db *sql.DB, sql string) error { } func createDB(cfg DBConfig) (*sql.DB, error) { - dbDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Name) - db, err := sql.Open("mysql", dbDSN) + driverCfg := mysql2.NewConfig() + driverCfg.User = cfg.User + driverCfg.Passwd = cfg.Password + driverCfg.Net = "tcp" + driverCfg.Addr = cfg.Host + ":" + strconv.Itoa(cfg.Port) + driverCfg.DBName = cfg.Name + + c, err := mysql2.NewConnector(driverCfg) if err != nil { return nil, errors.Trace(err) } - - return db, nil + return sql.OpenDB(c), nil } func closeDB(db *sql.DB) error { diff --git a/dumpling/export/config.go b/dumpling/export/config.go index 980de0d8807f5..b92d2922d2572 100644 --- a/dumpling/export/config.go +++ b/dumpling/export/config.go @@ -218,6 +218,31 @@ func (conf *Config) GetDSN(db string) string { return dsn } +// GetDriverConfig returns the MySQL driver config from Config. +func (conf *Config) GetDriverConfig(db string) *mysql.Config { + driverCfg := mysql.NewConfig() + // maxAllowedPacket=0 can be used to automatically fetch the max_allowed_packet variable from server on every connection. + // https://github.com/go-sql-driver/mysql#maxallowedpacket + hostPort := net.JoinHostPort(conf.Host, strconv.Itoa(conf.Port)) + driverCfg.User = conf.User + driverCfg.Passwd = conf.Password + driverCfg.Net = "tcp" + driverCfg.Addr = hostPort + driverCfg.DBName = db + driverCfg.Collation = "utf8mb4_general_ci" + driverCfg.ReadTimeout = conf.ReadTimeout + driverCfg.WriteTimeout = 30 * time.Second + driverCfg.InterpolateParams = true + driverCfg.MaxAllowedPacket = 0 + if conf.Security.DriveTLSName != "" { + driverCfg.TLSConfig = conf.Security.DriveTLSName + } + if conf.AllowCleartextPasswords { + driverCfg.AllowCleartextPasswords = true + } + return driverCfg +} + func timestampDirName() string { return fmt.Sprintf("./export-%s", time.Now().Format(time.RFC3339)) } diff --git a/dumpling/export/dump.go b/dumpling/export/dump.go index cdc91e6e4a389..857ef5d7470fb 100644 --- a/dumpling/export/dump.go +++ b/dumpling/export/dump.go @@ -37,7 +37,7 @@ import ( "golang.org/x/sync/errgroup" ) -var openDBFunc = sql.Open +var openDBFunc = openDB var errEmptyHandleVals = errors.New("empty handleVals for TiDB table") @@ -1309,11 +1309,11 @@ func startHTTPService(d *Dumper) error { // openSQLDB is an initialization step of Dumper. func openSQLDB(d *Dumper) error { conf := d.conf - pool, err := sql.Open("mysql", conf.GetDSN("")) + c, err := mysql.NewConnector(conf.GetDriverConfig("")) if err != nil { return errors.Trace(err) } - d.dbHandle = pool + d.dbHandle = sql.OpenDB(c) return nil } @@ -1510,12 +1510,20 @@ func setSessionParam(d *Dumper) error { } } } - if d.dbHandle, err = resetDBWithSessionParams(d.tctx, pool, conf.GetDSN(""), conf.SessionParams); err != nil { + if d.dbHandle, err = resetDBWithSessionParams(d.tctx, pool, conf.GetDriverConfig(""), conf.SessionParams); err != nil { return errors.Trace(err) } return nil } +func openDB(cfg *mysql.Config) (*sql.DB, error) { + c, err := mysql.NewConnector(cfg) + if err != nil { + return nil, errors.Trace(err) + } + return sql.OpenDB(c), nil +} + func (d *Dumper) renewSelectTableRegionFuncForLowerTiDB(tctx *tcontext.Context) error { conf := d.conf if !(conf.ServerInfo.ServerType == version.ServerTypeTiDB && conf.ServerInfo.ServerVersion != nil && conf.ServerInfo.HasTiKV && @@ -1532,7 +1540,7 @@ func (d *Dumper) renewSelectTableRegionFuncForLowerTiDB(tctx *tcontext.Context) d.selectTiDBTableRegionFunc = func(_ *tcontext.Context, _ *BaseConn, meta TableMeta) (pkFields []string, pkVals [][]string, err error) { return nil, nil, errors.Annotatef(errEmptyHandleVals, "table: `%s`.`%s`", escapeString(meta.DatabaseName()), escapeString(meta.TableName())) } - dbHandle, err := openDBFunc("mysql", conf.GetDSN("")) + dbHandle, err := openDBFunc(conf.GetDriverConfig("")) if err != nil { return errors.Trace(err) } diff --git a/dumpling/export/sql.go b/dumpling/export/sql.go index 83655df99e330..837bec568b9a7 100644 --- a/dumpling/export/sql.go +++ b/dumpling/export/sql.go @@ -10,7 +10,6 @@ import ( "fmt" "io" "math" - "net/url" "strconv" "strings" @@ -834,7 +833,7 @@ func isUnknownSystemVariableErr(err error) bool { // resetDBWithSessionParams will return a new sql.DB as a replacement for input `db` with new session parameters. // If returned error is nil, the input `db` will be closed. -func resetDBWithSessionParams(tctx *tcontext.Context, db *sql.DB, dsn string, params map[string]interface{}) (*sql.DB, error) { +func resetDBWithSessionParams(tctx *tcontext.Context, db *sql.DB, cfg *mysql.Config, params map[string]interface{}) (*sql.DB, error) { support := make(map[string]interface{}) for k, v := range params { var pv interface{} @@ -862,6 +861,10 @@ func resetDBWithSessionParams(tctx *tcontext.Context, db *sql.DB, dsn string, pa support[k] = pv } + if cfg.Params == nil { + cfg.Params = make(map[string]string) + } + for k, v := range support { var s string // Wrap string with quote to handle string with space. For example, '2020-10-20 13:41:40' @@ -871,19 +874,21 @@ func resetDBWithSessionParams(tctx *tcontext.Context, db *sql.DB, dsn string, pa } else { s = fmt.Sprintf("%v", v) } - dsn += fmt.Sprintf("&%s=%s", k, url.QueryEscape(s)) + cfg.Params[k] = s } db.Close() - newDB, err := sql.Open("mysql", dsn) - if err == nil { - // ping to make sure all session parameters are set correctly - err = newDB.PingContext(tctx) - if err != nil { - newDB.Close() - } + c, err := mysql.NewConnector(cfg) + if err != nil { + return nil, errors.Trace(err) + } + newDB := sql.OpenDB(c) + // ping to make sure all session parameters are set correctly + err = newDB.PingContext(tctx) + if err != nil { + newDB.Close() } - return newDB, errors.Trace(err) + return newDB, nil } func createConnWithConsistency(ctx context.Context, db *sql.DB, repeatableRead bool) (*sql.Conn, error) { diff --git a/dumpling/export/sql_test.go b/dumpling/export/sql_test.go index d98a8a3c76a64..04615637be8f1 100644 --- a/dumpling/export/sql_test.go +++ b/dumpling/export/sql_test.go @@ -1345,7 +1345,7 @@ func TestBuildVersion3RegionQueries(t *testing.T) { defer func() { openDBFunc = oldOpenFunc }() - openDBFunc = func(_, _ string) (*sql.DB, error) { + openDBFunc = func(*mysql.Config) (*sql.DB, error) { return db, nil } diff --git a/dumpling/tests/s3/import.go b/dumpling/tests/s3/import.go index 0489be3fa7a80..30dc95fae84b1 100644 --- a/dumpling/tests/s3/import.go +++ b/dumpling/tests/s3/import.go @@ -6,7 +6,9 @@ import ( "context" "database/sql" "fmt" + "net" "os" + "strconv" _ "github.com/go-sql-driver/mysql" "github.com/pingcap/errors" @@ -48,7 +50,7 @@ func main() { return errors.Trace(err) } - dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4", "root", "", "127.0.0.1", port, database) + dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4", "root", "", net.JoinHostPort("127.0.0.1", strconv.Itoa(port)), database) db, err := sql.Open("mysql", dsn) if err != nil { return errors.Trace(err) diff --git a/executor/adapter.go b/executor/adapter.go index b83c50ecfd7e1..a4c8075bf0395 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -363,7 +363,7 @@ func (a *ExecStmt) IsReadOnly(vars *variable.SessionVars) bool { // It returns the current information schema version that 'a' is using. func (a *ExecStmt) RebuildPlan(ctx context.Context) (int64, error) { ret := &plannercore.PreprocessorReturn{} - if err := plannercore.Preprocess(a.Ctx, a.StmtNode, plannercore.InTxnRetry, plannercore.InitTxnContextProvider, plannercore.WithPreprocessorReturn(ret)); err != nil { + if err := plannercore.Preprocess(ctx, a.Ctx, a.StmtNode, plannercore.InTxnRetry, plannercore.InitTxnContextProvider, plannercore.WithPreprocessorReturn(ret)); err != nil { return 0, err } diff --git a/executor/analyze_test.go b/executor/analyze_test.go index 2139816195b5f..246016a4082f7 100644 --- a/executor/analyze_test.go +++ b/executor/analyze_test.go @@ -17,6 +17,7 @@ package executor_test import ( "fmt" "io/ioutil" + "strconv" "strings" "sync/atomic" "testing" @@ -338,3 +339,78 @@ func TestAnalyzePartitionTableForFloat(t *testing.T) { } tk.MustExec("analyze table t1") } + +func TestAnalyzePartitionTableByConcurrencyInDynamic(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("set @@tidb_partition_prune_mode='dynamic'") + tk.MustExec("use test") + tk.MustExec("create table t(id int) partition by hash(id) partitions 4") + testcases := []struct { + concurrency string + }{ + { + concurrency: "1", + }, + { + concurrency: "2", + }, + { + concurrency: "3", + }, + { + concurrency: "4", + }, + { + concurrency: "5", + }, + } + // assert empty table + for _, tc := range testcases { + concurrency := tc.concurrency + fmt.Println("testcase ", concurrency) + tk.MustExec(fmt.Sprintf("set @@tidb_merge_partition_stats_concurrency=%v", concurrency)) + tk.MustQuery("select @@tidb_merge_partition_stats_concurrency").Check(testkit.Rows(concurrency)) + tk.MustExec("analyze table t") + tk.MustQuery("show stats_topn where partition_name = 'global' and table_name = 't'") + } + + for i := 1; i <= 500; i++ { + for j := 1; j <= 20; j++ { + tk.MustExec(fmt.Sprintf("insert into t (id) values (%v)", j)) + } + } + var expected [][]interface{} + for i := 1; i <= 20; i++ { + expected = append(expected, []interface{}{ + strconv.FormatInt(int64(i), 10), "500", + }) + } + testcases = []struct { + concurrency string + }{ + { + concurrency: "1", + }, + { + concurrency: "2", + }, + { + concurrency: "3", + }, + { + concurrency: "4", + }, + { + concurrency: "5", + }, + } + for _, tc := range testcases { + concurrency := tc.concurrency + fmt.Println("testcase ", concurrency) + tk.MustExec(fmt.Sprintf("set @@tidb_merge_partition_stats_concurrency=%v", concurrency)) + tk.MustQuery("select @@tidb_merge_partition_stats_concurrency").Check(testkit.Rows(concurrency)) + tk.MustExec("analyze table t") + tk.MustQuery("show stats_topn where partition_name = 'global' and table_name = 't'").CheckAt([]int{5, 6}, expected) + } +} diff --git a/executor/compiler.go b/executor/compiler.go index 8f0ac913a30f1..241b15874e1e2 100644 --- a/executor/compiler.go +++ b/executor/compiler.go @@ -76,7 +76,7 @@ func (c *Compiler) Compile(ctx context.Context, stmtNode ast.StmtNode) (_ *ExecS c.Ctx.GetSessionVars().StmtCtx.IsReadOnly = plannercore.IsReadOnly(stmtNode, c.Ctx.GetSessionVars()) ret := &plannercore.PreprocessorReturn{} - err = plannercore.Preprocess(c.Ctx, + err = plannercore.Preprocess(ctx, c.Ctx, stmtNode, plannercore.WithPreprocessorReturn(ret), plannercore.InitTxnContextProvider, diff --git a/executor/ddl.go b/executor/ddl.go index 1fd2b20eb70a1..fe105f074658d 100644 --- a/executor/ddl.go +++ b/executor/ddl.go @@ -153,7 +153,7 @@ func (e *DDLExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { case *ast.CreateTableStmt: err = e.executeCreateTable(x) case *ast.CreateViewStmt: - err = e.executeCreateView(x) + err = e.executeCreateView(ctx, x) case *ast.DropIndexStmt: err = e.executeDropIndex(x) case *ast.DropDatabaseStmt: @@ -281,9 +281,9 @@ func (e *DDLExec) createSessionTemporaryTable(s *ast.CreateTableStmt) error { return nil } -func (e *DDLExec) executeCreateView(s *ast.CreateViewStmt) error { +func (e *DDLExec) executeCreateView(ctx context.Context, s *ast.CreateViewStmt) error { ret := &core.PreprocessorReturn{} - err := core.Preprocess(e.ctx, s.Select, core.WithPreprocessorReturn(ret)) + err := core.Preprocess(ctx, e.ctx, s.Select, core.WithPreprocessorReturn(ret)) if err != nil { return errors.Trace(err) } diff --git a/executor/executor_test.go b/executor/executor_test.go index 5063362462cb5..97890c934f5a5 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -4931,7 +4931,7 @@ func TestIsPointGet(t *testing.T) { stmtNode, err := s.ParseOneStmt(sqlStr, "", "") require.NoError(t, err) preprocessorReturn := &plannercore.PreprocessorReturn{} - err = plannercore.Preprocess(ctx, stmtNode, plannercore.WithPreprocessorReturn(preprocessorReturn)) + err = plannercore.Preprocess(context.Background(), ctx, stmtNode, plannercore.WithPreprocessorReturn(preprocessorReturn)) require.NoError(t, err) p, _, err := planner.Optimize(context.TODO(), ctx, stmtNode, preprocessorReturn.InfoSchema) require.NoError(t, err) @@ -4964,7 +4964,7 @@ func TestClusteredIndexIsPointGet(t *testing.T) { stmtNode, err := s.ParseOneStmt(sqlStr, "", "") require.NoError(t, err) preprocessorReturn := &plannercore.PreprocessorReturn{} - err = plannercore.Preprocess(ctx, stmtNode, plannercore.WithPreprocessorReturn(preprocessorReturn)) + err = plannercore.Preprocess(context.Background(), ctx, stmtNode, plannercore.WithPreprocessorReturn(preprocessorReturn)) require.NoError(t, err) p, _, err := planner.Optimize(context.TODO(), ctx, stmtNode, preprocessorReturn.InfoSchema) require.NoError(t, err) diff --git a/executor/metrics_reader_test.go b/executor/metrics_reader_test.go index 680bfd872e9e3..276c99c8ac22d 100644 --- a/executor/metrics_reader_test.go +++ b/executor/metrics_reader_test.go @@ -64,7 +64,7 @@ func TestStmtLabel(t *testing.T) { stmtNode, err := parser.New().ParseOneStmt(tt.sql, "", "") require.NoError(t, err) preprocessorReturn := &plannercore.PreprocessorReturn{} - err = plannercore.Preprocess(tk.Session(), stmtNode, plannercore.WithPreprocessorReturn(preprocessorReturn)) + err = plannercore.Preprocess(context.Background(), tk.Session(), stmtNode, plannercore.WithPreprocessorReturn(preprocessorReturn)) require.NoError(t, err) _, _, err = planner.Optimize(context.TODO(), tk.Session(), stmtNode, preprocessorReturn.InfoSchema) require.NoError(t, err) diff --git a/expression/integration_test.go b/expression/integration_test.go index 5c47118c1ca46..945def182d76a 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -2852,7 +2852,7 @@ func TestFilterExtractFromDNF(t *testing.T) { require.NoError(t, err, "error %v, for expr %s", err, tt.exprStr) require.Len(t, stmts, 1) ret := &plannercore.PreprocessorReturn{} - err = plannercore.Preprocess(sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) + err = plannercore.Preprocess(context.Background(), sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) require.NoError(t, err, "error %v, for resolve name, expr %s", err, tt.exprStr) p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmts[0], ret.InfoSchema) require.NoError(t, err, "error %v, for build plan, expr %s", err, tt.exprStr) diff --git a/expression/typeinfer_test.go b/expression/typeinfer_test.go index b2554c4f3d548..28f71099e9c46 100644 --- a/expression/typeinfer_test.go +++ b/expression/typeinfer_test.go @@ -126,7 +126,7 @@ func TestInferType(t *testing.T) { require.NoError(t, err) ret := &plannercore.PreprocessorReturn{} - err = plannercore.Preprocess(sctx, stmt, plannercore.WithPreprocessorReturn(ret)) + err = plannercore.Preprocess(context.Background(), sctx, stmt, plannercore.WithPreprocessorReturn(ret)) require.NoError(t, err, comment) p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmt, ret.InfoSchema) require.NoError(t, err, comment) diff --git a/planner/core/cbo_test.go b/planner/core/cbo_test.go index 31ba6bfeb3e07..c88edf4470d9b 100644 --- a/planner/core/cbo_test.go +++ b/planner/core/cbo_test.go @@ -246,7 +246,7 @@ func TestIndexRead(t *testing.T) { require.Len(t, stmts, 1) stmt := stmts[0] ret := &core.PreprocessorReturn{} - err = core.Preprocess(ctx, stmt, core.WithPreprocessorReturn(ret)) + err = core.Preprocess(context.Background(), ctx, stmt, core.WithPreprocessorReturn(ret)) require.NoError(t, err) p, _, err := planner.Optimize(context.TODO(), ctx, stmt, ret.InfoSchema) require.NoError(t, err) @@ -276,7 +276,7 @@ func TestEmptyTable(t *testing.T) { require.Len(t, stmts, 1) stmt := stmts[0] ret := &core.PreprocessorReturn{} - err = core.Preprocess(ctx, stmt, core.WithPreprocessorReturn(ret)) + err = core.Preprocess(context.Background(), ctx, stmt, core.WithPreprocessorReturn(ret)) require.NoError(t, err) p, _, err := planner.Optimize(context.TODO(), ctx, stmt, ret.InfoSchema) require.NoError(t, err) @@ -343,7 +343,7 @@ func TestAnalyze(t *testing.T) { err = executor.ResetContextOfStmt(ctx, stmt) require.NoError(t, err) ret := &core.PreprocessorReturn{} - err = core.Preprocess(ctx, stmt, core.WithPreprocessorReturn(ret)) + err = core.Preprocess(context.Background(), ctx, stmt, core.WithPreprocessorReturn(ret)) require.NoError(t, err) p, _, err := planner.Optimize(context.TODO(), ctx, stmt, ret.InfoSchema) require.NoError(t, err) @@ -586,7 +586,7 @@ func BenchmarkOptimize(b *testing.B) { require.Len(b, stmts, 1) stmt := stmts[0] ret := &core.PreprocessorReturn{} - err = core.Preprocess(ctx, stmt, core.WithPreprocessorReturn(ret)) + err = core.Preprocess(context.Background(), ctx, stmt, core.WithPreprocessorReturn(ret)) require.NoError(b, err) b.Run(tt.sql, func(b *testing.B) { diff --git a/planner/core/collect_column_stats_usage_test.go b/planner/core/collect_column_stats_usage_test.go index 38d246ff8bfd7..c6f8cd6933f59 100644 --- a/planner/core/collect_column_stats_usage_test.go +++ b/planner/core/collect_column_stats_usage_test.go @@ -256,7 +256,7 @@ func TestCollectPredicateColumns(t *testing.T) { } stmt, err := s.p.ParseOneStmt(tt.sql, "", "") require.NoError(t, err, comment) - err = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) + err = Preprocess(context.Background(), s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) require.NoError(t, err, comment) builder, _ := NewPlanBuilder().Init(s.ctx, s.is, &hint.BlockHintProcessor{}) p, err := builder.Build(ctx, stmt) @@ -333,7 +333,7 @@ func TestCollectHistNeededColumns(t *testing.T) { } stmt, err := s.p.ParseOneStmt(tt.sql, "", "") require.NoError(t, err, comment) - err = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) + err = Preprocess(context.Background(), s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) require.NoError(t, err, comment) builder, _ := NewPlanBuilder().Init(s.ctx, s.is, &hint.BlockHintProcessor{}) p, err := builder.Build(ctx, stmt) diff --git a/planner/core/find_best_task.go b/planner/core/find_best_task.go index c73ce9f3c086d..2389f56da5337 100644 --- a/planner/core/find_best_task.go +++ b/planner/core/find_best_task.go @@ -1711,9 +1711,10 @@ func (ds *DataSource) crossEstimateRowCount(path *util.AccessPath, conds []expre if len(ranges) == 0 || len(accessConds) == 0 || err != nil { return 0, err == nil, corr } - idxID, idxExists := ds.stats.HistColl.ColID2IdxID[colID] - if !idxExists { - idxID = -1 + idxID := int64(-1) + idxIDs, idxExists := ds.stats.HistColl.ColID2IdxIDs[colID] + if idxExists && len(idxIDs) > 0 { + idxID = idxIDs[0] } rangeCounts, ok := getColumnRangeCounts(ds.ctx, colID, ranges, ds.tableStats.HistColl, idxID) if !ok { diff --git a/planner/core/indexmerge_test.go b/planner/core/indexmerge_test.go index f109b85aaee18..8867e5a64c744 100644 --- a/planner/core/indexmerge_test.go +++ b/planner/core/indexmerge_test.go @@ -69,7 +69,7 @@ func TestIndexMergePathGeneration(t *testing.T) { for i, tc := range input { stmt, err := parser.ParseOneStmt(tc, "", "") require.NoErrorf(t, err, "case:%v sql:%s", i, tc) - err = Preprocess(sctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: is})) + err = Preprocess(context.Background(), sctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: is})) require.NoError(t, err) builder, _ := NewPlanBuilder().Init(MockContext(), is, &hint.BlockHintProcessor{}) p, err := builder.Build(ctx, stmt) diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index 4e320c62e2e50..7458c7307c19c 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -502,7 +502,7 @@ func TestSubquery(t *testing.T) { stmt, err := s.p.ParseOneStmt(ca, "", "") require.NoError(t, err, comment) - err = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) + err = Preprocess(context.Background(), s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) require.NoError(t, err) p, _, err := BuildLogicalPlanForTest(ctx, s.ctx, stmt, s.is) require.NoError(t, err) @@ -530,7 +530,7 @@ func TestPlanBuilder(t *testing.T) { require.NoError(t, err, comment) s.ctx.GetSessionVars().SetHashJoinConcurrency(1) - err = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) + err = Preprocess(context.Background(), s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) require.NoError(t, err) p, _, err := BuildLogicalPlanForTest(ctx, s.ctx, stmt, s.is) require.NoError(t, err) @@ -927,7 +927,7 @@ func TestValidate(t *testing.T) { comment := fmt.Sprintf("for %s", sql) stmt, err := s.p.ParseOneStmt(sql, "", "") require.NoError(t, err, comment) - err = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) + err = Preprocess(context.Background(), s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) require.NoError(t, err) _, _, err = BuildLogicalPlanForTest(ctx, s.ctx, stmt, s.is) if tt.err == nil { @@ -1422,7 +1422,7 @@ func TestVisitInfo(t *testing.T) { require.NoError(t, err, comment) // TODO: to fix, Table 'test.ttt' doesn't exist - _ = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) + _ = Preprocess(context.Background(), s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) sctx := MockContext() builder, _ := NewPlanBuilder().Init(sctx, s.is, &hint.BlockHintProcessor{}) domain.GetDomain(sctx).MockInfoCacheAndLoadInfoSchema(s.is) @@ -1502,7 +1502,7 @@ func TestUnion(t *testing.T) { comment := fmt.Sprintf("case:%v sql:%s", i, tt) stmt, err := s.p.ParseOneStmt(tt, "", "") require.NoError(t, err, comment) - err = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) + err = Preprocess(context.Background(), s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) require.NoError(t, err) sctx := MockContext() builder, _ := NewPlanBuilder().Init(sctx, s.is, &hint.BlockHintProcessor{}) @@ -1535,7 +1535,7 @@ func TestTopNPushDown(t *testing.T) { comment := fmt.Sprintf("case:%v sql:%s", i, tt) stmt, err := s.p.ParseOneStmt(tt, "", "") require.NoError(t, err, comment) - err = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) + err = Preprocess(context.Background(), s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) require.NoError(t, err) sctx := MockContext() builder, _ := NewPlanBuilder().Init(sctx, s.is, &hint.BlockHintProcessor{}) @@ -1612,7 +1612,7 @@ func TestOuterJoinEliminator(t *testing.T) { comment := fmt.Sprintf("case:%v sql:%s", i, tt) stmt, err := s.p.ParseOneStmt(tt, "", "") require.NoError(t, err, comment) - err = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) + err = Preprocess(context.Background(), s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) require.NoError(t, err) sctx := MockContext() builder, _ := NewPlanBuilder().Init(sctx, s.is, &hint.BlockHintProcessor{}) @@ -1649,7 +1649,7 @@ func TestSelectView(t *testing.T) { comment := fmt.Sprintf("case:%v sql:%s", i, tt.sql) stmt, err := s.p.ParseOneStmt(tt.sql, "", "") require.NoError(t, err, comment) - err = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) + err = Preprocess(context.Background(), s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) require.NoError(t, err) builder, _ := NewPlanBuilder().Init(MockContext(), s.is, &hint.BlockHintProcessor{}) p, err := builder.Build(ctx, stmt) @@ -1732,7 +1732,7 @@ func (s *plannerSuiteWithOptimizeVars) optimize(ctx context.Context, sql string) if err != nil { return nil, nil, err } - err = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) + err = Preprocess(context.Background(), s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) if err != nil { return nil, nil, err } @@ -1841,7 +1841,7 @@ func TestSkylinePruning(t *testing.T) { comment := fmt.Sprintf("case:%v sql:%s", i, tt.sql) stmt, err := s.p.ParseOneStmt(tt.sql, "", "") require.NoError(t, err, comment) - err = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) + err = Preprocess(context.Background(), s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) require.NoError(t, err) sctx := MockContext() builder, _ := NewPlanBuilder().Init(sctx, s.is, &hint.BlockHintProcessor{}) @@ -1914,7 +1914,7 @@ func TestFastPlanContextTables(t *testing.T) { for _, tt := range tests { stmt, err := s.p.ParseOneStmt(tt.sql, "", "") require.NoError(t, err) - err = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) + err = Preprocess(context.Background(), s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) require.NoError(t, err) s.ctx.GetSessionVars().StmtCtx.Tables = nil p := TryFastPlan(s.ctx, stmt) @@ -1946,7 +1946,7 @@ func TestUpdateEQCond(t *testing.T) { comment := fmt.Sprintf("case:%v sql:%s", i, tt.sql) stmt, err := s.p.ParseOneStmt(tt.sql, "", "") require.NoError(t, err, comment) - err = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) + err = Preprocess(context.Background(), s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) require.NoError(t, err) sctx := MockContext() builder, _ := NewPlanBuilder().Init(sctx, s.is, &hint.BlockHintProcessor{}) @@ -1965,7 +1965,7 @@ func TestConflictedJoinTypeHints(t *testing.T) { ctx := context.TODO() stmt, err := s.p.ParseOneStmt(sql, "", "") require.NoError(t, err) - err = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) + err = Preprocess(context.Background(), s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) require.NoError(t, err) sctx := MockContext() builder, _ := NewPlanBuilder().Init(sctx, s.is, &hint.BlockHintProcessor{}) @@ -1988,7 +1988,7 @@ func TestSimplyOuterJoinWithOnlyOuterExpr(t *testing.T) { ctx := context.TODO() stmt, err := s.p.ParseOneStmt(sql, "", "") require.NoError(t, err) - err = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) + err = Preprocess(context.Background(), s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) require.NoError(t, err) sctx := MockContext() builder, _ := NewPlanBuilder().Init(sctx, s.is, &hint.BlockHintProcessor{}) @@ -2042,7 +2042,7 @@ func TestResolvingCorrelatedAggregate(t *testing.T) { comment := fmt.Sprintf("case:%v sql:%s", i, tt.sql) stmt, err := s.p.ParseOneStmt(tt.sql, "", "") require.NoError(t, err, comment) - err = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) + err = Preprocess(context.Background(), s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) require.NoError(t, err, comment) p, _, err := BuildLogicalPlanForTest(ctx, s.ctx, stmt, s.is) require.NoError(t, err, comment) @@ -2084,7 +2084,7 @@ func TestFastPathInvalidBatchPointGet(t *testing.T) { comment := fmt.Sprintf("case:%v sql:%s", i, tc.sql) stmt, err := s.p.ParseOneStmt(tc.sql, "", "") require.NoError(t, err, comment) - err = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) + err = Preprocess(context.Background(), s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) require.NoError(t, err, comment) plan := TryFastPlan(s.ctx, stmt) if tc.fastPlan { @@ -2106,7 +2106,7 @@ func TestTraceFastPlan(t *testing.T) { comment := fmt.Sprintf("sql:%s", sql) stmt, err := s.p.ParseOneStmt(sql, "", "") require.NoError(t, err, comment) - err = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) + err = Preprocess(context.Background(), s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) require.NoError(t, err, comment) plan := TryFastPlan(s.ctx, stmt) require.NotNil(t, plan) diff --git a/planner/core/logical_plan_trace_test.go b/planner/core/logical_plan_trace_test.go index 5bf38b6e18f86..7233b49cb24e1 100644 --- a/planner/core/logical_plan_trace_test.go +++ b/planner/core/logical_plan_trace_test.go @@ -399,7 +399,7 @@ func TestSingleRuleTraceStep(t *testing.T) { comment := fmt.Sprintf("case:%v sql:%s", i, sql) stmt, err := s.p.ParseOneStmt(sql, "", "") require.NoError(t, err, comment) - err = Preprocess(s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) + err = Preprocess(context.Background(), s.ctx, stmt, WithPreprocessorReturn(&PreprocessorReturn{InfoSchema: s.is})) require.NoError(t, err, comment) sctx := MockContext() sctx.GetSessionVars().StmtCtx.EnableOptimizeTrace = true diff --git a/planner/core/physical_plan_test.go b/planner/core/physical_plan_test.go index 652e1ed89c49e..fcbb2d568d40d 100644 --- a/planner/core/physical_plan_test.go +++ b/planner/core/physical_plan_test.go @@ -123,7 +123,7 @@ func TestAnalyzeBuildSucc(t *testing.T) { } else if err != nil { continue } - err = core.Preprocess(tk.Session(), stmt, core.WithPreprocessorReturn(&core.PreprocessorReturn{InfoSchema: is})) + err = core.Preprocess(context.Background(), tk.Session(), stmt, core.WithPreprocessorReturn(&core.PreprocessorReturn{InfoSchema: is})) require.NoError(t, err) _, _, err = planner.Optimize(context.Background(), tk.Session(), stmt, is) if tt.succ { @@ -165,7 +165,7 @@ func TestAnalyzeSetRate(t *testing.T) { stmt, err := p.ParseOneStmt(tt.sql, "", "") require.NoError(t, err, comment) - err = core.Preprocess(tk.Session(), stmt, core.WithPreprocessorReturn(&core.PreprocessorReturn{InfoSchema: is})) + err = core.Preprocess(context.Background(), tk.Session(), stmt, core.WithPreprocessorReturn(&core.PreprocessorReturn{InfoSchema: is})) require.NoError(t, err, comment) p, _, err := planner.Optimize(context.Background(), tk.Session(), stmt, is) require.NoError(t, err, comment) @@ -310,7 +310,7 @@ func TestDAGPlanBuilderBasePhysicalPlan(t *testing.T) { stmt, err := p.ParseOneStmt(tt, "", "") require.NoError(t, err, comment) - err = core.Preprocess(se, stmt, core.WithPreprocessorReturn(&core.PreprocessorReturn{InfoSchema: is})) + err = core.Preprocess(context.Background(), se, stmt, core.WithPreprocessorReturn(&core.PreprocessorReturn{InfoSchema: is})) require.NoError(t, err) p, _, err := planner.Optimize(context.TODO(), se, stmt, is) require.NoError(t, err) @@ -1658,7 +1658,7 @@ func TestDAGPlanBuilderSplitAvg(t *testing.T) { stmt, err := p.ParseOneStmt(tt.sql, "", "") require.NoError(t, err, comment) - err = core.Preprocess(tk.Session(), stmt, core.WithPreprocessorReturn(&core.PreprocessorReturn{InfoSchema: is})) + err = core.Preprocess(context.Background(), tk.Session(), stmt, core.WithPreprocessorReturn(&core.PreprocessorReturn{InfoSchema: is})) require.NoError(t, err) p, _, err := planner.Optimize(context.TODO(), tk.Session(), stmt, is) require.NoError(t, err, comment) diff --git a/planner/core/physical_plan_trace_test.go b/planner/core/physical_plan_trace_test.go index b6df1b1869452..c9a74e81469a1 100644 --- a/planner/core/physical_plan_trace_test.go +++ b/planner/core/physical_plan_trace_test.go @@ -90,7 +90,7 @@ func TestPhysicalOptimizeWithTraceEnabled(t *testing.T) { sql := testcase.sql stmt, err := p.ParseOneStmt(sql, "", "") require.NoError(t, err) - err = core.Preprocess(ctx, stmt, core.WithPreprocessorReturn(&core.PreprocessorReturn{InfoSchema: dom.InfoSchema()})) + err = core.Preprocess(context.Background(), ctx, stmt, core.WithPreprocessorReturn(&core.PreprocessorReturn{InfoSchema: dom.InfoSchema()})) require.NoError(t, err) sctx := core.MockContext() sctx.GetSessionVars().StmtCtx.EnableOptimizeTrace = true @@ -144,7 +144,7 @@ func TestPhysicalOptimizerTrace(t *testing.T) { stmt, err := p.ParseOneStmt(sql, "", "") require.NoError(t, err) - err = core.Preprocess(ctx, stmt, core.WithPreprocessorReturn(&core.PreprocessorReturn{InfoSchema: dom.InfoSchema()})) + err = core.Preprocess(context.Background(), ctx, stmt, core.WithPreprocessorReturn(&core.PreprocessorReturn{InfoSchema: dom.InfoSchema()})) require.NoError(t, err) sctx := core.MockContext() sctx.GetSessionVars().StmtCtx.EnableOptimizeTrace = true @@ -207,7 +207,7 @@ func TestPhysicalOptimizerTraceChildrenNotDuplicated(t *testing.T) { sql := "select * from t" stmt, err := p.ParseOneStmt(sql, "", "") require.NoError(t, err) - err = core.Preprocess(ctx, stmt, core.WithPreprocessorReturn(&core.PreprocessorReturn{InfoSchema: dom.InfoSchema()})) + err = core.Preprocess(context.Background(), ctx, stmt, core.WithPreprocessorReturn(&core.PreprocessorReturn{InfoSchema: dom.InfoSchema()})) require.NoError(t, err) sctx := core.MockContext() sctx.GetSessionVars().StmtCtx.EnableOptimizeTrace = true diff --git a/planner/core/plan_cache.go b/planner/core/plan_cache.go index e4bf81791e19c..e58e8b6d91708 100644 --- a/planner/core/plan_cache.go +++ b/planner/core/plan_cache.go @@ -42,8 +42,7 @@ import ( "go.uber.org/zap" ) -func planCachePreprocess(sctx sessionctx.Context, isGeneralPlanCache bool, is infoschema.InfoSchema, - stmt *PlanCacheStmt, params []expression.Expression) error { +func planCachePreprocess(ctx context.Context, sctx sessionctx.Context, isGeneralPlanCache bool, is infoschema.InfoSchema, stmt *PlanCacheStmt, params []expression.Expression) error { vars := sctx.GetSessionVars() stmtAst := stmt.PreparedAst vars.StmtCtx.StmtType = stmtAst.StmtType @@ -88,7 +87,7 @@ func planCachePreprocess(sctx sessionctx.Context, isGeneralPlanCache bool, is in // We should reset the tableRefs in the prepared update statements, otherwise, the ast nodes still hold the old // tableRefs columnInfo which will cause chaos in logic of trying point get plan. (should ban non-public column) ret := &PreprocessorReturn{InfoSchema: is} - err := Preprocess(sctx, stmtAst.Stmt, InPrepare, WithPreprocessorReturn(ret)) + err := Preprocess(ctx, sctx, stmtAst.Stmt, InPrepare, WithPreprocessorReturn(ret)) if err != nil { return ErrSchemaChanged.GenWithStack("Schema change caused error: %s", err.Error()) } @@ -116,7 +115,7 @@ func planCachePreprocess(sctx sessionctx.Context, isGeneralPlanCache bool, is in func GetPlanFromSessionPlanCache(ctx context.Context, sctx sessionctx.Context, isGeneralPlanCache bool, is infoschema.InfoSchema, stmt *PlanCacheStmt, params []expression.Expression) (plan Plan, names []*types.FieldName, err error) { - if err := planCachePreprocess(sctx, isGeneralPlanCache, is, stmt, params); err != nil { + if err := planCachePreprocess(ctx, sctx, isGeneralPlanCache, is, stmt, params); err != nil { return nil, nil, err } diff --git a/planner/core/plan_cache_utils.go b/planner/core/plan_cache_utils.go index ac73f34e9babb..971d481a315a9 100644 --- a/planner/core/plan_cache_utils.go +++ b/planner/core/plan_cache_utils.go @@ -84,7 +84,7 @@ func GeneratePlanCacheStmtWithAST(ctx context.Context, sctx sessionctx.Context, } ret := &PreprocessorReturn{} - err := Preprocess(sctx, stmt, InPrepare, WithPreprocessorReturn(ret)) + err := Preprocess(ctx, sctx, stmt, InPrepare, WithPreprocessorReturn(ret)) if err != nil { return nil, nil, 0, err } diff --git a/planner/core/plan_cost_detail_test.go b/planner/core/plan_cost_detail_test.go index 34584773aa6e8..0054cd351bf84 100644 --- a/planner/core/plan_cost_detail_test.go +++ b/planner/core/plan_cost_detail_test.go @@ -133,7 +133,7 @@ func TestPlanCostDetail(t *testing.T) { func optimize(t *testing.T, sql string, p *parser.Parser, ctx sessionctx.Context, dom *domain.Domain) map[int]*tracing.PhysicalPlanCostDetail { stmt, err := p.ParseOneStmt(sql, "", "") require.NoError(t, err) - err = plannercore.Preprocess(ctx, stmt, plannercore.WithPreprocessorReturn(&plannercore.PreprocessorReturn{InfoSchema: dom.InfoSchema()})) + err = plannercore.Preprocess(context.Background(), ctx, stmt, plannercore.WithPreprocessorReturn(&plannercore.PreprocessorReturn{InfoSchema: dom.InfoSchema()})) require.NoError(t, err) sctx := plannercore.MockContext() sctx.GetSessionVars().StmtCtx.EnableOptimizeTrace = true diff --git a/planner/core/point_get_plan_test.go b/planner/core/point_get_plan_test.go index 3a92d25719c09..27be9babce05a 100644 --- a/planner/core/point_get_plan_test.go +++ b/planner/core/point_get_plan_test.go @@ -328,7 +328,7 @@ func TestPointGetId(t *testing.T) { require.Len(t, stmts, 1) stmt := stmts[0] ret := &core.PreprocessorReturn{} - err = core.Preprocess(ctx, stmt, core.WithPreprocessorReturn(ret)) + err = core.Preprocess(context.Background(), ctx, stmt, core.WithPreprocessorReturn(ret)) require.NoError(t, err) p, _, err := planner.Optimize(context.TODO(), ctx, stmt, ret.InfoSchema) require.NoError(t, err) diff --git a/planner/core/preprocess.go b/planner/core/preprocess.go index da3f60d907bef..49904d751d489 100644 --- a/planner/core/preprocess.go +++ b/planner/core/preprocess.go @@ -118,12 +118,12 @@ func TryAddExtraLimit(ctx sessionctx.Context, node ast.StmtNode) ast.StmtNode { // Preprocess resolves table names of the node, and checks some statements' validation. // preprocessReturn used to extract the infoschema for the tableName and the timestamp from the asof clause. -func Preprocess(ctx sessionctx.Context, node ast.Node, preprocessOpt ...PreprocessOpt) error { +func Preprocess(ctx context.Context, sctx sessionctx.Context, node ast.Node, preprocessOpt ...PreprocessOpt) error { v := preprocessor{ - ctx: ctx, + sctx: sctx, tableAliasInJoin: make([]map[string]interface{}, 0), preprocessWith: &preprocessWith{cteCanUsed: make([]string, 0), cteBeforeOffset: make([]int, 0)}, - staleReadProcessor: staleread.NewStaleReadProcessor(ctx), + staleReadProcessor: staleread.NewStaleReadProcessor(ctx, sctx), } for _, optFn := range preprocessOpt { optFn(&v) @@ -208,7 +208,7 @@ func (pw *preprocessWith) UpdateCTEConsumerCount(tableName string) { // preprocessor is an ast.Visitor that preprocess // ast Nodes parsed from parser. type preprocessor struct { - ctx sessionctx.Context + sctx sessionctx.Context flag preprocessorFlag stmtTp byte showTp ast.ShowStmtType @@ -299,14 +299,14 @@ func (p *preprocessor) Enter(in ast.Node) (out ast.Node, skipChildren bool) { p.stmtTp = TypeCreate EraseLastSemicolon(node.OriginNode) EraseLastSemicolon(node.HintedNode) - p.checkBindGrammar(node.OriginNode, node.HintedNode, p.ctx.GetSessionVars().CurrentDB) + p.checkBindGrammar(node.OriginNode, node.HintedNode, p.sctx.GetSessionVars().CurrentDB) return in, true case *ast.DropBindingStmt: p.stmtTp = TypeDrop EraseLastSemicolon(node.OriginNode) if node.HintedNode != nil { EraseLastSemicolon(node.HintedNode) - p.checkBindGrammar(node.OriginNode, node.HintedNode, p.ctx.GetSessionVars().CurrentDB) + p.checkBindGrammar(node.OriginNode, node.HintedNode, p.sctx.GetSessionVars().CurrentDB) } return in, true case *ast.RecoverTableStmt, *ast.FlashBackTableStmt: @@ -339,7 +339,7 @@ func (p *preprocessor) Enter(in ast.Node) (out ast.Node, skipChildren bool) { p.flag |= inCreateOrDropTable } case *ast.TableSource: - isModeOracle := p.ctx.GetSessionVars().SQLMode&mysql.ModeOracle != 0 + isModeOracle := p.sctx.GetSessionVars().SQLMode&mysql.ModeOracle != 0 if _, ok := node.Source.(*ast.SelectStmt); ok && !isModeOracle && len(node.AsName.L) == 0 { p.err = dbterror.ErrDerivedMustHaveAlias.GenWithStackByArgs() } @@ -364,14 +364,14 @@ func (p *preprocessor) Enter(in ast.Node) (out ast.Node, skipChildren bool) { // start transaction read only as of timestamp .... // then we need set StmtCtx.IsStaleness as true in order to avoid take tso in PrepareTSFuture. if node.AsOf != nil { - p.ctx.GetSessionVars().StmtCtx.IsStaleness = true + p.sctx.GetSessionVars().StmtCtx.IsStaleness = true p.IsStaleness = true - } else if p.ctx.GetSessionVars().TxnReadTS.PeakTxnReadTS() > 0 { + } else if p.sctx.GetSessionVars().TxnReadTS.PeakTxnReadTS() > 0 { // If the begin statement was like following: // set transaction read only as of timestamp ... // begin // then we need set StmtCtx.IsStaleness as true in order to avoid take tso in PrepareTSFuture. - p.ctx.GetSessionVars().StmtCtx.IsStaleness = true + p.sctx.GetSessionVars().StmtCtx.IsStaleness = true p.IsStaleness = true } default: @@ -442,7 +442,7 @@ func bindableStmtType(node ast.StmtNode) byte { } func (p *preprocessor) tableByName(tn *ast.TableName) (table.Table, error) { - currentDB := p.ctx.GetSessionVars().CurrentDB + currentDB := p.sctx.GetSessionVars().CurrentDB if tn.Schema.String() != "" { currentDB = tn.Schema.L } @@ -462,8 +462,8 @@ func (p *preprocessor) tableByName(tn *ast.TableName) (table.Table, error) { // We should never leak that the table doesn't exist (i.e. attach ErrTableNotExists) // unless we know that the user has permissions to it, should it exist. // By checking here, this makes all SELECT/SHOW/INSERT/UPDATE/DELETE statements safe. - currentUser, activeRoles := p.ctx.GetSessionVars().User, p.ctx.GetSessionVars().ActiveRoles - if pm := privilege.GetPrivilegeManager(p.ctx); pm != nil { + currentUser, activeRoles := p.sctx.GetSessionVars().User, p.sctx.GetSessionVars().ActiveRoles + if pm := privilege.GetPrivilegeManager(p.sctx); pm != nil { if !pm.RequestVerification(activeRoles, sName.L, tn.Name.O, "", mysql.AllPrivMask) { u := currentUser.Username h := currentUser.Hostname @@ -811,7 +811,7 @@ func (p *preprocessor) checkAdminCheckTableGrammar(stmt *ast.AdminStmt) { func (p *preprocessor) checkCreateTableGrammar(stmt *ast.CreateTableStmt) { if stmt.ReferTable != nil { - schema := model.NewCIStr(p.ctx.GetSessionVars().CurrentDB) + schema := model.NewCIStr(p.sctx.GetSessionVars().CurrentDB) if stmt.ReferTable.Schema.String() != "" { schema = stmt.ReferTable.Schema } @@ -985,7 +985,7 @@ func (p *preprocessor) checkDropTableGrammar(stmt *ast.DropTableStmt) { } func (p *preprocessor) checkDropTemporaryTableGrammar(stmt *ast.DropTableStmt) { - currentDB := model.NewCIStr(p.ctx.GetSessionVars().CurrentDB) + currentDB := model.NewCIStr(p.sctx.GetSessionVars().CurrentDB) for _, t := range stmt.Tables { if isIncorrectName(t.Name.String()) { p.err = dbterror.ErrWrongTableName.GenWithStackByArgs(t.Name.String()) @@ -1030,7 +1030,7 @@ func (p *preprocessor) checkNonUniqTableAlias(stmt *ast.Join) { p.tableAliasInJoin = append(p.tableAliasInJoin, make(map[string]interface{})) } tableAliases := p.tableAliasInJoin[len(p.tableAliasInJoin)-1] - isOracleMode := p.ctx.GetSessionVars().SQLMode&mysql.ModeOracle != 0 + isOracleMode := p.sctx.GetSessionVars().SQLMode&mysql.ModeOracle != 0 if !isOracleMode { if err := isTableAliasDuplicate(stmt.Left, tableAliases); err != nil { p.err = err @@ -1103,7 +1103,7 @@ func (p *preprocessor) checkCreateIndexGrammar(stmt *ast.CreateIndexStmt) { } func (p *preprocessor) checkGroupBy(stmt *ast.GroupByClause) { - noopFuncsMode := p.ctx.GetSessionVars().NoopFuncsMode + noopFuncsMode := p.sctx.GetSessionVars().NoopFuncsMode for _, item := range stmt.Items { if !item.NullOrder && noopFuncsMode != variable.OnInt { err := expression.ErrFunctionsNoopImpl.GenWithStackByArgs("GROUP BY expr ASC|DESC") @@ -1112,7 +1112,7 @@ func (p *preprocessor) checkGroupBy(stmt *ast.GroupByClause) { return } // NoopFuncsMode is Warn, append an error - p.ctx.GetSessionVars().StmtCtx.AppendWarning(err) + p.sctx.GetSessionVars().StmtCtx.AppendWarning(err) } } } @@ -1495,7 +1495,7 @@ func (p *preprocessor) handleTableName(tn *ast.TableName) { } } - currentDB := p.ctx.GetSessionVars().CurrentDB + currentDB := p.sctx.GetSessionVars().CurrentDB if currentDB == "" { p.err = errors.Trace(ErrNoDB) return @@ -1537,11 +1537,11 @@ func (p *preprocessor) handleTableName(tn *ast.TableName) { p.err = err return } - currentDB := p.ctx.GetSessionVars().CurrentDB + currentDB := p.sctx.GetSessionVars().CurrentDB if tn.Schema.String() != "" { currentDB = tn.Schema.L } - table, err = tryLockMDLAndUpdateSchemaIfNecessary(p.ctx, model.NewCIStr(currentDB), table, p.ensureInfoSchema()) + table, err = tryLockMDLAndUpdateSchemaIfNecessary(p.sctx, model.NewCIStr(currentDB), table, p.ensureInfoSchema()) if err != nil { p.err = err return @@ -1586,8 +1586,8 @@ func (p *preprocessor) handleRepairName(tn *ast.TableName) { p.err = dbterror.ErrRepairTableFail.GenWithStackByArgs("table " + tn.Name.L + " is not in repair") return } - p.ctx.SetValue(domainutil.RepairedTable, tableInfo) - p.ctx.SetValue(domainutil.RepairedDatabase, dbInfo) + p.sctx.SetValue(domainutil.RepairedTable, tableInfo) + p.sctx.SetValue(domainutil.RepairedDatabase, dbInfo) } func (p *preprocessor) resolveShowStmt(node *ast.ShowStmt) { @@ -1595,14 +1595,14 @@ func (p *preprocessor) resolveShowStmt(node *ast.ShowStmt) { if node.Table != nil && node.Table.Schema.L != "" { node.DBName = node.Table.Schema.O } else { - node.DBName = p.ctx.GetSessionVars().CurrentDB + node.DBName = p.sctx.GetSessionVars().CurrentDB } } else if node.Table != nil && node.Table.Schema.L == "" { node.Table.Schema = model.NewCIStr(node.DBName) } if node.User != nil && node.User.CurrentUser { // Fill the Username and Hostname with the current user. - currentUser := p.ctx.GetSessionVars().User + currentUser := p.sctx.GetSessionVars().User if currentUser != nil { node.User.Username = currentUser.Username node.User.Hostname = currentUser.Hostname @@ -1613,7 +1613,7 @@ func (p *preprocessor) resolveShowStmt(node *ast.ShowStmt) { } func (p *preprocessor) resolveExecuteStmt(node *ast.ExecuteStmt) { - prepared, err := GetPreparedStmt(node, p.ctx.GetSessionVars()) + prepared, err := GetPreparedStmt(node, p.sctx.GetSessionVars()) if err != nil { p.err = err return @@ -1708,12 +1708,12 @@ func (p *preprocessor) updateStateFromStaleReadProcessor() error { // or is affected by the tidb_read_staleness session variable, then the statement will be makred as isStaleness // in stmtCtx if p.flag&initTxnContextProvider != 0 { - p.ctx.GetSessionVars().StmtCtx.IsStaleness = true - if !p.ctx.GetSessionVars().InTxn() { - txnManager := sessiontxn.GetTxnManager(p.ctx) + p.sctx.GetSessionVars().StmtCtx.IsStaleness = true + if !p.sctx.GetSessionVars().InTxn() { + txnManager := sessiontxn.GetTxnManager(p.sctx) newTxnRequest := &sessiontxn.EnterNewTxnRequest{ Type: sessiontxn.EnterNewTxnWithReplaceProvider, - Provider: staleread.NewStalenessTxnContextProvider(p.ctx, p.LastSnapshotTS, p.InfoSchema), + Provider: staleread.NewStalenessTxnContextProvider(p.sctx, p.LastSnapshotTS, p.InfoSchema), } if err := txnManager.EnterNewTxn(context.TODO(), newTxnRequest); err != nil { return err @@ -1738,12 +1738,12 @@ func (p *preprocessor) ensureInfoSchema() infoschema.InfoSchema { return p.InfoSchema } - p.InfoSchema = sessiontxn.GetTxnManager(p.ctx).GetTxnInfoSchema() + p.InfoSchema = sessiontxn.GetTxnManager(p.sctx).GetTxnInfoSchema() return p.InfoSchema } func (p *preprocessor) hasAutoConvertWarning(colDef *ast.ColumnDef) bool { - sessVars := p.ctx.GetSessionVars() + sessVars := p.sctx.GetSessionVars() if !sessVars.SQLMode.HasStrictMode() && colDef.Tp.GetType() == mysql.TypeVarchar { colDef.Tp.SetType(mysql.TypeBlob) if colDef.Tp.GetCharset() == charset.CharsetBin { diff --git a/planner/core/preprocess_test.go b/planner/core/preprocess_test.go index 2ff7bb8ce6d6a..6d6ebb1b8bf49 100644 --- a/planner/core/preprocess_test.go +++ b/planner/core/preprocess_test.go @@ -15,6 +15,7 @@ package core_test import ( + "context" "strings" "testing" @@ -45,7 +46,7 @@ func runSQL(t *testing.T, ctx sessionctx.Context, is infoschema.InfoSchema, sql if inPrepare { opts = append(opts, core.InPrepare) } - err = core.Preprocess(ctx, stmt, append(opts, core.WithPreprocessorReturn(&core.PreprocessorReturn{InfoSchema: is}))...) + err = core.Preprocess(context.Background(), ctx, stmt, append(opts, core.WithPreprocessorReturn(&core.PreprocessorReturn{InfoSchema: is}))...) require.Truef(t, terror.ErrorEqual(err, terr), "sql: %s, err:%v", sql, err) } @@ -415,7 +416,7 @@ func TestPreprocessCTE(t *testing.T) { require.NoError(t, err) require.Len(t, stmts, 1) - err = core.Preprocess(tk.Session(), stmts[0]) + err = core.Preprocess(context.Background(), tk.Session(), stmts[0]) require.NoError(t, err) var rs strings.Builder diff --git a/planner/core/stats_test.go b/planner/core/stats_test.go index a301b42f48ff3..d9f4b5a015b38 100644 --- a/planner/core/stats_test.go +++ b/planner/core/stats_test.go @@ -55,7 +55,7 @@ func TestGroupNDVs(t *testing.T) { stmt, err := p.ParseOneStmt(tt, "", "") require.NoError(t, err, comment) ret := &core.PreprocessorReturn{} - err = core.Preprocess(tk.Session(), stmt, core.WithPreprocessorReturn(ret)) + err = core.Preprocess(context.Background(), tk.Session(), stmt, core.WithPreprocessorReturn(ret)) require.NoError(t, err) tk.Session().GetSessionVars().PlanColumnID = 0 builder, _ := core.NewPlanBuilder().Init(tk.Session(), ret.InfoSchema, &hint.BlockHintProcessor{}) diff --git a/planner/funcdep/extract_fd_test.go b/planner/funcdep/extract_fd_test.go index 92a924df6c0be..7e5dca8daf44c 100644 --- a/planner/funcdep/extract_fd_test.go +++ b/planner/funcdep/extract_fd_test.go @@ -218,7 +218,7 @@ func TestFDSet_ExtractFD(t *testing.T) { require.NoError(t, err, comment) tk.Session().GetSessionVars().PlanID = 0 tk.Session().GetSessionVars().PlanColumnID = 0 - err = plannercore.Preprocess(tk.Session(), stmt, plannercore.WithPreprocessorReturn(&plannercore.PreprocessorReturn{InfoSchema: is})) + err = plannercore.Preprocess(context.Background(), tk.Session(), stmt, plannercore.WithPreprocessorReturn(&plannercore.PreprocessorReturn{InfoSchema: is})) require.NoError(t, err) require.NoError(t, sessiontxn.GetTxnManager(tk.Session()).AdviseWarmup()) builder, _ := plannercore.NewPlanBuilder().Init(tk.Session(), is, &hint.BlockHintProcessor{}) @@ -316,7 +316,7 @@ func TestFDSet_ExtractFDForApply(t *testing.T) { require.NoError(t, err, comment) tk.Session().GetSessionVars().PlanID = 0 tk.Session().GetSessionVars().PlanColumnID = 0 - err = plannercore.Preprocess(tk.Session(), stmt, plannercore.WithPreprocessorReturn(&plannercore.PreprocessorReturn{InfoSchema: is})) + err = plannercore.Preprocess(context.Background(), tk.Session(), stmt, plannercore.WithPreprocessorReturn(&plannercore.PreprocessorReturn{InfoSchema: is})) require.NoError(t, err, comment) require.NoError(t, sessiontxn.GetTxnManager(tk.Session()).AdviseWarmup()) builder, _ := plannercore.NewPlanBuilder().Init(tk.Session(), is, &hint.BlockHintProcessor{}) @@ -364,7 +364,7 @@ func TestFDSet_MakeOuterJoin(t *testing.T) { require.NoError(t, err, comment) tk.Session().GetSessionVars().PlanID = 0 tk.Session().GetSessionVars().PlanColumnID = 0 - err = plannercore.Preprocess(tk.Session(), stmt, plannercore.WithPreprocessorReturn(&plannercore.PreprocessorReturn{InfoSchema: is})) + err = plannercore.Preprocess(context.Background(), tk.Session(), stmt, plannercore.WithPreprocessorReturn(&plannercore.PreprocessorReturn{InfoSchema: is})) require.NoError(t, err, comment) require.NoError(t, sessiontxn.GetTxnManager(tk.Session()).AdviseWarmup()) builder, _ := plannercore.NewPlanBuilder().Init(tk.Session(), is, &hint.BlockHintProcessor{}) diff --git a/server/conn.go b/server/conn.go index 49a7f97a1df24..61638b4337c16 100644 --- a/server/conn.go +++ b/server/conn.go @@ -1976,7 +1976,7 @@ func (cc *clientConn) prefetchPointPlanKeys(ctx context.Context, stmts []ast.Stm } // TODO: the preprocess is run twice, we should find some way to avoid do it again. // TODO: handle the PreprocessorReturn. - if err = plannercore.Preprocess(cc.getCtx(), stmt); err != nil { + if err = plannercore.Preprocess(ctx, cc.getCtx(), stmt); err != nil { return nil, err } p := plannercore.TryFastPlan(cc.ctx.Session, stmt) diff --git a/session/nontransactional.go b/session/nontransactional.go index ef8adb541203e..540bb55a6c732 100644 --- a/session/nontransactional.go +++ b/session/nontransactional.go @@ -76,7 +76,7 @@ func (j job) String(redacted bool) string { // HandleNonTransactionalDelete is the entry point for a non-transactional delete func HandleNonTransactionalDelete(ctx context.Context, stmt *ast.NonTransactionalDeleteStmt, se Session) (sqlexec.RecordSet, error) { - err := core.Preprocess(se, stmt) + err := core.Preprocess(ctx, se, stmt) if err != nil { return nil, err } diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index c0b71b4e10af8..12841c40a67f5 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -1271,6 +1271,9 @@ type SessionVars struct { // LastPlanReplayerToken indicates the last plan replayer token LastPlanReplayerToken string + // AnalyzePartitionMergeConcurrency indicates concurrency for merging partition stats + AnalyzePartitionMergeConcurrency int + HookContext } diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index c350d71db5599..a95d497280f04 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -1917,6 +1917,13 @@ var defaultSysVars = []*SysVar{ s.RangeMaxSize = TidbOptInt64(val, DefTiDBOptRangeMaxSize) return nil }}, + { + Scope: ScopeGlobal | ScopeSession, Name: TiDBMergePartitionStatsConcurrency, Value: strconv.FormatInt(DefTiDBMergePartitionStatsConcurrency, 10), Type: TypeInt, MinValue: 1, MaxValue: MaxConfigurableConcurrency, + SetSession: func(s *SessionVars, val string) error { + s.AnalyzePartitionMergeConcurrency = TidbOptInt(val, DefTiDBMergePartitionStatsConcurrency) + return nil + }, + }, } // FeedbackProbability points to the FeedbackProbability in statistics package. diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index 9ebdd9ecc61be..38fecce6f95f4 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -749,6 +749,9 @@ const ( // ranges would exceed the limit, it chooses less accurate ranges such as full range. 0 indicates that there is no memory // limit for ranges. TiDBOptRangeMaxSize = "tidb_opt_range_max_size" + + // TiDBMergePartitionStatsConcurrency indicates the concurrecny when merge partition stats into global stats + TiDBMergePartitionStatsConcurrency = "tidb_merge_partition_stats_concurrency" ) // TiDB vars that have only global scope @@ -1057,6 +1060,7 @@ const ( DefTiDBOptRangeMaxSize = 0 DefTiDBCostModelVer = 1 DefTiDBServerMemoryLimitSessMinSize = 128 << 20 + DefTiDBMergePartitionStatsConcurrency = 1 DefTiDBServerMemoryLimitGCTrigger = 0.7 DefTiDBEnableGOGCTuner = true ) diff --git a/sessiontxn/staleread/processor.go b/sessiontxn/staleread/processor.go index 887b57f28fb45..b01a5c69da577 100644 --- a/sessiontxn/staleread/processor.go +++ b/sessiontxn/staleread/processor.go @@ -50,6 +50,7 @@ type Processor interface { } type baseProcessor struct { + ctx context.Context sctx sessionctx.Context txnManager sessiontxn.TxnManager @@ -59,7 +60,8 @@ type baseProcessor struct { is infoschema.InfoSchema } -func (p *baseProcessor) init(sctx sessionctx.Context) { +func (p *baseProcessor) init(ctx context.Context, sctx sessionctx.Context) { + p.ctx = ctx p.sctx = sctx p.txnManager = sessiontxn.GetTxnManager(sctx) } @@ -135,9 +137,9 @@ type staleReadProcessor struct { } // NewStaleReadProcessor creates a new stale read processor -func NewStaleReadProcessor(sctx sessionctx.Context) Processor { +func NewStaleReadProcessor(ctx context.Context, sctx sessionctx.Context) Processor { p := &staleReadProcessor{} - p.init(sctx) + p.init(ctx, sctx) return p } @@ -155,7 +157,7 @@ func (p *staleReadProcessor) OnSelectTable(tn *ast.TableName) error { } // If `stmtAsOfTS` is not 0, it means we use 'select ... from xxx as of timestamp ...' - stmtAsOfTS, err := parseAndValidateAsOf(p.sctx, tn.AsOf) + stmtAsOfTS, err := parseAndValidateAsOf(p.ctx, p.sctx, tn.AsOf) if err != nil { return err } @@ -238,7 +240,7 @@ func (p *staleReadProcessor) evaluateFromStmtTSOrSysVariable(stmtTS uint64) erro return p.setAsNonStaleRead() } -func parseAndValidateAsOf(sctx sessionctx.Context, asOf *ast.AsOfClause) (uint64, error) { +func parseAndValidateAsOf(ctx context.Context, sctx sessionctx.Context, asOf *ast.AsOfClause) (uint64, error) { if asOf == nil { return 0, nil } @@ -248,7 +250,7 @@ func parseAndValidateAsOf(sctx sessionctx.Context, asOf *ast.AsOfClause) (uint64 return 0, err } - if err = sessionctx.ValidateStaleReadTS(context.TODO(), sctx, ts); err != nil { + if err = sessionctx.ValidateStaleReadTS(ctx, sctx, ts); err != nil { return 0, err } diff --git a/sessiontxn/staleread/processor_test.go b/sessiontxn/staleread/processor_test.go index 4a98bff0364fc..5eb9f4aa89936 100644 --- a/sessiontxn/staleread/processor_test.go +++ b/sessiontxn/staleread/processor_test.go @@ -15,6 +15,7 @@ package staleread_test import ( + "context" "fmt" "testing" "time" @@ -331,7 +332,7 @@ func TestStaleReadProcessorInTxn(t *testing.T) { } func createProcessor(t *testing.T, se sessionctx.Context) staleread.Processor { - processor := staleread.NewStaleReadProcessor(se) + processor := staleread.NewStaleReadProcessor(context.Background(), se) require.False(t, processor.IsStaleness()) require.Equal(t, uint64(0), processor.GetStalenessReadTS()) require.Nil(t, processor.GetStalenessTSEvaluatorForPrepare()) diff --git a/statistics/cmsketch.go b/statistics/cmsketch.go index 848a10a653325..5f412371d2a0f 100644 --- a/statistics/cmsketch.go +++ b/statistics/cmsketch.go @@ -21,6 +21,7 @@ import ( "reflect" "sort" "strings" + "time" "github.com/pingcap/errors" "github.com/pingcap/failpoint" @@ -729,7 +730,7 @@ func NewTopN(n int) *TopN { // 1. `*TopN` is the final global-level topN. // 2. `[]TopNMeta` is the left topN value from the partition-level TopNs, but is not placed to global-level TopN. We should put them back to histogram latter. // 3. `[]*Histogram` are the partition-level histograms which just delete some values when we merge the global-level topN. -func MergePartTopN2GlobalTopN(sc *stmtctx.StatementContext, version int, topNs []*TopN, n uint32, hists []*Histogram, isIndex bool) (*TopN, []TopNMeta, []*Histogram, error) { +func MergePartTopN2GlobalTopN(loc *time.Location, version int, topNs []*TopN, n uint32, hists []*Histogram, isIndex bool) (*TopN, []TopNMeta, []*Histogram, error) { if checkEmptyTopNs(topNs) { return nil, nil, hists, nil } @@ -781,7 +782,7 @@ func MergePartTopN2GlobalTopN(sc *stmtctx.StatementContext, version int, topNs [ var err error if types.IsTypeTime(hists[0].Tp.GetType()) { // handle datetime values specially since they are encoded to int and we'll get int values if using DecodeOne. - _, d, err = codec.DecodeAsDateTime(val.Encoded, hists[0].Tp.GetType(), sc.TimeZone) + _, d, err = codec.DecodeAsDateTime(val.Encoded, hists[0].Tp.GetType(), loc) } else if types.IsTypeFloat(hists[0].Tp.GetType()) { _, d, err = codec.DecodeAsFloat32(val.Encoded, hists[0].Tp.GetType()) } else { @@ -866,6 +867,22 @@ func checkEmptyTopNs(topNs []*TopN) bool { return count == 0 } +// SortTopnMeta sort topnMeta +func SortTopnMeta(topnMetas []TopNMeta) []TopNMeta { + slices.SortFunc(topnMetas, func(i, j TopNMeta) bool { + if i.Count != j.Count { + return i.Count > j.Count + } + return bytes.Compare(i.Encoded, j.Encoded) < 0 + }) + return topnMetas +} + +// GetMergedTopNFromSortedSlice returns merged topn +func GetMergedTopNFromSortedSlice(sorted []TopNMeta, n uint32) (*TopN, []TopNMeta) { + return getMergedTopNFromSortedSlice(sorted, n) +} + func getMergedTopNFromSortedSlice(sorted []TopNMeta, n uint32) (*TopN, []TopNMeta) { slices.SortFunc(sorted, func(i, j TopNMeta) bool { if i.Count != j.Count { diff --git a/statistics/handle/handle.go b/statistics/handle/handle.go index 00660c9756a68..6c08b5245fcaf 100644 --- a/statistics/handle/handle.go +++ b/statistics/handle/handle.go @@ -15,6 +15,7 @@ package handle import ( + "bytes" "context" "encoding/json" "fmt" @@ -28,7 +29,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/tidb/config" - "github.com/pingcap/tidb/ddl/util" + ddlUtil "github.com/pingcap/tidb/ddl/util" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/metrics" @@ -41,6 +42,7 @@ import ( "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/mathutil" @@ -56,6 +58,9 @@ import ( const ( // TiDBGlobalStats represents the global-stats for a partitioned table. TiDBGlobalStats = "global" + + // maxPartitionMergeBatchSize indicates the max batch size for a worker to merge partition stats + maxPartitionMergeBatchSize = 256 ) // Handle can update stats info periodically. @@ -83,7 +88,7 @@ type Handle struct { // ddlEventCh is a channel to notify a ddl operation has happened. // It is sent only by owner or the drop stats executor, and read by stats handle. - ddlEventCh chan *util.Event + ddlEventCh chan *ddlUtil.Event // listHead contains all the stats collector required by session. listHead *SessionStatsCollector // globalMap contains all the delta map from collectors when we dump them to KV. @@ -197,7 +202,7 @@ type sessionPool interface { func NewHandle(ctx sessionctx.Context, lease time.Duration, pool sessionPool, tracker sessionctx.SysProcTracker, serverIDGetter func() uint64) (*Handle, error) { cfg := config.GetGlobalConfig() handle := &Handle{ - ddlEventCh: make(chan *util.Event, 100), + ddlEventCh: make(chan *ddlUtil.Event, 100), listHead: &SessionStatsCollector{mapper: make(tableDeltaMap), rateMap: make(errorRateDeltaMap)}, idxUsageListHead: &SessionIndexUsageCollector{mapper: make(indexUsageMap)}, pool: pool, @@ -547,7 +552,8 @@ func (h *Handle) mergePartitionStats2GlobalStats(sc sessionctx.Context, // Because after merging TopN, some numbers will be left. // These remaining topN numbers will be used as a separate bucket for later histogram merging. var popedTopN []statistics.TopNMeta - globalStats.TopN[i], popedTopN, allHg[i], err = statistics.MergePartTopN2GlobalTopN(sc.GetSessionVars().StmtCtx, sc.GetSessionVars().AnalyzeVersion, allTopN[i], uint32(opts[ast.AnalyzeOptNumTopN]), allHg[i], isIndex == 1) + wrapper := statistics.NewStatsWrapper(allHg[i], allTopN[i]) + globalStats.TopN[i], popedTopN, allHg[i], err = h.mergeGlobalStatsTopN(sc, wrapper, sc.GetSessionVars().StmtCtx.TimeZone, sc.GetSessionVars().AnalyzeVersion, uint32(opts[ast.AnalyzeOptNumTopN]), isIndex == 1) if err != nil { return } @@ -579,6 +585,104 @@ func (h *Handle) mergePartitionStats2GlobalStats(sc sessionctx.Context, return } +func (h *Handle) mergeGlobalStatsTopN(sc sessionctx.Context, wrapper *statistics.StatsWrapper, + timeZone *time.Location, version int, n uint32, isIndex bool) (*statistics.TopN, + []statistics.TopNMeta, []*statistics.Histogram, error) { + mergeConcurrency := sc.GetSessionVars().AnalyzePartitionMergeConcurrency + // use original method if concurrency equals 1 or for version1 + if mergeConcurrency < 2 { + return statistics.MergePartTopN2GlobalTopN(timeZone, version, wrapper.AllTopN, n, wrapper.AllHg, isIndex) + } + batchSize := len(wrapper.AllTopN) / mergeConcurrency + if batchSize < 1 { + batchSize = 1 + } else if batchSize > maxPartitionMergeBatchSize { + batchSize = maxPartitionMergeBatchSize + } + return h.mergeGlobalStatsTopNByConcurrency(mergeConcurrency, batchSize, wrapper, timeZone, version, n, isIndex) +} + +// mergeGlobalStatsTopNByConcurrency merge partition topN by concurrency +// To merge global stats topn by concurrency, we will separate the partition topn in concurrency part and deal it with different worker. +// mergeConcurrency is used to control the total concurrency of the running worker, and mergeBatchSize is sued to control +// the partition size for each worker to solve it +func (h *Handle) mergeGlobalStatsTopNByConcurrency(mergeConcurrency, mergeBatchSize int, wrapper *statistics.StatsWrapper, + timeZone *time.Location, version int, n uint32, isIndex bool) (*statistics.TopN, + []statistics.TopNMeta, []*statistics.Histogram, error) { + if len(wrapper.AllTopN) < mergeConcurrency { + mergeConcurrency = len(wrapper.AllTopN) + } + tasks := make([]*statistics.TopnStatsMergeTask, 0) + for start := 0; start < len(wrapper.AllTopN); { + end := start + mergeBatchSize + if end > len(wrapper.AllTopN) { + end = len(wrapper.AllTopN) + } + task := statistics.NewTopnStatsMergeTask(start, end) + tasks = append(tasks, task) + start = end + } + var wg util.WaitGroupWrapper + taskNum := len(tasks) + taskCh := make(chan *statistics.TopnStatsMergeTask, taskNum) + respCh := make(chan *statistics.TopnStatsMergeResponse, taskNum) + for i := 0; i < mergeConcurrency; i++ { + worker := statistics.NewTopnStatsMergeWorker(taskCh, respCh, wrapper) + wg.Run(func() { + worker.Run(timeZone, isIndex, n, version) + }) + } + for _, task := range tasks { + taskCh <- task + } + close(taskCh) + wg.Wait() + close(respCh) + resps := make([]*statistics.TopnStatsMergeResponse, 0) + + // handle Error + hasErr := false + for resp := range respCh { + if resp.Err != nil { + hasErr = true + } + resps = append(resps, resp) + } + if hasErr { + errMsg := make([]string, 0) + for _, resp := range resps { + if resp.Err != nil { + errMsg = append(errMsg, resp.Err.Error()) + } + } + return nil, nil, nil, errors.New(strings.Join(errMsg, ",")) + } + + // fetch the response from each worker and merge them into global topn stats + sorted := make([]statistics.TopNMeta, 0, mergeConcurrency) + leftTopn := make([]statistics.TopNMeta, 0) + for _, resp := range resps { + if resp.TopN != nil { + sorted = append(sorted, resp.TopN.TopN...) + } + leftTopn = append(leftTopn, resp.PopedTopn...) + for i, removeTopn := range resp.RemoveVals { + // Remove the value from the Hists. + if len(removeTopn) > 0 { + tmp := removeTopn + slices.SortFunc(tmp, func(i, j statistics.TopNMeta) bool { + cmpResult := bytes.Compare(i.Encoded, j.Encoded) + return cmpResult < 0 + }) + wrapper.AllHg[i].RemoveVals(tmp) + } + } + } + + globalTopN, popedTopn := statistics.GetMergedTopNFromSortedSlice(sorted, n) + return globalTopN, statistics.SortTopnMeta(append(leftTopn, popedTopn...)), wrapper.AllHg, nil +} + func (h *Handle) getTableByPhysicalID(is infoschema.InfoSchema, physicalID int64) (table.Table, bool) { if is.SchemaMetaVersion() != h.mu.schemaVersion { h.mu.schemaVersion = is.SchemaMetaVersion() diff --git a/statistics/histogram.go b/statistics/histogram.go index 2133ccad3b53b..8c662b6f04061 100644 --- a/statistics/histogram.go +++ b/statistics/histogram.go @@ -997,7 +997,7 @@ func (coll *HistColl) NewHistCollBySelectivity(sctx sessionctx.Context, statsNod Columns: make(map[int64]*Column), Indices: make(map[int64]*Index), Idx2ColumnIDs: coll.Idx2ColumnIDs, - ColID2IdxID: coll.ColID2IdxID, + ColID2IdxIDs: coll.ColID2IdxIDs, Count: coll.Count, } for _, node := range statsNodes { diff --git a/statistics/index.go b/statistics/index.go index 6b8a88501c30e..71d2aa839bd61 100644 --- a/statistics/index.go +++ b/statistics/index.go @@ -346,14 +346,30 @@ func (idx *Index) expBackoffEstimation(sctx sessionctx.Context, coll *HistColl, } colID := colsIDs[i] var ( - count float64 - err error + count float64 + err error + foundStats bool ) - if anotherIdxID, ok := coll.ColID2IdxID[colID]; ok && anotherIdxID != idx.Histogram.ID { - count, err = coll.GetRowCountByIndexRanges(sctx, anotherIdxID, tmpRan) - } else if col, ok := coll.Columns[colID]; ok && !col.IsInvalid(sctx, coll.Pseudo) { + if col, ok := coll.Columns[colID]; ok && !col.IsInvalid(sctx, coll.Pseudo) { + foundStats = true count, err = coll.GetRowCountByColumnRanges(sctx, colID, tmpRan) - } else { + } + if idxIDs, ok := coll.ColID2IdxIDs[colID]; ok && !foundStats && len(indexRange.LowVal) > 1 { + // Note the `len(indexRange.LowVal) > 1` condition here, it means we only recursively call + // `GetRowCountByIndexRanges()` when the input `indexRange` is a multi-column range. This + // check avoids infinite recursion. + for _, idxID := range idxIDs { + if idxID == idx.Histogram.ID { + continue + } + foundStats = true + count, err = coll.GetRowCountByIndexRanges(sctx, idxID, tmpRan) + if err == nil { + break + } + } + } + if !foundStats { continue } if err != nil { diff --git a/statistics/merge_worker.go b/statistics/merge_worker.go new file mode 100644 index 0000000000000..ac34605835559 --- /dev/null +++ b/statistics/merge_worker.go @@ -0,0 +1,188 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package statistics + +import ( + "time" + + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/codec" + "github.com/pingcap/tidb/util/hack" +) + +// StatsWrapper wrapper stats +type StatsWrapper struct { + AllHg []*Histogram + AllTopN []*TopN +} + +// NewStatsWrapper returns wrapper +func NewStatsWrapper(hg []*Histogram, topN []*TopN) *StatsWrapper { + return &StatsWrapper{ + AllHg: hg, + AllTopN: topN, + } +} + +type topnStatsMergeWorker struct { + taskCh <-chan *TopnStatsMergeTask + respCh chan<- *TopnStatsMergeResponse + // the stats in the wrapper should only be read during the worker + statsWrapper *StatsWrapper +} + +// NewTopnStatsMergeWorker returns topn merge worker +func NewTopnStatsMergeWorker( + taskCh <-chan *TopnStatsMergeTask, + respCh chan<- *TopnStatsMergeResponse, + wrapper *StatsWrapper) *topnStatsMergeWorker { + worker := &topnStatsMergeWorker{ + taskCh: taskCh, + respCh: respCh, + } + worker.statsWrapper = wrapper + return worker +} + +// TopnStatsMergeTask indicates a task for merge topn stats +type TopnStatsMergeTask struct { + start int + end int +} + +// NewTopnStatsMergeTask returns task +func NewTopnStatsMergeTask(start, end int) *TopnStatsMergeTask { + return &TopnStatsMergeTask{ + start: start, + end: end, + } +} + +// TopnStatsMergeResponse indicates topn merge worker response +type TopnStatsMergeResponse struct { + TopN *TopN + PopedTopn []TopNMeta + RemoveVals [][]TopNMeta + Err error +} + +// Run runs topn merge like statistics.MergePartTopN2GlobalTopN +func (worker *topnStatsMergeWorker) Run(timeZone *time.Location, isIndex bool, + n uint32, + version int) { + for task := range worker.taskCh { + start := task.start + end := task.end + checkTopNs := worker.statsWrapper.AllTopN[start:end] + allTopNs := worker.statsWrapper.AllTopN + allHists := worker.statsWrapper.AllHg + resp := &TopnStatsMergeResponse{} + if checkEmptyTopNs(checkTopNs) { + worker.respCh <- resp + return + } + partNum := len(allTopNs) + checkNum := len(checkTopNs) + topNsNum := make([]int, checkNum) + removeVals := make([][]TopNMeta, partNum) + for i, topN := range checkTopNs { + if topN == nil { + topNsNum[i] = 0 + continue + } + topNsNum[i] = len(topN.TopN) + } + // Different TopN structures may hold the same value, we have to merge them. + counter := make(map[hack.MutableString]float64) + // datumMap is used to store the mapping from the string type to datum type. + // The datum is used to find the value in the histogram. + datumMap := make(map[hack.MutableString]types.Datum) + + for i, topN := range checkTopNs { + if topN.TotalCount() == 0 { + continue + } + for _, val := range topN.TopN { + encodedVal := hack.String(val.Encoded) + _, exists := counter[encodedVal] + counter[encodedVal] += float64(val.Count) + if exists { + // We have already calculated the encodedVal from the histogram, so just continue to next topN value. + continue + } + // We need to check whether the value corresponding to encodedVal is contained in other partition-level stats. + // 1. Check the topN first. + // 2. If the topN doesn't contain the value corresponding to encodedVal. We should check the histogram. + for j := 0; j < partNum; j++ { + if (j == i && version >= 2) || allTopNs[j].findTopN(val.Encoded) != -1 { + continue + } + // Get the encodedVal from the hists[j] + datum, exists := datumMap[encodedVal] + if !exists { + // If the datumMap does not have the encodedVal datum, + // we should generate the datum based on the encoded value. + // This part is copied from the function MergePartitionHist2GlobalHist. + var d types.Datum + if isIndex { + d.SetBytes(val.Encoded) + } else { + var err error + if types.IsTypeTime(allHists[0].Tp.GetType()) { + // handle datetime values specially since they are encoded to int and we'll get int values if using DecodeOne. + _, d, err = codec.DecodeAsDateTime(val.Encoded, allHists[0].Tp.GetType(), timeZone) + } else if types.IsTypeFloat(allHists[0].Tp.GetType()) { + _, d, err = codec.DecodeAsFloat32(val.Encoded, allHists[0].Tp.GetType()) + } else { + _, d, err = codec.DecodeOne(val.Encoded) + } + if err != nil { + resp.Err = err + worker.respCh <- resp + return + } + } + datumMap[encodedVal] = d + datum = d + } + // Get the row count which the value is equal to the encodedVal from histogram. + count, _ := allHists[j].equalRowCount(datum, isIndex) + if count != 0 { + counter[encodedVal] += count + // Remove the value corresponding to encodedVal from the histogram. + removeVals[j] = append(removeVals[j], TopNMeta{Encoded: datum.GetBytes(), Count: uint64(count)}) + } + } + } + } + // record remove values + resp.RemoveVals = removeVals + + numTop := len(counter) + if numTop == 0 { + worker.respCh <- resp + continue + } + sorted := make([]TopNMeta, 0, numTop) + for value, cnt := range counter { + data := hack.Slice(string(value)) + sorted = append(sorted, TopNMeta{Encoded: data, Count: uint64(cnt)}) + } + globalTopN, leftTopN := getMergedTopNFromSortedSlice(sorted, n) + resp.TopN = globalTopN + resp.PopedTopn = leftTopN + worker.respCh <- resp + } +} diff --git a/statistics/selectivity_test.go b/statistics/selectivity_test.go index a71de236c2483..26e79c93cc5fd 100644 --- a/statistics/selectivity_test.go +++ b/statistics/selectivity_test.go @@ -85,7 +85,7 @@ func BenchmarkSelectivity(b *testing.B) { require.NoErrorf(b, err, "error %v, for expr %s", err, exprs) require.Len(b, stmts, 1) ret := &plannercore.PreprocessorReturn{} - err = plannercore.Preprocess(sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) + err = plannercore.Preprocess(context.Background(), sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) require.NoErrorf(b, err, "for %s", exprs) p, _, err := plannercore.BuildLogicalPlanForTest(context.Background(), sctx, stmts[0], ret.InfoSchema) require.NoErrorf(b, err, "error %v, for building plan, expr %s", err, exprs) @@ -527,7 +527,7 @@ func TestSelectivity(t *testing.T) { require.Len(t, stmts, 1) ret := &plannercore.PreprocessorReturn{} - err = plannercore.Preprocess(sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) + err = plannercore.Preprocess(context.Background(), sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) require.NoErrorf(t, err, "for expr %s", tt.exprs) p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmts[0], ret.InfoSchema) require.NoErrorf(t, err, "for building plan, expr %s", err, tt.exprs) @@ -639,7 +639,7 @@ func TestDNFCondSelectivity(t *testing.T) { require.Len(t, stmts, 1) ret := &plannercore.PreprocessorReturn{} - err = plannercore.Preprocess(sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) + err = plannercore.Preprocess(context.Background(), sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) require.NoErrorf(t, err, "error %v, for sql %s", err, tt) p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmts[0], ret.InfoSchema) require.NoErrorf(t, err, "error %v, for building plan, sql %s", err, tt) diff --git a/statistics/table.go b/statistics/table.go index 81cb4e9bf284f..2e66e39ab8152 100644 --- a/statistics/table.go +++ b/statistics/table.go @@ -101,10 +101,10 @@ type HistColl struct { Indices map[int64]*Index // Idx2ColumnIDs maps the index id to its column ids. It's used to calculate the selectivity in planner. Idx2ColumnIDs map[int64][]int64 - // ColID2IdxID maps the column id to index id whose first column is it. It's used to calculate the selectivity in planner. - ColID2IdxID map[int64]int64 - Count int64 - ModifyCount int64 // Total modify count in a table. + // ColID2IdxIDs maps the column id to a list index ids whose first column is it. It's used to calculate the selectivity in planner. + ColID2IdxIDs map[int64][]int64 + Count int64 + ModifyCount int64 // Total modify count in a table. // HavePhysicalID is true means this HistColl is from single table and have its ID's information. // The physical id is used when try to load column stats from storage. @@ -846,7 +846,7 @@ func (coll *HistColl) ID2UniqueID(columns []*expression.Column) *HistColl { return newColl } -// GenerateHistCollFromColumnInfo generates a new HistColl whose ColID2IdxID and IdxID2ColIDs is built from the given parameter. +// GenerateHistCollFromColumnInfo generates a new HistColl whose ColID2IdxIDs and IdxID2ColIDs is built from the given parameter. func (coll *HistColl) GenerateHistCollFromColumnInfo(infos []*model.ColumnInfo, columns []*expression.Column) *HistColl { newColHistMap := make(map[int64]*Column) colInfoID2UniqueID := make(map[int64]int64, len(columns)) @@ -869,7 +869,7 @@ func (coll *HistColl) GenerateHistCollFromColumnInfo(infos []*model.ColumnInfo, } newIdxHistMap := make(map[int64]*Index) idx2Columns := make(map[int64][]int64) - colID2IdxID := make(map[int64]int64) + colID2IdxIDs := make(map[int64][]int64) for _, idxHist := range coll.Indices { ids := make([]int64, 0, len(idxHist.Info.Columns)) for _, idxCol := range idxHist.Info.Columns { @@ -883,10 +883,13 @@ func (coll *HistColl) GenerateHistCollFromColumnInfo(infos []*model.ColumnInfo, if len(ids) == 0 { continue } - colID2IdxID[ids[0]] = idxHist.ID + colID2IdxIDs[ids[0]] = append(colID2IdxIDs[ids[0]], idxHist.ID) newIdxHistMap[idxHist.ID] = idxHist idx2Columns[idxHist.ID] = ids } + for _, idxIDs := range colID2IdxIDs { + slices.Sort(idxIDs) + } newColl := &HistColl{ PhysicalID: coll.PhysicalID, HavePhysicalID: coll.HavePhysicalID, @@ -895,7 +898,7 @@ func (coll *HistColl) GenerateHistCollFromColumnInfo(infos []*model.ColumnInfo, ModifyCount: coll.ModifyCount, Columns: newColHistMap, Indices: newIdxHistMap, - ColID2IdxID: colID2IdxID, + ColID2IdxIDs: colID2IdxIDs, Idx2ColumnIDs: idx2Columns, } return newColl @@ -1084,8 +1087,9 @@ func (coll *HistColl) getIndexRowCount(sctx sessionctx.Context, idxID int64, ind colID = colIDs[rangePosition] } // prefer index stats over column stats - if idx, ok := coll.ColID2IdxID[colID]; ok { - count, err = coll.GetRowCountByIndexRanges(sctx, idx, []*ranger.Range{&rang}) + if idxIDs, ok := coll.ColID2IdxIDs[colID]; ok && len(idxIDs) > 0 { + idxID := idxIDs[0] + count, err = coll.GetRowCountByIndexRanges(sctx, idxID, []*ranger.Range{&rang}) } else { count, err = coll.GetRowCountByColumnRanges(sctx, colID, []*ranger.Range{&rang}) } diff --git a/util/dbutil/common.go b/util/dbutil/common.go index 37b6da5fd1f49..df54e18bd6909 100644 --- a/util/dbutil/common.go +++ b/util/dbutil/common.go @@ -19,7 +19,7 @@ import ( "database/sql" "encoding/json" "fmt" - "net/url" + "net" "os" "strconv" "strings" @@ -107,26 +107,31 @@ func GetDBConfigFromEnv(schema string) DBConfig { // OpenDB opens a mysql connection FD func OpenDB(cfg DBConfig, vars map[string]string) (*sql.DB, error) { - var dbDSN string + driverCfg := mysql.NewConfig() + driverCfg.Params = make(map[string]string) + driverCfg.User = cfg.User + driverCfg.Passwd = cfg.Password + driverCfg.Net = "tcp" + driverCfg.Addr = net.JoinHostPort(cfg.Host, strconv.Itoa(cfg.Port)) + driverCfg.Params["charset"] = "utf8mb4" + if len(cfg.Snapshot) != 0 { log.Info("create connection with snapshot", zap.String("snapshot", cfg.Snapshot)) - dbDSN = fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4&tidb_snapshot=%s", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Snapshot) - } else { - dbDSN = fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4", cfg.User, cfg.Password, cfg.Host, cfg.Port) + driverCfg.Params["tidb_snapshot"] = cfg.Snapshot } for key, val := range vars { // key='val'. add single quote for better compatibility. - dbDSN += fmt.Sprintf("&%s=%%27%s%%27", key, url.QueryEscape(val)) + driverCfg.Params[key] = fmt.Sprintf("'%s'", val) } - dbConn, err := sql.Open("mysql", dbDSN) + c, err := mysql.NewConnector(driverCfg) if err != nil { return nil, errors.Trace(err) } - - err = dbConn.Ping() - return dbConn, errors.Trace(err) + db := sql.OpenDB(c) + err = db.Ping() + return db, errors.Trace(err) } // CloseDB closes the mysql fd diff --git a/util/ranger/bench_test.go b/util/ranger/bench_test.go index e72a5caa31ee6..c980a806f6066 100644 --- a/util/ranger/bench_test.go +++ b/util/ranger/bench_test.go @@ -109,7 +109,7 @@ WHERE require.NoError(b, err) require.Len(b, stmts, 1) ret := &plannercore.PreprocessorReturn{} - err = plannercore.Preprocess(sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) + err = plannercore.Preprocess(context.Background(), sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) require.NoError(b, err) ctx := context.Background() p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmts[0], ret.InfoSchema) diff --git a/util/ranger/ranger_test.go b/util/ranger/ranger_test.go index ccaba66b5e50d..411abe846499b 100644 --- a/util/ranger/ranger_test.go +++ b/util/ranger/ranger_test.go @@ -264,7 +264,7 @@ func TestTableRange(t *testing.T) { require.NoError(t, err) require.Len(t, stmts, 1) ret := &plannercore.PreprocessorReturn{} - err = plannercore.Preprocess(sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) + err = plannercore.Preprocess(context.Background(), sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) require.NoError(t, err) p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmts[0], ret.InfoSchema) require.NoError(t, err) @@ -453,7 +453,7 @@ create table t( require.NoError(t, err) require.Len(t, stmts, 1) ret := &plannercore.PreprocessorReturn{} - err = plannercore.Preprocess(sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) + err = plannercore.Preprocess(context.Background(), sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) require.NoError(t, err) p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmts[0], ret.InfoSchema) require.NoError(t, err) @@ -814,7 +814,7 @@ func TestColumnRange(t *testing.T) { require.NoError(t, err) require.Len(t, stmts, 1) ret := &plannercore.PreprocessorReturn{} - err = plannercore.Preprocess(sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) + err = plannercore.Preprocess(context.Background(), sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) require.NoError(t, err) p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmts[0], ret.InfoSchema) require.NoError(t, err) @@ -1196,7 +1196,7 @@ func TestIndexRangeForYear(t *testing.T) { require.NoError(t, err) require.Len(t, stmts, 1) ret := &plannercore.PreprocessorReturn{} - err = plannercore.Preprocess(sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) + err = plannercore.Preprocess(context.Background(), sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) require.NoError(t, err) p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmts[0], ret.InfoSchema) require.NoError(t, err) @@ -1264,7 +1264,7 @@ func TestPrefixIndexRangeScan(t *testing.T) { require.NoError(t, err) require.Len(t, stmts, 1) ret := &plannercore.PreprocessorReturn{} - err = plannercore.Preprocess(sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) + err = plannercore.Preprocess(context.Background(), sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) require.NoError(t, err) p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmts[0], ret.InfoSchema) require.NoError(t, err) @@ -1673,7 +1673,7 @@ create table t( require.NoError(t, err) require.Len(t, stmts, 1) ret := &plannercore.PreprocessorReturn{} - err = plannercore.Preprocess(sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) + err = plannercore.Preprocess(context.Background(), sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) require.NoError(t, err) p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmts[0], ret.InfoSchema) require.NoError(t, err) @@ -1914,7 +1914,7 @@ func TestTableShardIndex(t *testing.T) { require.NoError(t, err) require.Len(t, stmts, 1) ret := &plannercore.PreprocessorReturn{} - err = plannercore.Preprocess(sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) + err = plannercore.Preprocess(context.Background(), sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) require.NoError(t, err) p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmts[0], ret.InfoSchema) require.NoError(t, err) @@ -1942,7 +1942,7 @@ func TestTableShardIndex(t *testing.T) { require.NoError(t, err) require.Len(t, stmts, 1) ret := &plannercore.PreprocessorReturn{} - err = plannercore.Preprocess(sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) + err = plannercore.Preprocess(context.Background(), sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) require.NoError(t, err) p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmts[0], ret.InfoSchema) require.NoError(t, err) @@ -1960,7 +1960,7 @@ func TestTableShardIndex(t *testing.T) { require.NoError(t, err) require.Len(t, stmts, 1) ret := &plannercore.PreprocessorReturn{} - err = plannercore.Preprocess(sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) + err = plannercore.Preprocess(context.Background(), sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) require.NoError(t, err) p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmts[0], ret.InfoSchema) require.NoError(t, err) @@ -2105,7 +2105,7 @@ func getSelectionFromQuery(t *testing.T, sctx sessionctx.Context, sql string) *p require.NoError(t, err) require.Len(t, stmts, 1) ret := &plannercore.PreprocessorReturn{} - err = plannercore.Preprocess(sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) + err = plannercore.Preprocess(context.Background(), sctx, stmts[0], plannercore.WithPreprocessorReturn(ret)) require.NoError(t, err) p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmts[0], ret.InfoSchema) require.NoError(t, err)