Skip to content

Commit

Permalink
refactor(auth): use collections for Account state management (#16016)
Browse files Browse the repository at this point in the history
Co-authored-by: unknown unknown <unknown@unknown>
  • Loading branch information
testinginprod and unknown unknown committed May 26, 2023
1 parent 52ccb7b commit 3d15f9e
Show file tree
Hide file tree
Showing 24 changed files with 188 additions and 538 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ Ref: https://keepachangelog.com/en/1.0.0/
* `simulation.NewOperationMsg` is now 2-arity instead of 3-arity with the obsolete argument `codec.ProtoCodec` removed.
* The field `OperationMsg.Msg` is now of type `[]byte` instead of `json.RawMessage`.
* (cli) [#16209](https://github.com/cosmos/cosmos-sdk/pull/16209) Add API `StartCmdWithOptions` to create customized start command.
* (x/auth) [#16016](https://github.com/cosmos/cosmos-sdk/pull/16016) Use collections for accounts state management:
- removed: keeper `HasAccountByID`, `AccountAddressByID`, `SetParams
* (x/distribution) [#16211](https://github.com/cosmos/cosmos-sdk/pull/16211) Use collections for params state management.
* [#15284](https://github.com/cosmos/cosmos-sdk/pull/15284)
* `sdk.Msg.GetSigners` was deprecated and is no longer supported. Use the `cosmos.msg.v1.signer` protobuf annotation instead.
Expand Down
2 changes: 1 addition & 1 deletion baseapp/block_gas_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func TestBaseApp_BlockGas(t *testing.T) {
require.Equal(t, []byte("ok"), okValue)
}
// check block gas is always consumed
baseGas := uint64(50702) // baseGas is the gas consumed before tx msg
baseGas := uint64(57554) // baseGas is the gas consumed before tx msg
expGasConsumed := addUint64Saturating(tc.gasToConsume, baseGas)
if expGasConsumed > txtypes.MaxGasWanted {
// capped by gasLimit
Expand Down
6 changes: 6 additions & 0 deletions collections/indexes/reverse_pair_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,10 @@ func TestReversePair(t *testing.T) {
require.NoError(t, err)
require.Equal(t, "address1", pks[0].K1())
require.Equal(t, "address2", pks[1].K1())

// assert if we remove address1 atom balance, we can no longer find it in the index
err = indexedMap.Remove(ctx, collections.Join("address1", "atom"))
require.NoError(t, err)
_, err = indexedMap.Indexes.Denom.MatchExact(ctx, "atom")
require.ErrorIs(t, collections.ErrInvalidIterator, err)
}
48 changes: 39 additions & 9 deletions tests/integration/gov/genesis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"encoding/json"
"testing"

dbm "github.com/cosmos/cosmos-db"

abci "github.com/cometbft/cometbft/abci/types"
cmtproto "github.com/cometbft/cometbft/proto/tendermint/types"
"gotest.tools/v3/assert"
Expand Down Expand Up @@ -75,9 +77,10 @@ func TestImportExportQueues(t *testing.T) {
ctx := s1.app.BaseApp.NewContext(false, cmtproto.Header{})
addrs := simtestutil.AddTestAddrs(s1.BankKeeper, s1.StakingKeeper, ctx, 1, valTokens)

s1.app.FinalizeBlock(&abci.RequestFinalizeBlock{
_, err = s1.app.FinalizeBlock(&abci.RequestFinalizeBlock{
Height: s1.app.LastBlockHeight() + 1,
})
assert.NilError(t, err)

ctx = s1.app.BaseApp.NewContext(false, cmtproto.Header{})
// Create two proposals, put the second into the voting period
Expand Down Expand Up @@ -121,29 +124,41 @@ func TestImportExportQueues(t *testing.T) {
assert.NilError(t, err)

s2 := suite{}
db := dbm.NewMemDB()
conf2 := simtestutil.DefaultStartUpConfig()
conf2.DB = db
s2.app, err = simtestutil.SetupWithConfiguration(
depinject.Configs(
appConfig,
depinject.Supply(log.NewNopLogger()),
),
simtestutil.DefaultStartUpConfig(),
conf2,
&s2.AccountKeeper, &s2.BankKeeper, &s2.DistrKeeper, &s2.GovKeeper, &s2.StakingKeeper, &s2.cdc, &s2.appBuilder,
)
assert.NilError(t, err)

s2.app.InitChain(&abci.RequestInitChain{
Validators: []abci.ValidatorUpdate{},
ConsensusParams: simtestutil.DefaultConsensusParams,
AppStateBytes: stateBytes,
})
clearDB(t, db)
err = s2.app.CommitMultiStore().LoadLatestVersion()
assert.NilError(t, err)

_, err = s2.app.InitChain(
&abci.RequestInitChain{
Validators: []abci.ValidatorUpdate{},
ConsensusParams: simtestutil.DefaultConsensusParams,
AppStateBytes: stateBytes,
},
)
assert.NilError(t, err)

s2.app.FinalizeBlock(&abci.RequestFinalizeBlock{
_, err = s2.app.FinalizeBlock(&abci.RequestFinalizeBlock{
Height: s2.app.LastBlockHeight() + 1,
})
assert.NilError(t, err)

s2.app.FinalizeBlock(&abci.RequestFinalizeBlock{
_, err = s2.app.FinalizeBlock(&abci.RequestFinalizeBlock{
Height: s2.app.LastBlockHeight() + 1,
})
assert.NilError(t, err)

ctx2 := s2.app.BaseApp.NewContext(false, cmtproto.Header{})

Expand Down Expand Up @@ -174,3 +189,18 @@ func TestImportExportQueues(t *testing.T) {
assert.NilError(t, err)
assert.Assert(t, proposal2.Status == v1.StatusRejected)
}

func clearDB(t *testing.T, db *dbm.MemDB) {
iter, err := db.Iterator(nil, nil)
assert.NilError(t, err)
defer iter.Close()

var keys [][]byte
for ; iter.Valid(); iter.Next() {
keys = append(keys, iter.Key())
}

for _, k := range keys {
assert.NilError(t, db.Delete(k))
}
}
4 changes: 2 additions & 2 deletions x/auth/ante/ante_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1418,15 +1418,15 @@ func TestAnteHandlerReCheck(t *testing.T) {
for _, tc := range testCases {

// set testcase parameters
err := suite.accountKeeper.SetParams(suite.ctx, tc.params)
err := suite.accountKeeper.Params.Set(suite.ctx, tc.params)
require.NoError(t, err)

_, err = suite.anteHandler(suite.ctx, tx, false)

require.NotNil(t, err, "tx does not fail on recheck with updated params in test case: %s", tc.name)

// reset parameters to default values
err = suite.accountKeeper.SetParams(suite.ctx, authtypes.DefaultParams())
err = suite.accountKeeper.Params.Set(suite.ctx, authtypes.DefaultParams())
require.NoError(t, err)
}

Expand Down
20 changes: 11 additions & 9 deletions x/auth/ante/sigverify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func TestSetPubKey(t *testing.T) {
// set accounts and create msg for each address
for i, addr := range addrs {
acc := suite.accountKeeper.NewAccountWithAddress(suite.ctx, addr)
require.NoError(t, acc.SetAccountNumber(uint64(i)))
require.NoError(t, acc.SetAccountNumber(uint64(i+1000)))
suite.accountKeeper.SetAccount(suite.ctx, acc)
msgs[i] = testdata.NewTestMsg(addr)
}
Expand Down Expand Up @@ -153,12 +153,14 @@ func TestSigVerification(t *testing.T) {
addrs := []sdk.AccAddress{addr1, addr2, addr3}

msgs := make([]sdk.Msg, len(addrs))
accs := make([]sdk.AccountI, len(addrs))
// set accounts and create msg for each address
for i, addr := range addrs {
acc := suite.accountKeeper.NewAccountWithAddress(suite.ctx, addr)
require.NoError(t, acc.SetAccountNumber(uint64(i)))
require.NoError(t, acc.SetAccountNumber(uint64(i)+1000))
suite.accountKeeper.SetAccount(suite.ctx, acc)
msgs[i] = testdata.NewTestMsg(addr)
accs[i] = acc
}

feeAmount := testdata.NewTestFeeAmount()
Expand Down Expand Up @@ -190,11 +192,11 @@ func TestSigVerification(t *testing.T) {
validSigs := false
testCases := []testCase{
{"no signers", []cryptotypes.PrivKey{}, []uint64{}, []uint64{}, validSigs, false, true},
{"not enough signers", []cryptotypes.PrivKey{priv1, priv2}, []uint64{0, 1}, []uint64{0, 0}, validSigs, false, true},
{"wrong order signers", []cryptotypes.PrivKey{priv3, priv2, priv1}, []uint64{2, 1, 0}, []uint64{0, 0, 0}, validSigs, false, true},
{"not enough signers", []cryptotypes.PrivKey{priv1, priv2}, []uint64{accs[0].GetAccountNumber(), accs[1].GetAccountNumber()}, []uint64{0, 0}, validSigs, false, true},
{"wrong order signers", []cryptotypes.PrivKey{priv3, priv2, priv1}, []uint64{accs[2].GetAccountNumber(), accs[1].GetAccountNumber(), accs[0].GetAccountNumber()}, []uint64{0, 0, 0}, validSigs, false, true},
{"wrong accnums", []cryptotypes.PrivKey{priv1, priv2, priv3}, []uint64{7, 8, 9}, []uint64{0, 0, 0}, validSigs, false, true},
{"wrong sequences", []cryptotypes.PrivKey{priv1, priv2, priv3}, []uint64{0, 1, 2}, []uint64{3, 4, 5}, validSigs, false, true},
{"valid tx", []cryptotypes.PrivKey{priv1, priv2, priv3}, []uint64{0, 1, 2}, []uint64{0, 0, 0}, validSigs, false, false},
{"wrong sequences", []cryptotypes.PrivKey{priv1, priv2, priv3}, []uint64{accs[0].GetAccountNumber(), accs[1].GetAccountNumber(), accs[2].GetAccountNumber()}, []uint64{3, 4, 5}, validSigs, false, true},
{"valid tx", []cryptotypes.PrivKey{priv1, priv2, priv3}, []uint64{accs[0].GetAccountNumber(), accs[1].GetAccountNumber(), accs[2].GetAccountNumber()}, []uint64{0, 0, 0}, validSigs, false, false},
{"no err on recheck", []cryptotypes.PrivKey{priv1, priv2, priv3}, []uint64{0, 0, 0}, []uint64{0, 0, 0}, !validSigs, true, false},
}

Expand Down Expand Up @@ -265,7 +267,7 @@ func runSigDecorators(t *testing.T, params types.Params, _ bool, privs ...crypto

// Make block-height non-zero to include accNum in SignBytes
suite.ctx = suite.ctx.WithBlockHeight(1)
err := suite.accountKeeper.SetParams(suite.ctx, params)
err := suite.accountKeeper.Params.Set(suite.ctx, params)
require.NoError(t, err)

msgs := make([]sdk.Msg, len(privs))
Expand All @@ -275,10 +277,10 @@ func runSigDecorators(t *testing.T, params types.Params, _ bool, privs ...crypto
for i, priv := range privs {
addr := sdk.AccAddress(priv.PubKey().Address())
acc := suite.accountKeeper.NewAccountWithAddress(suite.ctx, addr)
require.NoError(t, acc.SetAccountNumber(uint64(i)))
require.NoError(t, acc.SetAccountNumber(uint64(i)+1000))
suite.accountKeeper.SetAccount(suite.ctx, acc)
msgs[i] = testdata.NewTestMsg(addr)
accNums[i] = uint64(i)
accNums[i] = acc.GetAccountNumber()
accSeqs[i] = uint64(0)
}
require.NoError(t, suite.txBuilder.SetMsgs(msgs...))
Expand Down
2 changes: 1 addition & 1 deletion x/auth/ante/testutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func SetupTestSuite(t *testing.T, isCheckTx bool) *AnteTestSuite {
suite.encCfg.Codec, runtime.NewKVStoreService(key), types.ProtoBaseAccount, maccPerms, sdk.Bech32MainPrefix, types.NewModuleAddress("gov").String(),
)
suite.accountKeeper.GetModuleAccount(suite.ctx, types.FeeCollectorName)
err := suite.accountKeeper.SetParams(suite.ctx, types.DefaultParams())
err := suite.accountKeeper.Params.Set(suite.ctx, types.DefaultParams())
require.NoError(t, err)

// We're using TestMsg encoding in some tests, so register it here.
Expand Down
77 changes: 11 additions & 66 deletions x/auth/keeper/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ package keeper

import (
"context"
"errors"

storetypes "cosmossdk.io/store/types"
"cosmossdk.io/collections"

sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/auth/types"
)

// NewAccountWithAddress implements AccountKeeperI.
Expand All @@ -31,51 +31,17 @@ func (ak AccountKeeper) NewAccount(ctx context.Context, acc sdk.AccountI) sdk.Ac

// HasAccount implements AccountKeeperI.
func (ak AccountKeeper) HasAccount(ctx context.Context, addr sdk.AccAddress) bool {
store := ak.storeService.OpenKVStore(ctx)
has, err := store.Has(types.AddressStoreKey(addr))
if err != nil {
panic(err)
}
return has
}

// HasAccountAddressByID checks account address exists by id.
func (ak AccountKeeper) HasAccountAddressByID(ctx context.Context, id uint64) bool {
store := ak.storeService.OpenKVStore(ctx)
has, err := store.Has(types.AccountNumberStoreKey(id))
if err != nil {
panic(err)
}
has, _ := ak.Accounts.Has(ctx, addr)
return has
}

// GetAccount implements AccountKeeperI.
func (ak AccountKeeper) GetAccount(ctx context.Context, addr sdk.AccAddress) sdk.AccountI {
store := ak.storeService.OpenKVStore(ctx)
bz, err := store.Get(types.AddressStoreKey(addr))
if err != nil {
acc, err := ak.Accounts.Get(ctx, addr)
if err != nil && !errors.Is(err, collections.ErrNotFound) {
panic(err)
}

if bz == nil {
return nil
}

return ak.decodeAccount(bz)
}

// GetAccountAddressById returns account address by id.
func (ak AccountKeeper) GetAccountAddressByID(ctx context.Context, id uint64) string {
store := ak.storeService.OpenKVStore(ctx)
bz, err := store.Get(types.AccountNumberStoreKey(id))
if err != nil {
panic(err)
}

if bz == nil {
return ""
}
return sdk.AccAddress(bz).String()
return acc
}

// GetAllAccounts returns all accounts in the accountKeeper.
Expand All @@ -90,29 +56,16 @@ func (ak AccountKeeper) GetAllAccounts(ctx context.Context) (accounts []sdk.Acco

// SetAccount implements AccountKeeperI.
func (ak AccountKeeper) SetAccount(ctx context.Context, acc sdk.AccountI) {
addr := acc.GetAddress()
store := ak.storeService.OpenKVStore(ctx)

bz, err := ak.MarshalAccount(acc)
err := ak.Accounts.Set(ctx, acc.GetAddress(), acc)
if err != nil {
panic(err)
}

store.Set(types.AddressStoreKey(addr), bz)
store.Set(types.AccountNumberStoreKey(acc.GetAccountNumber()), addr.Bytes())
}

// RemoveAccount removes an account for the account mapper store.
// NOTE: this will cause supply invariant violation if called
func (ak AccountKeeper) RemoveAccount(ctx context.Context, acc sdk.AccountI) {
addr := acc.GetAddress()
store := ak.storeService.OpenKVStore(ctx)
err := store.Delete(types.AddressStoreKey(addr))
if err != nil {
panic(err)
}

err = store.Delete(types.AccountNumberStoreKey(acc.GetAccountNumber()))
err := ak.Accounts.Remove(ctx, acc.GetAddress())
if err != nil {
panic(err)
}
Expand All @@ -121,18 +74,10 @@ func (ak AccountKeeper) RemoveAccount(ctx context.Context, acc sdk.AccountI) {
// IterateAccounts iterates over all the stored accounts and performs a callback function.
// Stops iteration when callback returns true.
func (ak AccountKeeper) IterateAccounts(ctx context.Context, cb func(account sdk.AccountI) (stop bool)) {
store := ak.storeService.OpenKVStore(ctx)
iterator, err := store.Iterator(types.AddressStoreKeyPrefix, storetypes.PrefixEndBytes(types.AddressStoreKeyPrefix))
err := ak.Accounts.Walk(ctx, nil, func(_ sdk.AccAddress, value sdk.AccountI) (bool, error) {
return cb(value), nil
})
if err != nil {
panic(err)
}

defer iterator.Close()
for ; iterator.Valid(); iterator.Next() {
account := ak.decodeAccount(iterator.Value())

if cb(account) {
break
}
}
}
Loading

0 comments on commit 3d15f9e

Please sign in to comment.