Skip to content

Commit

Permalink
Implement test
Browse files Browse the repository at this point in the history
  • Loading branch information
rianhughes committed Sep 25, 2024
1 parent 33ca761 commit 48a55b2
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 17 deletions.
98 changes: 98 additions & 0 deletions mocks/mock_plugin.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

80 changes: 80 additions & 0 deletions plugin/plugin_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package junoplugin_test

import (
"context"
"testing"
"time"

"github.com/NethermindEth/juno/blockchain"
"github.com/NethermindEth/juno/clients/feeder"
"github.com/NethermindEth/juno/db/pebble"
"github.com/NethermindEth/juno/mocks"
junoplugin "github.com/NethermindEth/juno/plugin"
adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder"
"github.com/NethermindEth/juno/sync"
"github.com/NethermindEth/juno/utils"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)

const timeout = time.Second

func TestPlugin(t *testing.T) {
mockCtrl := gomock.NewController(t)
t.Cleanup(mockCtrl.Finish)

plugin := mocks.NewMockJunoPlugin(mockCtrl)

mainClient := feeder.NewTestClient(t, &utils.Mainnet)
mainGw := adaptfeeder.New(mainClient)

integClient := feeder.NewTestClient(t, &utils.Integration)
integGw := adaptfeeder.New(integClient)

testDB := pebble.NewMemTest(t)

// sync to integration for 2 blocks
for i := range 2 {
su, block, err := integGw.StateUpdateWithBlock(context.Background(), uint64(i))
require.NoError(t, err)
plugin.EXPECT().NewBlock(block, su, gomock.Any())
}
bc := blockchain.New(testDB, &utils.Integration)
synchronizer := sync.New(bc, integGw, utils.NewNopZapLogger(), 0, false).WithPlugin(plugin)

ctx, cancel := context.WithTimeout(context.Background(), timeout)
require.NoError(t, synchronizer.Run(ctx))
cancel()

t.Run("resync to mainnet with the same db", func(t *testing.T) {
bc := blockchain.New(testDB, &utils.Mainnet)

// Ensure current head is Integration head
head, err := bc.HeadsHeader()
require.NoError(t, err)
require.Equal(t, utils.HexToFelt(t, "0x34e815552e42c5eb5233b99de2d3d7fd396e575df2719bf98e7ed2794494f86"), head.Hash)

// Reorg 2 blocks, then sync 3 blocks
su1, block1, err := integGw.StateUpdateWithBlock(context.Background(), uint64(1))
require.NoError(t, err)
su0, block0, err := integGw.StateUpdateWithBlock(context.Background(), uint64(0))
require.NoError(t, err)
plugin.EXPECT().RevertBlock(&junoplugin.BlockAndStateUpdate{block1, su1}, &junoplugin.BlockAndStateUpdate{block0, su0}, gomock.Any())
plugin.EXPECT().RevertBlock(&junoplugin.BlockAndStateUpdate{block0, su0}, &junoplugin.BlockAndStateUpdate{nil, nil}, gomock.Any())
for i := range 3 {
su, block, err := mainGw.StateUpdateWithBlock(context.Background(), uint64(i))
require.NoError(t, err)
plugin.EXPECT().NewBlock(block, su, gomock.Any())
}

synchronizer = sync.New(bc, mainGw, utils.NewNopZapLogger(), 0, false).WithPlugin(plugin)
ctx, cancel = context.WithTimeout(context.Background(), timeout)
require.NoError(t, synchronizer.Run(ctx))
cancel()

// After syncing (and reorging) the current head should be at mainnet
head, err = bc.HeadsHeader()
require.NoError(t, err)
require.Equal(t, utils.HexToFelt(t, "0x4e1f77f39545afe866ac151ac908bd1a347a2a8a7d58bef1276db4f06fdf2f6"), head.Hash)
})
}
37 changes: 24 additions & 13 deletions sync/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,25 +188,40 @@ func (s *Synchronizer) fetchUnknownClasses(ctx context.Context, stateUpdate *cor
return newClasses, closer()
}

func (s *Synchronizer) handlePluginRevertBlock(block *core.Block, stateUpdate *core.StateUpdate, reverseStateDiff *core.StateDiff) {
func (s *Synchronizer) handlePluginRevertBlock() {
if s.plugin == nil {
return

Check warning on line 193 in sync/sync.go

View check run for this annotation

Codecov / codecov/patch

sync/sync.go#L193

Added line #L193 was not covered by tests
}

toBlock, err := s.blockchain.Head()
fromBlock, err := s.blockchain.Head()
if err != nil {
s.log.Warnw("Failed to retrieve the reverted blockchain head block for the plugin", "err", err)
return

Check warning on line 198 in sync/sync.go

View check run for this annotation

Codecov / codecov/patch

sync/sync.go#L197-L198

Added lines #L197 - L198 were not covered by tests
}

toSU, err := s.blockchain.StateUpdateByNumber(toBlock.Number)
fromSU, err := s.blockchain.StateUpdateByNumber(fromBlock.Number)
if err != nil {
s.log.Warnw("Failed to retrieve the reverted blockchain head state-update for the plugin", "err", err)
return

Check warning on line 203 in sync/sync.go

View check run for this annotation

Codecov / codecov/patch

sync/sync.go#L202-L203

Added lines #L202 - L203 were not covered by tests
}

reverseStateDiff, err := s.blockchain.GetReverseStateDiff()
if err != nil {
s.log.Warnw("Failed to retrieve reverse state diff", "head", fromBlock.Number, "hash", fromBlock.Hash.ShortString(), "err", err)
}
var toBlock *core.Block
var toSU *core.StateUpdate
if fromBlock.Number != 0 {
toBlock, err = s.blockchain.BlockByHash(fromBlock.ParentHash)
if err != nil {
s.log.Warnw("Failed to retrieve the parent block for the plugin", "err", err)
return

Check warning on line 215 in sync/sync.go

View check run for this annotation

Codecov / codecov/patch

sync/sync.go#L214-L215

Added lines #L214 - L215 were not covered by tests
}
toSU, err = s.blockchain.StateUpdateByNumber(toBlock.Number)
if err != nil {
s.log.Warnw("Failed to retrieve the parents state-update for the plugin", "err", err)
return

Check warning on line 220 in sync/sync.go

View check run for this annotation

Codecov / codecov/patch

sync/sync.go#L219-L220

Added lines #L219 - L220 were not covered by tests
}
}
err = (s.plugin).RevertBlock(
&junoplugin.BlockAndStateUpdate{Block: block, StateUpdate: stateUpdate},
&junoplugin.BlockAndStateUpdate{Block: fromBlock, StateUpdate: fromSU},
&junoplugin.BlockAndStateUpdate{Block: toBlock, StateUpdate: toSU},
reverseStateDiff)
if err != nil {
Expand Down Expand Up @@ -239,14 +254,10 @@ func (s *Synchronizer) verifierTask(ctx context.Context, block *core.Block, stat
// revert the head and restart the sync process, hoping that the reorg is not deep
// if the reorg is deeper, we will end up here again and again until we fully revert reorged
// blocks
reverseStateDiff, err := s.blockchain.GetReverseStateDiff()
if err != nil {
s.log.Warnw("Failed to retrieve reverse state diff", "head", block.Number, "hash", block.Hash.ShortString(), "err", err)
}
s.revertHead(block)
if s.plugin != nil {
s.handlePluginRevertBlock(block, stateUpdate, reverseStateDiff)
s.handlePluginRevertBlock()
}
s.revertHead(block)
} else {
s.log.Warnw("Failed storing Block", "number", block.Number,
"hash", block.Hash.ShortString(), "err", err)
Expand Down
15 changes: 11 additions & 4 deletions vm/rust/src/juno_state_reader.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{
ffi::{c_char, c_uchar, c_void, c_int, CStr},
ffi::{c_char, c_int, c_uchar, c_void, CStr},
slice,
sync::Mutex,
};
Expand Down Expand Up @@ -75,8 +75,14 @@ impl StateReader for JunoStateReader {
let addr = felt_to_byte_array(contract_address.0.key());
let storage_key = felt_to_byte_array(key.0.key());
let mut buffer: [u8; 32] = [0; 32];
let wrote =
unsafe { JunoStateGetStorageAt(self.handle, addr.as_ptr(), storage_key.as_ptr(), buffer.as_mut_ptr()) };
let wrote = unsafe {
JunoStateGetStorageAt(
self.handle,
addr.as_ptr(),
storage_key.as_ptr(),
buffer.as_mut_ptr(),
)
};
if wrote == 0 {
Err(StateError::StateReadError(format!(
"failed to read location {} at address {}",
Expand Down Expand Up @@ -111,7 +117,8 @@ impl StateReader for JunoStateReader {
fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult<ClassHash> {
let addr = felt_to_byte_array(contract_address.0.key());
let mut buffer: [u8; 32] = [0; 32];
let wrote = unsafe { JunoStateGetClassHashAt(self.handle, addr.as_ptr(), buffer.as_mut_ptr()) };
let wrote =
unsafe { JunoStateGetClassHashAt(self.handle, addr.as_ptr(), buffer.as_mut_ptr()) };
if wrote == 0 {
Err(StateError::StateReadError(format!(
"failed to read class hash of address {}",
Expand Down

0 comments on commit 48a55b2

Please sign in to comment.