diff --git a/gocqlxtest/gocqlxtest.go b/gocqlxtest/gocqlxtest.go index 26cb669..b0f747d 100644 --- a/gocqlxtest/gocqlxtest.go +++ b/gocqlxtest/gocqlxtest.go @@ -18,7 +18,7 @@ import ( ) var ( - flagCluster = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples") + flagCluster = flag.String("cluster", "127.3.0.5", "a comma-separated list of host:port tuples") flagKeyspace = flag.String("keyspace", "gocqlx_test", "keyspace name") flagProto = flag.Int("proto", 0, "protcol version") flagCQL = flag.String("cql", "3.0.0", "CQL version") diff --git a/migrate/example/example_test.go b/migrate/example/example_test.go index 51c3f4a..457f132 100644 --- a/migrate/example/example_test.go +++ b/migrate/example/example_test.go @@ -50,6 +50,12 @@ func TestExample(t *testing.T) { reg.Add(migrate.CallComment, "3", log) migrate.Callback = reg.Callback + pending, err := migrate.Pending(context.Background(), session, cql.Files) + if err != nil { + t.Fatal("Pending:", err) + } + t.Log("Pending migrations:", len(pending)) + // First run prints data if err := migrate.FromFS(context.Background(), session, cql.Files); err != nil { t.Fatal("Migrate:", err) diff --git a/migrate/migrate.go b/migrate/migrate.go index b9329b6..37bda0e 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -87,6 +87,48 @@ func List(ctx context.Context, session gocqlx.Session) ([]*Info, error) { return v, nil } +// Pending provides a listing of pending migrations. +func Pending(ctx context.Context, session gocqlx.Session, f fs.FS) ([]*Info, error) { + applied, err := List(ctx, session) + if err != nil { + return nil, err + } + + fm, err := fs.Glob(f, "*.cql") + if err != nil { + return nil, fmt.Errorf("list migrations: %w", err) + } + + if len(applied) > len(fm) { + return nil, fmt.Errorf("database is ahead") + } + + pending := make([]*Info, 0, len(fm)-len(applied)) + + for i := range applied { + if applied[i].Name != fm[i] { + return nil, fmt.Errorf("inconsistent migrations found, expected %q got %q at %d", applied[i].Name, fm[i], i) + } + } + + for _, name := range fm[len(applied):] { + c, err := fileChecksum(f, name) + if err != nil { + return nil, fmt.Errorf("calculate checksum for %q: %w", name, err) + } + + info := &Info{ + Name: filepath.Base(name), + StartTime: time.Now(), + Checksum: c, + } + + pending = append(pending, info) + } + + return pending, nil +} + func ensureInfoTable(ctx context.Context, session gocqlx.Session) error { return session.ContextQuery(ctx, infoSchema, nil).ExecRelease() } diff --git a/migrate/migrate_test.go b/migrate/migrate_test.go index 2aaeb08..aa04e5c 100644 --- a/migrate/migrate_test.go +++ b/migrate/migrate_test.go @@ -45,6 +45,89 @@ func recreateTables(tb testing.TB, session gocqlx.Session) { } } +func TestPending(t *testing.T) { + session := gocqlxtest.CreateSession(t) + defer session.Close() + recreateTables(t, session) + + ctx := context.Background() + + t.Run("ahead", func(t *testing.T) { + defer recreateTables(t, session) + + if err := migrate.FromFS(ctx, session, makeTestFS(4)); err != nil { + t.Fatal(err) + } + + _, err := migrate.Pending(ctx, session, makeTestFS(2)) + + if err == nil || !strings.Contains(err.Error(), "ahead") { + t.Fatal("expected error") + } else { + t.Log(err) + } + }) + + t.Run("inconsistent", func(t *testing.T) { + defer recreateTables(t, session) + + if err := migrate.FromFS(ctx, session, makeTestFS(1)); err != nil { + t.Fatal(err) + } + + f := memfs.New() + writeFile(f, 1, fmt.Sprintf(insertMigrate, 1)+";") + + _, err := migrate.Pending(ctx, session, f) + + if err == nil || !strings.Contains(err.Error(), "inconsistent") { + t.Fatal("expected error") + } else { + t.Log(err) + } + }) + + t.Run("pending", func(t *testing.T) { + defer recreateTables(t, session) + + f := memfs.New() + writeFile(f, 0, fmt.Sprintf(insertMigrate, 0)+";") + + pending, err := migrate.Pending(ctx, session, f) + if err != nil { + t.Fatal(err) + } + if len(pending) != 1 { + t.Fatal("expected 2 pending migrations got", len(pending)) + } + + err = migrate.FromFS(ctx, session, f) + if err != nil { + t.Fatal(err) + } + + pending, err = migrate.Pending(ctx, session, f) + if err != nil { + t.Fatal(err) + } + if len(pending) != 0 { + t.Fatal("expected no pending migrations got", len(pending)) + } + + for i := 1; i < 3; i++ { + writeFile(f, i, fmt.Sprintf(insertMigrate, i)+";") + } + + pending, err = migrate.Pending(ctx, session, f) + if err != nil { + t.Fatal(err) + } + if len(pending) != 2 { + t.Fatal("expected 2 pending migrations got", len(pending)) + } + }) +} + func TestMigration(t *testing.T) { session := gocqlxtest.CreateSession(t) defer session.Close()