From 07803e549240489474d2dbb1b979c4ae5807e25b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=96mer=20Faruk=20IRMAK?= Date: Tue, 19 Sep 2023 10:21:46 +0300 Subject: [PATCH] Discard database txn if user callback panics --- db/pebble/db.go | 14 ++++++++++++++ db/pebble/db_test.go | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/db/pebble/db.go b/db/pebble/db.go index 5cd073c4c5..95bde89cf4 100644 --- a/db/pebble/db.go +++ b/db/pebble/db.go @@ -1,6 +1,8 @@ package pebble import ( + "fmt" + "os" "sync" "github.com/NethermindEth/juno/db" @@ -93,12 +95,14 @@ func (d *DB) Close() error { // View : see db.DB.View func (d *DB) View(fn func(txn db.Transaction) error) error { txn := d.NewTransaction(false) + defer discardTxnOnPanic(txn) return utils.RunAndWrapOnError(txn.Discard, fn(txn)) } // Update : see db.DB.Update func (d *DB) Update(fn func(txn db.Transaction) error) error { txn := d.NewTransaction(true) + defer discardTxnOnPanic(txn) if err := fn(txn); err != nil { return utils.RunAndWrapOnError(txn.Discard, err) } @@ -109,3 +113,13 @@ func (d *DB) Update(fn func(txn db.Transaction) error) error { func (d *DB) Impl() any { return d.pebble } + +func discardTxnOnPanic(txn db.Transaction) { + p := recover() + if p != nil { + if err := txn.Discard(); err != nil { + fmt.Fprintf(os.Stderr, "failed discarding panicing txn err: %s", err) + } + panic(p) + } +} diff --git a/db/pebble/db_test.go b/db/pebble/db_test.go index a6f51e46e7..c7f00847ad 100644 --- a/db/pebble/db_test.go +++ b/db/pebble/db_test.go @@ -391,3 +391,41 @@ func TestNext(t *testing.T) { require.NoError(t, it.Close()) }) } + +func TestPanic(t *testing.T) { + testDB := pebble.NewMemTest() + t.Cleanup(func() { + require.NoError(t, testDB.Close()) + }) + + t.Run("view", func(t *testing.T) { + defer func() { + p := recover() + require.NotNil(t, p) + }() + + require.NoError(t, testDB.View(func(txn db.Transaction) error { + panic("view") + })) + }) + + t.Run("update", func(t *testing.T) { + var panicingTxn db.Transaction + defer func() { + p := recover() + require.NotNil(t, p) + + require.ErrorIs(t, testDB.View(func(txn db.Transaction) error { + return txn.Get([]byte{0}, func(b []byte) error { return nil }) + }), db.ErrKeyNotFound) + require.EqualError(t, panicingTxn.Get([]byte{0}, func(b []byte) error { return nil }), "discarded txn") + }() + + require.NoError(t, testDB.Update(func(txn db.Transaction) error { + panicingTxn = txn + require.ErrorIs(t, txn.Get([]byte{0}, func(b []byte) error { return nil }), db.ErrKeyNotFound) + require.NoError(t, txn.Set([]byte{0}, []byte{0})) + panic("update") + })) + }) +}