diff --git a/db/buckets.go b/db/buckets.go index 3b5f75ebdf..c75c8d98a0 100644 --- a/db/buckets.go +++ b/db/buckets.go @@ -31,6 +31,7 @@ const ( Pending BlockCommitments Temporary // used temporarily for migrations + SchemaIntermediateState ) // Key flattens a prefix and series of byte arrays into a single []byte. diff --git a/migration/bucket_migrator.go b/migration/bucket_migrator.go index d680cb93cf..bff63799de 100644 --- a/migration/bucket_migrator.go +++ b/migration/bucket_migrator.go @@ -2,6 +2,7 @@ package migration import ( "bytes" + "context" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/utils" @@ -66,15 +67,16 @@ func (m *BucketMigrator) WithKeyFilter(keyFilter BucketMigratorKeyFilter) *Bucke return m } -func (m *BucketMigrator) Before() { +func (m *BucketMigrator) Before(_ []byte) error { m.before() + return nil } -func (m *BucketMigrator) Migrate(txn db.Transaction, network utils.Network) error { +func (m *BucketMigrator) Migrate(_ context.Context, txn db.Transaction, network utils.Network) ([]byte, error) { remainingInBatch := m.batchSize iterator, err := txn.NewIterator() if err != nil { - return err + return nil, err } for iterator.Seek(m.startFrom); iterator.Valid(); iterator.Next() { @@ -84,24 +86,24 @@ func (m *BucketMigrator) Migrate(txn db.Transaction, network utils.Network) erro } if pass, err := m.keyFilter(key); err != nil { - return utils.RunAndWrapOnError(iterator.Close, err) + return nil, utils.RunAndWrapOnError(iterator.Close, err) } else if pass { if remainingInBatch == 0 { m.startFrom = key - return utils.RunAndWrapOnError(iterator.Close, ErrCallWithNewTransaction) + return nil, utils.RunAndWrapOnError(iterator.Close, ErrCallWithNewTransaction) } remainingInBatch-- value, err := iterator.Value() if err != nil { - return utils.RunAndWrapOnError(iterator.Close, err) + return nil, utils.RunAndWrapOnError(iterator.Close, err) } if err = m.do(txn, key, value, network); err != nil { - return utils.RunAndWrapOnError(iterator.Close, err) + return nil, utils.RunAndWrapOnError(iterator.Close, err) } } } - return iterator.Close() + return nil, iterator.Close() } diff --git a/migration/bucket_migrator_test.go b/migration/bucket_migrator_test.go index 22dbb7e91c..718ce28946 100644 --- a/migration/bucket_migrator_test.go +++ b/migration/bucket_migrator_test.go @@ -2,6 +2,7 @@ package migration_test import ( "bytes" + "context" "errors" "testing" @@ -32,17 +33,20 @@ func TestBucketMover(t *testing.T) { return txn.Set(sourceBucket.Key(), []byte{44}) })) - mover.Before() + require.NoError(t, mover.Before(nil)) require.True(t, beforeCalled) - - err := testDB.Update(func(txn db.Transaction) error { - err := mover.Migrate(txn, utils.Mainnet) + var ( + intermediateState []byte + err error + ) + err = testDB.Update(func(txn db.Transaction) error { + intermediateState, err = mover.Migrate(context.Background(), txn, utils.Mainnet) require.ErrorIs(t, err, migration.ErrCallWithNewTransaction) return nil }) require.NoError(t, err) err = testDB.Update(func(txn db.Transaction) error { - err = mover.Migrate(txn, utils.Mainnet) + intermediateState, err = mover.Migrate(context.Background(), txn, utils.Mainnet) require.NoError(t, err) return nil }) @@ -76,4 +80,5 @@ func TestBucketMover(t *testing.T) { return nil }) require.NoError(t, err) + require.Nil(t, intermediateState) } diff --git a/migration/migration.go b/migration/migration.go index 8f760482a7..a14605cc8f 100644 --- a/migration/migration.go +++ b/migration/migration.go @@ -2,6 +2,7 @@ package migration import ( "bytes" + "context" "encoding/binary" "errors" "fmt" @@ -16,27 +17,34 @@ import ( "github.com/NethermindEth/juno/encoder" "github.com/NethermindEth/juno/utils" "github.com/bits-and-blooms/bitset" + "github.com/fxamacker/cbor/v2" "github.com/sourcegraph/conc/pool" ) +type schemaMetadata struct { + Version uint64 + IntermediateState []byte +} + type Migration interface { - Before() - Migrate(db.Transaction, utils.Network) error + Before(intermediateState []byte) error + // Migration should return intermediate state whenever it requests new txn or detects cancelled ctx. + Migrate(context.Context, db.Transaction, utils.Network) ([]byte, error) } type MigrationFunc func(db.Transaction, utils.Network) error // Migrate returns f(txn). -func (f MigrationFunc) Migrate(txn db.Transaction, network utils.Network) error { - return f(txn, network) +func (f MigrationFunc) Migrate(_ context.Context, txn db.Transaction, network utils.Network) ([]byte, error) { + return nil, f(txn, network) } // Before is a no-op. -func (f MigrationFunc) Before() {} +func (f MigrationFunc) Before(_ []byte) error { return nil } -// migrations contains a set of migrations that can be applied to a database. +// defaultMigrations contains a set of migrations that can be applied to a database. // After making breaking changes to the DB layout, add new migrations to this list. -var migrations = []Migration{ +var defaultMigrations = []Migration{ MigrationFunc(migration0000), MigrationFunc(relocateContractStorageRootKeys), MigrationFunc(recalculateBloomFilters), @@ -56,9 +64,13 @@ var migrations = []Migration{ var ErrCallWithNewTransaction = errors.New("call with new transaction") -func MigrateIfNeeded(targetDB db.DB, network utils.Network, log utils.SimpleLogger) error { +func MigrateIfNeeded(ctx context.Context, targetDB db.DB, network utils.Network, log utils.SimpleLogger) error { + return migrateIfNeeded(ctx, targetDB, network, log, defaultMigrations) +} + +func migrateIfNeeded(ctx context.Context, targetDB db.DB, network utils.Network, log utils.SimpleLogger, migrations []Migration) error { /* - Schema version of the targetDB determines which set of migrations need to be applied to the database. + Schema metadata of the targetDB determines which set of migrations need to be applied to the database. After a migration is successfully executed, which may update the database, the schema version is incremented by 1 by this loop. @@ -73,36 +85,39 @@ func MigrateIfNeeded(targetDB db.DB, network utils.Network, log utils.SimpleLogg new ones. It will be able to do this since the schema version it reads from the database will be non-zero and that is what we use to initialise the i loop variable. */ - version, err := SchemaVersion(targetDB) + metadata, err := SchemaMetadata(targetDB) if err != nil { return err } - for i := version; i < uint64(len(migrations)); i++ { + for i := metadata.Version; i < uint64(len(migrations)); i++ { + if err := ctx.Err(); err != nil { + return err + } log.Infow("Applying database migration", "stage", fmt.Sprintf("%d/%d", i+1, len(migrations))) migration := migrations[i] - migration.Before() + if err := migration.Before(metadata.IntermediateState); err != nil { + return err + } for { var migrationErr error if dbErr := targetDB.Update(func(txn db.Transaction) error { - migrationErr = migration.Migrate(txn, network) - if migrationErr != nil { - if errors.Is(migrationErr, ErrCallWithNewTransaction) { - return nil // Run the migration again with a new transaction. + metadata.IntermediateState, migrationErr = migration.Migrate(ctx, txn, network) + switch { + case migrationErr == nil || errors.Is(migrationErr, ctx.Err()): + if metadata.IntermediateState == nil { + metadata.Version++ } + return updateSchemaMetadata(txn, metadata) + case errors.Is(migrationErr, ErrCallWithNewTransaction): + return nil // Run migration again with new transaction. + default: return migrationErr } - - // Migration successful. Bump the version. - var versionBytes [8]byte - binary.BigEndian.PutUint64(versionBytes[:], i+1) - return txn.Set(db.SchemaVersion.Key(), versionBytes[:]) }); dbErr != nil { return dbErr } else if migrationErr == nil { break - } else if !errors.Is(migrationErr, ErrCallWithNewTransaction) { - return migrationErr } } } @@ -110,21 +125,46 @@ func MigrateIfNeeded(targetDB db.DB, network utils.Network, log utils.SimpleLogg return nil } -func SchemaVersion(targetDB db.DB) (uint64, error) { - version := uint64(0) +// SchemaMetadata retrieves metadata about a database schema from the given database. +func SchemaMetadata(targetDB db.DB) (schemaMetadata, error) { + metadata := schemaMetadata{} txn, err := targetDB.NewTransaction(false) if err != nil { - return 0, nil + return metadata, err } - err = txn.Get(db.SchemaVersion.Key(), func(bytes []byte) error { - version = binary.BigEndian.Uint64(bytes) + if err := txn.Get(db.SchemaVersion.Key(), func(b []byte) error { + metadata.Version = binary.BigEndian.Uint64(b) return nil - }) - if err != nil && !errors.Is(err, db.ErrKeyNotFound) { - return 0, utils.RunAndWrapOnError(txn.Discard, err) + }); err != nil && !errors.Is(err, db.ErrKeyNotFound) { + return metadata, utils.RunAndWrapOnError(txn.Discard, err) + } + + if err := txn.Get(db.SchemaIntermediateState.Key(), func(b []byte) error { + return cbor.Unmarshal(b, &metadata.IntermediateState) + }); err != nil && !errors.Is(err, db.ErrKeyNotFound) { + return metadata, utils.RunAndWrapOnError(txn.Discard, err) } - return version, txn.Discard() + return metadata, txn.Discard() +} + +// updateSchemaMetadata updates the schema in given database. +func updateSchemaMetadata(txn db.Transaction, schema schemaMetadata) error { + var ( + version [8]byte + state []byte + err error + ) + binary.BigEndian.PutUint64(version[:], schema.Version) + state, err = cbor.Marshal(schema.IntermediateState) + if err != nil { + return err + } + + if err := txn.Set(db.SchemaVersion.Key(), version[:]); err != nil { + return err + } + return txn.Set(db.SchemaIntermediateState.Key(), state) } // migration0000 makes sure the targetDB is empty @@ -227,7 +267,7 @@ type changeTrieNodeEncoding struct { } } -func (m *changeTrieNodeEncoding) Before() { +func (m *changeTrieNodeEncoding) Before(_ []byte) error { m.trieNodeBuckets = map[db.Bucket]*struct { seekTo []byte skipLen int @@ -245,6 +285,7 @@ func (m *changeTrieNodeEncoding) Before() { skipLen: 1 + felt.Bytes, }, } + return nil } type node struct { @@ -314,7 +355,7 @@ func (n *node) _UnmarshalBinary(data []byte) error { return err } -func (m *changeTrieNodeEncoding) Migrate(txn db.Transaction, _ utils.Network) error { +func (m *changeTrieNodeEncoding) Migrate(_ context.Context, txn db.Transaction, _ utils.Network) ([]byte, error) { // If we made n a trie.Node, the encoder would fall back to the custom encoding methods. // We instead define a cutom struct to force the encoder to use the default encoding. var n node @@ -371,15 +412,15 @@ func (m *changeTrieNodeEncoding) Migrate(txn db.Transaction, _ utils.Network) er iterator, err := txn.NewIterator() if err != nil { - return err + return nil, err } for bucket, info := range m.trieNodeBuckets { if err := migrateF(iterator, bucket, info.seekTo, info.skipLen); err != nil { - return utils.RunAndWrapOnError(iterator.Close, err) + return nil, utils.RunAndWrapOnError(iterator.Close, err) } } - return iterator.Close() + return nil, iterator.Close() } // calculateBlockCommitments calculates the txn and event commitments for each block and stores them separately diff --git a/migration/migration_pkg_test.go b/migration/migration_pkg_test.go index 85ff3e1e00..cc24ef08f9 100644 --- a/migration/migration_pkg_test.go +++ b/migration/migration_pkg_test.go @@ -3,6 +3,8 @@ package migration import ( "bytes" "context" + "encoding/binary" + "errors" "testing" "github.com/NethermindEth/juno/blockchain" @@ -137,9 +139,10 @@ func TestChangeTrieNodeEncoding(t *testing.T) { })) m := new(changeTrieNodeEncoding) - m.Before() + require.NoError(t, m.Before(nil)) require.NoError(t, testdb.Update(func(txn db.Transaction) error { - return m.Migrate(txn, utils.Mainnet) + _, err := m.Migrate(context.Background(), txn, utils.Mainnet) + return err })) require.NoError(t, testdb.Update(func(txn db.Transaction) error { @@ -248,3 +251,144 @@ func TestMigrateTrieNodesFromBitsetToTrieKey(t *testing.T) { require.Equal(t, felt.Zero, trieNode.Left.Felt()) require.Equal(t, felt.Zero, trieNode.Right.Felt()) } + +func TestSchemaMetadata(t *testing.T) { + t.Run("conversion", func(t *testing.T) { + t.Run("version not set", func(t *testing.T) { + testDB := pebble.NewMemTest(t) + metadata, err := SchemaMetadata(testDB) + require.NoError(t, err) + require.Equal(t, uint64(0), metadata.Version) + require.Nil(t, metadata.IntermediateState) + }) + + t.Run("version set", func(t *testing.T) { + testDB := pebble.NewMemTest(t) + var version [8]byte + binary.BigEndian.PutUint64(version[:], 1) + require.NoError(t, testDB.Update(func(txn db.Transaction) error { + return txn.Set(db.SchemaVersion.Key(), version[:]) + })) + + metadata, err := SchemaMetadata(testDB) + require.NoError(t, err) + require.Equal(t, uint64(1), metadata.Version) + require.Nil(t, metadata.IntermediateState) + }) + }) + t.Run("update", func(t *testing.T) { + t.Run("Intermediate nil", func(t *testing.T) { + testDB := pebble.NewMemTest(t) + version := uint64(5) + require.NoError(t, testDB.Update(func(txn db.Transaction) error { + return updateSchemaMetadata(txn, schemaMetadata{ + Version: version, + IntermediateState: nil, + }) + })) + metadata, err := SchemaMetadata(testDB) + require.NoError(t, err) + require.Equal(t, version, metadata.Version) + require.Nil(t, metadata.IntermediateState) + }) + + t.Run("Intermediate not nil", func(t *testing.T) { + testDB := pebble.NewMemTest(t) + var ( + intermediateState = []byte{1, 2, 3, 4} + version = uint64(5) + ) + require.NoError(t, testDB.Update(func(txn db.Transaction) error { + return updateSchemaMetadata(txn, schemaMetadata{ + Version: version, + IntermediateState: intermediateState, + }) + })) + metadata, err := SchemaMetadata(testDB) + require.NoError(t, err) + require.Equal(t, version, metadata.Version) + require.Equal(t, intermediateState, metadata.IntermediateState) + }) + + t.Run("Intermediate empty", func(t *testing.T) { + testDB := pebble.NewMemTest(t) + var ( + intermediateState = make([]byte, 0) + version = uint64(5) + ) + require.NoError(t, testDB.Update(func(txn db.Transaction) error { + return updateSchemaMetadata(txn, schemaMetadata{ + Version: version, + IntermediateState: intermediateState, + }) + })) + metadata, err := SchemaMetadata(testDB) + require.NoError(t, err) + require.Equal(t, version, metadata.Version) + require.Equal(t, intermediateState, metadata.IntermediateState) + }) + }) +} + +type testMigration struct { + exec func(context.Context, db.Transaction, utils.Network) ([]byte, error) + before func([]byte) error +} + +func (f testMigration) Migrate(ctx context.Context, txn db.Transaction, network utils.Network) ([]byte, error) { + return f.exec(ctx, txn, network) +} + +func (f testMigration) Before(state []byte) error { return f.before(state) } + +func TestMigrateIfNeededInternal(t *testing.T) { + t.Run("failure at schema", func(t *testing.T) { + testDB := pebble.NewMemTest(t) + migrations := []Migration{ + testMigration{ + exec: func(context.Context, db.Transaction, utils.Network) ([]byte, error) { + return nil, errors.New("foo") + }, + before: func([]byte) error { + return errors.New("bar") + }, + }, + } + require.ErrorContains(t, migrateIfNeeded(context.Background(), testDB, utils.Mainnet, utils.NewNopZapLogger(), migrations), "bar") + }) + + t.Run("call with new tx", func(t *testing.T) { + testDB := pebble.NewMemTest(t) + var counter int + migrations := []Migration{ + testMigration{ + exec: func(context.Context, db.Transaction, utils.Network) ([]byte, error) { + if counter == 0 { + counter++ + return nil, ErrCallWithNewTransaction + } + return nil, nil + }, + before: func([]byte) error { + return nil + }, + }, + } + require.NoError(t, migrateIfNeeded(context.Background(), testDB, utils.Mainnet, utils.NewNopZapLogger(), migrations)) + }) + + t.Run("error during migration", func(t *testing.T) { + testDB := pebble.NewMemTest(t) + migrations := []Migration{ + testMigration{ + exec: func(context.Context, db.Transaction, utils.Network) ([]byte, error) { + return nil, errors.New("foo") + }, + before: func([]byte) error { + return nil + }, + }, + } + require.ErrorContains(t, migrateIfNeeded(context.Background(), testDB, utils.Mainnet, utils.NewNopZapLogger(), migrations), "foo") + }) +} diff --git a/migration/migration_test.go b/migration/migration_test.go index 1e27b6b6c3..48dd3257fc 100644 --- a/migration/migration_test.go +++ b/migration/migration_test.go @@ -1,6 +1,7 @@ package migration_test import ( + "context" "testing" "github.com/NethermindEth/juno/db/pebble" @@ -12,18 +13,30 @@ import ( func TestMigrateIfNeeded(t *testing.T) { testDB := pebble.NewMemTest(t) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + t.Run("Migration should not happen on cancelled ctx", func(t *testing.T) { + require.ErrorIs(t, migration.MigrateIfNeeded(ctx, testDB, utils.Mainnet, utils.NewNopZapLogger()), ctx.Err()) + }) + + meta, err := migration.SchemaMetadata(testDB) + require.NoError(t, err) + require.Equal(t, uint64(0), meta.Version) + require.Nil(t, meta.IntermediateState) + t.Run("Migration should happen on empty DB", func(t *testing.T) { - require.NoError(t, migration.MigrateIfNeeded(testDB, utils.Mainnet, utils.NewNopZapLogger())) + require.NoError(t, migration.MigrateIfNeeded(context.Background(), testDB, utils.Mainnet, utils.NewNopZapLogger())) }) - version, err := migration.SchemaVersion(testDB) + meta, err = migration.SchemaMetadata(testDB) require.NoError(t, err) - require.NotEqual(t, 0, version) + require.NotEqual(t, uint64(0), meta.Version) + require.Nil(t, meta.IntermediateState) t.Run("subsequent calls to MigrateIfNeeded should not change the DB version", func(t *testing.T) { - require.NoError(t, migration.MigrateIfNeeded(testDB, utils.Mainnet, utils.NewNopZapLogger())) - postVersion, postErr := migration.SchemaVersion(testDB) + require.NoError(t, migration.MigrateIfNeeded(context.Background(), testDB, utils.Mainnet, utils.NewNopZapLogger())) + postVersion, postErr := migration.SchemaMetadata(testDB) require.NoError(t, postErr) - require.Equal(t, version, postVersion) + require.Equal(t, meta, postVersion) }) } diff --git a/node/node.go b/node/node.go index 40098b8537..8cb4ea7666 100644 --- a/node/node.go +++ b/node/node.go @@ -257,7 +257,11 @@ func (n *Node) Run(ctx context.Context) { } n.log.Debugw(fmt.Sprintf("Running Juno with config:\n%s", string(yamlConfig))) - if err := migration.MigrateIfNeeded(n.db, n.cfg.Network, n.log); err != nil { + if err := migration.MigrateIfNeeded(ctx, n.db, n.cfg.Network, n.log); err != nil { + if errors.Is(err, context.Canceled) { + n.log.Infow("DB Migration cancelled") + return + } n.log.Errorw("Error while migrating the DB", "err", err) return } diff --git a/node/node_test.go b/node/node_test.go index 2504ca91f0..31c677c4dd 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -1,6 +1,7 @@ package node_test import ( + "context" "testing" "time" @@ -33,6 +34,10 @@ func TestNewNode(t *testing.T) { P2PBootPeers: "", } - _, err := node.New(config, "v0.3") + n, err := node.New(config, "v0.3") require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + n.Run(ctx) }