diff --git a/dot/state/block.go b/dot/state/block.go index 0ce071a6e4..5ac6233200 100644 --- a/dot/state/block.go +++ b/dot/state/block.go @@ -205,8 +205,6 @@ func (bs *BlockState) GetHeader(hash common.Hash) (header *types.Header, err err return header, nil } - result := types.NewEmptyHeader() - if bs.db == nil { return nil, fmt.Errorf("database is nil") } @@ -220,6 +218,7 @@ func (bs *BlockState) GetHeader(hash common.Hash) (header *types.Header, err err return nil, err } + result := types.NewEmptyHeader() err = scale.Unmarshal(data, result) if err != nil { return nil, err diff --git a/dot/state/block_finalisation.go b/dot/state/block_finalisation.go index d7677bebc4..cf867c9f45 100644 --- a/dot/state/block_finalisation.go +++ b/dot/state/block_finalisation.go @@ -118,7 +118,10 @@ func (bs *BlockState) SetFinalisedHash(hash common.Hash, round, setID uint64) er bs.Lock() defer bs.Unlock() - has, _ := bs.HasHeader(hash) + has, err := bs.HasHeader(hash) + if err != nil { + return fmt.Errorf("could not check header for hash %s: %w", hash, err) + } if !has { return fmt.Errorf("cannot finalise unknown block %s", hash) } diff --git a/dot/state/grandpa.go b/dot/state/grandpa.go index 7b29387e25..32850e9dec 100644 --- a/dot/state/grandpa.go +++ b/dot/state/grandpa.go @@ -138,7 +138,8 @@ func (s *GrandpaState) GetLatestRound() (uint64, error) { return round, nil } -// SetNextChange sets the next authority change +// SetNextChange sets the next authority change at the given block number. +// NOTE: This block number will be the last block in the current set and not part of the next set. func (s *GrandpaState) SetNextChange(authorities []types.GrandpaVoter, number uint) error { currSetID, err := s.GetCurrentSetID() if err != nil { @@ -180,7 +181,7 @@ func (s *GrandpaState) setSetIDChangeAtBlock(setID uint64, number uint) error { return s.db.Put(setIDChangeKey(setID), common.UintToBytes(number)) } -// GetSetIDChange returs the block number where the set ID was updated +// GetSetIDChange returns the block number where the set ID was updated func (s *GrandpaState) GetSetIDChange(setID uint64) (blockNumber uint, err error) { num, err := s.db.Get(setIDChangeKey(setID)) if err != nil { @@ -191,7 +192,7 @@ func (s *GrandpaState) GetSetIDChange(setID uint64) (blockNumber uint, err error } // GetSetIDByBlockNumber returns the set ID for a given block number -func (s *GrandpaState) GetSetIDByBlockNumber(num uint) (uint64, error) { +func (s *GrandpaState) GetSetIDByBlockNumber(blockNumber uint) (uint64, error) { curr, err := s.GetCurrentSetID() if err != nil { return 0, err @@ -215,13 +216,16 @@ func (s *GrandpaState) GetSetIDByBlockNumber(num uint) (uint64, error) { return 0, err } - // if the given block number is greater or equal to the block number of the set ID change, - // return the current set ID - if num <= changeUpper && num > changeLower { + // Set id changes at the last block in the set. So, block (changeLower) at which current + // set id was set, does not belong to current set. Thus, all block numbers in given set + // would be more than changeLower. + // Next set id change happens at the last block of current set. Thus, a block number from + // given set could be lower or equal to changeUpper. + if blockNumber <= changeUpper && blockNumber > changeLower { return curr, nil } - if num > changeUpper { + if blockNumber > changeUpper { return curr + 1, nil } diff --git a/lib/grandpa/errors.go b/lib/grandpa/errors.go index 54937cf1ef..d7f8f5555b 100644 --- a/lib/grandpa/errors.go +++ b/lib/grandpa/errors.go @@ -63,6 +63,8 @@ var ( // ErrNoJustification is returned when no justification can be found for a block, ie. it has not been finalised ErrNoJustification = errors.New("no justification found for block") + ErrBlockHashMismatch = errors.New("block hash does not correspond to given block number") + // ErrMinVotesNotMet is returned when the number of votes is less than the required minimum in a Justification ErrMinVotesNotMet = errors.New("minimum number of votes not met in a Justification") diff --git a/lib/grandpa/message_handler.go b/lib/grandpa/message_handler.go index 3402e8e8d3..ec4e78c8a3 100644 --- a/lib/grandpa/message_handler.go +++ b/lib/grandpa/message_handler.go @@ -9,6 +9,7 @@ import ( "fmt" "reflect" + "github.com/ChainSafe/chaindb" "github.com/ChainSafe/gossamer/dot/network" "github.com/ChainSafe/gossamer/dot/telemetry" "github.com/ChainSafe/gossamer/dot/types" @@ -104,6 +105,16 @@ func (h *MessageHandler) handleNeighbourMessage(msg *NeighbourMessage) error { func (h *MessageHandler) handleCommitMessage(msg *CommitMessage) error { logger.Debugf("received commit message, msg: %+v", msg) + err := verifyBlockHashAgainstBlockNumber(h.blockState, msg.Vote.Hash, uint(msg.Vote.Number)) + if err != nil { + if errors.Is(err, chaindb.ErrKeyNotFound) { + h.grandpa.tracker.addCommit(msg) + logger.Infof("we might not have synced to the given block %s yet: %s", msg.Vote.Hash, err) + return nil + } + return err + } + containsPrecommitsSignedBy := make([]string, len(msg.AuthData)) for i, authData := range msg.AuthData { containsPrecommitsSignedBy[i] = authData.AuthorityID.String() @@ -184,6 +195,16 @@ func (h *MessageHandler) handleCatchUpResponse(msg *CatchUpResponse) error { "received catch up response with hash %s for round %d and set id %d", msg.Hash, msg.Round, msg.SetID) + err := verifyBlockHashAgainstBlockNumber(h.blockState, msg.Hash, uint(msg.Number)) + if err != nil { + if errors.Is(err, chaindb.ErrKeyNotFound) { + h.grandpa.tracker.addCatchUpResponse(msg) + logger.Infof("we might not have synced to the given block %s yet: %s", msg.Hash, err) + return nil + } + return err + } + // TODO: re-add catch-up logic (#1531) if true { return nil @@ -300,12 +321,13 @@ func (h *MessageHandler) verifyCommitMessageJustification(fm *CommitMessage) err err := h.verifyJustification(just, fm.Round, h.grandpa.state.setID, precommit) if err != nil { + logger.Errorf("could not verify justification: %s", err) continue } isDescendant, err := h.blockState.IsDescendantOf(fm.Vote.Hash, just.Vote.Hash) if err != nil { - logger.Warnf("verifyCommitMessageJustification: %s", err) + logger.Warnf("could not check for descendant: %s", err) continue } @@ -330,6 +352,17 @@ func (h *MessageHandler) verifyPreVoteJustification(msg *CatchUpResponse) (commo voters := make(map[ed25519.PublicKeyBytes]map[common.Hash]int, len(msg.PreVoteJustification)) eqVotesByHash := make(map[common.Hash]map[ed25519.PublicKeyBytes]struct{}) + for _, pvj := range msg.PreVoteJustification { + err := verifyBlockHashAgainstBlockNumber(h.blockState, pvj.Vote.Hash, uint(pvj.Vote.Number)) + if err != nil { + if errors.Is(err, chaindb.ErrKeyNotFound) { + h.grandpa.tracker.addCatchUpResponse(msg) + logger.Infof("we might not have synced to the given block %s yet: %s", pvj.Vote.Hash, err) + continue + } + return common.Hash{}, err + } + } // identify equivocatory votes by hash for _, justification := range msg.PreVoteJustification { hashsToCount, ok := voters[justification.AuthorityID] @@ -386,6 +419,18 @@ func (h *MessageHandler) verifyPreVoteJustification(msg *CatchUpResponse) (commo } func (h *MessageHandler) verifyPreCommitJustification(msg *CatchUpResponse) error { + for _, pcj := range msg.PreCommitJustification { + err := verifyBlockHashAgainstBlockNumber(h.blockState, pcj.Vote.Hash, uint(pcj.Vote.Number)) + if err != nil { + if errors.Is(err, chaindb.ErrKeyNotFound) { + h.grandpa.tracker.addCatchUpResponse(msg) + logger.Infof("we might not have synced to the given block %s yet: %s", pcj.Vote.Hash, err) + continue + } + return err + } + } + auths := make([]AuthData, len(msg.PreCommitJustification)) for i, pcj := range msg.PreCommitJustification { auths[i] = AuthData{AuthorityID: pcj.AuthorityID} @@ -562,6 +607,18 @@ func (s *Service) VerifyBlockJustification(hash common.Hash, justification []byt return ErrMinVotesNotMet } + err = verifyBlockHashAgainstBlockNumber(s.blockState, fj.Commit.Hash, uint(fj.Commit.Number)) + if err != nil { + return err + } + + for _, preCommit := range fj.Commit.Precommits { + err := verifyBlockHashAgainstBlockNumber(s.blockState, preCommit.Vote.Hash, uint(preCommit.Vote.Number)) + if err != nil { + return err + } + } + err = s.blockState.SetFinalisedHash(hash, fj.Round, setID) if err != nil { return err @@ -573,6 +630,19 @@ func (s *Service) VerifyBlockJustification(hash common.Hash, justification []byt return nil } +func verifyBlockHashAgainstBlockNumber(bs BlockState, hash common.Hash, number uint) error { + header, err := bs.GetHeader(hash) + if err != nil { + return fmt.Errorf("could not get header from block hash: %w", err) + } + + if header.Number != number { + return fmt.Errorf("%w: expected number %d from header but got number %d", + ErrBlockHashMismatch, header.Number, number) + } + return nil +} + func isInAuthSet(auth *ed25519.PublicKey, set []types.GrandpaVoter) bool { for _, a := range set { if bytes.Equal(a.Key.Encode(), auth.Encode()) { diff --git a/lib/grandpa/message_handler_test.go b/lib/grandpa/message_handler_test.go index 5e30d15a92..381575be60 100644 --- a/lib/grandpa/message_handler_test.go +++ b/lib/grandpa/message_handler_test.go @@ -250,7 +250,7 @@ func TestMessageHandler_VerifyJustification_InvalidSig(t *testing.T) { func TestMessageHandler_CommitMessage_NoCatchUpRequest_ValidSig(t *testing.T) { gs, st := newTestService(t) - round := uint64(77) + round := uint64(1) gs.state.round = round just := buildTestJustification(t, int(gs.state.threshold()), round, gs.state.setID, kr, precommit) err := st.Grandpa.SetPrecommits(round, gs.state.setID, just) @@ -505,6 +505,18 @@ func TestMessageHandler_VerifyPreVoteJustification(t *testing.T) { telemetryMock.EXPECT().SendMessage(gomock.Any()).AnyTimes() gs, st := newTestService(t) + + body, err := types.NewBodyFromBytes([]byte{0}) + require.NoError(t, err) + + block := &types.Block{ + Header: *testHeader, + Body: *body, + } + + err = st.Block.AddBlock(block) + require.NoError(t, err) + h := NewMessageHandler(gs, st.Block, telemetryMock) just := buildTestJustification(t, int(gs.state.threshold()), 1, gs.state.setID, kr, prevote) @@ -525,6 +537,18 @@ func TestMessageHandler_VerifyPreCommitJustification(t *testing.T) { telemetryMock.EXPECT().SendMessage(gomock.Any()).AnyTimes() gs, st := newTestService(t) + + body, err := types.NewBodyFromBytes([]byte{0}) + require.NoError(t, err) + + block := &types.Block{ + Header: *testHeader, + Body: *body, + } + + err = st.Block.AddBlock(block) + require.NoError(t, err) + h := NewMessageHandler(gs, st.Block, telemetryMock) round := uint64(1) @@ -537,7 +561,7 @@ func TestMessageHandler_VerifyPreCommitJustification(t *testing.T) { Number: uint32(round), } - err := h.verifyPreCommitJustification(msg) + err = h.verifyPreCommitJustification(msg) require.NoError(t, err) } @@ -553,7 +577,7 @@ func TestMessageHandler_HandleCatchUpResponse(t *testing.T) { h := NewMessageHandler(gs, st.Block, telemetryMock) - round := uint64(77) + round := uint64(1) gs.state.round = round + 1 pvJust := buildTestJustification(t, int(gs.state.threshold()), round, gs.state.setID, kr, prevote) @@ -605,7 +629,7 @@ func TestMessageHandler_VerifyBlockJustification_WithEquivocatoryVotes(t *testin } gs, st := newTestService(t) - err := st.Grandpa.SetNextChange(auths, 1) + err := st.Grandpa.SetNextChange(auths, 0) require.NoError(t, err) body, err := types.NewBodyFromBytes([]byte{0}) @@ -623,8 +647,8 @@ func TestMessageHandler_VerifyBlockJustification_WithEquivocatoryVotes(t *testin require.NoError(t, err) require.Equal(t, uint64(1), setID) - round := uint64(2) - number := uint32(2) + round := uint64(1) + number := uint32(1) precommits := buildTestJustification(t, 20, round, setID, kr, precommit) just := newJustification(round, testHash, number, precommits) data, err := scale.Marshal(*just) @@ -647,7 +671,7 @@ func TestMessageHandler_VerifyBlockJustification(t *testing.T) { } gs, st := newTestService(t) - err := st.Grandpa.SetNextChange(auths, 1) + err := st.Grandpa.SetNextChange(auths, 0) require.NoError(t, err) body, err := types.NewBodyFromBytes([]byte{0}) @@ -667,8 +691,8 @@ func TestMessageHandler_VerifyBlockJustification(t *testing.T) { genhash := st.Block.GenesisHash() - round := uint64(2) - number := uint32(2) + round := uint64(1) + number := uint32(1) precommits := buildTestJustification(t, 2, round, setID, kr, precommit) just := newJustification(round, testHash, number, precommits) data, err := scale.Marshal(*just)