Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(consensus): the Context is passed as a param instead of being held as a field by SHC #2238

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/papyrus_node/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ fn run_consensus(
let start_height = config.start_height;

Ok(tokio::spawn(papyrus_consensus::run_consensus(
Arc::new(context),
context,
start_height,
validator_id,
consensus_channels.broadcasted_messages_receiver,
Expand Down
21 changes: 9 additions & 12 deletions crates/sequencing/papyrus_consensus/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
// TODO(Matan): fix #[allow(missing_docs)].
//! A consensus implementation for a [`Starknet`](https://www.starknet.io/) node.

use std::sync::Arc;

use futures::channel::{mpsc, oneshot};
use papyrus_common::metrics as papyrus_metrics;
use papyrus_network::network_manager::BroadcastSubscriberReceiver;
Expand Down Expand Up @@ -37,8 +35,8 @@ use futures::StreamExt;

#[instrument(skip(context, validator_id, network_receiver, cached_messages), level = "info")]
#[allow(missing_docs)]
async fn run_height<BlockT: ConsensusBlock>(
context: Arc<dyn ConsensusContext<Block = BlockT>>,
async fn run_height<BlockT: ConsensusBlock, ContextT: ConsensusContext<Block = BlockT>>(
context: &ContextT,
height: BlockNumber,
validator_id: ValidatorId,
network_receiver: &mut BroadcastSubscriberReceiver<ConsensusMessage>,
Expand All @@ -49,10 +47,9 @@ where
Into<(ProposalInit, mpsc::Receiver<BlockT::ProposalChunk>, oneshot::Receiver<BlockHash>)>,
{
let validators = context.validators(height).await;
let mut shc =
SingleHeightConsensus::new(Arc::clone(&context), height, validator_id, validators);
let mut shc = SingleHeightConsensus::new(height, validator_id, validators);

if let Some(decision) = shc.start().await? {
if let Some(decision) = shc.start(context).await? {
return Ok(decision);
}

Expand Down Expand Up @@ -91,9 +88,9 @@ where
// Special case due to fake streaming.
let (proposal_init, content_receiver, fin_receiver) =
ProposalWrapper(proposal).into();
shc.handle_proposal(proposal_init, content_receiver, fin_receiver).await?
shc.handle_proposal(context, proposal_init, content_receiver, fin_receiver).await?
}
_ => shc.handle_message(message).await?,
_ => shc.handle_message(context, message).await?,
};

if let Some(decision) = maybe_decision {
Expand All @@ -105,8 +102,8 @@ where
// TODO(dvir): add test for this.
#[instrument(skip(context, start_height, network_receiver), level = "info")]
#[allow(missing_docs)]
pub async fn run_consensus<BlockT: ConsensusBlock>(
context: Arc<dyn ConsensusContext<Block = BlockT>>,
pub async fn run_consensus<BlockT: ConsensusBlock, ContextT: ConsensusContext<Block = BlockT>>(
context: ContextT,
start_height: BlockNumber,
validator_id: ValidatorId,
mut network_receiver: BroadcastSubscriberReceiver<ConsensusMessage>,
Expand All @@ -119,7 +116,7 @@ where
let mut future_messages = Vec::new();
loop {
let decision = run_height(
Arc::clone(&context),
&context,
current_height,
validator_id,
&mut network_receiver,
Expand Down
71 changes: 37 additions & 34 deletions crates/sequencing/papyrus_consensus/src/single_height_consensus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
mod single_height_consensus_test;

use std::collections::{HashMap, VecDeque};
use std::sync::Arc;

use futures::channel::{mpsc, oneshot};
use papyrus_protobuf::consensus::{ConsensusMessage, Vote, VoteType};
Expand All @@ -29,7 +28,6 @@ const ROUND_ZERO: Round = 0;
/// out messages "directly" to the network, and returning a decision to the caller.
pub(crate) struct SingleHeightConsensus<BlockT: ConsensusBlock> {
height: BlockNumber,
context: Arc<dyn ConsensusContext<Block = BlockT>>,
validators: Vec<ValidatorId>,
id: ValidatorId,
state_machine: StateMachine,
Expand All @@ -39,17 +37,11 @@ pub(crate) struct SingleHeightConsensus<BlockT: ConsensusBlock> {
}

impl<BlockT: ConsensusBlock> SingleHeightConsensus<BlockT> {
pub(crate) fn new(
context: Arc<dyn ConsensusContext<Block = BlockT>>,
height: BlockNumber,
id: ValidatorId,
validators: Vec<ValidatorId>,
) -> Self {
pub(crate) fn new(height: BlockNumber, id: ValidatorId, validators: Vec<ValidatorId>) -> Self {
// TODO(matan): Use actual weights, not just `len`.
let state_machine = StateMachine::new(validators.len() as u32);
Self {
height,
context,
validators,
id,
state_machine,
Expand All @@ -59,22 +51,26 @@ impl<BlockT: ConsensusBlock> SingleHeightConsensus<BlockT> {
}
}

#[instrument(skip(self), fields(height=self.height.0), level = "debug")]
pub(crate) async fn start(&mut self) -> Result<Option<Decision<BlockT>>, ConsensusError> {
#[instrument(skip_all, fields(height=self.height.0), level = "debug")]
pub(crate) async fn start<ContextT: ConsensusContext<Block = BlockT>>(
&mut self,
context: &ContextT,
) -> Result<Option<Decision<BlockT>>, ConsensusError> {
info!("Starting consensus with validators {:?}", self.validators);
let events = self.state_machine.start();
self.handle_state_machine_events(events).await
self.handle_state_machine_events(context, events).await
}

/// Receive a proposal from a peer node. Returns only once the proposal has been fully received
/// and processed.
#[instrument(
skip(self, init, p2p_messages_receiver, fin_receiver),
skip_all,
fields(height = %self.height),
level = "debug",
)]
pub(crate) async fn handle_proposal(
pub(crate) async fn handle_proposal<ContextT: ConsensusContext<Block = BlockT>>(
&mut self,
context: &ContextT,
init: ProposalInit,
p2p_messages_receiver: mpsc::Receiver<<BlockT as ConsensusBlock>::ProposalChunk>,
fin_receiver: oneshot::Receiver<BlockHash>,
Expand All @@ -83,7 +79,7 @@ impl<BlockT: ConsensusBlock> SingleHeightConsensus<BlockT> {
"Received proposal: proposal_height={}, proposer={:?}",
init.height.0, init.proposer
);
let proposer_id = self.context.proposer(&self.validators, self.height);
let proposer_id = context.proposer(&self.validators, self.height);
if init.height != self.height {
let msg = format!("invalid height: expected {:?}, got {:?}", self.height, init.height);
return Err(ConsensusError::InvalidProposal(proposer_id, self.height, msg));
Expand All @@ -94,8 +90,7 @@ impl<BlockT: ConsensusBlock> SingleHeightConsensus<BlockT> {
return Err(ConsensusError::InvalidProposal(proposer_id, self.height, msg));
}

let block_receiver =
self.context.validate_proposal(self.height, p2p_messages_receiver).await;
let block_receiver = context.validate_proposal(self.height, p2p_messages_receiver).await;
// TODO(matan): Actual Tendermint should handle invalid proposals.
let block = block_receiver.await.map_err(|_| {
ConsensusError::InvalidProposal(
Expand Down Expand Up @@ -124,27 +119,29 @@ impl<BlockT: ConsensusBlock> SingleHeightConsensus<BlockT> {
// TODO(matan): Handle multiple rounds.
self.proposals.insert(ROUND_ZERO, block);
let sm_events = self.state_machine.handle_event(sm_proposal);
self.handle_state_machine_events(sm_events).await
self.handle_state_machine_events(context, sm_events).await
}

/// Handle messages from peer nodes.
#[instrument(skip_all)]
pub(crate) async fn handle_message(
pub(crate) async fn handle_message<ContextT: ConsensusContext<Block = BlockT>>(
&mut self,
context: &ContextT,
message: ConsensusMessage,
) -> Result<Option<Decision<BlockT>>, ConsensusError> {
debug!("Received message: {:?}", message);
match message {
ConsensusMessage::Proposal(_) => {
unimplemented!("Proposals should use `handle_proposal` due to fake streaming")
}
ConsensusMessage::Vote(vote) => self.handle_vote(vote).await,
ConsensusMessage::Vote(vote) => self.handle_vote(context, vote).await,
}
}

#[instrument(skip_all)]
async fn handle_vote(
async fn handle_vote<ContextT: ConsensusContext<Block = BlockT>>(
&mut self,
context: &ContextT,
vote: Vote,
) -> Result<Option<Decision<BlockT>>, ConsensusError> {
let (votes, sm_vote) = match vote.vote_type {
Expand All @@ -170,21 +167,24 @@ impl<BlockT: ConsensusBlock> SingleHeightConsensus<BlockT> {

votes.insert((ROUND_ZERO, vote.voter), vote);
let sm_events = self.state_machine.handle_event(sm_vote);
self.handle_state_machine_events(sm_events).await
self.handle_state_machine_events(context, sm_events).await
}

// Handle events output by the state machine.
#[instrument(skip_all)]
async fn handle_state_machine_events(
async fn handle_state_machine_events<ContextT: ConsensusContext<Block = BlockT>>(
&mut self,
context: &ContextT,
mut events: VecDeque<StateMachineEvent>,
) -> Result<Option<Decision<BlockT>>, ConsensusError> {
while let Some(event) = events.pop_front() {
trace!("Handling event: {:?}", event);
match event {
StateMachineEvent::StartRound(block_hash, round) => {
events.append(
&mut self.handle_state_machine_start_round(block_hash, round).await,
&mut self
.handle_state_machine_start_round(context, block_hash, round)
.await,
);
}
StateMachineEvent::Proposal(_, _) => {
Expand All @@ -195,37 +195,39 @@ impl<BlockT: ConsensusBlock> SingleHeightConsensus<BlockT> {
return self.handle_state_machine_decision(block_hash, round).await;
}
StateMachineEvent::Prevote(block_hash, round) => {
self.handle_state_machine_vote(block_hash, round, VoteType::Prevote).await?;
self.handle_state_machine_vote(context, block_hash, round, VoteType::Prevote)
.await?;
}
StateMachineEvent::Precommit(block_hash, round) => {
self.handle_state_machine_vote(block_hash, round, VoteType::Precommit).await?;
self.handle_state_machine_vote(context, block_hash, round, VoteType::Precommit)
.await?;
}
}
}
Ok(None)
}

#[instrument(skip(self), level = "debug")]
async fn handle_state_machine_start_round(
#[instrument(skip(self, context), level = "debug")]
async fn handle_state_machine_start_round<ContextT: ConsensusContext<Block = BlockT>>(
&mut self,
context: &ContextT,
block_hash: Option<BlockHash>,
round: Round,
) -> VecDeque<StateMachineEvent> {
// TODO(matan): Support re-proposing validValue.
assert!(block_hash.is_none(), "Reproposing is not yet supported");
let proposer_id = self.context.proposer(&self.validators, self.height);
let proposer_id = context.proposer(&self.validators, self.height);
if proposer_id != self.id {
debug!("Validator");
return self.state_machine.handle_event(StateMachineEvent::StartRound(None, round));
}
debug!("Proposer");

let (p2p_messages_receiver, block_receiver) =
self.context.build_proposal(self.height).await;
let (p2p_messages_receiver, block_receiver) = context.build_proposal(self.height).await;
let (fin_sender, fin_receiver) = oneshot::channel();
let init = ProposalInit { height: self.height, proposer: self.id };
// Peering is a permanent component, so if sending to it fails we cannot continue.
self.context
context
.propose(init, p2p_messages_receiver, fin_receiver)
.await
.expect("Failed sending Proposal to Peering");
Expand All @@ -245,8 +247,9 @@ impl<BlockT: ConsensusBlock> SingleHeightConsensus<BlockT> {
}

#[instrument(skip_all)]
async fn handle_state_machine_vote(
async fn handle_state_machine_vote<ContextT: ConsensusContext<Block = BlockT>>(
&mut self,
context: &ContextT,
block_hash: BlockHash,
round: Round,
vote_type: VoteType,
Expand All @@ -260,7 +263,7 @@ impl<BlockT: ConsensusBlock> SingleHeightConsensus<BlockT> {
// TODO(matan): Consider refactoring not to panic, rather log and return the error.
panic!("State machine should not send repeat votes: old={:?}, new={:?}", old, vote);
}
self.context.broadcast(ConsensusMessage::Vote(vote)).await?;
context.broadcast(ConsensusMessage::Vote(vote)).await?;
Ok(None)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,27 +54,26 @@ async fn proposer() {
.returning(move |_| Ok(()));

let mut shc = SingleHeightConsensus::new(
Arc::new(context),
BlockNumber(0),
node_id,
vec![node_id, 2_u32.into(), 3_u32.into(), 4_u32.into()],
);

// Sends proposal and prevote.
assert!(matches!(shc.start().await, Ok(None)));
assert!(matches!(shc.start(&context).await, Ok(None)));

assert_eq!(shc.handle_message(prevote(block.id(), 0, 2_u32.into())).await, Ok(None));
assert_eq!(shc.handle_message(prevote(block.id(), 0, 3_u32.into())).await, Ok(None));
assert_eq!(shc.handle_message(&context, prevote(block.id(), 0, 2_u32.into())).await, Ok(None));
assert_eq!(shc.handle_message(&context, prevote(block.id(), 0, 3_u32.into())).await, Ok(None));

let precommits = vec![
precommit(block.id(), 0, 1_u32.into()),
precommit(BlockHash(Felt::TWO), 0, 4_u32.into()), // Ignores since disagrees.
precommit(block.id(), 0, 2_u32.into()),
precommit(block.id(), 0, 3_u32.into()),
];
assert_eq!(shc.handle_message(precommits[1].clone()).await, Ok(None));
assert_eq!(shc.handle_message(precommits[2].clone()).await, Ok(None));
let decision = shc.handle_message(precommits[3].clone()).await.unwrap().unwrap();
assert_eq!(shc.handle_message(&context, precommits[1].clone()).await, Ok(None));
assert_eq!(shc.handle_message(&context, precommits[2].clone()).await, Ok(None));
let decision = shc.handle_message(&context, precommits[3].clone()).await.unwrap().unwrap();
assert_eq!(decision.block, block);
assert!(
decision
Expand Down Expand Up @@ -116,7 +115,6 @@ async fn validator() {

// Creation calls to `context.validators`.
let mut shc = SingleHeightConsensus::new(
Arc::new(context),
BlockNumber(0),
node_id,
vec![node_id, proposer, 3_u32.into(), 4_u32.into()],
Expand All @@ -128,23 +126,24 @@ async fn validator() {

let res = shc
.handle_proposal(
&context,
ProposalInit { height: BlockNumber(0), proposer },
mpsc::channel(1).1, // content - ignored by SHC.
fin_receiver,
)
.await;
assert_eq!(res, Ok(None));

assert_eq!(shc.handle_message(prevote(block.id(), 0, 2_u32.into())).await, Ok(None));
assert_eq!(shc.handle_message(prevote(block.id(), 0, 3_u32.into())).await, Ok(None));
assert_eq!(shc.handle_message(&context, prevote(block.id(), 0, 2_u32.into())).await, Ok(None));
assert_eq!(shc.handle_message(&context, prevote(block.id(), 0, 3_u32.into())).await, Ok(None));

let precommits = vec![
precommit(block.id(), 0, 2_u32.into()),
precommit(block.id(), 0, 3_u32.into()),
precommit(block.id(), 0, node_id),
];
assert_eq!(shc.handle_message(precommits[0].clone()).await, Ok(None));
let decision = shc.handle_message(precommits[1].clone()).await.unwrap().unwrap();
assert_eq!(shc.handle_message(&context, precommits[0].clone()).await, Ok(None));
let decision = shc.handle_message(&context, precommits[1].clone()).await.unwrap().unwrap();
assert_eq!(decision.block, block);
assert!(
decision
Expand Down
12 changes: 1 addition & 11 deletions crates/sequencing/papyrus_consensus/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
#[cfg(test)]
#[path = "types_test.rs"]
mod types_test;

use std::fmt::Debug;

use async_trait::async_trait;
Expand Down Expand Up @@ -68,14 +64,8 @@ pub trait ConsensusBlock: Send {
}

/// Interface for consensus to call out to the node.
// Why `Send + Sync`?
// 1. We expect multiple components within consensus to concurrently access the context.
// 2. The other option is for each component to have its own copy (i.e. clone) of the context, but
// this is object unsafe (Clone requires Sized).
// 3. Given that we see the context as basically a connector to other components in the node, the
// limitation of Sync to keep functions `&self` shouldn't be a problem.
#[async_trait]
pub trait ConsensusContext: Send + Sync {
pub trait ConsensusContext {
/// The [block](`ConsensusBlock`) type built by `ConsensusContext` from a proposal.
// We use an associated type since consensus is indifferent to the actual content of a proposal,
// but we cannot use generics due to object safety.
Expand Down
10 changes: 0 additions & 10 deletions crates/sequencing/papyrus_consensus/src/types_test.rs

This file was deleted.

Loading