Skip to content

Commit

Permalink
[chore] submodule pattern updates to TreeStore
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanlott committed Jul 12, 2023
1 parent 021d90f commit 6181db1
Show file tree
Hide file tree
Showing 7 changed files with 230 additions and 68 deletions.
11 changes: 8 additions & 3 deletions persistence/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,14 @@ func (*persistenceModule) Create(bus modules.Bus, options ...modules.ModuleOptio
treeModule, err := trees.Create(
bus,
trees.WithTreeStoreDirectory(persistenceCfg.TreesStoreDir),
trees.WithLogger(m.logger))
trees.WithLogger(m.logger),
trees.WithTxIndexer(txIndexer))
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to create TreeStoreModule: %w", err)
}
treeStoreModule, ok := treeModule.(modules.TreeStoreModule)
if !ok {
return nil, fmt.Errorf("failed to cast %T as TreeStoreModule", treeModule)
}

m.config = persistenceCfg
Expand All @@ -117,7 +122,7 @@ func (*persistenceModule) Create(bus modules.Bus, options ...modules.ModuleOptio

m.blockStore = blockStore
m.txIndexer = txIndexer
m.stateTrees = treeModule
m.stateTrees = treeStoreModule

// TECHDEBT: reconsider if this is the best place to call `populateGenesisState`. Note that
// this forces the genesis state to be reloaded on every node startup until state
Expand Down
11 changes: 0 additions & 11 deletions persistence/sql/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"

"github.com/jackc/pgx/v5"
"github.com/pokt-network/pocket/persistence/indexer"
ptypes "github.com/pokt-network/pocket/persistence/types"
coreTypes "github.com/pokt-network/pocket/shared/core/types"
)
Expand Down Expand Up @@ -91,16 +90,6 @@ func GetAccountsUpdated(
return accounts, nil
}

// GetTransactions takes a transaction indexer and returns the transactions for the current height
func GetTransactions(txi indexer.TxIndexer, height uint64) ([]*coreTypes.IndexedTransaction, error) {
// TECHDEBT(#813): Avoid this cast to int64
indexedTxs, err := txi.GetByHeight(int64(height), false)
if err != nil {
return nil, fmt.Errorf("failed to get transactions by height: %w", err)
}
return indexedTxs, nil
}

// GetPools returns the pools updated at the given height
func GetPools(pgtx pgx.Tx, height uint64) ([]*coreTypes.Account, error) {
pools, err := GetAccountsUpdated(pgtx, ptypes.Pool, height)
Expand Down
58 changes: 41 additions & 17 deletions persistence/trees/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,26 @@ package trees
import (
"fmt"

"github.com/pokt-network/pocket/persistence/indexer"
"github.com/pokt-network/pocket/persistence/kvstore"
"github.com/pokt-network/pocket/shared/modules"
"github.com/pokt-network/smt"
)

func (*treeStore) Create(bus modules.Bus, options ...modules.TreeStoreOption) (modules.TreeStoreModule, error) {
m := &treeStore{}
var _ modules.Module = &TreeStore{}

func (*TreeStore) Create(bus modules.Bus, options ...modules.ModuleOption) (modules.Module, error) {
m := &TreeStore{}

bus.RegisterModule(m)

for _, option := range options {
option(m)
}

m.SetBus(bus)
if m.TXI == nil {
m.TXI = bus.GetPersistenceModule().GetTxIndexer()
}

if err := m.setupTrees(); err != nil {
return nil, err
Expand All @@ -24,35 +31,52 @@ func (*treeStore) Create(bus modules.Bus, options ...modules.TreeStoreOption) (m
return m, nil
}

func Create(bus modules.Bus, options ...modules.TreeStoreOption) (modules.TreeStoreModule, error) {
return new(treeStore).Create(bus, options...)
func Create(bus modules.Bus, options ...modules.ModuleOption) (modules.Module, error) {
return new(TreeStore).Create(bus, options...)
}

// WithLogger assigns a logger for the tree store
func WithLogger(logger *modules.Logger) modules.TreeStoreOption {
return func(m modules.TreeStoreModule) {
if mod, ok := m.(*treeStore); ok {
func WithLogger(logger *modules.Logger) modules.ModuleOption {
return func(m modules.InjectableModule) {
if mod, ok := m.(*TreeStore); ok {
mod.logger = logger
}
}
}

// WithTreeStoreDirectory assigns the path where the tree store
// saves its data.
func WithTreeStoreDirectory(path string) modules.TreeStoreOption {
return func(m modules.TreeStoreModule) {
if mod, ok := m.(*treeStore); ok {
mod.treeStoreDir = path
func WithTreeStoreDirectory(path string) modules.ModuleOption {
return func(m modules.InjectableModule) {
mod, ok := m.(*TreeStore)
if ok {
mod.TreeStoreDir = path
}
}
}

// WithTxIndexer assigns a TxIndexer for use during operation.
func WithTxIndexer(txi indexer.TxIndexer) modules.ModuleOption {
return func(m modules.InjectableModule) {
mod, ok := m.(*TreeStore)
if ok {
mod.TXI = txi
}
}
}

func (t *treeStore) setupTrees() error {
if t.treeStoreDir == ":memory:" {
func (t *TreeStore) GetModuleName() string { return modules.TreeStoreModuleName }
func (t *TreeStore) Start() error { return nil }
func (t *TreeStore) Stop() error { return nil }
func (t *TreeStore) GetBus() modules.Bus { return t.Bus }
func (t *TreeStore) SetBus(bus modules.Bus) { t.Bus = bus }

func (t *TreeStore) setupTrees() error {
if t.TreeStoreDir == ":memory:" {
return t.setupInMemory()
}

nodeStore, err := kvstore.NewKVStore(fmt.Sprintf("%s/%s_nodes", t.treeStoreDir, RootTreeName))
nodeStore, err := kvstore.NewKVStore(fmt.Sprintf("%s/%s_nodes", t.TreeStoreDir, RootTreeName))
if err != nil {
return err
}
Expand All @@ -64,7 +88,7 @@ func (t *treeStore) setupTrees() error {
t.merkleTrees = make(map[string]*stateTree, len(stateTreeNames))

for i := 0; i < len(stateTreeNames); i++ {
nodeStore, err := kvstore.NewKVStore(fmt.Sprintf("%s/%s_nodes", t.treeStoreDir, stateTreeNames[i]))
nodeStore, err := kvstore.NewKVStore(fmt.Sprintf("%s/%s_nodes", t.TreeStoreDir, stateTreeNames[i]))
if err != nil {
return err
}
Expand All @@ -78,7 +102,7 @@ func (t *treeStore) setupTrees() error {
return nil
}

func (t *treeStore) setupInMemory() error {
func (t *TreeStore) setupInMemory() error {
nodeStore := kvstore.NewMemKVStore()
t.rootTree = &stateTree{
name: RootTreeName,
Expand Down
121 changes: 121 additions & 0 deletions persistence/trees/module_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package trees_test

import (
"fmt"
"testing"

"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"

"github.com/pokt-network/pocket/internal/testutil"
"github.com/pokt-network/pocket/p2p/providers/current_height_provider"
"github.com/pokt-network/pocket/p2p/providers/peerstore_provider"
"github.com/pokt-network/pocket/persistence/trees"
"github.com/pokt-network/pocket/runtime"
"github.com/pokt-network/pocket/runtime/genesis"
"github.com/pokt-network/pocket/runtime/test_artifacts"
coreTypes "github.com/pokt-network/pocket/shared/core/types"
cryptoPocket "github.com/pokt-network/pocket/shared/crypto"
"github.com/pokt-network/pocket/shared/modules"
mockModules "github.com/pokt-network/pocket/shared/modules/mocks"
)

const (
serviceURLFormat = "node%d.consensus:42069"
)

func TestTreeStore_Create(t *testing.T) {
ctrl := gomock.NewController(t)
mockRuntimeMgr := mockModules.NewMockRuntimeMgr(ctrl)
mockBus := createMockBus(t, mockRuntimeMgr)

genesisStateMock := createMockGenesisState(nil)
persistenceMock := preparePersistenceMock(t, mockBus, genesisStateMock)

mockBus.EXPECT().GetPersistenceModule().Return(persistenceMock).AnyTimes()
persistenceMock.EXPECT().GetBus().AnyTimes().Return(mockBus)
persistenceMock.EXPECT().NewRWContext(int64(0)).AnyTimes()
persistenceMock.EXPECT().GetTxIndexer().AnyTimes()

treemod, err := trees.Create(mockBus,
trees.WithTreeStoreDirectory(":memory:"))
assert.NoError(t, err)
got := treemod.GetBus()
assert.Equal(t, got, mockBus)
}

func TestTreeStore_DebugClearAll(t *testing.T) {
// TODO: Write test case for the DebugClearAll method
t.Skip("TODO: Write test case for DebugClearAll method")
}

// createMockGenesisState configures and returns a mocked GenesisState
func createMockGenesisState(valKeys []cryptoPocket.PrivateKey) *genesis.GenesisState {
genesisState := new(genesis.GenesisState)
validators := make([]*coreTypes.Actor, len(valKeys))
for i, valKey := range valKeys {
addr := valKey.Address().String()
mockActor := &coreTypes.Actor{
ActorType: coreTypes.ActorType_ACTOR_TYPE_VAL,
Address: addr,
PublicKey: valKey.PublicKey().String(),
ServiceUrl: validatorId(i + 1),
StakedAmount: test_artifacts.DefaultStakeAmountString,
PausedHeight: int64(0),
UnstakingHeight: int64(0),
Output: addr,
}
validators[i] = mockActor
}
genesisState.Validators = validators

return genesisState
}

// Persistence mock - only needed for validatorMap access
func preparePersistenceMock(t *testing.T, busMock *mockModules.MockBus, genesisState *genesis.GenesisState) *mockModules.MockPersistenceModule {
ctrl := gomock.NewController(t)

persistenceModuleMock := mockModules.NewMockPersistenceModule(ctrl)
readCtxMock := mockModules.NewMockPersistenceReadContext(ctrl)

readCtxMock.EXPECT().GetAllValidators(gomock.Any()).Return(genesisState.GetValidators(), nil).AnyTimes()
readCtxMock.EXPECT().GetAllStakedActors(gomock.Any()).DoAndReturn(func(height int64) ([]*coreTypes.Actor, error) {
return testutil.Concatenate[*coreTypes.Actor](
genesisState.GetValidators(),
genesisState.GetServicers(),
genesisState.GetFishermen(),
genesisState.GetApplications(),
), nil
}).AnyTimes()
persistenceModuleMock.EXPECT().NewReadContext(gomock.Any()).Return(readCtxMock, nil).AnyTimes()
readCtxMock.EXPECT().Release().AnyTimes()

persistenceModuleMock.EXPECT().GetBus().Return(busMock).AnyTimes()
persistenceModuleMock.EXPECT().SetBus(busMock).AnyTimes()
persistenceModuleMock.EXPECT().GetModuleName().Return(modules.PersistenceModuleName).AnyTimes()
busMock.RegisterModule(persistenceModuleMock)

return persistenceModuleMock
}

func validatorId(i int) string {
return fmt.Sprintf(serviceURLFormat, i)
}

// createMockBus returns a mock bus with stubbed out functions for bus registration
func createMockBus(t *testing.T, runtimeMgr modules.RuntimeMgr) *mockModules.MockBus {
t.Helper()
ctrl := gomock.NewController(t)
mockBus := mockModules.NewMockBus(ctrl)
mockBus.EXPECT().GetRuntimeMgr().Return(runtimeMgr).AnyTimes()
mockBus.EXPECT().RegisterModule(gomock.Any()).DoAndReturn(func(m modules.Module) {
m.SetBus(mockBus)
}).AnyTimes()
mockModulesRegistry := mockModules.NewMockModulesRegistry(ctrl)
mockModulesRegistry.EXPECT().GetModule(peerstore_provider.PeerstoreProviderSubmoduleName).Return(nil, runtime.ErrModuleNotRegistered(peerstore_provider.PeerstoreProviderSubmoduleName)).AnyTimes()
mockModulesRegistry.EXPECT().GetModule(current_height_provider.ModuleName).Return(nil, runtime.ErrModuleNotRegistered(current_height_provider.ModuleName)).AnyTimes()
mockBus.EXPECT().GetModulesRegistry().Return(mockModulesRegistry).AnyTimes()
mockBus.EXPECT().PublishEventToBus(gomock.Any()).AnyTimes()
return mockBus
}
Loading

0 comments on commit 6181db1

Please sign in to comment.