Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(dot/state): store raw authority keys and decode when verifying block signature #3627

Merged
merged 13 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions dot/digest/digest_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ func TestHandler_HandleNextEpochData(t *testing.T) {

handler.handleBlockFinalisation(ctx)

stored, err := handler.epochState.(*state.EpochState).GetEpochData(targetEpoch, nil)
stored, err := handler.epochState.(*state.EpochState).GetEpochDataRaw(targetEpoch, nil)
require.NoError(t, err)

digestValue, err := digest.Value()
Expand All @@ -326,8 +326,7 @@ func TestHandler_HandleNextEpochData(t *testing.T) {
t.Fatal()
}

res, err := act.ToEpochData()
require.NoError(t, err)
res := act.ToEpochDataRaw()
require.Equal(t, res, stored)
}

Expand Down
61 changes: 23 additions & 38 deletions dot/state/epoch.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,12 @@ func NewEpochStateFromGenesis(db database.Database, blockState *BlockState,
nextConfigData: make(nextEpochMap[types.NextConfigDataV1]),
}

auths, err := types.BABEAuthorityRawToAuthority(genesisConfig.GenesisAuthorities)
if err != nil {
return nil, err
epochDataRaw := &types.EpochDataRaw{
Authorities: genesisConfig.GenesisAuthorities,
Randomness: genesisConfig.Randomness,
}

err = s.SetEpochData(0, &types.EpochData{
Authorities: auths,
Randomness: genesisConfig.Randomness,
})
err = s.SetEpochDataRaw(0, epochDataRaw)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -235,10 +232,8 @@ func (s *EpochState) GetEpochForBlock(header *types.Header) (uint64, error) {
return 0, errNoPreRuntimeDigest
}

// SetEpochData sets the epoch data for a given epoch
func (s *EpochState) SetEpochData(epoch uint64, info *types.EpochData) error {
raw := info.ToEpochDataRaw()

// SetEpochDataRaw sets the epoch data raw for a given epoch
func (s *EpochState) SetEpochDataRaw(epoch uint64, raw *types.EpochDataRaw) error {
enc, err := scale.Marshal(*raw)
if err != nil {
return err
Expand All @@ -247,17 +242,17 @@ func (s *EpochState) SetEpochData(epoch uint64, info *types.EpochData) error {
return s.db.Put(epochDataKey(epoch), enc)
}

// GetEpochData returns the epoch data for a given epoch persisted in database
// GetEpochDataRaw returns the raw epoch data for a given epoch persisted in database
// otherwise will try to get the data from the in-memory map using the header
// if the header params is nil then it will search only in database
func (s *EpochState) GetEpochData(epoch uint64, header *types.Header) (*types.EpochData, error) {
epochData, err := s.getEpochDataFromDatabase(epoch)
func (s *EpochState) GetEpochDataRaw(epoch uint64, header *types.Header) (*types.EpochDataRaw, error) {
epochDataRaw, err := s.getEpochDataRawFromDatabase(epoch)
if err != nil && !errors.Is(err, database.ErrNotFound) {
return nil, fmt.Errorf("failed to retrieve epoch data from database: %w", err)
}

if epochData != nil {
return epochData, nil
if epochDataRaw != nil {
return epochDataRaw, nil
}

if header == nil {
Expand All @@ -272,38 +267,33 @@ func (s *EpochState) GetEpochData(epoch uint64, header *types.Header) (*types.Ep
return nil, fmt.Errorf("failed to get epoch data from memory: %w", err)
}

epochData, err = inMemoryEpochData.ToEpochData()
if err != nil {
return nil, fmt.Errorf("cannot transform into epoch data: %w", err)
}

return epochData, nil
return inMemoryEpochData.ToEpochDataRaw(), nil
}

// getEpochDataFromDatabase returns the epoch data for a given epoch persisted in database
func (s *EpochState) getEpochDataFromDatabase(epoch uint64) (*types.EpochData, error) {
// getEpochDataRawFromDatabase returns the epoch data for a given epoch persisted in database
func (s *EpochState) getEpochDataRawFromDatabase(epoch uint64) (*types.EpochDataRaw, error) {
enc, err := s.db.Get(epochDataKey(epoch))
if err != nil {
return nil, err
}

raw := &types.EpochDataRaw{}
raw := new(types.EpochDataRaw)
err = scale.Unmarshal(enc, raw)
if err != nil {
return nil, err
return nil, fmt.Errorf("unmarshaling into epoch data raw: %w", err)
}

return raw.ToEpochData()
return raw, nil
}

// GetLatestEpochData returns the EpochData for the current epoch
func (s *EpochState) GetLatestEpochData() (*types.EpochData, error) {
// GetLatestEpochDataRaw returns the EpochData for the current epoch
func (s *EpochState) GetLatestEpochDataRaw() (*types.EpochDataRaw, error) {
curr, err := s.GetCurrentEpoch()
if err != nil {
return nil, err
}

return s.GetEpochData(curr, nil)
return s.GetEpochDataRaw(curr, nil)
}

// SetConfigData sets the BABE config data for a given epoch
Expand Down Expand Up @@ -586,7 +576,7 @@ func (s *EpochState) FinalizeBABENextEpochData(finalizedHeader *types.Header) er
nextEpoch = finalizedBlockEpoch + 1
}

epochInDatabase, err := s.getEpochDataFromDatabase(nextEpoch)
epochRawInDatabase, err := s.getEpochDataRawFromDatabase(nextEpoch)

// if an error occurs and the error is database.ErrNotFound we ignore
// since this error is what we will handle in the next lines
Expand All @@ -595,7 +585,7 @@ func (s *EpochState) FinalizeBABENextEpochData(finalizedHeader *types.Header) er
}

// epoch data already defined we don't need to lookup in the map
if epochInDatabase != nil {
if epochRawInDatabase != nil {
return nil
}

Expand All @@ -604,12 +594,7 @@ func (s *EpochState) FinalizeBABENextEpochData(finalizedHeader *types.Header) er
return fmt.Errorf("cannot find next epoch data: %w", err)
}

ed, err := finalizedNextEpochData.ToEpochData()
if err != nil {
return fmt.Errorf("cannot transform epoch data: %w", err)
}

err = s.SetEpochData(nextEpoch, ed)
err = s.SetEpochDataRaw(nextEpoch, finalizedNextEpochData.ToEpochDataRaw())
if err != nil {
return fmt.Errorf("cannot set epoch data: %w", err)
}
Expand Down
32 changes: 13 additions & 19 deletions dot/state/epoch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,46 +58,42 @@ func TestEpochState_EpochData(t *testing.T) {
keyring, err := keystore.NewSr25519Keyring()
require.NoError(t, err)

auth := types.Authority{
Key: keyring.Alice().Public().(*sr25519.PublicKey),
auth := types.AuthorityRaw{
Key: keyring.Alice().Public().(*sr25519.PublicKey).AsBytes(),
Weight: 1,
}

info := &types.EpochData{
Authorities: []types.Authority{auth},
info := &types.EpochDataRaw{
Authorities: []types.AuthorityRaw{auth},
Randomness: [32]byte{77},
}

err = s.SetEpochData(1, info)
err = s.SetEpochDataRaw(1, info)
require.NoError(t, err)
res, err := s.GetEpochData(1, nil)
res, err := s.GetEpochDataRaw(1, nil)
require.NoError(t, err)
require.Equal(t, info.Randomness, res.Randomness)

for i, auth := range res.Authorities {
expected, err := info.Authorities[i].Encode()
require.NoError(t, err)
res, err := auth.Encode()
require.NoError(t, err)
require.Equal(t, expected, res)
require.Equal(t, info.Authorities[i], auth)
}
}

func TestEpochState_GetStartSlotForEpoch(t *testing.T) {
s := newEpochStateFromGenesis(t)

info := &types.EpochData{
info := &types.EpochDataRaw{
Randomness: [32]byte{77},
}

err := s.SetEpochData(2, info)
err := s.SetEpochDataRaw(2, info)
require.NoError(t, err)

info = &types.EpochData{
info = &types.EpochDataRaw{
Randomness: [32]byte{77},
}

err = s.SetEpochData(3, info)
err = s.SetEpochDataRaw(3, info)
require.NoError(t, err)

start, err := s.GetStartSlotForEpoch(0)
Expand Down Expand Up @@ -405,10 +401,8 @@ func TestStoreAndFinalizeBabeNextEpochData(t *testing.T) {
} else {
require.NoError(t, err)

expected, err := expectedNextEpochData.ToEpochData()
require.NoError(t, err)

gotNextEpochData, err := epochState.GetEpochData(tt.finalizeEpoch, nil)
expected := expectedNextEpochData.ToEpochDataRaw()
gotNextEpochData, err := epochState.GetEpochDataRaw(tt.finalizeEpoch, nil)
require.NoError(t, err)

require.Equal(t, expected, gotNextEpochData)
Expand Down
13 changes: 4 additions & 9 deletions dot/types/consensus_digest.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,16 +102,11 @@ func (d NextEpochData) String() string { //skipcq: GO-W1029
}

// ToEpochData returns the NextEpochData as EpochData
func (d *NextEpochData) ToEpochData() (*EpochData, error) { //skipcq: GO-W1029
auths, err := BABEAuthorityRawToAuthority(d.Authorities)
if err != nil {
return nil, err
}

return &EpochData{
Authorities: auths,
func (d *NextEpochData) ToEpochDataRaw() *EpochDataRaw {
return &EpochDataRaw{
Authorities: d.Authorities,
Randomness: d.Randomness,
}, nil
}
}

// BABEOnDisabled represents a GRANDPA authority being disabled
Expand Down
12 changes: 4 additions & 8 deletions lib/babe/babe.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,28 +258,24 @@ func (b *Service) Stop() error {
}

// Authorities returns the current BABE authorities
func (b *Service) Authorities() []types.Authority {
auths := make([]types.Authority, len(b.epochHandler.epochData.authorities))
for i, auth := range b.epochHandler.epochData.authorities {
auths[i] = *auth.DeepCopy()
}
return auths
func (b *Service) AuthoritiesRaw() []types.AuthorityRaw {
return b.epochHandler.epochData.authorities
}

// IsStopped returns true if the service is stopped (ie not producing blocks)
func (b *Service) IsStopped() bool {
return b.ctx.Err() != nil
}

func (b *Service) getAuthorityIndex(Authorities []types.Authority) (uint32, error) {
func (b *Service) getAuthorityIndex(Authorities []types.AuthorityRaw) (uint32, error) {
if !b.authority {
return 0, ErrNotAuthority
}

pub := b.keypair.Public()

for i, auth := range Authorities {
if bytes.Equal(pub.Encode(), auth.Key.Encode()) {
if bytes.Equal(pub.Encode(), auth.Key[:]) {
return uint32(i), nil
}
}
Expand Down
6 changes: 3 additions & 3 deletions lib/babe/babe_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ func TestService_GetAuthorityIndex(t *testing.T) {
pubA := kpA.Public().(*sr25519.PublicKey)
pubB := kpB.Public().(*sr25519.PublicKey)

authData := []types.Authority{
{Key: pubA, Weight: 1},
{Key: pubB, Weight: 1},
authData := []types.AuthorityRaw{
{Key: pubA.AsBytes(), Weight: 1},
{Key: pubB.AsBytes(), Weight: 1},
}

bs := &Service{
Expand Down
6 changes: 3 additions & 3 deletions lib/babe/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func checkPrimaryThreshold(randomness Randomness,

func claimSecondarySlotVRF(randomness Randomness,
slot, epoch uint64,
authorities []types.Authority,
authorities []types.AuthorityRaw,
keypair *sr25519.Keypair,
authorityIndex uint32,
) (*VrfOutputAndProof, error) {
Expand Down Expand Up @@ -123,8 +123,8 @@ func claimSecondarySlotVRF(randomness Randomness,
}, nil
}

func claimSecondarySlotPlain(randomness Randomness, slot uint64, authorities []types.Authority, authorityIndex uint32,
) error {
func claimSecondarySlotPlain(randomness Randomness, slot uint64,
authorities []types.AuthorityRaw, authorityIndex uint32) error {
secondarySlotAuthor, err := getSecondarySlotAuthor(slot, len(authorities), randomness)
if err != nil {
return fmt.Errorf("cannot get secondary slot author: %w", err)
Expand Down
8 changes: 4 additions & 4 deletions lib/babe/epoch.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func (b *Service) getEpochData(epoch uint64, bestBlock *types.Header) (*epochDat
return epochData, nil
}

currEpochData, err := b.epochState.GetEpochData(epoch, bestBlock)
currEpochData, err := b.epochState.GetEpochDataRaw(epoch, bestBlock)
if err != nil {
return nil, fmt.Errorf("cannot get epoch data for epoch %d: %w", epoch, err)
}
Expand Down Expand Up @@ -127,13 +127,13 @@ func (b *Service) getEpochData(epoch uint64, bestBlock *types.Header) (*epochDat
func (b *Service) getLatestEpochData() (resEpochData *epochData, error error) {
resEpochData = &epochData{}

epochData, err := b.epochState.GetLatestEpochData()
epochDataRaw, err := b.epochState.GetLatestEpochDataRaw()
if err != nil {
return nil, fmt.Errorf("cannot get latest epoch data: %w", err)
}

resEpochData.randomness = epochData.Randomness
resEpochData.authorities = epochData.Authorities
resEpochData.randomness = epochDataRaw.Randomness
resEpochData.authorities = epochDataRaw.Authorities

configData, err := b.epochState.GetLatestConfigData()
if err != nil {
Expand Down
14 changes: 10 additions & 4 deletions lib/babe/epoch_handler_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@ func TestEpochHandler_run_shouldReturnAfterContextCancel(t *testing.T) {
epochData := &epochData{
threshold: scale.MaxUint128,
authorityIndex: authorityIndex,
authorities: []types.Authority{
*types.NewAuthority(aliceKeyPair.Public(), 1),
authorities: []types.AuthorityRaw{
{
Key: [32]byte(aliceKeyPair.Public().Encode()),
Weight: 1,
},
},
}

Expand Down Expand Up @@ -66,8 +69,11 @@ func TestEpochHandler_run(t *testing.T) {
epochData := &epochData{
threshold: scale.MaxUint128,
authorityIndex: authorityIndex,
authorities: []types.Authority{
*types.NewAuthority(aliceKeyPair.Public(), 1),
authorities: []types.AuthorityRaw{
{
Key: [32]byte(aliceKeyPair.Public().Encode()),
Weight: 1,
},
},
}

Expand Down
Loading
Loading