diff --git a/.golangci.yml b/.golangci.yml
index dddccb2f25e39..4324251605d48 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -1,5 +1,5 @@
 run:
-  timeout: 7m
+  timeout: 10m
 linters:
   disable-all: true
   enable:
diff --git a/br/cmd/tidb-lightning/main.go b/br/cmd/tidb-lightning/main.go
index 083e47e82d65d..84362433f222c 100644
--- a/br/cmd/tidb-lightning/main.go
+++ b/br/cmd/tidb-lightning/main.go
@@ -30,7 +30,10 @@ import (
 func main() {
 	globalCfg := config.Must(config.LoadGlobalConfig(os.Args[1:], nil))
-	fmt.Fprintf(os.Stdout, "Verbose debug logs will be written to %s\n\n", globalCfg.App.Config.File)
+	logToFile := globalCfg.App.File != "" && globalCfg.App.File != "-"
+	if logToFile {
+		fmt.Fprintf(os.Stdout, "Verbose debug logs will be written to %s\n\n", globalCfg.App.Config.File)
+	}
 
 	app := lightning.New(globalCfg)
 
@@ -95,7 +98,7 @@ func main() {
 	}
 
 	// call Sync() with log to stdout may return error in some case, so just skip it
-	if globalCfg.App.File != "" {
+	if logToFile {
 		syncErr := logger.Sync()
 		if syncErr != nil {
 			fmt.Fprintln(os.Stderr, "sync log failed", syncErr)
diff --git a/br/pkg/backup/client_test.go b/br/pkg/backup/client_test.go
index 3c3688f79bc9f..de4b6dfd2b588 100644
--- a/br/pkg/backup/client_test.go
+++ b/br/pkg/backup/client_test.go
@@ -258,7 +258,6 @@ func (r *testBackup) TestSendCreds(c *C) {
 	c.Assert(err, IsNil)
 	opts := &storage.ExternalStorageOptions{
 		SendCredentials: true,
-		SkipCheckPath:   true,
 	}
 	_, err = storage.New(r.ctx, backend, opts)
 	c.Assert(err, IsNil)
@@ -277,7 +276,6 @@ func (r *testBackup) TestSendCreds(c *C) {
 	c.Assert(err, IsNil)
 	opts = &storage.ExternalStorageOptions{
 		SendCredentials: false,
-		SkipCheckPath:   true,
 	}
 	_, err = storage.New(r.ctx, backend, opts)
 	c.Assert(err, IsNil)
diff --git a/br/pkg/lightning/backend/local/local.go b/br/pkg/lightning/backend/local/local.go
index 030c53b1509c3..7c56a8e8df9dd 100644
--- a/br/pkg/lightning/backend/local/local.go
+++ b/br/pkg/lightning/backend/local/local.go
@@ -92,7 +92,8 @@ const (
 	gRPCBackOffMaxDelay = 10 * time.Minute
 	// See: https://github.com/tikv/tikv/blob/e030a0aae9622f3774df89c62f21b2171a72a69e/etc/config-template.toml#L360
-	regionMaxKeyCount      = 1_440_000
+	// lower the max-key-count to avoid TiKV triggering region auto split
+	regionMaxKeyCount      = 1_280_000
 	defaultRegionSplitSize = 96 * units.MiB
 
 	propRangeIndex = "tikv.range_index"
@@ -1513,7 +1514,12 @@ func (local *local) WriteToTiKV(
 	size := int64(0)
 	totalCount := int64(0)
 	firstLoop := true
-	regionMaxSize := regionSplitSize * 4 / 3
+	// if region-split-size <= 96MiB, bump the threshold a bit to avoid too many split retries,
+	// because the range properties are not 100% accurate
+	regionMaxSize := regionSplitSize
+	if regionSplitSize <= defaultRegionSplitSize {
+		regionMaxSize = regionSplitSize * 4 / 3
+	}
 
 	for iter.First(); iter.Valid(); iter.Next() {
 		size += int64(len(iter.Key()) + len(iter.Value()))
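To make the new write threshold concrete, here is a minimal, self-contained sketch of the logic in the hunk above. The helper name `effectiveRegionMaxSize` is hypothetical; `defaultRegionSplitSize` and the 4/3 slack factor come straight from local.go, and `units.MiB` is from github.com/docker/go-units as in the real file:

```go
package main

import (
	"fmt"

	"github.com/docker/go-units"
)

const defaultRegionSplitSize = 96 * units.MiB

// effectiveRegionMaxSize mirrors the hunk above: only split sizes at or below
// the 96 MiB default get the 4/3 slack; larger custom split sizes are used
// as-is so they no longer balloon past what TiKV would auto-split anyway.
func effectiveRegionMaxSize(regionSplitSize int64) int64 {
	if regionSplitSize <= defaultRegionSplitSize {
		return regionSplitSize * 4 / 3
	}
	return regionSplitSize
}

func main() {
	fmt.Println(effectiveRegionMaxSize(96 * units.MiB))  // 134217728: slack applied
	fmt.Println(effectiveRegionMaxSize(256 * units.MiB)) // 268435456: used as-is
}
```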
diff --git a/br/pkg/lightning/backend/noop/noop.go b/br/pkg/lightning/backend/noop/noop.go
index ac02ab9482e1f..430c4c5a83e8c 100644
--- a/br/pkg/lightning/backend/noop/noop.go
+++ b/br/pkg/lightning/backend/noop/noop.go
@@ -140,7 +140,7 @@ func (b noopBackend) ResetEngine(ctx context.Context, engineUUID uuid.UUID) erro
 
 // LocalWriter obtains a thread-local EngineWriter for writing rows into the given engine.
 func (b noopBackend) LocalWriter(context.Context, *backend.LocalWriterConfig, uuid.UUID) (backend.EngineWriter, error) {
-	return noopWriter{}, nil
+	return Writer{}, nil
 }
 
 func (b noopBackend) CollectLocalDuplicateRows(ctx context.Context, tbl table.Table, tableName string, opts *kv.SessionOptions) (bool, error) {
@@ -174,16 +174,23 @@ func (r noopRow) Size() uint64 {
 func (r noopRow) ClassifyAndAppend(*kv.Rows, *verification.KVChecksum, *kv.Rows, *verification.KVChecksum) {
 }
 
-type noopWriter struct{}
+// Writer defines a local writer that does nothing.
+type Writer struct{}
 
-func (w noopWriter) AppendRows(context.Context, string, []string, kv.Rows) error {
+func (w Writer) AppendRows(context.Context, string, []string, kv.Rows) error {
 	return nil
 }
 
-func (w noopWriter) IsSynced() bool {
+func (w Writer) IsSynced() bool {
 	return true
 }
 
-func (w noopWriter) Close(context.Context) (backend.ChunkFlushStatus, error) {
-	return nil, nil
+func (w Writer) Close(context.Context) (backend.ChunkFlushStatus, error) {
+	return trueStatus{}, nil
+}
+
+type trueStatus struct{}
+
+func (s trueStatus) Flushed() bool {
+	return true
 }
diff --git a/br/pkg/lightning/backend/tidb/tidb.go b/br/pkg/lightning/backend/tidb/tidb.go
index 938d2bae72d9e..1b95fe558ef88 100644
--- a/br/pkg/lightning/backend/tidb/tidb.go
+++ b/br/pkg/lightning/backend/tidb/tidb.go
@@ -331,12 +331,12 @@ func (enc *tidbEncoder) Encode(logger log.Logger, row []types.Datum, _ int64, co
 }
 
 // EncodeRowForRecord encodes a row to a string compatible with INSERT statements.
-func EncodeRowForRecord(encTable table.Table, sqlMode mysql.SQLMode, row []types.Datum) string {
+func EncodeRowForRecord(encTable table.Table, sqlMode mysql.SQLMode, row []types.Datum, columnPermutation []int) string {
 	enc := tidbEncoder{
 		tbl:  encTable,
 		mode: sqlMode,
 	}
-	resRow, err := enc.Encode(log.L(), row, 0, nil, "", 0)
+	resRow, err := enc.Encode(log.L(), row, 0, columnPermutation, "", 0)
 	if err != nil {
 		return fmt.Sprintf("/* ERROR: %s */", err)
 	}
diff --git a/br/pkg/lightning/common/storage_unix.go b/br/pkg/lightning/common/storage_unix.go
index ba22e92354ceb..7e602cbe58eec 100644
--- a/br/pkg/lightning/common/storage_unix.go
+++ b/br/pkg/lightning/common/storage_unix.go
@@ -23,13 +23,18 @@ import (
 	"syscall"
 
 	"github.com/pingcap/errors"
+	"github.com/pingcap/failpoint"
 	"golang.org/x/sys/unix"
 )
 
 // GetStorageSize gets storage's capacity and available size
func GetStorageSize(dir string) (size StorageSize, err error) {
+	failpoint.Inject("GetStorageSize", func(val failpoint.Value) {
+		injectedSize := val.(int)
+		failpoint.Return(StorageSize{Capacity: uint64(injectedSize), Available: uint64(injectedSize)}, nil)
+	})
 
	var stat unix.Statfs_t
 	err = unix.Statfs(dir, &stat)
 	if err != nil {
 		return size, errors.Annotatef(err, "cannot get disk capacity at %s", dir)
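A hedged sketch of how a test activates the failpoint injected above. This assumes the failpoint markers have been rewritten with failpoint-ctl (as TiDB's make targets do); without that rewrite, `failpoint.Inject` is an inert marker. The failpoint path and the `return(2048)` term mirror the check_info_test.go added later in this diff:

```go
package common_test

import (
	"testing"

	"github.com/pingcap/failpoint"
	"github.com/pingcap/tidb/br/pkg/lightning/common"
)

func TestGetStorageSizeFailpoint(t *testing.T) {
	fp := "github.com/pingcap/tidb/br/pkg/lightning/common/GetStorageSize"
	// "return(2048)" makes failpoint.Inject receive the value 2048.
	if err := failpoint.Enable(fp, "return(2048)"); err != nil {
		t.Fatal(err)
	}
	defer func() { _ = failpoint.Disable(fp) }()

	// The directory is ignored while the failpoint is active.
	size, err := common.GetStorageSize("/nonexistent")
	if err != nil || size.Capacity != 2048 || size.Available != 2048 {
		t.Fatalf("unexpected result: %+v, %v", size, err)
	}
}
```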
diff --git a/br/pkg/lightning/common/storage_windows.go b/br/pkg/lightning/common/storage_windows.go
index 21a2398ad66c3..a95e8f8eeebfc 100644
--- a/br/pkg/lightning/common/storage_windows.go
+++ b/br/pkg/lightning/common/storage_windows.go
@@ -23,6 +23,7 @@ import (
 	"unsafe"
 
 	"github.com/pingcap/errors"
+	"github.com/pingcap/failpoint"
 )
 
 var (
@@ -32,6 +33,10 @@ var (
 
 // GetStorageSize gets storage's capacity and available size
 func GetStorageSize(dir string) (size StorageSize, err error) {
+	failpoint.Inject("GetStorageSize", func(val failpoint.Value) {
+		injectedSize := val.(int)
+		failpoint.Return(StorageSize{Capacity: uint64(injectedSize), Available: uint64(injectedSize)}, nil)
+	})
 	r, _, e := getDiskFreeSpaceExW.Call(
 		uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(dir))),
 		uintptr(unsafe.Pointer(&size.Available)),
diff --git a/br/pkg/lightning/config/config.go b/br/pkg/lightning/config/config.go
index 8f2e6f2dfa9ac..438ec148b4118 100644
--- a/br/pkg/lightning/config/config.go
+++ b/br/pkg/lightning/config/config.go
@@ -101,7 +101,7 @@ const (
 )
 
 var (
-	supportedStorageTypes = []string{"file", "local", "s3", "noop", "gcs"}
+	supportedStorageTypes = []string{"file", "local", "s3", "noop", "gcs", "gs"}
 
 	DefaultFilter = []string{
 		"*.*",
diff --git a/br/pkg/lightning/config/config_test.go b/br/pkg/lightning/config/config_test.go
index 1e7e751b20b3d..4710f493dca93 100644
--- a/br/pkg/lightning/config/config_test.go
+++ b/br/pkg/lightning/config/config_test.go
@@ -156,6 +156,33 @@ func (s *configTestSuite) TestAdjustInvalidBackend(c *C) {
 	c.Assert(err, ErrorMatches, "invalid config: unsupported `tikv-importer\\.backend` \\(no_such_backend\\)")
 }
 
+func (s *configTestSuite) TestCheckAndAdjustFilePath(c *C) {
+	tmpDir := c.MkDir()
+	// use slashPath in url to be compatible with windows
+	slashPath := filepath.ToSlash(tmpDir)
+
+	cfg := config.NewConfig()
+	cases := []string{
+		tmpDir,
+		".",
+		"file://" + slashPath,
+		"local://" + slashPath,
+		"s3://bucket_name",
+		"s3://bucket_name/path/to/dir",
+		"gcs://bucketname/path/to/dir",
+		"gs://bucketname/path/to/dir",
+		"noop:///",
+	}
+
+	for _, testCase := range cases {
+		cfg.Mydumper.SourceDir = testCase
+
+		err := cfg.CheckAndAdjustFilePath()
+		c.Assert(err, IsNil)
+	}
+}
+
 func (s *configTestSuite) TestAdjustFileRoutePath(c *C) {
 	cfg := config.NewConfig()
 	assignMinimalLegalValue(cfg)
@@ -581,6 +608,13 @@ func (s *configTestSuite) TestLoadConfig(c *C) {
 
 	result := taskCfg.String()
 	c.Assert(result, Matches, `.*"pd-addr":"172.16.30.11:2379,172.16.30.12:2379".*`)
+
+	cfg, err = config.LoadGlobalConfig([]string{}, nil)
+	c.Assert(err, IsNil)
+	c.Assert(cfg.App.Config.File, Matches, ".*lightning.log.*")
+	cfg, err = config.LoadGlobalConfig([]string{"--log-file", "-"}, nil)
+	c.Assert(err, IsNil)
+	c.Assert(cfg.App.Config.File, Equals, "-")
 }
 
 func (s *configTestSuite) TestDefaultImporterBackendValue(c *C) {
diff --git a/br/pkg/lightning/config/global.go b/br/pkg/lightning/config/global.go
index 7eb8e240c9dfe..d9bd80ef4139a 100644
--- a/br/pkg/lightning/config/global.go
+++ b/br/pkg/lightning/config/global.go
@@ -200,10 +200,7 @@ func LoadGlobalConfig(args []string, extraFlags func(*flag.FlagSet)) (*GlobalCon
 	if *logFilePath != "" {
 		cfg.App.Config.File = *logFilePath
 	}
-	// "-" is a special config for log to stdout
-	if cfg.App.Config.File == "-" {
-		cfg.App.Config.File = ""
-	} else if cfg.App.Config.File == "" {
+	if cfg.App.Config.File == "" {
 		cfg.App.Config.File = timestampLogFileName()
 	}
 	if *tidbHost != "" {
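To make the new contract concrete: after the global.go hunk above, `-` survives `LoadGlobalConfig` untouched and is interpreted only by the logging layer (see the log.go hunk below). A hedged sketch of the expected behavior, grounded in the config_test.go assertions above; the printed values are illustrative:

```go
package main

import (
	"fmt"

	"github.com/pingcap/tidb/br/pkg/lightning/config"
)

func main() {
	// "-" now survives LoadGlobalConfig; only log.InitLogger treats it as stdout.
	cfg := config.Must(config.LoadGlobalConfig([]string{"--log-file", "-"}, nil))
	fmt.Println(cfg.App.Config.File) // "-"

	// With no flag at all, a timestamped default such as
	// lightning.log.2021-09-01T00.00.00+0800 is filled in.
	cfg = config.Must(config.LoadGlobalConfig([]string{}, nil))
	fmt.Println(cfg.App.Config.File)
}
```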
 		return errors.Annotate(err, "create storage failed")
 	}
 
+	// returning expectedErr from the callback means we met at least one file
+	expectedErr := errors.New("Stop Iter")
+	walkErr := s.WalkDir(ctx, &storage.WalkOption{ListCount: 1}, func(string, int64) error {
+		// return an error on the first regular file to break the walk loop
+		return expectedErr
+	})
+	if !errors.ErrorEqual(walkErr, expectedErr) {
+		if walkErr == nil {
+			return errors.Errorf("data-source-dir '%s' doesn't exist or contains no files", taskCfg.Mydumper.SourceDir)
+		}
+		return errors.Annotatef(walkErr, "visit data-source-dir '%s' failed", taskCfg.Mydumper.SourceDir)
+	}
+
 	loadTask := log.L().Begin(zap.InfoLevel, "load data source")
 	var mdl *mydump.MDLoader
 	mdl, err = mydump.NewMyDumpLoaderWithStore(ctx, taskCfg, s)
diff --git a/br/pkg/lightning/log/log.go b/br/pkg/lightning/log/log.go
index 2dc24acac1541..8521cf85a6579 100644
--- a/br/pkg/lightning/log/log.go
+++ b/br/pkg/lightning/log/log.go
@@ -93,8 +93,8 @@ func InitLogger(cfg *Config, tidbLoglevel string) error {
 		// Filter logs from TiDB and PD.
 		return NewFilterCore(core, "github.com/tikv/pd/")
 	})
-
-	if len(cfg.File) > 0 {
+	// "-" is a special config for log to stdout.
+	if len(cfg.File) > 0 && cfg.File != "-" {
 		logCfg.File = pclog.FileLogConfig{
 			Filename: cfg.File,
 			MaxSize:  cfg.FileMaxSize,
diff --git a/br/pkg/lightning/log/log_serial_test.go b/br/pkg/lightning/log/log_serial_test.go
new file mode 100644
index 0000000000000..63ef2bf321ab1
--- /dev/null
+++ b/br/pkg/lightning/log/log_serial_test.go
@@ -0,0 +1,43 @@
+package log_test
+
+import (
+	"io"
+	"os"
+	"testing"
+
+	"github.com/pingcap/tidb/br/pkg/lightning/log"
+	"github.com/stretchr/testify/require"
+)
+
+func TestInitStdoutLogger(t *testing.T) {
+	r, w, err := os.Pipe()
+	require.NoError(t, err)
+	oldStdout := os.Stdout
+	os.Stdout = w
+
+	msg := "logger is initialized to stdout"
+	outputC := make(chan string, 1)
+	go func() {
+		buf := make([]byte, 4096)
+		n := 0
+		for {
+			nn, err := r.Read(buf[n:])
+			if nn == 0 || err == io.EOF {
+				break
+			}
+			require.NoError(t, err)
+			n += nn
+		}
+		outputC <- string(buf[:n])
+	}()
+
+	logCfg := &log.Config{File: "-"}
+	require.NoError(t, log.InitLogger(logCfg, "info"))
+	log.L().Info(msg)
+
+	os.Stdout = oldStdout
+	require.NoError(t, w.Close())
+	output := <-outputC
+	require.NoError(t, r.Close())
+	require.Contains(t, output, msg)
+}
diff --git a/br/pkg/lightning/mydump/loader.go b/br/pkg/lightning/mydump/loader.go
index 27bab8fa5cf7b..e50f61c1dce5a 100644
--- a/br/pkg/lightning/mydump/loader.go
+++ b/br/pkg/lightning/mydump/loader.go
@@ -18,10 +18,12 @@ import (
 	"context"
 	"path/filepath"
 	"sort"
+	"strings"
 
 	"github.com/pingcap/errors"
 	filter "github.com/pingcap/tidb-tools/pkg/table-filter"
 	router "github.com/pingcap/tidb-tools/pkg/table-router"
+	"github.com/pingcap/tidb/br/pkg/lightning/common"
 	"github.com/pingcap/tidb/br/pkg/lightning/config"
 	"github.com/pingcap/tidb/br/pkg/lightning/log"
 	"github.com/pingcap/tidb/br/pkg/storage"
@@ -30,12 +32,30 @@ import (
 
 type MDDatabaseMeta struct {
 	Name       string
-	SchemaFile string
+	SchemaFile FileInfo
 	Tables     []*MDTableMeta
 	Views      []*MDTableMeta
 	charSet    string
 }
 
+func (m *MDDatabaseMeta) GetSchema(ctx context.Context, store storage.ExternalStorage) (string, error) {
+	schema, err := ExportStatement(ctx, store, m.SchemaFile, m.charSet)
+	if err != nil {
+		log.L().Warn("failed to extract table schema",
+			zap.String("Path", m.SchemaFile.FileMeta.Path),
+			log.ShortError(err),
+		)
+		schema = nil
+	}
+	schemaStr := strings.TrimSpace(string(schema))
+	// set a default schema if the schema SQL is empty
+	if len(schemaStr) == 0 {
+		schemaStr = "CREATE DATABASE IF NOT EXISTS " + common.EscapeIdentifier(m.Name)
+	}
+
+	return schemaStr, nil
+}
+
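Since a dump may contain data files but no `*-schema-create.sql`, `GetSchema` now falls back to a generated statement. A small illustrative sketch of that fallback; `escapeIdentifier` here is a stand-in for the real `common.EscapeIdentifier` (wrap in backticks, double any embedded backtick):

```go
package main

import (
	"fmt"
	"strings"
)

// escapeIdentifier is an illustrative stand-in for common.EscapeIdentifier.
func escapeIdentifier(name string) string {
	var sb strings.Builder
	sb.WriteByte('`')
	sb.WriteString(strings.ReplaceAll(name, "`", "``"))
	sb.WriteByte('`')
	return sb.String()
}

func main() {
	// What MDDatabaseMeta.GetSchema falls back to when no schema file exists:
	fmt.Println("CREATE DATABASE IF NOT EXISTS " + escapeIdentifier("db"))   // ... `db`
	fmt.Println("CREATE DATABASE IF NOT EXISTS " + escapeIdentifier("ba`r")) // ... `ba``r`
}
```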
 type MDTableMeta struct {
 	DB   string
 	Name string
@@ -219,7 +239,7 @@ func (s *mdLoaderSetup) setup(ctx context.Context, store storage.ExternalStorage
 	// setup database schema
 	if len(s.dbSchemas) != 0 {
 		for _, fileInfo := range s.dbSchemas {
-			if _, dbExists := s.insertDB(fileInfo.TableName.Schema, fileInfo.FileMeta.Path); dbExists && s.loader.router == nil {
+			if _, dbExists := s.insertDB(fileInfo); dbExists && s.loader.router == nil {
 				return errors.Errorf("invalid database schema file, duplicated item - %s", fileInfo.FileMeta.Path)
 			}
 		}
@@ -406,15 +426,15 @@ func (s *mdLoaderSetup) route() error {
 	return nil
 }
 
-func (s *mdLoaderSetup) insertDB(dbName string, path string) (*MDDatabaseMeta, bool) {
-	dbIndex, ok := s.dbIndexMap[dbName]
+func (s *mdLoaderSetup) insertDB(f FileInfo) (*MDDatabaseMeta, bool) {
+	dbIndex, ok := s.dbIndexMap[f.TableName.Schema]
 	if ok {
 		return s.loader.dbs[dbIndex], true
 	}
-	s.dbIndexMap[dbName] = len(s.loader.dbs)
+	s.dbIndexMap[f.TableName.Schema] = len(s.loader.dbs)
 	ptr := &MDDatabaseMeta{
-		Name:       dbName,
-		SchemaFile: path,
+		Name:       f.TableName.Schema,
+		SchemaFile: f,
 		charSet:    s.loader.charSet,
 	}
 	s.loader.dbs = append(s.loader.dbs, ptr)
@@ -422,7 +442,13 @@ func (s *mdLoaderSetup) insertDB(dbName string, path string) (*MDDatabaseMeta, b
 }
 
 func (s *mdLoaderSetup) insertTable(fileInfo FileInfo) (*MDTableMeta, bool, bool) {
-	dbMeta, dbExists := s.insertDB(fileInfo.TableName.Schema, "")
+	dbFileInfo := FileInfo{
+		TableName: filter.Table{
+			Schema: fileInfo.TableName.Schema,
+		},
+		FileMeta: SourceFileMeta{Type: SourceTypeSchemaSchema},
+	}
+	dbMeta, dbExists := s.insertDB(dbFileInfo)
 	tableIndex, ok := s.tableIndexMap[fileInfo.TableName]
 	if ok {
 		return dbMeta.Tables[tableIndex], dbExists, true
@@ -442,7 +468,13 @@ func (s *mdLoaderSetup) insertTable(fileInfo FileInfo) (*MDTableMeta, bool, bool
 }
 
 func (s *mdLoaderSetup) insertView(fileInfo FileInfo) (bool, bool) {
-	dbMeta, dbExists := s.insertDB(fileInfo.TableName.Schema, "")
+	dbFileInfo := FileInfo{
+		TableName: filter.Table{
+			Schema: fileInfo.TableName.Schema,
+		},
+		FileMeta: SourceFileMeta{Type: SourceTypeSchemaSchema},
+	}
+	dbMeta, dbExists := s.insertDB(dbFileInfo)
 	_, ok := s.tableIndexMap[fileInfo.TableName]
 	if ok {
 		meta := &MDTableMeta{
diff --git a/br/pkg/lightning/mydump/loader_test.go b/br/pkg/lightning/mydump/loader_test.go
index 76bc50eba2793..08442cacffd86 100644
--- a/br/pkg/lightning/mydump/loader_test.go
+++ b/br/pkg/lightning/mydump/loader_test.go
@@ -179,6 +179,9 @@ func (s *testMydumpLoaderSuite) TestTableInfoNotFound(c *C) {
 	loader, err := md.NewMyDumpLoader(ctx, s.cfg)
 	c.Assert(err, IsNil)
 	for _, dbMeta := range loader.GetDatabases() {
+		dbSQL, err := dbMeta.GetSchema(ctx, store)
+		c.Assert(err, IsNil)
+		c.Assert(dbSQL, Equals, "CREATE DATABASE IF NOT EXISTS `db`")
 		for _, tblMeta := range dbMeta.Tables {
 			sql, err := tblMeta.GetSchema(ctx, store)
 			c.Assert(sql, Equals, "")
@@ -272,8 +275,14 @@ func (s *testMydumpLoaderSuite) TestDataWithoutSchema(c *C) {
 	mdl, err := md.NewMyDumpLoader(context.Background(), s.cfg)
 	c.Assert(err, IsNil)
 	c.Assert(mdl.GetDatabases(), DeepEquals, []*md.MDDatabaseMeta{{
-		Name:       "db",
-		SchemaFile: "",
+		Name: "db",
+		SchemaFile: md.FileInfo{
+			TableName: filter.Table{
+				Schema: "db",
+				Name:   "",
+			},
+			FileMeta: md.SourceFileMeta{Type: md.SourceTypeSchemaSchema},
+		},
 		Tables: []*md.MDTableMeta{{
 			DB: "db",
Name: "tbl", @@ -302,7 +311,7 @@ func (s *testMydumpLoaderSuite) TestTablesWithDots(c *C) { c.Assert(err, IsNil) c.Assert(mdl.GetDatabases(), DeepEquals, []*md.MDDatabaseMeta{{ Name: "db", - SchemaFile: "db-schema-create.sql", + SchemaFile: md.FileInfo{TableName: filter.Table{Schema: "db", Name: ""}, FileMeta: md.SourceFileMeta{Path: "db-schema-create.sql", Type: md.SourceTypeSchemaSchema}}, Tables: []*md.MDTableMeta{ { DB: "db", @@ -396,7 +405,7 @@ func (s *testMydumpLoaderSuite) TestRouter(c *C) { c.Assert(mdl.GetDatabases(), DeepEquals, []*md.MDDatabaseMeta{ { Name: "a1", - SchemaFile: "a1-schema-create.sql", + SchemaFile: md.FileInfo{TableName: filter.Table{Schema: "a1", Name: ""}, FileMeta: md.SourceFileMeta{Path: "a1-schema-create.sql", Type: md.SourceTypeSchemaSchema}}, Tables: []*md.MDTableMeta{ { DB: "a1", @@ -427,11 +436,11 @@ func (s *testMydumpLoaderSuite) TestRouter(c *C) { }, { Name: "d0", - SchemaFile: "d0-schema-create.sql", + SchemaFile: md.FileInfo{TableName: filter.Table{Schema: "d0", Name: ""}, FileMeta: md.SourceFileMeta{Path: "d0-schema-create.sql", Type: md.SourceTypeSchemaSchema}}, }, { Name: "b", - SchemaFile: "a0-schema-create.sql", + SchemaFile: md.FileInfo{TableName: filter.Table{Schema: "b", Name: ""}, FileMeta: md.SourceFileMeta{Path: "a0-schema-create.sql", Type: md.SourceTypeSchemaSchema}}, Tables: []*md.MDTableMeta{ { DB: "b", @@ -449,7 +458,7 @@ func (s *testMydumpLoaderSuite) TestRouter(c *C) { }, { Name: "c", - SchemaFile: "c0-schema-create.sql", + SchemaFile: md.FileInfo{TableName: filter.Table{Schema: "c", Name: ""}, FileMeta: md.SourceFileMeta{Path: "c0-schema-create.sql", Type: md.SourceTypeSchemaSchema}}, Tables: []*md.MDTableMeta{ { DB: "c", @@ -463,7 +472,7 @@ func (s *testMydumpLoaderSuite) TestRouter(c *C) { }, { Name: "v", - SchemaFile: "e0-schema-create.sql", + SchemaFile: md.FileInfo{TableName: filter.Table{Schema: "v", Name: ""}, FileMeta: md.SourceFileMeta{Path: "e0-schema-create.sql", Type: md.SourceTypeSchemaSchema}}, Tables: []*md.MDTableMeta{ { DB: "v", @@ -552,7 +561,7 @@ func (s *testMydumpLoaderSuite) TestFileRouting(c *C) { c.Assert(mdl.GetDatabases(), DeepEquals, []*md.MDDatabaseMeta{ { Name: "d1", - SchemaFile: filepath.FromSlash("d1/schema.sql"), + SchemaFile: md.FileInfo{TableName: filter.Table{Schema: "d1", Name: ""}, FileMeta: md.SourceFileMeta{Path: filepath.FromSlash("d1/schema.sql"), Type: md.SourceTypeSchemaSchema}}, Tables: []*md.MDTableMeta{ { DB: "d1", @@ -605,7 +614,7 @@ func (s *testMydumpLoaderSuite) TestFileRouting(c *C) { }, { Name: "d2", - SchemaFile: filepath.FromSlash("d2/schema.sql"), + SchemaFile: md.FileInfo{TableName: filter.Table{Schema: "d2", Name: ""}, FileMeta: md.SourceFileMeta{Path: filepath.FromSlash("d2/schema.sql"), Type: md.SourceTypeSchemaSchema}}, Tables: []*md.MDTableMeta{ { DB: "d2", diff --git a/br/pkg/lightning/restore/check_info.go b/br/pkg/lightning/restore/check_info.go index 1b86ee482f362..d597b6e2646fb 100644 --- a/br/pkg/lightning/restore/check_info.go +++ b/br/pkg/lightning/restore/check_info.go @@ -454,33 +454,31 @@ func (rc *Controller) localResource(sourceSize int64) error { if err != nil { return errors.Trace(err) } - localAvailable := storageSize.Available + localAvailable := int64(storageSize.Available) var message string var passed bool switch { - case localAvailable > uint64(sourceSize): + case localAvailable > sourceSize: message = fmt.Sprintf("local disk resources are rich, estimate sorted data size %s, local available is %s", units.BytesSize(float64(sourceSize)), 
diff --git a/br/pkg/lightning/restore/check_info_test.go b/br/pkg/lightning/restore/check_info_test.go
new file mode 100644
index 0000000000000..98556d6f78ef7
--- /dev/null
+++ b/br/pkg/lightning/restore/check_info_test.go
@@ -0,0 +1,80 @@
+// Copyright 2021 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package restore
+
+import (
+	"context"
+
+	. "github.com/pingcap/check"
"github.com/pingcap/check" + "github.com/pingcap/failpoint" + + "github.com/pingcap/tidb/br/pkg/lightning/config" + "github.com/pingcap/tidb/br/pkg/lightning/worker" + "github.com/pingcap/tidb/br/pkg/storage" +) + +var _ = Suite(&checkInfoSuite{}) + +type checkInfoSuite struct{} + +func (s *checkInfoSuite) TestLocalResource(c *C) { + dir := c.MkDir() + mockStore, err := storage.NewLocalStorage(dir) + c.Assert(err, IsNil) + + err = failpoint.Enable("github.com/pingcap/tidb/br/pkg/lightning/common/GetStorageSize", "return(2048)") + c.Assert(err, IsNil) + defer func() { + _ = failpoint.Disable("github.com/pingcap/tidb/br/pkg/lightning/common/GetStorageSize") + }() + + cfg := config.NewConfig() + cfg.Mydumper.SourceDir = dir + cfg.TikvImporter.SortedKVDir = dir + cfg.TikvImporter.Backend = "local" + rc := &Controller{ + cfg: cfg, + store: mockStore, + ioWorkers: worker.NewPool(context.Background(), 1, "io"), + } + + // 1. source-size is smaller than disk-size, won't trigger error information + rc.checkTemplate = NewSimpleTemplate() + err = rc.localResource(1000) + c.Assert(err, IsNil) + tmpl := rc.checkTemplate.(*SimpleTemplate) + c.Assert(tmpl.warnFailedCount, Equals, 1) + c.Assert(tmpl.criticalFailedCount, Equals, 0) + c.Assert(tmpl.normalMsgs[1], Matches, "local disk resources are rich, estimate sorted data size 1000B, local available is 2KiB") + + // 2. source-size is bigger than disk-size, with default disk-quota will trigger a critical error + rc.checkTemplate = NewSimpleTemplate() + err = rc.localResource(4096) + c.Assert(err, IsNil) + tmpl = rc.checkTemplate.(*SimpleTemplate) + c.Assert(tmpl.warnFailedCount, Equals, 1) + c.Assert(tmpl.criticalFailedCount, Equals, 1) + c.Assert(tmpl.criticalMsgs[0], Matches, "local disk space may not enough to finish import, estimate sorted data size is 4KiB, but local available is 2KiB, please set `tikv-importer.disk-quota` to a smaller value than 2KiB or change `mydumper.sorted-kv-dir` to another disk with enough space to finish imports") + + // 3. 
+
+	// 3. source-size is bigger than disk-size, with a valid disk-quota will trigger a warning
+	rc.checkTemplate = NewSimpleTemplate()
+	rc.cfg.TikvImporter.DiskQuota = config.ByteSize(1024)
+	err = rc.localResource(4096)
+	c.Assert(err, IsNil)
+	tmpl = rc.checkTemplate.(*SimpleTemplate)
+	c.Assert(tmpl.warnFailedCount, Equals, 1)
+	c.Assert(tmpl.criticalFailedCount, Equals, 0)
+	c.Assert(tmpl.normalMsgs[1], Matches, "local disk space may not be enough to finish import, estimate sorted data size is 4KiB, but local available is 2KiB, we will use disk-quota \\(size: 1KiB\\) to finish imports, which may slow down import")
+}
diff --git a/br/pkg/lightning/restore/check_template.go b/br/pkg/lightning/restore/check_template.go
index 3fb8c22904caa..f38e23aa00f8e 100644
--- a/br/pkg/lightning/restore/check_template.go
+++ b/br/pkg/lightning/restore/check_template.go
@@ -51,7 +51,8 @@ type SimpleTemplate struct {
 	count               int
 	warnFailedCount     int
 	criticalFailedCount int
-	failedMsg           []string
+	normalMsgs          []string // only used in unit test now
+	criticalMsgs        []string
 	t                   table.Writer
 }
 
@@ -65,16 +66,12 @@ func NewSimpleTemplate() Template {
 		{Name: "Passed", WidthMax: 6},
 	})
 	return &SimpleTemplate{
-		0,
-		0,
-		0,
-		make([]string, 0),
-		t,
+		t: t,
 	}
 }
 
 func (c *SimpleTemplate) FailedMsg() string {
-	return strings.Join(c.failedMsg, ";\n")
+	return strings.Join(c.criticalMsgs, ";\n")
 }
 
 func (c *SimpleTemplate) Collect(t CheckType, passed bool, msg string) {
@@ -87,7 +84,11 @@ func (c *SimpleTemplate) Collect(t CheckType, passed bool, msg string) {
 			c.warnFailedCount++
 		}
 	}
-	c.failedMsg = append(c.failedMsg, msg)
+	if !passed && t == Critical {
+		c.criticalMsgs = append(c.criticalMsgs, msg)
+	} else {
+		c.normalMsgs = append(c.normalMsgs, msg)
+	}
 	c.t.AppendRow(table.Row{c.count, msg, t, passed})
 	c.t.AppendSeparator()
 }
@@ -108,7 +109,7 @@ func (c *SimpleTemplate) FailedCount(t CheckType) int {
 
 func (c *SimpleTemplate) Output() string {
 	c.t.SetAllowedRowLength(170)
-	c.t.SetRowPainter(table.RowPainter(func(row table.Row) text.Colors {
+	c.t.SetRowPainter(func(row table.Row) text.Colors {
 		if passed, ok := row[3].(bool); ok {
 			if !passed {
 				if typ, ok := row[2].(CheckType); ok {
@@ -122,7 +123,7 @@ func (c *SimpleTemplate) Output() string {
 			}
 		}
 		return nil
-	}))
+	})
 	res := c.t.Render()
 	summary := "\n"
 	if c.criticalFailedCount > 0 {
diff --git a/br/pkg/lightning/restore/restore.go b/br/pkg/lightning/restore/restore.go
index 0ea46ea67bf14..0ff3260d9875f 100644
--- a/br/pkg/lightning/restore/restore.go
+++ b/br/pkg/lightning/restore/restore.go
@@ -432,6 +432,7 @@ func (rc *Controller) Run(ctx context.Context) error {
 		rc.setGlobalVariables,
 		rc.restoreSchema,
 		rc.preCheckRequirements,
+		rc.initCheckpoint,
 		rc.restoreTables,
 		rc.fullCompact,
 		rc.cleanCheckpoints,
@@ -497,11 +498,7 @@ type schemaJob struct {
 	dbName   string
 	tblName  string // empty for create db jobs
 	stmtType schemaStmtType
-	stmts    []*schemaStmt
-}
-
-type schemaStmt struct {
-	sql string
+	stmts    []string
 }
 
 type restoreSchemaWorker struct {
@@ -514,6 +511,15 @@ type restoreSchemaWorker struct {
 	store  storage.ExternalStorage
 }
 
+func (worker *restoreSchemaWorker) addJob(sqlStr string, job *schemaJob) error {
+	stmts, err := createIfNotExistsStmt(worker.glue.GetParser(), sqlStr, job.dbName, job.tblName)
+	if err != nil {
+		return err
+	}
+	job.stmts = stmts
+	return worker.appendJob(job)
+}
+
@@ -525,15 +531,15 @@ func (worker *restoreSchemaWorker) makeJobs(
 	var err error
 	// 1. restore databases, execute statements concurrency
 	for _, dbMeta := range dbMetas {
-		restoreSchemaJob := &schemaJob{
+		sql, err := dbMeta.GetSchema(worker.ctx, worker.store)
+		if err != nil {
+			return err
+		}
+		err = worker.addJob(sql, &schemaJob{
 			dbName:   dbMeta.Name,
+			tblName:  "",
 			stmtType: schemaCreateDatabase,
-			stmts:    make([]*schemaStmt, 0, 1),
-		}
-		restoreSchemaJob.stmts = append(restoreSchemaJob.stmts, &schemaStmt{
-			sql: createDatabaseIfNotExistStmt(dbMeta.Name),
 		})
-		err = worker.appendJob(restoreSchemaJob)
 		if err != nil {
 			return err
 		}
@@ -559,30 +565,19 @@ func (worker *restoreSchemaWorker) makeJobs(
 				return errors.Errorf("table `%s`.`%s` schema not found", dbMeta.Name, tblMeta.Name)
 			}
 			sql, err := tblMeta.GetSchema(worker.ctx, worker.store)
+			if err != nil {
+				return err
+			}
 			if sql != "" {
-				stmts, err := createTableIfNotExistsStmt(worker.glue.GetParser(), sql, dbMeta.Name, tblMeta.Name)
-				if err != nil {
-					return err
-				}
-				restoreSchemaJob := &schemaJob{
+				err = worker.addJob(sql, &schemaJob{
 					dbName:   dbMeta.Name,
 					tblName:  tblMeta.Name,
 					stmtType: schemaCreateTable,
-					stmts:    make([]*schemaStmt, 0, len(stmts)),
-				}
-				for _, sql := range stmts {
-					restoreSchemaJob.stmts = append(restoreSchemaJob.stmts, &schemaStmt{
-						sql: sql,
-					})
-				}
-				err = worker.appendJob(restoreSchemaJob)
+				})
 				if err != nil {
 					return err
 				}
 			}
-			if err != nil {
-				return err
-			}
 		}
 	}
 	err = worker.wait()
@@ -594,22 +589,11 @@ func (worker *restoreSchemaWorker) makeJobs(
 		for _, viewMeta := range dbMeta.Views {
 			sql, err := viewMeta.GetSchema(worker.ctx, worker.store)
 			if sql != "" {
-				stmts, err := createTableIfNotExistsStmt(worker.glue.GetParser(), sql, dbMeta.Name, viewMeta.Name)
-				if err != nil {
-					return err
-				}
-				restoreSchemaJob := &schemaJob{
+				err = worker.addJob(sql, &schemaJob{
 					dbName:   dbMeta.Name,
 					tblName:  viewMeta.Name,
 					stmtType: schemaCreateView,
-					stmts:    make([]*schemaStmt, 0, len(stmts)),
-				}
-				for _, sql := range stmts {
-					restoreSchemaJob.stmts = append(restoreSchemaJob.stmts, &schemaStmt{
-						sql: sql,
-					})
-				}
-				err = worker.appendJob(restoreSchemaJob)
+				})
 				if err != nil {
 					return err
 				}
@@ -670,8 +654,8 @@ loop:
 				DB: session,
 			}
 			for _, stmt := range job.stmts {
-				task := logger.Begin(zap.DebugLevel, fmt.Sprintf("execute SQL: %s", stmt.sql))
-				err = sqlWithRetry.Exec(worker.ctx, "run create schema job", stmt.sql)
+				task := logger.Begin(zap.DebugLevel, fmt.Sprintf("execute SQL: %s", stmt))
+				err = sqlWithRetry.Exec(worker.ctx, "run create schema job", stmt)
 				task.End(zap.ErrorLevel, err)
 				if err != nil {
 					err = errors.Annotatef(err, "%s %s failed", job.stmtType.String(), common.UniqueTable(job.dbName, job.tblName))
@@ -731,7 +715,7 @@ func (worker *restoreSchemaWorker) appendJob(job *schemaJob) error {
 	case <-worker.ctx.Done():
 		// cancel the job
 		worker.wg.Done()
-		return worker.ctx.Err()
+		return errors.Trace(worker.ctx.Err())
 	case worker.jobCh <- job:
 		return nil
 	}
@@ -771,14 +755,20 @@ func (rc *Controller) restoreSchema(ctx context.Context) error {
 	}
 	rc.dbInfos = dbInfos
 
-	if rc.tidbGlue.OwnsSQLExecutor() {
-		if err = rc.DataCheck(ctx); err != nil {
-			return errors.Trace(err)
-		}
+	sysVars := ObtainImportantVariables(ctx, rc.tidbGlue.GetSQLExecutor(), !rc.isTiDBBackend())
+	// override by manually set vars
+	for k, v := range rc.cfg.TiDB.Vars {
+		sysVars[k] = v
 	}
+	rc.sysVars = sysVars
+
+	return nil
+}
+
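The net effect on the run pipeline: checkpoint initialization moves out of restoreSchema into its own phase that runs only after the pre-checks pass. A toy, runnable sketch of that ordering guarantee (the phase names are real, the stub functions are illustrative; it mirrors the `panicCheckpointDB` assertion in TestPreCheckFailed below):

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

func main() {
	preCheckFailed := errors.New("pre-check failed")
	phases := []struct {
		name string
		fn   func(context.Context) error
	}{
		{"setGlobalVariables", func(context.Context) error { return nil }},
		{"restoreSchema", func(context.Context) error { return nil }}, // no checkpoint I/O anymore
		{"preCheckRequirements", func(context.Context) error { return preCheckFailed }},
		{"initCheckpoint", func(context.Context) error { panic("should not reach here") }},
	}
	for _, p := range phases {
		if err := p.fn(context.Background()); err != nil {
			// A failed pre-check can no longer leave freshly created checkpoint data behind.
			fmt.Println(p.name, "stopped the run:", err)
			return
		}
	}
}
```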
+// initCheckpoint initializes all tables' checkpoint data
+func (rc *Controller) initCheckpoint(ctx context.Context) error {
 	// Load new checkpoints
-	err = rc.checkpointsDB.Initialize(ctx, rc.cfg, dbInfos)
+	err := rc.checkpointsDB.Initialize(ctx, rc.cfg, rc.dbInfos)
 	if err != nil {
 		return errors.Trace(err)
 	}
@@ -789,20 +779,8 @@ func (rc *Controller) restoreSchema(ctx context.Context) error {
 
 	go rc.listenCheckpointUpdates()
 
-	sysVars := ObtainImportantVariables(ctx, rc.tidbGlue.GetSQLExecutor(), !rc.isTiDBBackend())
-	// override by manually set vars
-	for k, v := range rc.cfg.TiDB.Vars {
-		sysVars[k] = v
-	}
-	rc.sysVars = sysVars
-
-	// Estimate the number of chunks for progress reporting
-	err = rc.estimateChunkCountIntoMetrics(ctx)
-	if err != nil {
-		return errors.Trace(err)
-	}
-
-	return nil
+	// Estimate the number of chunks for progress reporting
+	return rc.estimateChunkCountIntoMetrics(ctx)
 }
 
 // verifyCheckpoint check whether previous task checkpoint is compatible with task config
@@ -1486,6 +1464,9 @@ func (rc *Controller) restoreTables(ctx context.Context) error {
 		if err != nil {
 			return errors.Trace(err)
 		}
+		if cp.Status < checkpoints.CheckpointStatusAllWritten && len(tableMeta.DataFiles) == 0 {
+			continue
+		}
 		igCols, err := rc.cfg.Mydumper.IgnoreColumns.GetIgnoreColumns(dbInfo.Name, tableInfo.Name, rc.cfg.Mydumper.CaseSensitive)
 		if err != nil {
 			return errors.Trace(err)
@@ -1551,7 +1532,6 @@ func (tr *TableRestore) restoreTable(
 	cp *checkpoints.TableCheckpoint,
 ) (bool, error) {
 	// 1. Load the table info.
-
 	select {
 	case <-ctx.Done():
 		return false, ctx.Err()
@@ -1812,7 +1792,10 @@ func (rc *Controller) setGlobalVariables(ctx context.Context) error {
 		return nil
 	}
 	// set new collation flag base on tidb config
-	enabled := ObtainNewCollationEnabled(ctx, rc.tidbGlue.GetSQLExecutor())
+	enabled, err := ObtainNewCollationEnabled(ctx, rc.tidbGlue.GetSQLExecutor())
+	if err != nil {
+		return err
+	}
 	// we should enable/disable new collation here since in server mode, tidb config
 	// may be different in different tasks
 	collate.SetNewCollationEnabledForTest(enabled)
@@ -1865,6 +1848,10 @@ func (rc *Controller) isTiDBBackend() bool {
 // 4. Lightning configuration
 // before restore tables start.
 func (rc *Controller) preCheckRequirements(ctx context.Context) error {
+	if err := rc.DataCheck(ctx); err != nil {
+		return errors.Trace(err)
+	}
+
 	if rc.cfg.App.CheckRequirements {
 		if err := rc.ClusterIsAvailable(ctx); err != nil {
 			return errors.Trace(err)
@@ -1926,8 +1913,7 @@ func (rc *Controller) preCheckRequirements(ctx context.Context) error {
 		if !taskExist && rc.taskMgr != nil {
 			rc.taskMgr.CleanupTask(ctx)
 		}
-		return errors.Errorf("tidb-lightning check failed."+
-			" Please fix the failed check(s):\n %s", rc.checkTemplate.FailedMsg())
+		return errors.Errorf("tidb-lightning pre-check failed: %s", rc.checkTemplate.FailedMsg())
 	}
 	return nil
 }
@@ -2349,7 +2335,7 @@ func (cr *chunkRestore) encodeLoop(
 
 	hasIgnoredEncodeErr := false
 	if encodeErr != nil {
-		rowText := tidb.EncodeRowForRecord(t.encTable, rc.cfg.TiDB.SQLMode, lastRow.Row)
+		rowText := tidb.EncodeRowForRecord(t.encTable, rc.cfg.TiDB.SQLMode, lastRow.Row, cr.chunk.ColumnPermutation)
 		encodeErr = rc.errorMgr.RecordTypeError(ctx, logger, t.tableName, cr.chunk.Key.Path, newOffset, rowText, encodeErr)
 		err = errors.Annotatef(encodeErr, "in file %s at offset %d", &cr.chunk.Key, newOffset)
 		hasIgnoredEncodeErr = true
diff --git a/br/pkg/lightning/restore/restore_test.go b/br/pkg/lightning/restore/restore_test.go
index fc44917fdf3b0..c322a1037df39 100644
--- a/br/pkg/lightning/restore/restore_test.go
+++ b/br/pkg/lightning/restore/restore_test.go
@@ -300,6 +300,63 @@ func (s *restoreSuite) TestDiskQuotaLock(c *C) {
 	}
 }
 
+// failMetaMgrBuilder mocks meta manager init failure
+type failMetaMgrBuilder struct {
+	metaMgrBuilder
+}
+
+func (b failMetaMgrBuilder) Init(context.Context) error {
+	return errors.New("mock init meta failure")
+}
+
+type panicCheckpointDB struct {
+	checkpoints.DB
+}
+
+func (cp panicCheckpointDB) Initialize(context.Context, *config.Config, map[string]*checkpoints.TidbDBInfo) error {
+	panic("should not reach here")
+}
+
+func (s *restoreSuite) TestPreCheckFailed(c *C) {
+	cfg := config.NewConfig()
+	cfg.TikvImporter.Backend = config.BackendTiDB
+	cfg.App.CheckRequirements = false
+
+	db, mock, err := sqlmock.New()
+	c.Assert(err, IsNil)
+	g := glue.NewExternalTiDBGlue(db, mysql.ModeNone)
+
+	ctl := &Controller{
+		cfg:            cfg,
+		saveCpCh:       make(chan saveCp),
+		checkpointsDB:  panicCheckpointDB{},
+		metaMgrBuilder: failMetaMgrBuilder{},
+		checkTemplate:  NewSimpleTemplate(),
+		tidbGlue:       g,
+	}
+
+	mock.ExpectBegin()
+	mock.ExpectQuery("SHOW VARIABLES WHERE Variable_name IN .*").
+		WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).
+			AddRow("tidb_row_format_version", "2"))
+	mock.ExpectCommit()
+	// precheck failed, will not do init checkpoint.
+	err = ctl.Run(context.Background())
+	c.Assert(err, ErrorMatches, ".*mock init meta failure")
+	c.Assert(mock.ExpectationsWereMet(), IsNil)
+
+	mock.ExpectBegin()
+	mock.ExpectQuery("SHOW VARIABLES WHERE Variable_name IN .*").
+		WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).
+			AddRow("tidb_row_format_version", "2"))
+	mock.ExpectCommit()
+	ctl.saveCpCh = make(chan saveCp)
+	// precheck failed, will not do init checkpoint.
+	err1 := ctl.Run(context.Background())
+	c.Assert(err1.Error(), Equals, err.Error())
+	c.Assert(mock.ExpectationsWereMet(), IsNil)
+}
+
 var _ = Suite(&tableRestoreSuite{})
 
 type tableRestoreSuiteBase struct {
@@ -534,6 +591,93 @@ func (s *tableRestoreSuite) TestPopulateChunks(c *C) {
 	s.cfg.Mydumper.CSV.Header = false
 }
 
+type errorLocalWriter struct{}
+
+func (w errorLocalWriter) AppendRows(context.Context, string, []string, kv.Rows) error {
+	return errors.New("mock write rows failed")
+}
+
+func (w errorLocalWriter) IsSynced() bool {
+	return true
+}
+
+func (w errorLocalWriter) Close(context.Context) (backend.ChunkFlushStatus, error) {
+	return nil, nil
+}
+
+func (s *tableRestoreSuite) TestRestoreEngineFailed(c *C) {
+	ctx := context.Background()
+	ctrl := gomock.NewController(c)
+	mockBackend := mock.NewMockBackend(ctrl)
+	rc := &Controller{
+		cfg:            s.cfg,
+		pauser:         DeliverPauser,
+		ioWorkers:      worker.NewPool(ctx, 1, "io"),
+		regionWorkers:  worker.NewPool(ctx, 10, "region"),
+		store:          s.store,
+		backend:        backend.MakeBackend(mockBackend),
+		errorSummaries: makeErrorSummaries(log.L()),
+		saveCpCh:       make(chan saveCp, 1),
+		diskQuotaLock:  newDiskQuotaLock(),
+	}
+	defer close(rc.saveCpCh)
+	go func() {
+		for cp := range rc.saveCpCh {
+			cp.waitCh <- nil
+		}
+	}()
+
+	cp := &checkpoints.TableCheckpoint{
+		Engines: make(map[int32]*checkpoints.EngineCheckpoint),
+	}
+	err := s.tr.populateChunks(ctx, rc, cp)
+	c.Assert(err, IsNil)
+
+	tbl, err := tables.TableFromMeta(kv.NewPanickingAllocators(0), s.tableInfo.Core)
+	c.Assert(err, IsNil)
+	_, indexUUID := backend.MakeUUID("`db`.`table`", -1)
+	_, dataUUID := backend.MakeUUID("`db`.`table`", 0)
+	realBackend := tidb.NewTiDBBackend(nil, "replace", nil)
+	mockBackend.EXPECT().OpenEngine(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
+	mockBackend.EXPECT().OpenEngine(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
+	mockBackend.EXPECT().CloseEngine(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
+	mockBackend.EXPECT().NewEncoder(gomock.Any(), gomock.Any()).
+		Return(realBackend.NewEncoder(tbl, &kv.SessionOptions{})).
+		AnyTimes()
+	mockBackend.EXPECT().MakeEmptyRows().Return(realBackend.MakeEmptyRows()).AnyTimes()
+	mockBackend.EXPECT().LocalWriter(gomock.Any(), gomock.Any(), dataUUID).Return(noop.Writer{}, nil)
+	mockBackend.EXPECT().LocalWriter(gomock.Any(), gomock.Any(), indexUUID).
+		Return(nil, errors.New("mock open index local writer failed"))
+	openedIdxEngine, err := rc.backend.OpenEngine(ctx, nil, "`db`.`table`", -1)
+	c.Assert(err, IsNil)
+
+	// if opening the first engine writer fails, the error should be returned directly
+	_, err = s.tr.restoreEngine(ctx, rc, openedIdxEngine, 0, cp.Engines[0])
+	c.Assert(err, ErrorMatches, "mock open index local writer failed")
+
+	localWriter := func(ctx context.Context, cfg *backend.LocalWriterConfig, engineUUID uuid.UUID) (backend.EngineWriter, error) {
+		time.Sleep(20 * time.Millisecond)
+		select {
+		case <-ctx.Done():
+			return nil, errors.New("mock open index local writer failed after ctx.Done")
+		default:
+			return noop.Writer{}, nil
+		}
+	}
+	mockBackend.EXPECT().OpenEngine(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
+	mockBackend.EXPECT().OpenEngine(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
+	mockBackend.EXPECT().LocalWriter(gomock.Any(), gomock.Any(), dataUUID).Return(errorLocalWriter{}, nil).AnyTimes()
+	mockBackend.EXPECT().LocalWriter(gomock.Any(), gomock.Any(), indexUUID).
+		DoAndReturn(localWriter).AnyTimes()
+
+	openedIdxEngine, err = rc.backend.OpenEngine(ctx, nil, "`db`.`table`", -1)
+	c.Assert(err, IsNil)
+
+	// open engine failed after write rows failed, should return write rows error
+	_, err = s.tr.restoreEngine(ctx, rc, openedIdxEngine, 0, cp.Engines[0])
+	c.Assert(err, ErrorMatches, "mock write rows failed")
+}
+
 func (s *tableRestoreSuite) TestPopulateChunksCSVHeader(c *C) {
 	fakeDataDir := c.MkDir()
 	store, err := storage.NewLocalStorage(fakeDataDir)
@@ -997,7 +1141,7 @@ func (s *tableRestoreSuite) TestTableRestoreMetrics(c *C) {
 	chunkPending := metric.ReadCounter(metric.ChunkCounter.WithLabelValues(metric.ChunkStatePending))
 	chunkFinished := metric.ReadCounter(metric.ChunkCounter.WithLabelValues(metric.ChunkStatePending))
 	c.Assert(chunkPending-chunkPendingBase, Equals, float64(7))
-	c.Assert(chunkFinished-chunkFinishedBase, Equals, chunkPending)
+	c.Assert(chunkFinished-chunkFinishedBase, Equals, chunkPending-chunkPendingBase)
 	engineFinished := metric.ReadCounter(metric.ProcessedEngineCounter.WithLabelValues("imported", metric.TableResultSuccess))
 	c.Assert(engineFinished-engineFinishedBase, Equals, float64(8))
diff --git a/br/pkg/lightning/restore/table_restore.go b/br/pkg/lightning/restore/table_restore.go
index 8da5a210ce885..37b842187d42d 100644
--- a/br/pkg/lightning/restore/table_restore.go
+++ b/br/pkg/lightning/restore/table_restore.go
@@ -196,17 +196,6 @@ func (tr *TableRestore) restoreEngines(pCtx context.Context, rc *Controller, cp
 		tr.logger.Error("fail to restoreEngines because indexengine is nil")
 		return errors.Errorf("table %v index engine checkpoint not found", tr.tableName)
 	}
-	// If there is an index engine only, it indicates no data needs to restore.
-	// So we can change status to imported directly and avoid opening engine.
-	if len(cp.Engines) == 1 {
-		if err := rc.saveStatusCheckpoint(pCtx, tr.tableName, indexEngineID, nil, checkpoints.CheckpointStatusImported); err != nil {
-			return errors.Trace(err)
-		}
-		if err := rc.saveStatusCheckpoint(pCtx, tr.tableName, checkpoints.WholeTableEngineID, nil, checkpoints.CheckpointStatusIndexImported); err != nil {
-			return errors.Trace(err)
-		}
-		return nil
-	}
 
 	ctx, cancel := context.WithCancel(pCtx)
 	defer cancel()
@@ -456,6 +445,11 @@ func (tr *TableRestore) restoreEngine(
 		}
 	}()
 
+	setError := func(err error) {
+		chunkErr.Set(err)
+		cancel()
+	}
+
 	// Restore table data
 	for chunkIndex, chunk := range cp.Chunks {
 		if chunk.Chunk.Offset >= chunk.Chunk.EndOffset {
@@ -494,7 +488,8 @@ func (tr *TableRestore) restoreEngine(
 		// 4. flush kvs data (into tikv node)
 		cr, err := newChunkRestore(ctx, chunkIndex, rc.cfg, chunk, rc.ioWorkers, rc.store, tr.tableInfo)
 		if err != nil {
-			return nil, errors.Trace(err)
+			setError(err)
+			break
 		}
 		var remainChunkCnt float64
 		if chunk.Chunk.Offset < chunk.Chunk.EndOffset {
@@ -502,19 +497,23 @@ func (tr *TableRestore) restoreEngine(
 			metric.ChunkCounter.WithLabelValues(metric.ChunkStatePending).Add(remainChunkCnt)
 		}
 
-		restoreWorker := rc.regionWorkers.Apply()
-		wg.Add(1)
-
 		dataWriter, err := dataEngine.LocalWriter(ctx, dataWriterCfg)
 		if err != nil {
-			return nil, errors.Trace(err)
+			cr.close()
+			setError(err)
+			break
 		}
 
 		indexWriter, err := indexEngine.LocalWriter(ctx, &backend.LocalWriterConfig{})
 		if err != nil {
-			return nil, errors.Trace(err)
+			_, _ = dataWriter.Close(ctx)
+			cr.close()
+			setError(err)
+			break
 		}
+
+		restoreWorker := rc.regionWorkers.Apply()
+		wg.Add(1)
 		go func(w *worker.Worker, cr *chunkRestore) {
 			// Restore a chunk.
 			defer func() {
@@ -549,8 +548,7 @@ func (tr *TableRestore) restoreEngine(
 				}
 			} else {
 				metric.ChunkCounter.WithLabelValues(metric.ChunkStateFailed).Add(remainChunkCnt)
-				chunkErr.Set(err)
-				cancel()
+				setError(err)
 			}
 		}(restoreWorker, cr)
 	}
@@ -669,10 +667,7 @@ func (tr *TableRestore) postProcess(
 	forcePostProcess bool,
 	metaMgr tableMetaMgr,
 ) (bool, error) {
-	// there are no data in this table, no need to do post process
-	// this is important for tables that are just the dump table of views
-	// because at this stage, the table was already deleted and replaced by the related view
-	if !rc.backend.ShouldPostProcess() || len(cp.Engines) == 1 {
+	if !rc.backend.ShouldPostProcess() {
 		return false, nil
 	}
 
@@ -994,8 +989,8 @@ func estimateCompactionThreshold(cp *checkpoints.TableCheckpoint, factor int64)
 	threshold := totalRawFileSize / 512
 	threshold = utils.NextPowerOfTwo(threshold)
 	if threshold < compactionLowerThreshold {
-		// disable compaction if threshold is smaller than lower bound
-		threshold = 0
+		// too many small SST files would cause inaccurate region range estimation, so keep the lower bound
+		threshold = compactionLowerThreshold
 	} else if threshold > compactionUpperThreshold {
 		threshold = compactionUpperThreshold
 	}
diff --git a/br/pkg/lightning/restore/tidb.go b/br/pkg/lightning/restore/tidb.go
index ee1252fd3862d..2cdd2b894297d 100644
--- a/br/pkg/lightning/restore/tidb.go
+++ b/br/pkg/lightning/restore/tidb.go
@@ -120,7 +120,7 @@ func DBFromConfig(ctx context.Context, dsn config.DBStore) (*sql.DB, error) {
 	}
 
 	for k, v := range vars {
-		q := fmt.Sprintf("SET SESSION %s = %s;", k, v)
+		q := fmt.Sprintf("SET SESSION %s = '%s';", k, v)
 		if _, err1 := db.ExecContext(ctx, q); err1 != nil {
 			log.L().Warn("set session variable failed, will skip this query",
 				zap.String("query", q), zap.Error(err1))
@@ -177,7 +177,7 @@ loopCreate:
 	for tbl, sqlCreateTable := range tablesSchema {
 		task.Debug("create table", zap.String("schema", sqlCreateTable))
 
-		sqlCreateStmts, err = createTableIfNotExistsStmt(g.GetParser(), sqlCreateTable, database, tbl)
+		sqlCreateStmts, err = createIfNotExistsStmt(g.GetParser(), sqlCreateTable, database, tbl)
 		if err != nil {
 			break
 		}
@@ -200,14 +200,7 @@ loopCreate:
 	return errors.Trace(err)
 }
 
-func createDatabaseIfNotExistStmt(dbName string) string {
-	var createDatabase strings.Builder
-	createDatabase.WriteString("CREATE DATABASE IF NOT EXISTS ")
-	common.WriteMySQLIdentifier(&createDatabase, dbName)
-	return createDatabase.String()
-}
-
-func createTableIfNotExistsStmt(p *parser.Parser, createTable, dbName, tblName string) ([]string, error) {
+func createIfNotExistsStmt(p *parser.Parser, createTable, dbName, tblName string) ([]string, error) {
 	stmts, _, err := p.ParseSQL(createTable)
 	if err != nil {
 		return []string{}, err
@@ -219,6 +212,9 @@ func createTableIfNotExistsStmt(p *parser.Parser, createTable, dbName, tblName s
 	retStmts := make([]string, 0, len(stmts))
 	for _, stmt := range stmts {
 		switch node := stmt.(type) {
+		case *ast.CreateDatabaseStmt:
+			node.Name = dbName
+			node.IfNotExists = true
 		case *ast.CreateTableStmt:
 			node.Table.Schema = model.NewCIStr(dbName)
 			node.Table.Name = model.NewCIStr(tblName)
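A hedged sketch of the parse-rewrite-restore round trip that `createIfNotExistsStmt` now applies to CREATE DATABASE as well as CREATE TABLE. It assumes the standalone pingcap/parser module of this era (where `CreateDatabaseStmt.Name` is a plain string, as the hunk above shows); the input and target names are illustrative:

```go
package main

import (
	"fmt"
	"strings"

	"github.com/pingcap/parser"
	"github.com/pingcap/parser/ast"
	"github.com/pingcap/parser/format"
	_ "github.com/pingcap/parser/test_driver"
)

func main() {
	p := parser.New()
	stmts, _, err := p.ParseSQL("CREATE DATABASE `foo` CHARACTER SET = utf8;")
	if err != nil {
		panic(err)
	}
	for _, stmt := range stmts {
		if node, ok := stmt.(*ast.CreateDatabaseStmt); ok {
			node.Name = "testdb"    // retarget the database name
			node.IfNotExists = true // make the statement idempotent
		}
		var sb strings.Builder
		if err := stmt.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)); err != nil {
			panic(err)
		}
		// CREATE DATABASE IF NOT EXISTS `testdb` CHARACTER SET = utf8;
		fmt.Println(sb.String() + ";")
	}
}
```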
@@ -287,13 +283,15 @@ func LoadSchemaInfo(
 		if err != nil {
 			return nil, errors.Trace(err)
 		}
+		// Table names are case-sensitive in mydump.MDTableMeta.
+		// We should always use the original tbl.Name in checkpoints.
 		tableInfo := &checkpoints.TidbTableInfo{
 			ID:   tblInfo.ID,
 			DB:   schema.Name,
-			Name: tableName,
+			Name: tbl.Name,
 			Core: tblInfo,
 		}
-		dbInfo.Tables[tableName] = tableInfo
+		dbInfo.Tables[tbl.Name] = tableInfo
 	}
 
 	result[schema.Name] = dbInfo
@@ -369,7 +367,7 @@ func ObtainImportantVariables(ctx context.Context, g glue.SQLExecutor, needTiDBV
 	return result
 }
 
-func ObtainNewCollationEnabled(ctx context.Context, g glue.SQLExecutor) bool {
+func ObtainNewCollationEnabled(ctx context.Context, g glue.SQLExecutor) (bool, error) {
 	newCollationEnabled := false
 	newCollationVal, err := g.ObtainStringWithLog(
 		ctx,
@@ -379,9 +377,13 @@ func ObtainNewCollationEnabled(ctx context.Context, g glue.SQLExecutor) bool {
 	)
 	if err == nil && newCollationVal == "True" {
 		newCollationEnabled = true
+	} else if errors.ErrorEqual(err, sql.ErrNoRows) {
+		// ignore if target variable is not found, this may happen if tidb < v4.0
+		newCollationEnabled = false
+		err = nil
 	}
 
-	return newCollationEnabled
+	return newCollationEnabled, errors.Trace(err)
 }
 
 // AlterAutoIncrement rebase the table auto increment id
diff --git a/br/pkg/lightning/restore/tidb_test.go b/br/pkg/lightning/restore/tidb_test.go
index f066f8438581e..1cb86f9a2a406 100644
--- a/br/pkg/lightning/restore/tidb_test.go
+++ b/br/pkg/lightning/restore/tidb_test.go
@@ -16,6 +16,7 @@ package restore
 
 import (
 	"context"
+	"database/sql"
 	"testing"
 
 	"github.com/DATA-DOG/go-sqlmock"
@@ -64,103 +65,109 @@ func (s *tidbSuite) TearDownTest(c *C) {
 func (s *tidbSuite) TestCreateTableIfNotExistsStmt(c *C) {
 	dbName := "testdb"
-	createTableIfNotExistsStmt := func(createTable, tableName string) []string {
-		res, err := createTableIfNotExistsStmt(s.tiGlue.GetParser(), createTable, dbName, tableName)
+	createSQLIfNotExistsStmt := func(createTable, tableName string) []string {
+		res, err := createIfNotExistsStmt(s.tiGlue.GetParser(), createTable, dbName, tableName)
 		c.Assert(err, IsNil)
 		return res
 	}
 
 	c.Assert(
-		createTableIfNotExistsStmt("CREATE TABLE `foo`(`bar` TINYINT(1));", "foo"),
+		createSQLIfNotExistsStmt("CREATE DATABASE `foo` CHARACTER SET = utf8 COLLATE = utf8_general_ci;", ""),
+		DeepEquals,
+		[]string{"CREATE DATABASE IF NOT EXISTS `testdb` CHARACTER SET = utf8 COLLATE = utf8_general_ci;"},
+	)
+
+	c.Assert(
+		createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` TINYINT(1));", "foo"),
 		DeepEquals,
 		[]string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` TINYINT(1));"},
 	)
 	c.Assert(
-		createTableIfNotExistsStmt("CREATE TABLE IF NOT EXISTS `foo`(`bar` TINYINT(1));", "foo"),
+		createSQLIfNotExistsStmt("CREATE TABLE IF NOT EXISTS `foo`(`bar` TINYINT(1));", "foo"),
 		DeepEquals,
 		[]string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` TINYINT(1));"},
 	)
 
 	// case insensitive
 	c.Assert(
-		createTableIfNotExistsStmt("/* cOmmEnt */ creAte tablE `fOo`(`bar` TinyinT(1));", "fOo"),
+		createSQLIfNotExistsStmt("/* cOmmEnt */ creAte tablE `fOo`(`bar` TinyinT(1));", "fOo"),
 		DeepEquals,
 		[]string{"CREATE TABLE IF NOT EXISTS `testdb`.`fOo` (`bar` TINYINT(1));"},
 	)
 	c.Assert(
-		createTableIfNotExistsStmt("/* coMMenT */ crEatE tAble If not EXISts `FoO`(`bAR` tiNyInT(1));", "FoO"),
+		createSQLIfNotExistsStmt("/* coMMenT */ crEatE tAble If not EXISts `FoO`(`bAR` tiNyInT(1));", "FoO"),
 		DeepEquals,
 		[]string{"CREATE TABLE IF NOT EXISTS `testdb`.`FoO` (`bAR` TINYINT(1));"},
 	)
 
 	// only one "CREATE TABLE" is replaced
 	c.Assert(
-		createTableIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) COMMENT 'CREATE TABLE');", "foo"),
"foo"), DeepEquals, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) COMMENT 'CREATE TABLE');"}, ) // test clustered index consistency c.Assert( - createTableIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) PRIMARY KEY CLUSTERED COMMENT 'CREATE TABLE');", "foo"), + createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) PRIMARY KEY CLUSTERED COMMENT 'CREATE TABLE');", "foo"), DeepEquals, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) PRIMARY KEY /*T![clustered_index] CLUSTERED */ COMMENT 'CREATE TABLE');"}, ) c.Assert( - createTableIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) COMMENT 'CREATE TABLE', PRIMARY KEY (`bar`) NONCLUSTERED);", "foo"), + createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) COMMENT 'CREATE TABLE', PRIMARY KEY (`bar`) NONCLUSTERED);", "foo"), DeepEquals, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) COMMENT 'CREATE TABLE',PRIMARY KEY(`bar`) /*T![clustered_index] NONCLUSTERED */);"}, ) c.Assert( - createTableIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) PRIMARY KEY /*T![clustered_index] NONCLUSTERED */ COMMENT 'CREATE TABLE');", "foo"), + createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) PRIMARY KEY /*T![clustered_index] NONCLUSTERED */ COMMENT 'CREATE TABLE');", "foo"), DeepEquals, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) PRIMARY KEY /*T![clustered_index] NONCLUSTERED */ COMMENT 'CREATE TABLE');"}, ) c.Assert( - createTableIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) COMMENT 'CREATE TABLE', PRIMARY KEY (`bar`) /*T![clustered_index] CLUSTERED */);", "foo"), + createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) COMMENT 'CREATE TABLE', PRIMARY KEY (`bar`) /*T![clustered_index] CLUSTERED */);", "foo"), DeepEquals, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) COMMENT 'CREATE TABLE',PRIMARY KEY(`bar`) /*T![clustered_index] CLUSTERED */);"}, ) c.Assert( - createTableIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) PRIMARY KEY AUTO_RANDOM(2) COMMENT 'CREATE TABLE');", "foo"), + createSQLIfNotExistsStmt("CREATE TABLE `foo`(`bar` INT(1) PRIMARY KEY AUTO_RANDOM(2) COMMENT 'CREATE TABLE');", "foo"), DeepEquals, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`foo` (`bar` INT(1) PRIMARY KEY /*T![auto_rand] AUTO_RANDOM(2) */ COMMENT 'CREATE TABLE');"}, ) // upper case becomes shorter c.Assert( - createTableIfNotExistsStmt("CREATE TABLE `ſ`(`ı` TINYINT(1));", "ſ"), + createSQLIfNotExistsStmt("CREATE TABLE `ſ`(`ı` TINYINT(1));", "ſ"), DeepEquals, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`ſ` (`ı` TINYINT(1));"}, ) // upper case becomes longer c.Assert( - createTableIfNotExistsStmt("CREATE TABLE `ɑ`(`ȿ` TINYINT(1));", "ɑ"), + createSQLIfNotExistsStmt("CREATE TABLE `ɑ`(`ȿ` TINYINT(1));", "ɑ"), DeepEquals, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`ɑ` (`ȿ` TINYINT(1));"}, ) // non-utf-8 c.Assert( - createTableIfNotExistsStmt("CREATE TABLE `\xcc\xcc\xcc`(`\xdd\xdd\xdd` TINYINT(1));", "\xcc\xcc\xcc"), + createSQLIfNotExistsStmt("CREATE TABLE `\xcc\xcc\xcc`(`\xdd\xdd\xdd` TINYINT(1));", "\xcc\xcc\xcc"), DeepEquals, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`\xcc\xcc\xcc` (`ÝÝÝ` TINYINT(1));"}, ) // renaming a table c.Assert( - createTableIfNotExistsStmt("create table foo(x int);", "ba`r"), + createSQLIfNotExistsStmt("create table foo(x int);", "ba`r"), DeepEquals, []string{"CREATE TABLE IF NOT EXISTS `testdb`.`ba``r` (`x` INT);"}, ) // conditional comments c.Assert( - createTableIfNotExistsStmt(` + createSQLIfNotExistsStmt(` /*!40101 SET NAMES 
 /*!40101 SET NAMES binary*/;
 /*!40014 SET FOREIGN_KEY_CHECKS=0*/;
 CREATE TABLE x.y (z double) ENGINE=InnoDB AUTO_INCREMENT=8343230 DEFAULT CHARSET=utf8;
@@ -175,7 +182,7 @@ func (s *tidbSuite) TestCreateTableIfNotExistsStmt(c *C) {
 
 	// create view
 	c.Assert(
-		createTableIfNotExistsStmt(`
+		createSQLIfNotExistsStmt(`
 /*!40101 SET NAMES binary*/;
 DROP TABLE IF EXISTS v2;
 DROP VIEW IF EXISTS v2;
@@ -319,7 +326,8 @@ func (s *tidbSuite) TestLoadSchemaInfo(c *C) {
 		"CREATE TABLE `t1` (`a` INT PRIMARY KEY);"+
 			"CREATE TABLE `t2` (`b` VARCHAR(20), `c` BOOL, KEY (`b`, `c`));"+
 			// an extra table that not exists in dbMetas
-			"CREATE TABLE `t3` (`d` VARCHAR(20), `e` BOOL);",
+			"CREATE TABLE `t3` (`d` VARCHAR(20), `e` BOOL);"+
+			"CREATE TABLE `T4` (`f` BIGINT PRIMARY KEY);",
 		"", "")
 	c.Assert(err, IsNil)
 	tableInfos := make([]*model.TableInfo, 0, len(nodes))
@@ -344,6 +352,10 @@ func (s *tidbSuite) TestLoadSchemaInfo(c *C) {
 					DB:   "db",
 					Name: "t2",
 				},
+				{
+					DB:   "db",
+					Name: "t4",
+				},
 			},
 		},
 	}
@@ -369,13 +381,19 @@ func (s *tidbSuite) TestLoadSchemaInfo(c *C) {
 				Name: "t2",
 				Core: tableInfos[1],
 			},
+			"t4": {
+				ID:   103,
+				DB:   "db",
+				Name: "t4",
+				Core: tableInfos[3],
+			},
 		},
 	},
 	})
 
 	tableCntAfter := metric.ReadCounter(metric.TableCounter.WithLabelValues(metric.TableStatePending, metric.TableResultSuccess))
 
-	c.Assert(tableCntAfter-tableCntBefore, Equals, 2.0)
+	c.Assert(tableCntAfter-tableCntBefore, Equals, 3.0)
 }
 
 func (s *tidbSuite) TestLoadSchemaInfoMissing(c *C) {
@@ -505,8 +523,22 @@ func (s *tidbSuite) TestObtainNewCollationEnabled(c *C) {
 	ctx := context.Background()
 
 	s.mockDB.
-		ExpectQuery("\\QSELECT variable_value FROM mysql.tidb WHERE variable_name = 'new_collation_enabled'\\E")
-	version := ObtainNewCollationEnabled(ctx, s.tiGlue.GetSQLExecutor())
+		ExpectQuery("\\QSELECT variable_value FROM mysql.tidb WHERE variable_name = 'new_collation_enabled'\\E").
+		WillReturnError(errors.New("mock permission deny"))
+	s.mockDB.
+		ExpectQuery("\\QSELECT variable_value FROM mysql.tidb WHERE variable_name = 'new_collation_enabled'\\E").
+		WillReturnError(errors.New("mock permission deny"))
+	s.mockDB.
+		ExpectQuery("\\QSELECT variable_value FROM mysql.tidb WHERE variable_name = 'new_collation_enabled'\\E").
+		WillReturnError(errors.New("mock permission deny"))
+	_, err := ObtainNewCollationEnabled(ctx, s.tiGlue.GetSQLExecutor())
+	c.Assert(err, ErrorMatches, "obtain new collation enabled failed: mock permission deny")
+
+	s.mockDB.
+		ExpectQuery("\\QSELECT variable_value FROM mysql.tidb WHERE variable_name = 'new_collation_enabled'\\E").
+		WillReturnRows(sqlmock.NewRows([]string{"variable_value"}).RowError(0, sql.ErrNoRows))
+	version, err := ObtainNewCollationEnabled(ctx, s.tiGlue.GetSQLExecutor())
+	c.Assert(err, IsNil)
 	c.Assert(version, Equals, false)
 
 	kvMap := map[string]bool{
@@ -518,7 +550,8 @@ func (s *tidbSuite) TestObtainNewCollationEnabled(c *C) {
 			ExpectQuery("\\QSELECT variable_value FROM mysql.tidb WHERE variable_name = 'new_collation_enabled'\\E").
 			WillReturnRows(sqlmock.NewRows([]string{"variable_value"}).AddRow(k))
 
-		version := ObtainNewCollationEnabled(ctx, s.tiGlue.GetSQLExecutor())
+		version, err = ObtainNewCollationEnabled(ctx, s.tiGlue.GetSQLExecutor())
+		c.Assert(err, IsNil)
 		c.Assert(version, Equals, v)
 	}
 	s.mockDB.
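The gcs.go hunk below drops the 1000-key cap and drains the object iterator completely. A hedged, library-style sketch of that pattern using the same cloud.google.com/go/storage API the file already imports (`walkAll` is a hypothetical helper, not the real method):

```go
package gcswalk

import (
	"context"
	"fmt"

	"cloud.google.com/go/storage"
	"google.golang.org/api/iterator"
)

func walkAll(ctx context.Context, bucket *storage.BucketHandle, prefix string) error {
	query := &storage.Query{Prefix: prefix}
	// only need each object's name and size
	if err := query.SetAttrSelection([]string{"Name", "Size"}); err != nil {
		return err
	}
	iter := bucket.Objects(ctx, query)
	for {
		attrs, err := iter.Next()
		if err == iterator.Done {
			return nil // every page consumed; no artificial 1000-key cap
		}
		if err != nil {
			return err
		}
		fmt.Println(attrs.Name, attrs.Size)
	}
}
```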
diff --git a/br/pkg/storage/gcs.go b/br/pkg/storage/gcs.go index 07ce5c8a862b9..c54141b8ee560 100644 --- a/br/pkg/storage/gcs.go +++ b/br/pkg/storage/gcs.go @@ -180,11 +180,6 @@ func (s *gcsStorage) WalkDir(ctx context.Context, opt *WalkOption, fn func(strin opt = &WalkOption{} } - maxKeys := int64(1000) - if opt.ListCount > 0 { - maxKeys = opt.ListCount - } - prefix := path.Join(s.gcs.Prefix, opt.SubDir) if len(prefix) > 0 && !strings.HasSuffix(prefix, "/") { prefix += "/" @@ -194,7 +189,7 @@ func (s *gcsStorage) WalkDir(ctx context.Context, opt *WalkOption, fn func(strin // only need each object's name and size query.SetAttrSelection([]string{"Name", "Size"}) iter := s.bucket.Objects(ctx, query) - for i := int64(0); i != maxKeys; i++ { + for { attrs, err := iter.Next() if err == iterator.Done { break @@ -281,14 +276,6 @@ func newGCSStorage(ctx context.Context, gcs *backuppb.GCS, opts *ExternalStorage // so we need find sst in slash directory gcs.Prefix += "//" } - // TODO remove it after BR remove cfg skip-check-path - if !opts.SkipCheckPath { - // check bucket exists - _, err = bucket.Attrs(ctx) - if err != nil { - return nil, errors.Annotatef(err, "gcs://%s/%s", gcs.Bucket, gcs.Prefix) - } - } return &gcsStorage{gcs: gcs, bucket: bucket}, nil } diff --git a/br/pkg/storage/gcs_test.go b/br/pkg/storage/gcs_test.go index c3e63d6d410a2..ccf3927497bea 100644 --- a/br/pkg/storage/gcs_test.go +++ b/br/pkg/storage/gcs_test.go @@ -4,6 +4,7 @@ package storage import ( "context" + "fmt" "io" "os" @@ -95,6 +96,31 @@ func (r *testStorageSuite) TestGCS(c *C) { c.Assert(list, Equals, "keykey1key2") c.Assert(totalSize, Equals, int64(42)) + // test 1003 files + totalSize = 0 + for i := 0; i < 1000; i += 1 { + err = stg.WriteFile(ctx, fmt.Sprintf("f%d", i), []byte("data")) + c.Assert(err, IsNil) + } + filesSet := make(map[string]struct{}, 1003) + err = stg.WalkDir(ctx, nil, func(name string, size int64) error { + filesSet[name] = struct{}{} + totalSize += size + return nil + }) + c.Assert(err, IsNil) + c.Assert(totalSize, Equals, int64(42+4000)) + _, ok := filesSet["key"] + c.Assert(ok, IsTrue) + _, ok = filesSet["key1"] + c.Assert(ok, IsTrue) + _, ok = filesSet["key2"] + c.Assert(ok, IsTrue) + for i := 0; i < 1000; i += 1 { + _, ok = filesSet[fmt.Sprintf("f%d", i)] + c.Assert(ok, IsTrue) + } + efr, err := stg.Open(ctx, "key2") c.Assert(err, IsNil) diff --git a/br/pkg/storage/s3.go b/br/pkg/storage/s3.go index 2c07b5af2cad0..6accafee7363d 100644 --- a/br/pkg/storage/s3.go +++ b/br/pkg/storage/s3.go @@ -283,14 +283,6 @@ func newS3Storage(backend *backuppb.S3, opts *ExternalStorageOptions) (*S3Storag } c := s3.New(ses) - // TODO remove it after BR remove cfg skip-check-path - if !opts.SkipCheckPath { - err = checkS3Bucket(c, &qs) - if err != nil { - return nil, errors.Annotatef(berrors.ErrStorageInvalidConfig, "Bucket %s is not accessible: %v", qs.Bucket, err) - } - } - if len(qs.Prefix) > 0 && !strings.HasSuffix(qs.Prefix, "/") { qs.Prefix += "/" } diff --git a/br/pkg/storage/s3_test.go b/br/pkg/storage/s3_test.go index 413f5e8881da1..cf30828b07c65 100644 --- a/br/pkg/storage/s3_test.go +++ b/br/pkg/storage/s3_test.go @@ -288,7 +288,6 @@ func (s *s3Suite) TestS3Storage(c *C) { _, err := New(ctx, s3, &ExternalStorageOptions{ SendCredentials: test.sendCredential, CheckPermissions: test.hackPermission, - SkipCheckPath: true, }) if test.errReturn { c.Assert(err, NotNil) @@ -414,7 +413,7 @@ func (s *s3Suite) TestS3Storage(c *C) { func (s *s3Suite) TestS3URI(c *C) { backend, err := 
ParseBackend("s3://bucket/prefix/", nil) c.Assert(err, IsNil) - storage, err := New(context.Background(), backend, &ExternalStorageOptions{SkipCheckPath: true}) + storage, err := New(context.Background(), backend, &ExternalStorageOptions{}) c.Assert(err, IsNil) c.Assert(storage.URI(), Equals, "s3://bucket/prefix/") } diff --git a/br/pkg/storage/storage.go b/br/pkg/storage/storage.go index af05abac398fa..177656fc378a0 100644 --- a/br/pkg/storage/storage.go +++ b/br/pkg/storage/storage.go @@ -121,18 +121,6 @@ type ExternalStorageOptions struct { // NoCredentials means that no cloud credentials are supplied to BR NoCredentials bool - // SkipCheckPath marks whether to skip checking path's existence. - // - // This should only be set to true in testing, to avoid interacting with the - // real world. - // When this field is false (i.e. path checking is enabled), the New() - // function will ensure the path referred by the backend exists by - // recursively creating the folders. This will also throw an error if such - // operation is impossible (e.g. when the bucket storing the path is missing). - - deprecated: use checkPermissions and specify the checkPermission instead. - SkipCheckPath bool - // HTTPClient to use. The created storage may ignore this field if it is not // directly using HTTP (e.g. the local storage). HTTPClient *http.Client @@ -148,7 +136,6 @@ type ExternalStorageOptions struct { func Create(ctx context.Context, backend *backuppb.StorageBackend, sendCreds bool) (ExternalStorage, error) { return New(ctx, backend, &ExternalStorageOptions{ SendCredentials: sendCreds, - SkipCheckPath: false, HTTPClient: nil, }) } diff --git a/br/pkg/task/backup.go b/br/pkg/task/backup.go index 7a9037c20f80c..87461f53bab74 100644 --- a/br/pkg/task/backup.go +++ b/br/pkg/task/backup.go @@ -257,7 +257,6 @@ func RunBackup(c context.Context, g glue.Glue, cmdName string, cfg *BackupConfig opts := storage.ExternalStorageOptions{ NoCredentials: cfg.NoCreds, SendCredentials: cfg.SendCreds, - SkipCheckPath: cfg.SkipCheckPath, } if err = client.SetStorage(ctx, u, &opts); err != nil { return errors.Trace(err) diff --git a/br/pkg/task/backup_raw.go b/br/pkg/task/backup_raw.go index d8d11ea95c3a1..febe151218706 100644 --- a/br/pkg/task/backup_raw.go +++ b/br/pkg/task/backup_raw.go @@ -150,7 +150,6 @@ func RunBackupRaw(c context.Context, g glue.Glue, cmdName string, cfg *RawKvConf opts := storage.ExternalStorageOptions{ NoCredentials: cfg.NoCreds, SendCredentials: cfg.SendCreds, - SkipCheckPath: cfg.SkipCheckPath, } if err = client.SetStorage(ctx, u, &opts); err != nil { return errors.Trace(err) diff --git a/br/pkg/task/common.go b/br/pkg/task/common.go index 36de8583ea92e..fc1c8ecc5748d 100644 --- a/br/pkg/task/common.go +++ b/br/pkg/task/common.go @@ -485,6 +485,9 @@ func (cfg *Config) ParseFromFlags(flags *pflag.FlagSet) error { if cfg.SkipCheckPath, err = flags.GetBool(flagSkipCheckPath); err != nil { return errors.Trace(err) } + if cfg.SkipCheckPath { + log.L().Info("--skip-check-path is deprecated, there is no need to set it explicitly anymore") + } if err = cfg.parseCipherInfo(flags); err != nil { return errors.Trace(err) @@ -548,7 +551,6 @@ func storageOpts(cfg *Config) *storage.ExternalStorageOptions { return &storage.ExternalStorageOptions{ NoCredentials: cfg.NoCreds, SendCredentials: cfg.SendCreds, - SkipCheckPath: cfg.SkipCheckPath, } } diff --git a/br/pkg/task/restore.go b/br/pkg/task/restore.go index a80549d005905..ae46f15b1f6ce 100644 --- a/br/pkg/task/restore.go +++ b/br/pkg/task/restore.go @@ -249,7 +249,6 @@
func RunRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf opts := storage.ExternalStorageOptions{ NoCredentials: cfg.NoCreds, SendCredentials: cfg.SendCreds, - SkipCheckPath: cfg.SkipCheckPath, } if err = client.SetStorage(ctx, u, &opts); err != nil { return errors.Trace(err) diff --git a/br/pkg/task/restore_log.go b/br/pkg/task/restore_log.go index 26a8bdae0add0..45486c4f8e577 100644 --- a/br/pkg/task/restore_log.go +++ b/br/pkg/task/restore_log.go @@ -121,7 +121,6 @@ func RunLogRestore(c context.Context, g glue.Glue, cfg *LogRestoreConfig) error opts := storage.ExternalStorageOptions{ NoCredentials: cfg.NoCreds, SendCredentials: cfg.SendCreds, - SkipCheckPath: cfg.SkipCheckPath, } if err = client.SetStorage(ctx, u, &opts); err != nil { return errors.Trace(err) diff --git a/br/pkg/utils/retry.go b/br/pkg/utils/retry.go index a076190b953d6..51a833d8d136c 100644 --- a/br/pkg/utils/retry.go +++ b/br/pkg/utils/retry.go @@ -117,7 +117,9 @@ func isSingleRetryableError(err error) bool { case *mysql.MySQLError: switch nerr.Number { // ErrLockDeadlock can retry to commit while meet deadlock - case tmysql.ErrUnknown, tmysql.ErrLockDeadlock, tmysql.ErrWriteConflictInTiDB, tmysql.ErrPDServerTimeout, tmysql.ErrTiKVServerTimeout, tmysql.ErrTiKVServerBusy, tmysql.ErrResolveLockTimeout, tmysql.ErrRegionUnavailable: + case tmysql.ErrUnknown, tmysql.ErrLockDeadlock, tmysql.ErrWriteConflict, tmysql.ErrWriteConflictInTiDB, + tmysql.ErrPDServerTimeout, tmysql.ErrTiKVServerTimeout, tmysql.ErrTiKVServerBusy, tmysql.ErrResolveLockTimeout, + tmysql.ErrRegionUnavailable, tmysql.ErrInfoSchemaExpired, tmysql.ErrInfoSchemaChanged, tmysql.ErrTxnRetryable: return true default: return false diff --git a/br/pkg/utils/retry_test.go b/br/pkg/utils/retry_test.go index b5c54287f1cce..0186e2314ee99 100644 --- a/br/pkg/utils/retry_test.go +++ b/br/pkg/utils/retry_test.go @@ -37,6 +37,10 @@ func (s *utilSuite) TestIsRetryableError(c *C) { c.Assert(IsRetryableError(&mysql.MySQLError{Number: tmysql.ErrResolveLockTimeout}), IsTrue) c.Assert(IsRetryableError(&mysql.MySQLError{Number: tmysql.ErrRegionUnavailable}), IsTrue) c.Assert(IsRetryableError(&mysql.MySQLError{Number: tmysql.ErrWriteConflictInTiDB}), IsTrue) + c.Assert(IsRetryableError(&mysql.MySQLError{Number: tmysql.ErrWriteConflict}), IsTrue) + c.Assert(IsRetryableError(&mysql.MySQLError{Number: tmysql.ErrInfoSchemaExpired}), IsTrue) + c.Assert(IsRetryableError(&mysql.MySQLError{Number: tmysql.ErrInfoSchemaChanged}), IsTrue) + c.Assert(IsRetryableError(&mysql.MySQLError{Number: tmysql.ErrTxnRetryable}), IsTrue) // gRPC Errors c.Assert(IsRetryableError(status.Error(codes.Canceled, "")), IsFalse) diff --git a/br/tests/lightning_column_permutation/data/perm-schema-create.sql b/br/tests/lightning_column_permutation/data/perm-schema-create.sql index fe9a5be60a3ff..28138f8d72659 100644 --- a/br/tests/lightning_column_permutation/data/perm-schema-create.sql +++ b/br/tests/lightning_column_permutation/data/perm-schema-create.sql @@ -1 +1 @@ -CREATE DATABASE `perm` IF NOT EXISTS; +CREATE DATABASE IF NOT EXISTS `perm`; diff --git a/br/tests/lightning_incremental/data/incr.empty_table-schema.sql b/br/tests/lightning_incremental/data/incr.empty_table-schema.sql new file mode 100644 index 0000000000000..881156cb99fd5 --- /dev/null +++ b/br/tests/lightning_incremental/data/incr.empty_table-schema.sql @@ -0,0 +1 @@ +CREATE TABLE `empty_table` (id int primary key); diff --git a/br/tests/lightning_incremental/data1/incr.empty_table-schema.sql 
b/br/tests/lightning_incremental/data1/incr.empty_table-schema.sql new file mode 100644 index 0000000000000..881156cb99fd5 --- /dev/null +++ b/br/tests/lightning_incremental/data1/incr.empty_table-schema.sql @@ -0,0 +1 @@ +CREATE TABLE `empty_table` (id int primary key); diff --git a/br/tests/lightning_incremental/data1/incr.empty_table2-schema.sql b/br/tests/lightning_incremental/data1/incr.empty_table2-schema.sql new file mode 100644 index 0000000000000..610412d940815 --- /dev/null +++ b/br/tests/lightning_incremental/data1/incr.empty_table2-schema.sql @@ -0,0 +1 @@ +CREATE TABLE `empty_table2` (id int primary key, s varchar(16)); diff --git a/br/tests/lightning_incremental/run.sh b/br/tests/lightning_incremental/run.sh index f04630055936a..4cdd5a53ec74b 100644 --- a/br/tests/lightning_incremental/run.sh +++ b/br/tests/lightning_incremental/run.sh @@ -18,60 +18,66 @@ set -eu check_cluster_version 4 0 0 "incremental restore" || exit 0 -DB_NAME=incr +run_lightning_and_check_meta() { + run_lightning --backend local "$@" + # check that the metadata table does not exist + run_sql "SHOW DATABASES like 'lightning_metadata';" + check_not_contains "Database: lightning_metadata" +} -for backend in importer local; do - run_sql "DROP DATABASE IF EXISTS incr;" - run_lightning --backend $backend +DB_NAME=incr - for tbl in auto_random pk_auto_inc rowid_uk_inc uk_auto_inc; do - run_sql "SELECT count(*) from incr.$tbl" - check_contains "count(*): 3" - done +run_sql "DROP DATABASE IF EXISTS incr;" +run_sql "DROP DATABASE IF EXISTS lightning_metadata;" +run_lightning_and_check_meta - for tbl in auto_random pk_auto_inc rowid_uk_inc uk_auto_inc; do - if [ "$tbl" = "auto_random" ]; then - run_sql "SELECT id & b'000001111111111111111111111111111111111111111111111111111111111' as inc FROM incr.$tbl" - else - run_sql "SELECT id as inc FROM incr.$tbl" - fi - check_contains 'inc: 1' - check_contains 'inc: 2' - check_contains 'inc: 3' - done +for tbl in auto_random pk_auto_inc rowid_uk_inc uk_auto_inc; do + run_sql "SELECT count(*) from incr.$tbl" + check_contains "count(*): 3" +done - for tbl in pk_auto_inc rowid_uk_inc; do - run_sql "SELECT group_concat(v) from incr.$tbl group by 'all';" - check_contains "group_concat(v): a,b,c" - done +for tbl in auto_random pk_auto_inc rowid_uk_inc uk_auto_inc; do + if [ "$tbl" = "auto_random" ]; then + run_sql "SELECT id & b'000001111111111111111111111111111111111111111111111111111111111' as inc FROM incr.$tbl" + else + run_sql "SELECT id as inc FROM incr.$tbl" + fi + check_contains 'inc: 1' + check_contains 'inc: 2' + check_contains 'inc: 3' +done - run_sql "SELECT sum(pk) from incr.uk_auto_inc;" - check_contains "sum(pk): 6" +for tbl in pk_auto_inc rowid_uk_inc; do + run_sql "SELECT group_concat(v) from incr.$tbl group by 'all';" + check_contains "group_concat(v): a,b,c" +done - # incrementally import all data in data1 - run_lightning --backend $backend -d "tests/$TEST_NAME/data1" +run_sql "SELECT sum(pk) from incr.uk_auto_inc;" +check_contains "sum(pk): 6" - for tbl in auto_random pk_auto_inc rowid_uk_inc uk_auto_inc; do - run_sql "SELECT count(*) from incr.$tbl" - check_contains "count(*): 6" - done +# incrementally import all data in data1 +run_lightning_and_check_meta -d "tests/$TEST_NAME/data1" - for tbl in auto_random pk_auto_inc rowid_uk_inc uk_auto_inc; do - if [ "$tbl" = "auto_random" ]; then - run_sql "SELECT id & b'000001111111111111111111111111111111111111111111111111111111111' as inc FROM incr.$tbl" - else - run_sql "SELECT id as inc FROM incr.$tbl" - fi - check_contains
'inc: 4' - check_contains 'inc: 5' - check_contains 'inc: 6' - done +for tbl in auto_random pk_auto_inc rowid_uk_inc uk_auto_inc; do + run_sql "SELECT count(*) from incr.$tbl" + check_contains "count(*): 6" +done - for tbl in pk_auto_inc rowid_uk_inc; do - run_sql "SELECT group_concat(v) from incr.$tbl group by 'all';" - check_contains "group_concat(v): a,b,c,d,e,f" - done +for tbl in auto_random pk_auto_inc rowid_uk_inc uk_auto_inc; do + if [ "$tbl" = "auto_random" ]; then + run_sql "SELECT id & b'000001111111111111111111111111111111111111111111111111111111111' as inc FROM incr.$tbl" + else + run_sql "SELECT id as inc FROM incr.$tbl" + fi + check_contains 'inc: 4' + check_contains 'inc: 5' + check_contains 'inc: 6' +done - for tbl in pk_auto_inc rowid_uk_inc; do - run_sql "SELECT group_concat(v) from incr.$tbl group by 'all';" - check_contains "group_concat(v): a,b,c,d,e,f" - done +for tbl in pk_auto_inc rowid_uk_inc; do + run_sql "SELECT group_concat(v) from incr.$tbl group by 'all';" + check_contains "group_concat(v): a,b,c,d,e,f" done + +run_sql "SELECT sum(pk) from incr.uk_auto_inc;" +check_contains "sum(pk): 21" diff --git a/br/tests/lightning_new_collation/data/nc-schema-create.sql b/br/tests/lightning_new_collation/data/nc-schema-create.sql new file mode 100644 index 0000000000000..6608189c71304 --- /dev/null +++ b/br/tests/lightning_new_collation/data/nc-schema-create.sql @@ -0,0 +1 @@ +CREATE DATABASE nc CHARACTER SET = utf8mb4 COLLATE = utf8mb4_general_ci; diff --git a/br/tests/lightning_new_collation/data/nc.ci-schema.sql b/br/tests/lightning_new_collation/data/nc.ci-schema.sql new file mode 100644 index 0000000000000..1e7958a76409c --- /dev/null +++ b/br/tests/lightning_new_collation/data/nc.ci-schema.sql @@ -0,0 +1 @@ +CREATE TABLE ci(i INT PRIMARY KEY, v varchar(32)); diff --git a/br/tests/lightning_new_collation/data/nc.ci.0.csv b/br/tests/lightning_new_collation/data/nc.ci.0.csv new file mode 100644 index 0000000000000..a1b4dcff21e40 --- /dev/null +++ b/br/tests/lightning_new_collation/data/nc.ci.0.csv @@ -0,0 +1,2 @@ +i,v +1,aA diff --git a/br/tests/lightning_new_collation/run.sh b/br/tests/lightning_new_collation/run.sh index d4c49e3c61192..f360ed3a94fc4 100644 --- a/br/tests/lightning_new_collation/run.sh +++ b/br/tests/lightning_new_collation/run.sh @@ -54,6 +54,10 @@ for BACKEND in local importer tidb; do run_sql "SELECT j FROM nc.t WHERE s = 'This_Is_Test4'"; check_contains "j: 4" + run_sql "SELECT i, v from nc.ci where v = 'aa';" + check_contains "i: 1" + check_contains "v: aA" + done # restart with original config diff --git a/br/tests/lightning_s3/run.sh b/br/tests/lightning_s3/run.sh index 6fed0af2b81da..5b2973784fd7e 100755 --- a/br/tests/lightning_s3/run.sh +++ b/br/tests/lightning_s3/run.sh @@ -62,6 +62,20 @@ _EOF_ run_sql "DROP DATABASE IF EXISTS $DB;" run_sql "DROP TABLE IF EXISTS $DB.$TABLE;" +# test a path that does not exist +rm -f $TEST_DIR/lightning.log +SOURCE_DIR="s3://$BUCKET/not-exist-path?endpoint=http%3A//127.0.0.1%3A9900&access_key=$MINIO_ACCESS_KEY&secret_access_key=$MINIO_SECRET_KEY&force_path_style=true" +! run_lightning -d $SOURCE_DIR --backend local 2> /dev/null +grep -Eq "data-source-dir .* doesn't exist or contains no files" $TEST_DIR/lightning.log + +# test an empty dir +rm -f $TEST_DIR/lightning.log +emptyPath=empty-bucket/empty-path +mkdir -p $DBPATH/$emptyPath +SOURCE_DIR="s3://$emptyPath/not-exist-path?endpoint=http%3A//127.0.0.1%3A9900&access_key=$MINIO_ACCESS_KEY&secret_access_key=$MINIO_SECRET_KEY&force_path_style=true" +!
run_lightning -d $SOURCE_DIR --backend local 2> /dev/null +grep -Eq "data-source-dir .* doesn't exist or contains no files" $TEST_DIR/lightning.log + SOURCE_DIR="s3://$BUCKET/?endpoint=http%3A//127.0.0.1%3A9900&access_key=$MINIO_ACCESS_KEY&secret_access_key=$MINIO_SECRET_KEY&force_path_style=true" run_lightning -d $SOURCE_DIR --backend local 2> /dev/null run_sql "SELECT count(*), sum(i) FROM \`$DB\`.$TABLE" diff --git a/br/tests/lightning_sqlmode/on.toml b/br/tests/lightning_sqlmode/on.toml index 4bd09629d1394..c78047e08516f 100644 --- a/br/tests/lightning_sqlmode/on.toml +++ b/br/tests/lightning_sqlmode/on.toml @@ -1,3 +1,7 @@ +[lightning] +max-error = 20 +task-info-schema-name = 'sqlmodedb_lightning_task_info' + [tikv-importer] backend = 'local' diff --git a/br/tests/lightning_sqlmode/run.sh b/br/tests/lightning_sqlmode/run.sh index e2829964009dc..81d44c2450d6d 100755 --- a/br/tests/lightning_sqlmode/run.sh +++ b/br/tests/lightning_sqlmode/run.sh @@ -49,10 +49,39 @@ check_contains 'hex(c): ' check_contains 'd: ' run_sql 'DROP DATABASE IF EXISTS sqlmodedb' +run_sql 'DROP DATABASE IF EXISTS sqlmodedb_lightning_task_info' -set +e run_lightning --config "tests/$TEST_NAME/on.toml" --log-file "$TEST_DIR/sqlmode-error.log" -[ $? -ne 0 ] || exit 1 -set -e grep -q '\["kv convert failed"\].*\[original=.*kind=uint64,val=9.*\] \[originalCol=1\] \[colName=a\] \[colType="timestamp BINARY"\]' "$TEST_DIR/sqlmode-error.log" + +run_sql 'SELECT min(id), max(id) FROM sqlmodedb.t' +check_contains 'min(id): 4' +check_contains 'max(id): 4' + +run_sql 'SELECT count(*) FROM sqlmodedb_lightning_task_info.type_error_v1' +check_contains 'count(*): 4' + +run_sql 'SELECT path, `offset`, error, row_data FROM sqlmodedb_lightning_task_info.type_error_v1 WHERE table_name = "`sqlmodedb`.`t`" AND row_data LIKE "(1,%";' +check_contains 'path: sqlmodedb.t.1.sql' +check_contains 'offset: 53' +check_contains 'cannot convert datum from unsigned bigint to type timestamp.' +check_contains "row_data: (1,9,128,'too long','x,y,z')" + +run_sql 'SELECT path, `offset`, error, row_data FROM sqlmodedb_lightning_task_info.type_error_v1 WHERE table_name = "`sqlmodedb`.`t`" AND row_data LIKE "(2,%";' +check_contains 'path: sqlmodedb.t.1.sql' +check_contains 'offset: 100' +check_contains "Incorrect timestamp value: '2000-00-00 00:00:00'" +check_contains "row_data: (2,'2000-00-00 00:00:00',-99999,'🤩',3)" + +run_sql 'SELECT path, `offset`, error, row_data FROM sqlmodedb_lightning_task_info.type_error_v1 WHERE table_name = "`sqlmodedb`.`t`" AND row_data LIKE "(3,%";' +check_contains 'path: sqlmodedb.t.1.sql' +check_contains 'offset: 149' +check_contains "Incorrect timestamp value: '9999-12-31 23:59:59'" +check_contains "row_data: (3,'9999-12-31 23:59:59','NaN',x'99','x+y')" + +run_sql 'SELECT path, `offset`, error, row_data FROM sqlmodedb_lightning_task_info.type_error_v1 WHERE table_name = "`sqlmodedb`.`t`" AND row_data LIKE "(5,%";' +check_contains 'path: sqlmodedb.t.1.sql' +check_contains 'offset: 237' +check_contains "Column 'a' cannot be null" +check_contains "row_data: (5,NULL,NULL,NULL,NULL)" diff --git a/dumpling/export/dump.go b/dumpling/export/dump.go index 7cf20eb790270..29215c58b0f0e 100755 --- a/dumpling/export/dump.go +++ b/dumpling/export/dump.go @@ -878,7 +878,20 @@ func prepareTableListToDump(tctx *tcontext.Context, conf *Config, db *sql.Conn) if !conf.NoViews { tableTypes = append(tableTypes, TableTypeView) } - conf.Tables, err = ListAllDatabasesTables(tctx, db, databases, getListTableTypeByConf(conf), tableTypes...) 
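The replacement lines that follow make dumpling probe for sequences before deciding how to list tables. As a standalone illustration, here is roughly the same INFORMATION_SCHEMA query that the new CheckIfSeqExists helper (added further down in sql.go) issues; the DSN and the driver import are placeholder assumptions, not part of this patch:

package main

import (
	"context"
	"database/sql"
	"fmt"

	_ "github.com/go-sql-driver/mysql" // assumed driver, not part of this patch
)

func main() {
	// Placeholder DSN; point it at any MySQL-compatible server.
	db, err := sql.Open("mysql", "root@tcp(127.0.0.1:4000)/")
	if err != nil {
		panic(err)
	}
	defer db.Close()
	var count int
	// The same query the new CheckIfSeqExists helper issues.
	const query = "SELECT COUNT(1) as c FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='SEQUENCE'"
	if err := db.QueryRowContext(context.Background(), query).Scan(&count); err != nil {
		panic(err)
	}
	fmt.Println("sequences present:", count > 0)
}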
+ + ifSeqExists, err := CheckIfSeqExists(db) + if err != nil { + return err + } + var listType listTableType + if ifSeqExists { + tctx.L().Warn("dumpling tableType `sequence` is unsupported for now") + listType = listTableByShowFullTables + } else { + listType = getListTableTypeByConf(conf) + } + + conf.Tables, err = ListAllDatabasesTables(tctx, db, databases, listType, tableTypes...) if err != nil { return err } diff --git a/dumpling/export/prepare.go b/dumpling/export/prepare.go index f9036ec32ea98..300c971b1eee2 100644 --- a/dumpling/export/prepare.go +++ b/dumpling/export/prepare.go @@ -112,6 +112,9 @@ const ( TableTypeBase TableType = iota // TableTypeView represents the view table TableTypeView + // TableTypeSequence represents the sequence table + // TODO: need to be supported + TableTypeSequence ) const ( @@ -119,6 +122,8 @@ const ( TableTypeBaseStr = "BASE TABLE" // TableTypeViewStr represents the view table string TableTypeViewStr = "VIEW" + // TableTypeSequenceStr represents the sequence table string + TableTypeSequenceStr = "SEQUENCE" ) func (t TableType) String() string { @@ -127,6 +132,8 @@ func (t TableType) String() string { return TableTypeBaseStr case TableTypeView: return TableTypeViewStr + case TableTypeSequence: + return TableTypeSequenceStr default: return "UNKNOWN" } @@ -139,6 +146,8 @@ func ParseTableType(s string) (TableType, error) { return TableTypeBase, nil case TableTypeViewStr: return TableTypeView, nil + case TableTypeSequenceStr: + return TableTypeSequence, nil default: return TableTypeBase, errors.Errorf("unknown table type %s", s) } diff --git a/dumpling/export/sql.go b/dumpling/export/sql.go index 4b0203a1665da..9a984acaaa599 100644 --- a/dumpling/export/sql.go +++ b/dumpling/export/sql.go @@ -691,6 +691,19 @@ func CheckTiDBWithTiKV(db *sql.DB) (bool, error) { return count > 0, nil } +// CheckIfSeqExists uses SQL to check whether a sequence exists +func CheckIfSeqExists(db *sql.Conn) (bool, error) { + var count int + const query = "SELECT COUNT(1) as c FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='SEQUENCE'" + row := db.QueryRowContext(context.Background(), query) + err := row.Scan(&count) + if err != nil { + return false, errors.Annotatef(err, "sql: %s", query) + } + + return count > 0, nil +} + // CheckTiDBEnableTableLock use sql variable to check whether current TiDB has TiKV func CheckTiDBEnableTableLock(db *sql.Conn) (bool, error) { tidbConfig, err := getTiDBConfig(db) diff --git a/dumpling/export/sql_test.go b/dumpling/export/sql_test.go index 8209e6e54b53d..19edc27c25df9 100644 --- a/dumpling/export/sql_test.go +++ b/dumpling/export/sql_test.go @@ -1736,6 +1736,35 @@ func TestPickupPossibleField(t *testing.T) { } } +func TestCheckIfSeqExists(t *testing.T) { + t.Parallel() + + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { + require.NoError(t, db.Close()) + }() + + conn, err := db.Conn(context.Background()) + require.NoError(t, err) + + mock.ExpectQuery("SELECT COUNT"). + WillReturnRows(sqlmock.NewRows([]string{"c"}). + AddRow("1")) + + exists, err := CheckIfSeqExists(conn) + require.NoError(t, err) + require.Equal(t, true, exists) + + mock.ExpectQuery("SELECT COUNT"). + WillReturnRows(sqlmock.NewRows([]string{"c"}).
+ AddRow("0")) + + exists, err = CheckIfSeqExists(conn) + require.NoError(t, err) + require.Equal(t, false, exists) +} + func makeVersion(major, minor, patch int64, preRelease string) *semver.Version { return &semver.Version{ Major: major, diff --git a/dumpling/tests/basic/run.sh b/dumpling/tests/basic/run.sh index 6caccce221433..5eccbb77514e5 100644 --- a/dumpling/tests/basic/run.sh +++ b/dumpling/tests/basic/run.sh @@ -88,6 +88,12 @@ actual=$(sed -n '2p' ${DUMPLING_OUTPUT_DIR}/result.000000000.csv) echo "expected 2, actual ${actual}" [ "$actual" = 2 ] +# Test for dump with sequence +run_dumpling | tee ${DUMPLING_OUTPUT_DIR}/dumpling.log +actual=$(grep -w "dump failed" ${DUMPLING_OUTPUT_DIR}/dumpling.log|wc -l) +echo "expected 0, actual ${actual}" +[ "$actual" = 0 ] + # Test for tidb_mem_quota_query configuration export GO_FAILPOINTS="github.com/pingcap/tidb/dumpling/export/PrintTiDBMemQuotaQuery=1*return" run_dumpling > ${DUMPLING_OUTPUT_DIR}/dumpling.log diff --git a/executor/coprocessor.go b/executor/coprocessor.go index 6eb438d5aaeb5..3811475fa9212 100644 --- a/executor/coprocessor.go +++ b/executor/coprocessor.go @@ -144,8 +144,8 @@ func (h *CoprocessorDAGHandler) buildDAGExecutor(req *coprocessor.Request) (Exec Username: dagReq.User.UserName, Hostname: dagReq.User.UserHost, } - authName, authHost, success := pm.GetAuthWithoutVerification(dagReq.User.UserName, dagReq.User.UserHost) - if success { + authName, authHost, success := pm.MatchIdentity(dagReq.User.UserName, dagReq.User.UserHost, false) + if success && pm.GetAuthWithoutVerification(authName, authHost) { h.sctx.GetSessionVars().User.AuthUsername = authName h.sctx.GetSessionVars().User.AuthHostname = authHost h.sctx.GetSessionVars().ActiveRoles = pm.GetDefaultRoles(authName, authHost) diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index b155370d64462..eca2d075f6930 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -292,6 +292,9 @@ func (c *castAsStringFunctionClass) getFunction(ctx sessionctx.Context, args []E argTp := args[0].GetType().EvalType() switch argTp { case types.ETInt: + if bf.tp.Flen == types.UnspecifiedLength { + bf.tp.Flen = args[0].GetType().Flen + } sig = &builtinCastIntAsStringSig{bf} sig.setPbCode(tipb.ScalarFuncSig_CastIntAsString) case types.ETReal: diff --git a/expression/integration_test.go b/expression/integration_test.go index 68f9e1a45684d..eded8ae8e9c48 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -10526,3 +10526,16 @@ func (s *testIntegrationSuite) TestIssue29244(c *C) { tk.MustExec("set tidb_enable_vectorized_expression = off;") tk.MustQuery("select microsecond(a) from t;").Check(testkit.Rows("123500", "123500")) } + +func (s *testIntegrationSuite) TestIssue29513(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustQuery("select '123' union select cast(45678 as char);").Sort().Check(testkit.Rows("123", "45678")) + tk.MustQuery("select '123' union select cast(45678 as char(2));").Sort().Check(testkit.Rows("123", "45")) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int);") + tk.MustExec("insert into t values(45678);") + tk.MustQuery("select '123' union select cast(a as char) from t;").Sort().Check(testkit.Rows("123", "45678")) + tk.MustQuery("select '123' union select cast(a as char(2)) from t;").Sort().Check(testkit.Rows("123", "45")) +} diff --git a/expression/scalar_function.go b/expression/scalar_function.go index dd7805a6c282f..943eb4fc89cb8 100644 --- 
a/expression/scalar_function.go +++ b/expression/scalar_function.go @@ -288,6 +288,7 @@ func (sf *ScalarFunction) Clone() Expression { } c.SetCharsetAndCollation(sf.CharsetAndCollation(sf.GetCtx())) c.SetCoercibility(sf.Coercibility()) + c.SetRepertoire(sf.Repertoire()) return c } diff --git a/expression/scalar_function_test.go b/expression/scalar_function_test.go index 66e4222dbc310..827cc63af6060 100644 --- a/expression/scalar_function_test.go +++ b/expression/scalar_function_test.go @@ -50,6 +50,8 @@ func TestScalarFunction(t *testing.T) { require.True(t, ok) require.Equal(t, "values", newSf.FuncName.O) require.Equal(t, mysql.TypeLonglong, newSf.RetType.Tp) + require.Equal(t, sf.Coercibility(), newSf.Coercibility()) + require.Equal(t, sf.Repertoire(), newSf.Repertoire()) _, ok = newSf.Function.(*builtinValuesIntSig) require.True(t, ok) } diff --git a/expression/typeinfer_test.go b/expression/typeinfer_test.go index 13f5d81380e7e..e6735d632ece7 100644 --- a/expression/typeinfer_test.go +++ b/expression/typeinfer_test.go @@ -187,9 +187,9 @@ func (s *testInferTypeSuite) createTestCase4Constants() []typeInferTestCase { func (s *testInferTypeSuite) createTestCase4Cast() []typeInferTestCase { return []typeInferTestCase{ - {"CAST(c_int_d AS BINARY)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, -1, -1}, // TODO: Flen should be 11. + {"CAST(c_int_d AS BINARY)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 11, -1}, {"CAST(c_int_d AS BINARY(5))", mysql.TypeString, charset.CharsetBin, mysql.BinaryFlag, 5, -1}, - {"CAST(c_int_d AS CHAR)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, -1, -1}, // TODO: Flen should be 11. + {"CAST(c_int_d AS CHAR)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 11, -1}, {"CAST(c_int_d AS CHAR(5))", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 5, -1}, {"CAST(c_int_d AS DATE)", mysql.TypeDate, charset.CharsetBin, mysql.BinaryFlag, 10, 0}, {"CAST(c_int_d AS DATETIME)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 19, 0}, @@ -436,8 +436,8 @@ func (s *testInferTypeSuite) createTestCase4StrFuncs() []typeInferTestCase { {"reverse(c_int_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 20, types.UnspecifiedLength}, {"reverse(c_bigint_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 20, types.UnspecifiedLength}, - {"reverse(c_float_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, -1, types.UnspecifiedLength}, - {"reverse(c_double_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, -1, types.UnspecifiedLength}, + {"reverse(c_float_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, types.UnspecifiedLength, types.UnspecifiedLength}, + {"reverse(c_double_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, types.UnspecifiedLength, types.UnspecifiedLength}, {"reverse(c_decimal )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 8, types.UnspecifiedLength}, {"reverse(c_char )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 20, types.UnspecifiedLength}, {"reverse(c_varchar )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 20, types.UnspecifiedLength}, diff --git a/privilege/privilege.go b/privilege/privilege.go index e0b9d41f41b1d..af5ff9924ffe9 100644 --- a/privilege/privilege.go +++ b/privilege/privilege.go @@ -59,10 +59,15 @@ type Manager interface { RequestDynamicVerificationWithUser(privName string, grantable bool, user *auth.UserIdentity) bool // ConnectionVerification verifies user privilege for connection. 
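The hunk that follows reworks the Manager contract: identity resolution is split off from credential verification, which now requires an exact user/host match. A trimmed-down, self-contained sketch of the intended call order, with a fake implementation and simplified signatures (tlsState and error details omitted; none of this code is part of the patch):

package main

import "fmt"

// manager models the reworked privilege.Manager: MatchIdentity resolves
// which mysql.user record applies (wildcards, reverse DNS), then
// ConnectionVerification checks credentials against that exact record.
type manager interface {
	MatchIdentity(user, host string, skipNameResolve bool) (string, string, bool)
	ConnectionVerification(user, host string, auth, salt []byte) bool
}

type fakeManager struct{}

func (fakeManager) MatchIdentity(user, host string, _ bool) (string, string, bool) {
	return user, "%", true // pretend the wildcard record matched
}

func (fakeManager) ConnectionVerification(user, host string, _, _ []byte) bool {
	return host == "%" // verify only against the resolved record
}

func main() {
	var pm manager = fakeManager{}
	u, h, ok := pm.MatchIdentity("useridentity", "192.168.1.2", false)
	if !ok {
		fmt.Println("access denied")
		return
	}
	fmt.Println("verified:", pm.ConnectionVerification(u, h, nil, nil))
}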
- ConnectionVerification(user, host string, auth, salt []byte, tlsState *tls.ConnectionState) (string, string, bool) + // Requires exact match on user name and host name. + ConnectionVerification(user, host string, auth, salt []byte, tlsState *tls.ConnectionState) bool // GetAuthWithoutVerification uses to get auth name without verification. - GetAuthWithoutVerification(user, host string) (string, string, bool) + // Requires exact match on user name and host name. + GetAuthWithoutVerification(user, host string) bool + + // MatchIdentity finds the matching user record for a user + host pair + MatchIdentity(user, host string, skipNameResolve bool) (string, string, bool) // DBIsVisible returns true is the database is visible to current user. DBIsVisible(activeRole []*auth.RoleIdentity, db string) bool diff --git a/privilege/privileges/cache.go b/privilege/privileges/cache.go index dc90170a500ef..4d388e073b205 100644 --- a/privilege/privileges/cache.go +++ b/privilege/privileges/cache.go @@ -31,6 +31,7 @@ import ( "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/chunk" @@ -848,6 +849,9 @@ func decodeSetToPrivilege(s types.Set) mysql.PrivilegeType { // See https://dev.mysql.com/doc/refman/5.7/en/account-names.html func (record *baseRecord) hostMatch(s string) bool { if record.hostIPNet == nil { + if record.Host == "localhost" && net.ParseIP(s).IsLoopback() { + return true + } return false } ip := net.ParseIP(s).To4() @@ -890,14 +894,54 @@ func patternMatch(str string, patChars, patTypes []byte) bool { return stringutil.DoMatchBytes(str, patChars, patTypes) } -// connectionVerification verifies the connection have access to TiDB server. -func (p *MySQLPrivilege) connectionVerification(user, host string) *UserRecord { +// matchIdentity finds an identity to match a user + host +// using the correct rules according to MySQL. +func (p *MySQLPrivilege) matchIdentity(user, host string, skipNameResolve bool) *UserRecord { for i := 0; i < len(p.User); i++ { record := &p.User[i] if record.match(user, host) { return record } } + + // If skip-name-resolve is not enabled, and the host is not localhost, + // we can fall back and try to resolve with all addrs that match. + // TODO: this is imported from previous code in session.Auth(), and can be improved in future. + if !skipNameResolve && host != variable.DefHostname { + addrs, err := net.LookupAddr(host) + if err != nil { + logutil.BgLogger().Warn( + "net.LookupAddr returned an error during auth check", + zap.String("host", host), + zap.Error(err), + ) + return nil + } + for _, addr := range addrs { + for i := 0; i < len(p.User); i++ { + record := &p.User[i] + if record.match(user, addr) { + return record + } + } + } + } + return nil +} + +// connectionVerification verifies the username + hostname according to exact +// match from the mysql.user privilege table. Call matchIdentity() first if you +// do not have an exact match yet.
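Stepping back to the hostMatch change earlier in this hunk: the new special case means a record created for `user`@`localhost` now also matches loopback peer addresses. A tiny standalone illustration of the same check (my own sketch; it adds a nil guard the patch does not need because the address there is always valid):

package main

import (
	"fmt"
	"net"
)

// localhostMatches mirrors the new special case in hostMatch: a mysql.user
// record with Host == "localhost" also matches loopback peer addresses.
func localhostMatches(peer string) bool {
	ip := net.ParseIP(peer)
	return ip != nil && ip.IsLoopback()
}

func main() {
	for _, peer := range []string{"127.0.0.1", "::1", "192.168.1.1"} {
		fmt.Printf("%s -> %v\n", peer, localhostMatches(peer)) // true, true, false
	}
}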
+func (p *MySQLPrivilege) connectionVerification(user, host string) *UserRecord { + records, exists := p.UserMap[user] + if exists { + for i := 0; i < len(records); i++ { + record := &records[i] + if record.Host == host { // exact match + return record + } + } + } return nil } diff --git a/privilege/privileges/privileges.go b/privilege/privileges/privileges.go index 104c2c3782387..fea22acef641c 100644 --- a/privilege/privileges/privileges.go +++ b/privilege/privileges/privileges.go @@ -256,8 +256,21 @@ func (p *UserPrivileges) GetAuthPlugin(user, host string) (string, error) { return "", errors.New("Failed to get plugin for user") } +// MatchIdentity implements the Manager interface. +func (p *UserPrivileges) MatchIdentity(user, host string, skipNameResolve bool) (u string, h string, success bool) { + if SkipWithGrant { + return user, host, true + } + mysqlPriv := p.Handle.Get() + record := mysqlPriv.matchIdentity(user, host, skipNameResolve) + if record != nil { + return record.User, record.Host, true + } + return "", "", false +} + // GetAuthWithoutVerification implements the Manager interface. -func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (u string, h string, success bool) { +func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (success bool) { if SkipWithGrant { p.user = user p.host = host @@ -273,16 +286,14 @@ func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (u string return } - u = record.User - h = record.Host p.user = user - p.host = h + p.host = record.Host success = true return } // ConnectionVerification implements the Manager interface. -func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte, tlsState *tls.ConnectionState) (u string, h string, success bool) { +func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte, tlsState *tls.ConnectionState) (success bool) { if SkipWithGrant { p.user = user p.host = host @@ -298,9 +309,6 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio return } - u = record.User - h = record.Host - globalPriv := mysqlPriv.matchGlobalPriv(user, host) if globalPriv != nil { if !p.checkSSL(globalPriv, tlsState) { @@ -328,7 +336,7 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio // empty password if len(pwd) == 0 && len(authentication) == 0 { p.user = user - p.host = h + p.host = record.Host success = true return } @@ -371,7 +379,7 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio } p.user = user - p.host = h + p.host = record.Host success = true return } diff --git a/server/conn.go b/server/conn.go index 49a42fe54bb94..54e1c3728b672 100644 --- a/server/conn.go +++ b/server/conn.go @@ -211,6 +211,9 @@ func (cc *clientConn) String() string { // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest // https://bugs.mysql.com/bug.php?id=93044 func (cc *clientConn) authSwitchRequest(ctx context.Context, plugin string) ([]byte, error) { + failpoint.Inject("FakeAuthSwitch", func() { + failpoint.Return([]byte(plugin), nil) + }) enclen := 1 + len(plugin) + 1 + len(cc.salt) + 1 data := cc.alloc.AllocWithLen(4, enclen) data = append(data, mysql.AuthSwitchRequest) // switch request @@ -708,9 +711,10 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con func (cc *clientConn) handleAuthPlugin(ctx context.Context, resp *handshakeResponse41) error { if 
resp.Capability&mysql.ClientPluginAuth > 0 { - newAuth, err := cc.checkAuthPlugin(ctx, &resp.AuthPlugin) + newAuth, err := cc.checkAuthPlugin(ctx, resp) if err != nil { logutil.Logger(ctx).Warn("failed to check the user authplugin", zap.Error(err)) + return err } if len(newAuth) > 0 { resp.Auth = newAuth @@ -718,30 +722,18 @@ func (cc *clientConn) handleAuthPlugin(ctx context.Context, resp *handshakeRespo switch resp.AuthPlugin { case mysql.AuthCachingSha2Password: - resp.Auth, err = cc.authSha(ctx) - if err != nil { - return err - } case mysql.AuthNativePassword: case mysql.AuthSocket: default: logutil.Logger(ctx).Warn("Unknown Auth Plugin", zap.String("plugin", resp.AuthPlugin)) } } else { + // MySQL 5.1 and older clients don't support authentication plugins. logutil.Logger(ctx).Warn("Client without Auth Plugin support; Please upgrade client") - if cc.ctx == nil { - err := cc.openSession() - if err != nil { - return err - } - } - userplugin, err := cc.ctx.AuthPluginForUser(&auth.UserIdentity{Username: cc.user, Hostname: cc.peerHost}) + _, err := cc.checkAuthPlugin(ctx, resp) if err != nil { return err } - if userplugin != mysql.AuthNativePassword && userplugin != "" { - return errNotSupportedAuthMode - } resp.AuthPlugin = mysql.AuthNativePassword } return nil @@ -845,7 +837,7 @@ func (cc *clientConn) openSessionAndDoAuth(authData []byte, authPlugin string) e } // Check if the Authentication Plugin of the server, client and user configuration matches -func (cc *clientConn) checkAuthPlugin(ctx context.Context, authPlugin *string) ([]byte, error) { +func (cc *clientConn) checkAuthPlugin(ctx context.Context, resp *handshakeResponse41) ([]byte, error) { // Open a context unless this was done before. if cc.ctx == nil { err := cc.openSession() @@ -854,12 +846,34 @@ func (cc *clientConn) checkAuthPlugin(ctx context.Context, authPlugin *string) ( } } - userplugin, err := cc.ctx.AuthPluginForUser(&auth.UserIdentity{Username: cc.user, Hostname: cc.peerHost}) + authData := resp.Auth + hasPassword := "YES" + if len(authData) == 0 { + hasPassword = "NO" + } + host, _, err := cc.PeerHost(hasPassword) if err != nil { return nil, err } + // Find the identity of the user based on username and peer host. + identity, err := cc.ctx.MatchIdentity(cc.user, host) + if err != nil { + return nil, errAccessDenied.FastGenByArgs(cc.user, host, hasPassword) + } + // Get the plugin for the identity. + userplugin, err := cc.ctx.AuthPluginForUser(identity) + if err != nil { + logutil.Logger(ctx).Warn("Failed to get authentication method for user", + zap.String("user", cc.user), zap.String("host", host)) + } + failpoint.Inject("FakeUser", func(val failpoint.Value) { + userplugin = val.(string) + }) if userplugin == mysql.AuthSocket { - *authPlugin = mysql.AuthSocket + if !cc.isUnixSocket { + return nil, errAccessDenied.FastGenByArgs(cc.user, host, hasPassword) + } + resp.AuthPlugin = mysql.AuthSocket user, err := user.LookupId(fmt.Sprint(cc.socketCredUID)) if err != nil { return nil, err @@ -867,9 +881,19 @@ func (cc *clientConn) checkAuthPlugin(ctx context.Context, authPlugin *string) ( return []byte(user.Username), nil } if len(userplugin) == 0 { - logutil.Logger(ctx).Warn("No user plugin set, assuming MySQL Native Password", - zap.String("user", cc.user), zap.String("host", cc.peerHost)) - *authPlugin = mysql.AuthNativePassword + // No user plugin set, assuming MySQL Native Password + // This happens if the account doesn't exist or if the account doesn't have + // a password set. 
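The branch being introduced at this point is one arm of a larger negotiation. A compressed, standalone model of the rules checkAuthPlugin ends up implementing; this is a simplification under my own reading of the hunk (the real code also consults cc.authPlugin and actually sends the AuthSwitchRequest packet):

package main

import (
	"errors"
	"fmt"
)

const native = "mysql_native_password"

// negotiate models the plugin decision in checkAuthPlugin. pluginAuth
// reflects resp.Capability & mysql.ClientPluginAuth.
func negotiate(pluginAuth bool, clientPlugin, userPlugin string) (string, error) {
	if userPlugin == "" {
		// Unknown account or no password set: assume native password and
		// switch the client over if it offered something else.
		if pluginAuth && clientPlugin != native {
			return native, nil // server answers with an AuthSwitchRequest
		}
		return clientPlugin, nil
	}
	if clientPlugin != userPlugin {
		if pluginAuth {
			return userPlugin, nil // AuthSwitchRequest to the user's plugin
		}
		if userPlugin != native {
			// Pre-plugin clients (MySQL 5.1 and older) cannot switch.
			return "", errors.New("not supported auth mode")
		}
	}
	return userPlugin, nil
}

func main() {
	p, err := negotiate(true, "caching_sha2_password", native)
	fmt.Println(p, err) // mysql_native_password <nil>
}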
+ if resp.AuthPlugin != mysql.AuthNativePassword { + if resp.Capability&mysql.ClientPluginAuth > 0 { + resp.AuthPlugin = mysql.AuthNativePassword + authData, err := cc.authSwitchRequest(ctx, mysql.AuthNativePassword) + if err != nil { + return nil, err + } + return authData, nil + } + } return nil, nil } @@ -878,13 +902,18 @@ func (cc *clientConn) checkAuthPlugin(ctx context.Context, authPlugin *string) ( // or if the authentication method send by the server doesn't match the authentication // method send by the client (*authPlugin) then we need to switch the authentication // method to match the one configured for that specific user. - if (cc.authPlugin != userplugin) || (cc.authPlugin != *authPlugin) { - authData, err := cc.authSwitchRequest(ctx, userplugin) - if err != nil { - return nil, err + if (cc.authPlugin != userplugin) || (cc.authPlugin != resp.AuthPlugin) { + if resp.Capability&mysql.ClientPluginAuth > 0 { + authData, err := cc.authSwitchRequest(ctx, userplugin) + if err != nil { + return nil, err + } + resp.AuthPlugin = userplugin + return authData, nil + } else if userplugin != mysql.AuthNativePassword { + // MySQL 5.1 and older don't support authentication plugins yet + return nil, errNotSupportedAuthMode } - *authPlugin = userplugin - return authData, nil } return nil, nil diff --git a/server/conn_test.go b/server/conn_test.go index dc50900e41624..95aeaf7af41d9 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -894,8 +894,6 @@ func TestShowErrors(t *testing.T) { } func TestHandleAuthPlugin(t *testing.T) { - t.Parallel() - store, clean := testkit.CreateMockStore(t) defer clean() @@ -905,25 +903,202 @@ func TestHandleAuthPlugin(t *testing.T) { drv := NewTiDBDriver(store) srv, err := NewServer(cfg, drv) require.NoError(t, err) + ctx := context.Background() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("CREATE USER unativepassword") + defer func() { + tk.MustExec("DROP USER unativepassword") + }() + // 5.7 or newer client trying to authenticate with mysql_native_password cc := &clientConn{ connectionID: 1, alloc: arena.NewAllocator(1024), + collation: mysql.DefaultCollationID, + peerHost: "localhost", pkt: &packetIO{ bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), }, - collation: mysql.DefaultCollationID, - server: srv, - user: "root", + server: srv, + user: "unativepassword", } - ctx := context.Background() resp := handshakeResponse41{ Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + AuthPlugin: mysql.AuthNativePassword, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + + // 8.0 or newer client trying to authenticate with caching_sha2_password + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + AuthPlugin: mysql.AuthCachingSha2Password, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + require.Equal(t, resp.Auth, []byte(mysql.AuthNativePassword)) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) + + // MySQL 5.1 or older client, without authplugin support + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + collation: mysql.DefaultCollationID, 
+ peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + + // === Target account has mysql_native_password === + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeUser", "return(\"mysql_native_password\")")) + + // 5.7 or newer client trying to authenticate with mysql_native_password + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + AuthPlugin: mysql.AuthNativePassword, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + require.Equal(t, []byte(mysql.AuthNativePassword), resp.Auth) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) + + // 8.0 or newer client trying to authenticate with caching_sha2_password + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + AuthPlugin: mysql.AuthCachingSha2Password, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + require.Equal(t, []byte(mysql.AuthNativePassword), resp.Auth) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) + + // MySQL 5.1 or older client, without authplugin support + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeUser")) + + // === Target account has caching_sha2_password === + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeUser", "return(\"caching_sha2_password\")")) + + // 5.7 or newer client trying to authenticate with mysql_native_password + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + AuthPlugin: mysql.AuthNativePassword, } err = cc.handleAuthPlugin(ctx, &resp) require.NoError(t, err) + require.Equal(t, []byte(mysql.AuthCachingSha2Password), resp.Auth) + require.NoError(t, 
failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) - resp.Capability = mysql.ClientProtocol41 + // 8.0 or newer client trying to authenticate with caching_sha2_password + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + AuthPlugin: mysql.AuthCachingSha2Password, + } err = cc.handleAuthPlugin(ctx, &resp) require.NoError(t, err) + require.Equal(t, []byte(mysql.AuthCachingSha2Password), resp.Auth) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) + + // MySQL 5.1 or older client, without authplugin support + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.Error(t, err) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeUser")) } diff --git a/server/http_handler_test.go b/server/http_handler_test.go index d78642b1651a0..ca17cfef8f29c 100644 --- a/server/http_handler_test.go +++ b/server/http_handler_test.go @@ -483,6 +483,7 @@ func (ts *basicHTTPHandlerTestSuite) startServer(c *C) { cfg.Port = 0 cfg.Status.StatusPort = 0 cfg.Status.ReportStatus = true + cfg.Socket = "" server, err := NewServer(cfg, ts.tidbdrv) c.Assert(err, IsNil) diff --git a/server/plan_replayer_test.go b/server/plan_replayer_test.go index 903f771463ee8..a007e1f5d3c9a 100644 --- a/server/plan_replayer_test.go +++ b/server/plan_replayer_test.go @@ -39,6 +39,7 @@ func TestDumpPlanReplayerAPI(t *testing.T) { client := newTestServerClient() cfg := newTestConfig() cfg.Port = client.port + cfg.Socket = "" cfg.Status.StatusPort = client.statusPort cfg.Status.ReportStatus = true diff --git a/server/statistics_handler_serial_test.go b/server/statistics_handler_serial_test.go index 7c56cf2186831..9d81d36dcb083 100644 --- a/server/statistics_handler_serial_test.go +++ b/server/statistics_handler_serial_test.go @@ -38,6 +38,7 @@ func TestDumpStatsAPI(t *testing.T) { client := newTestServerClient() cfg := newTestConfig() cfg.Port = client.port + cfg.Socket = "" cfg.Status.StatusPort = client.statusPort cfg.Status.ReportStatus = true diff --git a/session/session.go b/session/session.go index aa08d554c9f13..913faa3591aa9 100644 --- a/session/session.go +++ b/session/session.go @@ -24,7 +24,6 @@ import ( "crypto/tls" "encoding/json" "fmt" - "net" "runtime/pprof" "runtime/trace" "strconv" @@ -146,6 +145,7 @@ type Session interface { Auth(user *auth.UserIdentity, auth []byte, salt []byte) bool AuthWithoutVerification(user *auth.UserIdentity) bool AuthPluginForUser(user *auth.UserIdentity) (string, error) + MatchIdentity(username, remoteHost string) (*auth.UserIdentity, error) ShowProcess() *util.ProcessInfo // Return the information of the txn current running TxnInfo() *txninfo.TxnInfo @@ -2211,91 +2211,61 @@ func (s *session) AuthPluginForUser(user *auth.UserIdentity) (string, error) { return authplugin, nil } +// Auth validates a user using 
an authentication string and salt. +// If the password fails, it will keep trying other users until exhausted. +// This means it cannot be refactored to use MatchIdentity yet. func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []byte) bool { pm := privilege.GetPrivilegeManager(s) - - // Check IP or localhost. - var success bool - user.AuthUsername, user.AuthHostname, success = pm.ConnectionVerification(user.Username, user.Hostname, authentication, salt, s.sessionVars.TLSConnectionState) - if success { + authUser, err := s.MatchIdentity(user.Username, user.Hostname) + if err != nil { + return false + } + if pm.ConnectionVerification(authUser.Username, authUser.Hostname, authentication, salt, s.sessionVars.TLSConnectionState) { + user.AuthUsername = authUser.Username + user.AuthHostname = authUser.Hostname s.sessionVars.User = user s.sessionVars.ActiveRoles = pm.GetDefaultRoles(user.AuthUsername, user.AuthHostname) return true - } else if user.Hostname == variable.DefHostname { - return false } + return false +} - // Check Hostname. - for _, addr := range s.getHostByIP(user.Hostname) { - u, h, success := pm.ConnectionVerification(user.Username, addr, authentication, salt, s.sessionVars.TLSConnectionState) - if success { - s.sessionVars.User = &auth.UserIdentity{ - Username: user.Username, - Hostname: addr, - AuthUsername: u, - AuthHostname: h, - } - s.sessionVars.ActiveRoles = pm.GetDefaultRoles(u, h) - return true - } +// MatchIdentity finds the matching username + hostname record in the MySQL privilege tables +// for a username + hostname, since MySQL can have wildcards. +func (s *session) MatchIdentity(username, remoteHost string) (*auth.UserIdentity, error) { + pm := privilege.GetPrivilegeManager(s) + var success bool + var skipNameResolve bool + var user = &auth.UserIdentity{} + varVal, err := s.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.SkipNameResolve) + if err == nil && variable.TiDBOptOn(varVal) { + skipNameResolve = true } - return false + user.Username, user.Hostname, success = pm.MatchIdentity(username, remoteHost, skipNameResolve) + if success { + return user, nil + } + // This error is not returned to the user; an access-denied error is returned instead + return nil, fmt.Errorf("could not find matching user in MatchIdentity: %s, %s", username, remoteHost) } // AuthWithoutVerification is required by the ResetConnection RPC func (s *session) AuthWithoutVerification(user *auth.UserIdentity) bool { pm := privilege.GetPrivilegeManager(s) - - // Check IP or localhost. - var success bool - user.AuthUsername, user.AuthHostname, success = pm.GetAuthWithoutVerification(user.Username, user.Hostname) - if success { + authUser, err := s.MatchIdentity(user.Username, user.Hostname) + if err != nil { + return false + } + if pm.GetAuthWithoutVerification(authUser.Username, authUser.Hostname) { + user.AuthUsername = authUser.Username + user.AuthHostname = authUser.Hostname s.sessionVars.User = user s.sessionVars.ActiveRoles = pm.GetDefaultRoles(user.AuthUsername, user.AuthHostname) return true - } else if user.Hostname == variable.DefHostname { - return false - } - - // Check Hostname.
- for _, addr := range s.getHostByIP(user.Hostname) { - u, h, success := pm.GetAuthWithoutVerification(user.Username, addr) - if success { - s.sessionVars.User = &auth.UserIdentity{ - Username: user.Username, - Hostname: addr, - AuthUsername: u, - AuthHostname: h, - } - s.sessionVars.ActiveRoles = pm.GetDefaultRoles(u, h) - return true - } } return false } -func (s *session) getHostByIP(ip string) []string { - if ip == "127.0.0.1" { - return []string{variable.DefHostname} - } - skipNameResolve, err := s.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.SkipNameResolve) - if err == nil && variable.TiDBOptOn(skipNameResolve) { - return []string{ip} // user wants to skip name resolution - } - addrs, err := net.LookupAddr(ip) - if err != nil { - // These messages can be noisy. - // See: https://github.com/pingcap/tidb/pull/13989 - logutil.BgLogger().Debug( - "net.LookupAddr returned an error during auth check", - zap.String("ip", ip), - zap.Error(err), - ) - return []string{ip} - } - return addrs -} - // RefreshVars implements the sessionctx.Context interface. func (s *session) RefreshVars(ctx context.Context) error { pruneMode, err := s.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.TiDBPartitionPruneMode) diff --git a/session/session_test.go b/session/session_test.go index b2b386fefbd61..b4aadd0306bc1 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -18,6 +18,7 @@ import ( "context" "flag" "fmt" + "net" "os" "path" "runtime" @@ -691,6 +692,50 @@ func (s *testSessionSuite) TestGlobalVarAccessor(c *C) { c.Assert(terror.ErrorEqual(err, variable.ErrUnknownTimeZone), IsTrue) } +func (s *testSessionSuite) TestMatchIdentity(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("CREATE USER `useridentity`@`%`") + tk.MustExec("CREATE USER `useridentity`@`localhost`") + tk.MustExec("CREATE USER `useridentity`@`192.168.1.1`") + tk.MustExec("CREATE USER `useridentity`@`example.com`") + + // The MySQL matching rule is most specific to least specific. + // So if I log in from 192.168.1.1 I should match that entry always. + identity, err := tk.Se.MatchIdentity("useridentity", "192.168.1.1") + c.Assert(err, IsNil) + c.Assert(identity.Username, Equals, "useridentity") + c.Assert(identity.Hostname, Equals, "192.168.1.1") + + // If I log in from localhost, I should match localhost + identity, err = tk.Se.MatchIdentity("useridentity", "localhost") + c.Assert(err, IsNil) + c.Assert(identity.Username, Equals, "useridentity") + c.Assert(identity.Hostname, Equals, "localhost") + + // If I log in from 192.168.1.2 I should match wildcard. + identity, err = tk.Se.MatchIdentity("useridentity", "192.168.1.2") + c.Assert(err, IsNil) + c.Assert(identity.Username, Equals, "useridentity") + c.Assert(identity.Hostname, Equals, "%") + + identity, err = tk.Se.MatchIdentity("useridentity", "127.0.0.1") + c.Assert(err, IsNil) + c.Assert(identity.Username, Equals, "useridentity") + c.Assert(identity.Hostname, Equals, "localhost") + + // This uses the lookup of example.com to get an IP address. + // We then login with that IP address, but expect it to match the example.com + // entry in the privileges table (by reverse lookup). 
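The FIXME in the test comment above comes down to forward and reverse DNS not being symmetric. A quick standalone check of that asymmetry (results depend on the local resolver, so treat this as illustrative only):

package main

import (
	"fmt"
	"net"
)

func main() {
	// example.com resolves to some IP, but reverse-resolving that IP does
	// not normally yield example.com, so the reverse-lookup fallback in
	// matchIdentity cannot find the `useridentity`@`example.com` record
	// and the `%` record wins instead.
	ips, err := net.LookupHost("example.com")
	if err != nil || len(ips) == 0 {
		fmt.Println("lookup failed:", err)
		return
	}
	names, err := net.LookupAddr(ips[0])
	fmt.Printf("%s reverse-resolves to %v (err: %v)\n", ips[0], names, err)
}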
+ ips, err := net.LookupHost("example.com") + c.Assert(err, IsNil) + identity, err = tk.Se.MatchIdentity("useridentity", ips[0]) + c.Assert(err, IsNil) + c.Assert(identity.Username, Equals, "useridentity") + // FIXME: we *should* match example.com instead + // as long as skip-name-resolve is not set (DEFAULT) + c.Assert(identity.Hostname, Equals, "%") +} + func (s *testSessionSuite) TestGetSysVariables(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) diff --git a/util/logutil/slow_query_logger.go b/util/logutil/slow_query_logger.go index 5f81f3d73b2f4..2588c36131fd9 100644 --- a/util/logutil/slow_query_logger.go +++ b/util/logutil/slow_query_logger.go @@ -29,14 +29,12 @@ var _pool = buffer.NewPool() func newSlowQueryLogger(cfg *LogConfig) (*zap.Logger, *log.ZapProperties, error) { - // copy global config and override slow query log file - // if slow query log filename is empty, slow query log will behave the same as global log + // copy the global log config to the slow log config + // if the filename of the slow log config is empty, the slow log behaves the same as the global log. sqConfig := cfg.Config if len(cfg.SlowQueryFile) != 0 { - sqConfig.File = log.FileLogConfig{ - MaxSize: cfg.File.MaxSize, - Filename: cfg.SlowQueryFile, - } + sqConfig.File = cfg.File + sqConfig.File.Filename = cfg.SlowQueryFile } // create the slow query logger
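To close out: the slow-query-logger change relies on Go's value-copy semantics for structs, so the slow log now inherits every file option (not just MaxSize) before overriding the filename. A minimal sketch with stand-in types for log.Config and log.FileLogConfig:

package main

import "fmt"

// Stand-ins for log.FileLogConfig / log.Config; only the fields needed to
// show the copy-then-override pattern.
type FileLogConfig struct {
	MaxSize  int
	Filename string
}

type Config struct {
	Level string
	File  FileLogConfig
}

func main() {
	global := Config{Level: "info", File: FileLogConfig{MaxSize: 300, Filename: "tidb.log"}}
	sq := global                       // struct assignment copies by value
	sq.File.Filename = "tidb-slow.log" // override only the filename; MaxSize is inherited
	fmt.Println(global.File.Filename, sq.File.Filename) // tidb.log tidb-slow.log
}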