diff --git a/Cargo.lock b/Cargo.lock
index 194de7660432..5d259bd7d00c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -7219,6 +7219,7 @@ name = "polkadot-primitives-test-helpers"
 version = "0.9.18"
 dependencies = [
  "polkadot-primitives",
+ "rand 0.8.5",
  "sp-application-crypto",
  "sp-keyring",
  "sp-runtime",
diff --git a/node/network/approval-distribution/src/lib.rs b/node/network/approval-distribution/src/lib.rs
index ca6f5dccc5a9..a231fe0b3472 100644
--- a/node/network/approval-distribution/src/lib.rs
+++ b/node/network/approval-distribution/src/lib.rs
@@ -720,7 +720,7 @@ impl State {
 		// to all peers in the BlockEntry's known_by set who know about the block,
 		// excluding the peer in the source, if source has kind MessageSource::Peer.
 		let maybe_peer_id = source.peer_id();
-		let peers = entry
+		let mut peers = entry
 			.known_by
 			.keys()
 			.cloned()
@@ -729,8 +729,7 @@ impl State {
 		let assignments = vec![(assignment, claimed_candidate_index)];
 		let gossip_peers = &self.gossip_peers;
-		let peers =
-			util::choose_random_subset(|e| gossip_peers.contains(e), peers, MIN_GOSSIP_PEERS);
+		util::choose_random_subset(|e| gossip_peers.contains(e), &mut peers, MIN_GOSSIP_PEERS);

 		// Add the fingerprint of the assignment to the knowledge of each peer.
 		for peer in peers.iter() {
@@ -943,7 +942,7 @@ impl State {
 		// to all peers in the BlockEntry's known_by set who know about the block,
 		// excluding the peer in the source, if source has kind MessageSource::Peer.
 		let maybe_peer_id = source.peer_id();
-		let peers = entry
+		let mut peers = entry
 			.known_by
 			.keys()
 			.cloned()
@@ -951,8 +950,7 @@ impl State {
 			.collect::<Vec<_>>();
 		let gossip_peers = &self.gossip_peers;
-		let peers =
-			util::choose_random_subset(|e| gossip_peers.contains(e), peers, MIN_GOSSIP_PEERS);
+		util::choose_random_subset(|e| gossip_peers.contains(e), &mut peers, MIN_GOSSIP_PEERS);

 		// Add the fingerprint of the assignment to the knowledge of each peer.
 		for peer in peers.iter() {
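Both approval-distribution call sites above switch from the old return-value form (`let peers = util::choose_random_subset(...)`) to the new in-place form that takes `&mut peers`. The sketch below is an illustrative, self-contained re-implementation of that pattern (toy `u32` peer ids and a local stand-in for the utility introduced later in this diff, not the actual `polkadot-node-subsystem-util` function), showing that the same binding holds the pruned subset afterwards.

```rust
// Sketch only: local stand-in for util::choose_random_subset (in-place, priority-first).
// Assumes the rand 0.8 crate; `MIN_GOSSIP_PEERS` is a toy constant here.
use rand::seq::SliceRandom;

const MIN_GOSSIP_PEERS: usize = 4;

fn choose_random_subset<T, F: Fn(&T) -> bool>(is_priority: F, v: &mut Vec<T>, min: usize) {
    // Move priority elements to the front; `i` is where non-priority elements start.
    let mut i = 0;
    for j in 0..v.len() {
        if is_priority(&v[j]) {
            v.swap(i, j);
            i += 1;
        }
    }
    if i >= min || v.len() <= i {
        v.truncate(i);
        return
    }
    // Shuffle the non-priority tail and keep just enough elements to reach `min`.
    v[i..].shuffle(&mut rand::thread_rng());
    v.truncate(min);
}

fn main() {
    let gossip_peers = vec![1u32, 2, 3];
    let mut peers: Vec<u32> = (1..=10).collect();
    // New call shape: the vector is pruned in place instead of being returned.
    choose_random_subset(|p| gossip_peers.contains(p), &mut peers, MIN_GOSSIP_PEERS);
    assert_eq!(peers.len(), MIN_GOSSIP_PEERS);
    assert!(gossip_peers.iter().all(|g| peers.contains(g)));
    println!("selected peers: {:?}", peers);
}
```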
diff --git a/node/network/bitfield-distribution/src/lib.rs b/node/network/bitfield-distribution/src/lib.rs
index befdec66b359..6bd952233111 100644
--- a/node/network/bitfield-distribution/src/lib.rs
+++ b/node/network/bitfield-distribution/src/lib.rs
@@ -346,7 +346,7 @@ async fn relay_message(
 	let _span = span.child("interested-peers");
 	// pass on the bitfield distribution to all interested peers
-	let interested_peers = peer_views
+	let mut interested_peers = peer_views
 		.iter()
 		.filter_map(|(peer, view)| {
 			// check interest in the peer in this message's relay parent
@@ -363,9 +363,9 @@
 			}
 		})
 		.collect::<Vec<_>>();
-	let interested_peers = util::choose_random_subset(
+	util::choose_random_subset(
 		|e| gossip_peers.contains(e),
-		interested_peers,
+		&mut interested_peers,
 		MIN_GOSSIP_PEERS,
 	);
 	interested_peers.iter().for_each(|peer| {
diff --git a/node/network/statement-distribution/src/lib.rs b/node/network/statement-distribution/src/lib.rs
index 1931c545c0d5..0bf43b883cd3 100644
--- a/node/network/statement-distribution/src/lib.rs
+++ b/node/network/statement-distribution/src/lib.rs
@@ -32,7 +32,7 @@ use polkadot_node_network_protocol::{
 	IfDisconnected, PeerId, UnifiedReputationChange as Rep, View,
 };
 use polkadot_node_primitives::{SignedFullStatement, Statement, UncheckedSignedFullStatement};
-use polkadot_node_subsystem_util::{self as util, MIN_GOSSIP_PEERS};
+use polkadot_node_subsystem_util::{self as util, rand, MIN_GOSSIP_PEERS};
 use polkadot_primitives::v2::{
 	AuthorityDiscoveryId, CandidateHash, CommittedCandidateReceipt, CompactStatement, Hash,
@@ -115,16 +115,19 @@ const LOG_TARGET: &str = "parachain::statement-distribution";
 const MAX_LARGE_STATEMENTS_PER_SENDER: usize = 20;

 /// The statement distribution subsystem.
-pub struct StatementDistributionSubsystem {
+pub struct StatementDistributionSubsystem<R> {
 	/// Pointer to a keystore, which is required for determining this node's validator index.
 	keystore: SyncCryptoStorePtr,
 	/// Receiver for incoming large statement requests.
 	req_receiver: Option<IncomingRequestReceiver<StatementFetchingRequest>>,
 	/// Prometheus metrics
 	metrics: Metrics,
+	/// Pseudo-random generator for peer selection logic
+	rng: R,
 }

-impl<Context> overseer::Subsystem<Context> for StatementDistributionSubsystem
+impl<Context, R: rand::Rng + Send + Sync + 'static> overseer::Subsystem<Context>
+	for StatementDistributionSubsystem<R>
 where
 	Context: SubsystemContext,
 	Context: overseer::SubsystemContext,
@@ -142,17 +145,6 @@ where
 	}
 }

-impl StatementDistributionSubsystem {
-	/// Create a new Statement Distribution Subsystem
-	pub fn new(
-		keystore: SyncCryptoStorePtr,
-		req_receiver: IncomingRequestReceiver,
-		metrics: Metrics,
-	) -> Self {
-		Self { keystore, req_receiver: Some(req_receiver), metrics }
-	}
-}
-
 #[derive(Default)]
 struct RecentOutdatedHeads {
 	buf: VecDeque,
 }
@@ -906,6 +898,7 @@ async fn circulate_statement_and_dependents(
 	statement: SignedFullStatement,
 	priority_peers: Vec<PeerId>,
 	metrics: &Metrics,
+	rng: &mut impl rand::Rng,
 ) {
 	let active_head = match active_heads.get_mut(&relay_parent) {
 		Some(res) => res,
@@ -932,6 +925,7 @@
 				stored,
 				priority_peers,
 				metrics,
+				rng,
 			)
 			.await,
 		)),
@@ -1019,6 +1013,7 @@ async fn circulate_statement<'a>(
 	stored: StoredStatement<'a>,
 	mut priority_peers: Vec<PeerId>,
 	metrics: &Metrics,
+	rng: &mut impl rand::Rng,
 ) -> Vec<PeerId> {
 	let fingerprint = stored.fingerprint();
@@ -1041,8 +1036,12 @@
 	let priority_set: HashSet<&PeerId> = priority_peers.iter().collect();
 	peers_to_send.retain(|p| !priority_set.contains(p));

-	let mut peers_to_send =
-		util::choose_random_subset(|e| gossip_peers.contains(e), peers_to_send, MIN_GOSSIP_PEERS);
+	util::choose_random_subset_with_rng(
+		|e| gossip_peers.contains(e),
+		&mut peers_to_send,
+		rng,
+		MIN_GOSSIP_PEERS,
+	);
 	// We don't want to use less peers, than we would without any priority peers:
 	let min_size = std::cmp::max(peers_to_send.len(), MIN_GOSSIP_PEERS);
 	// Make set full:
@@ -1313,6 +1312,7 @@ async fn handle_incoming_message_and_circulate<'a>(
 	message: protocol_v1::StatementDistributionMessage,
 	req_sender: &mpsc::Sender,
 	metrics: &Metrics,
+	rng: &mut impl rand::Rng,
 ) {
 	let handled_incoming = match peers.get_mut(&peer) {
 		Some(data) =>
@@ -1348,6 +1348,7 @@
 			statement,
 			Vec::new(),
 			metrics,
+			rng,
 		)
 		.await;
 	}
@@ -1458,7 +1459,12 @@ async fn handle_incoming_message<'a>(
 		Ok(()) => {},
 		Err(DeniedStatement::NotUseful) => return None,
 		Err(DeniedStatement::UsefulButKnown) => {
+			// Note the received statement in the peer data
+			peer_data
+				.receive(&relay_parent, &fingerprint, max_message_count)
+				.expect("checked in `check_can_receive` above; qed");
 			report_peer(ctx, peer, BENEFIT_VALID_STATEMENT).await;
+			return None
 		},
 	}
@@ -1558,6 +1564,7 @@ async fn update_peer_view_and_maybe_send_unlocked(
 	active_heads: &HashMap,
 	new_view: View,
 	metrics: &Metrics,
+	rng: &mut impl rand::Rng,
 ) {
 	let old_view = std::mem::replace(&mut peer_data.view, new_view);
@@ -1568,9 +1575,10 @@ async fn update_peer_view_and_maybe_send_unlocked(
 	let is_gossip_peer = gossip_peers.contains(&peer);
 	let lucky = is_gossip_peer ||
-		util::gen_ratio(
+		util::gen_ratio_rng(
 			util::MIN_GOSSIP_PEERS.saturating_sub(gossip_peers.len()),
 			util::MIN_GOSSIP_PEERS,
+			rng,
 		);

 	// Add entries for all relay-parents in the new view but not the old.
@@ -1597,6 +1605,7 @@ async fn handle_network_update(
 	req_sender: &mpsc::Sender,
 	update: NetworkBridgeEvent,
 	metrics: &Metrics,
+	rng: &mut impl rand::Rng,
 ) {
 	match update {
 		NetworkBridgeEvent::PeerConnected(peer, role, maybe_authority) => {
@@ -1638,6 +1647,7 @@
 				&*active_heads,
 				view,
 				metrics,
+				rng,
 			)
 			.await
 		}
@@ -1654,6 +1664,7 @@
 				message,
 				req_sender,
 				metrics,
+				rng,
 			)
 			.await;
 		},
@@ -1670,6 +1681,7 @@
 				&*active_heads,
 				view,
 				metrics,
+				rng,
 			)
 			.await,
 			None => (),
@@ -1681,7 +1693,17 @@
 	}
 }

-impl StatementDistributionSubsystem {
+impl<R: rand::Rng> StatementDistributionSubsystem<R> {
+	/// Create a new Statement Distribution Subsystem
+	pub fn new(
+		keystore: SyncCryptoStorePtr,
+		req_receiver: IncomingRequestReceiver,
+		metrics: Metrics,
+		rng: R,
+	) -> Self {
+		Self { keystore, req_receiver: Some(req_receiver), metrics, rng }
+	}
+
 	async fn run(
 		mut self,
 		mut ctx: (impl SubsystemContext
@@ -1803,7 +1825,7 @@ impl StatementDistributionSubsystem {
 	}

 	async fn handle_requester_message(
-		&self,
+		&mut self,
 		ctx: &mut impl SubsystemContext,
 		gossip_peers: &HashSet,
 		peers: &mut HashMap,
@@ -1861,6 +1883,7 @@ impl StatementDistributionSubsystem {
 			message,
 			req_sender,
 			&self.metrics,
+			&mut self.rng,
 		)
 		.await;
 	}
@@ -1910,7 +1933,7 @@ impl StatementDistributionSubsystem {
 	async fn handle_subsystem_message(
-		&self,
+		&mut self,
 		ctx: &mut (impl SubsystemContext + overseer::SubsystemContext),
 		runtime: &mut RuntimeInfo,
 		peers: &mut HashMap,
@@ -2022,6 +2045,7 @@ impl StatementDistributionSubsystem {
 					statement,
 					group_peers,
 					metrics,
+					&mut self.rng,
 				)
 				.await;
 			},
@@ -2036,6 +2060,7 @@ impl StatementDistributionSubsystem {
 				req_sender,
 				event,
 				metrics,
+				&mut self.rng,
 			)
 			.await;
 		},
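Two things change in this file: the subsystem becomes generic over an injected `R: rand::Rng` (threaded through every helper as `&mut impl rand::Rng`), and the `UsefulButKnown` branch now records the duplicate `Seconded` statement in the peer's received knowledge before returning, so a later `Valid` statement from that peer is accounted for correctly. The sketch below is a minimal, self-contained illustration of the RNG-injection pattern only; `GossipEngine` and `pick_lucky` are hypothetical names, not the real subsystem, and it assumes rand 0.8.

```rust
// Minimal sketch of injecting an RNG instead of calling thread_rng() internally:
// production passes an entropy-seeded RNG, tests pass a deterministic one.
use rand::{rngs::StdRng, Rng, SeedableRng};

struct GossipEngine<R> {
    rng: R,
}

impl<R: Rng> GossipEngine<R> {
    fn new(rng: R) -> Self {
        Self { rng }
    }

    fn pick_lucky(&mut self, candidates: usize) -> bool {
        // Stand-in for util::gen_ratio_rng: true with probability 1/candidates.
        gen_ratio_rng(1, candidates, &mut self.rng)
    }
}

fn gen_ratio_rng(a: usize, b: usize, rng: &mut impl Rng) -> bool {
    rng.gen_ratio(a as u32, b as u32)
}

fn main() {
    // Production-style: non-deterministic.
    let mut prod = GossipEngine::new(StdRng::from_entropy());
    let _ = prod.pick_lucky(4);

    // Test-style: fully reproducible with a fixed seed.
    let first = GossipEngine::new(StdRng::seed_from_u64(0)).pick_lucky(4);
    let again = GossipEngine::new(StdRng::seed_from_u64(0)).pick_lucky(4);
    assert_eq!(first, again);
}
```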
diff --git a/node/network/statement-distribution/src/tests.rs b/node/network/statement-distribution/src/tests.rs
index 10462fc1a580..9e91ac5ba650 100644
--- a/node/network/statement-distribution/src/tests.rs
+++ b/node/network/statement-distribution/src/tests.rs
@@ -29,7 +29,9 @@ use polkadot_node_network_protocol::{
 use polkadot_node_primitives::{Statement, UncheckedSignedFullStatement};
 use polkadot_node_subsystem_test_helpers::mock::make_ferdie_keystore;
 use polkadot_primitives::v2::{Hash, SessionInfo, ValidationCode};
-use polkadot_primitives_test_helpers::{dummy_committed_candidate_receipt, dummy_hash};
+use polkadot_primitives_test_helpers::{
+	dummy_committed_candidate_receipt, dummy_hash, AlwaysZeroRng,
+};
 use polkadot_subsystem::{
 	jaeger,
 	messages::{RuntimeApiMessage, RuntimeApiRequest},
@@ -511,6 +513,7 @@ fn peer_view_update_sends_messages() {
 			&active_heads,
 			new_view.clone(),
 			&Default::default(),
+			&mut AlwaysZeroRng,
 		)
 		.await;
@@ -640,6 +643,7 @@ fn circulated_statement_goes_to_all_peers_with_view() {
 			statement,
 			Vec::new(),
 			&Metrics::default(),
+			&mut AlwaysZeroRng,
 		)
 		.await;
@@ -723,6 +727,7 @@ fn receiving_from_one_sends_to_another_and_to_candidate_backing() {
 			Arc::new(LocalKeystore::in_memory()),
 			statement_req_receiver,
 			Default::default(),
+			AlwaysZeroRng,
 		);
 		s.run(ctx).await.unwrap();
 	};
@@ -915,6 +920,7 @@ fn receiving_large_statement_from_one_sends_to_another_and_to_candidate_backing(
 			make_ferdie_keystore(),
 			statement_req_receiver,
 			Default::default(),
+			AlwaysZeroRng,
 		);
 		s.run(ctx).await.unwrap();
 	};
@@ -1412,6 +1418,7 @@ fn share_prioritizes_backing_group() {
 			make_ferdie_keystore(),
 			statement_req_receiver,
 			Default::default(),
+			AlwaysZeroRng,
 		);
 		s.run(ctx).await.unwrap();
 	};
@@ -1695,6 +1702,7 @@ fn peer_cant_flood_with_large_statements() {
 			make_ferdie_keystore(),
 			statement_req_receiver,
 			Default::default(),
+			AlwaysZeroRng,
 		);
 		s.run(ctx).await.unwrap();
 	};
@@ -1842,6 +1850,347 @@ fn peer_cant_flood_with_large_statements() {
 	executor::block_on(future::join(test_fut, bg));
 }

+// This test addresses an issue where received knowledge was not updated on
+// subsequent `Seconded` statements.
+// See https://github.com/paritytech/polkadot/pull/5177
+#[test]
+fn handle_multiple_seconded_statements() {
+	let relay_parent_hash = Hash::repeat_byte(1);
+
+	let candidate = dummy_committed_candidate_receipt(relay_parent_hash);
+	let candidate_hash = candidate.hash();
+
+	// We want to ensure that `peer_a` and `peer_b` do not end up in the random `lucky` subset.
+	let mut all_peers: Vec<PeerId> = Vec::with_capacity(MIN_GOSSIP_PEERS + 4);
+	let peer_a = PeerId::random();
+	let peer_b = PeerId::random();
+	assert_ne!(peer_a, peer_b);
+
+	for _ in 0..MIN_GOSSIP_PEERS + 2 {
+		all_peers.push(PeerId::random());
+	}
+	all_peers.push(peer_a.clone());
+	all_peers.push(peer_b.clone());
+
+	let mut lucky_peers = all_peers.clone();
+	util::choose_random_subset_with_rng(
+		|_| false,
+		&mut lucky_peers,
+		&mut AlwaysZeroRng,
+		MIN_GOSSIP_PEERS,
+	);
+	lucky_peers.sort();
+	assert_eq!(lucky_peers.len(), MIN_GOSSIP_PEERS);
+	assert!(!lucky_peers.contains(&peer_a));
+	assert!(!lucky_peers.contains(&peer_b));
+
+	let validators = vec![
+		Sr25519Keyring::Alice.pair(),
+		Sr25519Keyring::Bob.pair(),
+		Sr25519Keyring::Charlie.pair(),
+	];
+
+	let session_info = make_session_info(validators, vec![]);
+
+	let session_index = 1;
+
+	let pool = sp_core::testing::TaskExecutor::new();
+	let (ctx, mut handle) = polkadot_node_subsystem_test_helpers::make_subsystem_context(pool);
+
+	let (statement_req_receiver, _) = IncomingRequest::get_config_receiver();
+
+	let virtual_overseer_fut = async move {
+		let s = StatementDistributionSubsystem::new(
+			Arc::new(LocalKeystore::in_memory()),
+			statement_req_receiver,
+			Default::default(),
+			AlwaysZeroRng,
+		);
+		s.run(ctx).await.unwrap();
+	};
+
+	let test_fut = async move {
+		// register our active heads.
+		handle
+			.send(FromOverseer::Signal(OverseerSignal::ActiveLeaves(
+				ActiveLeavesUpdate::start_work(ActivatedLeaf {
+					hash: relay_parent_hash,
+					number: 1,
+					status: LeafStatus::Fresh,
+					span: Arc::new(jaeger::Span::Disabled),
+				}),
+			)))
+			.await;
+
+		assert_matches!(
+			handle.recv().await,
+			AllMessages::RuntimeApi(
+				RuntimeApiMessage::Request(r, RuntimeApiRequest::SessionIndexForChild(tx))
+			)
+				if r == relay_parent_hash
+			=> {
+				let _ = tx.send(Ok(session_index));
+			}
+		);
+
+		assert_matches!(
+			handle.recv().await,
+			AllMessages::RuntimeApi(
+				RuntimeApiMessage::Request(r, RuntimeApiRequest::SessionInfo(sess_index, tx))
+			)
+				if r == relay_parent_hash && sess_index == session_index
+			=> {
+				let _ = tx.send(Ok(Some(session_info)));
+			}
+		);
+
+		// notify of peers and view
+		for peer in all_peers.iter() {
+			handle
+				.send(FromOverseer::Communication {
+					msg: StatementDistributionMessage::NetworkBridgeUpdateV1(
+						NetworkBridgeEvent::PeerConnected(peer.clone(), ObservedRole::Full, None),
+					),
+				})
+				.await;
+			handle
+				.send(FromOverseer::Communication {
+					msg: StatementDistributionMessage::NetworkBridgeUpdateV1(
+						NetworkBridgeEvent::PeerViewChange(peer.clone(), view![relay_parent_hash]),
+					),
+				})
+				.await;
+		}
+
+		// Explicitly add all `lucky` peers to the gossip peers to ensure that neither `peer_a`
+		// nor `peer_b` receives statements.
+		handle
+			.send(FromOverseer::Communication {
+				msg: StatementDistributionMessage::NetworkBridgeUpdateV1(
+					NetworkBridgeEvent::NewGossipTopology(
+						lucky_peers.iter().cloned().collect::<HashSet<_>>(),
+					),
+				),
+			})
+			.await;
+
+		// receive a seconded statement from peer A. It should be propagated onwards to peer B
+		// and to candidate backing.
+		let statement = {
+			let signing_context = SigningContext { parent_hash: relay_parent_hash, session_index };
+
+			let keystore: SyncCryptoStorePtr = Arc::new(LocalKeystore::in_memory());
+			let alice_public = CryptoStore::sr25519_generate_new(
+				&*keystore,
+				ValidatorId::ID,
+				Some(&Sr25519Keyring::Alice.to_seed()),
+			)
+			.await
+			.unwrap();
+
+			SignedFullStatement::sign(
+				&keystore,
+				Statement::Seconded(candidate.clone()),
+				&signing_context,
+				ValidatorIndex(0),
+				&alice_public.into(),
+			)
+			.await
+			.ok()
+			.flatten()
+			.expect("should be signed")
+		};
+
+		// `PeerA` sends a `Seconded` message
+		handle
+			.send(FromOverseer::Communication {
+				msg: StatementDistributionMessage::NetworkBridgeUpdateV1(
+					NetworkBridgeEvent::PeerMessage(
+						peer_a.clone(),
+						protocol_v1::StatementDistributionMessage::Statement(
+							relay_parent_hash,
+							statement.clone().into(),
+						),
+					),
+				),
+			})
+			.await;
+
+		assert_matches!(
+			handle.recv().await,
+			AllMessages::NetworkBridge(
+				NetworkBridgeMessage::ReportPeer(p, r)
+			) => {
+				assert_eq!(p, peer_a);
+				assert_eq!(r, BENEFIT_VALID_STATEMENT_FIRST);
+			}
+		);
+
+		// After the first valid statement, we expect messages to be circulated
+		assert_matches!(
+			handle.recv().await,
+			AllMessages::CandidateBacking(
+				CandidateBackingMessage::Statement(r, s)
+			) => {
+				assert_eq!(r, relay_parent_hash);
+				assert_eq!(s, statement);
+			}
+		);
+
+		assert_matches!(
+			handle.recv().await,
+			AllMessages::NetworkBridge(
+				NetworkBridgeMessage::SendValidationMessage(
+					recipients,
+					protocol_v1::ValidationProtocol::StatementDistribution(
+						protocol_v1::StatementDistributionMessage::Statement(r, s)
+					),
+				)
+			) => {
+				assert!(!recipients.contains(&peer_b));
+				assert_eq!(r, relay_parent_hash);
+				assert_eq!(s, statement.clone().into());
+			}
+		);
+
+		// `PeerB` sends a `Seconded` message: valid but known
+		handle
+			.send(FromOverseer::Communication {
+				msg: StatementDistributionMessage::NetworkBridgeUpdateV1(
+					NetworkBridgeEvent::PeerMessage(
+						peer_b.clone(),
+						protocol_v1::StatementDistributionMessage::Statement(
+							relay_parent_hash,
+							statement.clone().into(),
+						),
+					),
+				),
+			})
+			.await;
+
+		assert_matches!(
+			handle.recv().await,
+			AllMessages::NetworkBridge(
+				NetworkBridgeMessage::ReportPeer(p, r)
+			) => {
+				assert_eq!(p, peer_b);
+				assert_eq!(r, BENEFIT_VALID_STATEMENT);
+			}
+		);
+
+		// Create a `Valid` statement
+		let statement = {
+			let signing_context = SigningContext { parent_hash: relay_parent_hash, session_index };
+
+			let keystore: SyncCryptoStorePtr = Arc::new(LocalKeystore::in_memory());
+			let alice_public = CryptoStore::sr25519_generate_new(
+				&*keystore,
+				ValidatorId::ID,
+				Some(&Sr25519Keyring::Alice.to_seed()),
+			)
+			.await
+			.unwrap();
+
+			SignedFullStatement::sign(
+				&keystore,
+				Statement::Valid(candidate_hash),
+				&signing_context,
+				ValidatorIndex(0),
+				&alice_public.into(),
+			)
+			.await
+			.ok()
+			.flatten()
+			.expect("should be signed")
+		};
+
+		// `PeerA` sends a `Valid` message
+		handle
+			.send(FromOverseer::Communication {
+				msg: StatementDistributionMessage::NetworkBridgeUpdateV1(
+					NetworkBridgeEvent::PeerMessage(
+						peer_a.clone(),
+						protocol_v1::StatementDistributionMessage::Statement(
+							relay_parent_hash,
+							statement.clone().into(),
+						),
+					),
+				),
+			})
+			.await;
+
+		assert_matches!(
+			handle.recv().await,
+			AllMessages::NetworkBridge(
+				NetworkBridgeMessage::ReportPeer(p, r)
+			) => {
+				assert_eq!(p, peer_a);
+				assert_eq!(r, BENEFIT_VALID_STATEMENT_FIRST);
+			}
+		);
+
+		assert_matches!(
+			handle.recv().await,
+			AllMessages::CandidateBacking(
+				CandidateBackingMessage::Statement(r, s)
+			) => {
+				assert_eq!(r, relay_parent_hash);
+				assert_eq!(s, statement);
+			}
+		);
+
+		assert_matches!(
+			handle.recv().await,
+			AllMessages::NetworkBridge(
+				NetworkBridgeMessage::SendValidationMessage(
+					recipients,
+					protocol_v1::ValidationProtocol::StatementDistribution(
+						protocol_v1::StatementDistributionMessage::Statement(r, s)
+					),
+				)
+			) => {
+				assert!(!recipients.contains(&peer_b));
+				assert_eq!(r, relay_parent_hash);
+				assert_eq!(s, statement.clone().into());
+			}
+		);
+
+		// `PeerB` sends a `Valid` message
+		handle
+			.send(FromOverseer::Communication {
+				msg: StatementDistributionMessage::NetworkBridgeUpdateV1(
+					NetworkBridgeEvent::PeerMessage(
+						peer_b.clone(),
+						protocol_v1::StatementDistributionMessage::Statement(
+							relay_parent_hash,
+							statement.clone().into(),
+						),
+					),
+				),
+			})
+			.await;
+
+		// We expect that the `Valid` statement from `PeerB` is still accepted even though
+		// `PeerB` was not the first to send the earlier `Seconded`.
+		assert_matches!(
+			handle.recv().await,
+			AllMessages::NetworkBridge(
+				NetworkBridgeMessage::ReportPeer(p, r)
+			) => {
+				assert_eq!(p, peer_b);
+				assert_eq!(r, BENEFIT_VALID_STATEMENT);
+			}
+		);
+
+		handle.send(FromOverseer::Signal(OverseerSignal::Conclude)).await;
+	};
+
+	futures::pin_mut!(test_fut);
+	futures::pin_mut!(virtual_overseer_fut);
+
+	executor::block_on(future::join(test_fut, virtual_overseer_fut));
+}
+
 fn make_session_info(validators: Vec, groups: Vec<Vec<ValidatorIndex>>) -> SessionInfo {
 	let validator_groups: Vec<Vec<ValidatorIndex>> = groups
 		.iter()
diff --git a/node/service/src/overseer.rs b/node/service/src/overseer.rs
index a9a757163381..a8d14536fd7a 100644
--- a/node/service/src/overseer.rs
+++ b/node/service/src/overseer.rs
@@ -63,6 +63,7 @@ pub use polkadot_node_core_dispute_coordinator::DisputeCoordinatorSubsystem;
 pub use polkadot_node_core_provisioner::ProvisionerSubsystem;
 pub use polkadot_node_core_pvf_checker::PvfCheckerSubsystem;
 pub use polkadot_node_core_runtime_api::RuntimeApiSubsystem;
+use polkadot_node_subsystem_util::rand::{self, SeedableRng};
 pub use polkadot_statement_distribution::StatementDistributionSubsystem;

 /// Arguments passed for overseer construction.
@@ -148,7 +149,7 @@ pub fn prepared_overseer_builder<'a, Spawner, RuntimeClient>(
 	CandidateValidationSubsystem,
 	PvfCheckerSubsystem,
 	CandidateBackingSubsystem,
-	StatementDistributionSubsystem,
+	StatementDistributionSubsystem<rand::rngs::StdRng>,
 	AvailabilityDistributionSubsystem,
 	AvailabilityRecoverySubsystem,
 	BitfieldSigningSubsystem,
@@ -255,6 +256,7 @@ where
 		keystore.clone(),
 		statement_req_receiver,
 		Metrics::register(registry)?,
+		rand::rngs::StdRng::from_entropy(),
 	))
 	.approval_distribution(ApprovalDistributionSubsystem::new(Metrics::register(registry)?))
 	.approval_voting(ApprovalVotingSubsystem::with_config(
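The service wiring above reaches `rand` through `polkadot_node_subsystem_util::rand` rather than declaring the dependency again, so every consumer stays on the exact same `rand` version as the utility crate. Below is a minimal sketch of that re-export pattern with a hypothetical module standing in for the utility crate; only the `pub use rand;` line and the standard rand 0.8 APIs (`StdRng::from_entropy`, `gen_bool`) are taken as given.

```rust
// Sketch of the `pub use rand;` re-export pattern (hypothetical module name `util_crate`
// standing in for polkadot-node-subsystem-util). Consumers reach rand via the re-export.
mod util_crate {
    pub use rand;
}

use util_crate::rand::{rngs::StdRng, Rng, SeedableRng};

fn main() {
    // The service layer seeds from OS entropy once, at subsystem construction time.
    let mut rng = StdRng::from_entropy();
    println!("coin flip: {}", rng.gen_bool(0.5));
}
```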
diff --git a/node/subsystem-util/src/lib.rs b/node/subsystem-util/src/lib.rs
index ccdfe7982b59..6886d298eb9d 100644
--- a/node/subsystem-util/src/lib.rs
+++ b/node/subsystem-util/src/lib.rs
@@ -55,6 +55,7 @@ use polkadot_primitives::v2::{
 	PersistedValidationData, SessionIndex, SessionInfo, Signed, SigningContext, ValidationCode,
 	ValidationCodeHash, ValidatorId, ValidatorIndex, ValidatorSignature,
 };
+pub use rand;
 use sp_application_crypto::AppKey;
 use sp_core::{traits::SpawnNamed, ByteArray};
 use sp_keystore::{CryptoStore, Error as KeystoreError, SyncCryptoStorePtr};
@@ -276,33 +277,41 @@ pub fn find_validator_group(

 /// Choose a random subset of `min` elements.
 /// But always include `is_priority` elements.
-pub fn choose_random_subset<T, F: Fn(&T) -> bool>(
+pub fn choose_random_subset<T, F: Fn(&T) -> bool>(is_priority: F, v: &mut Vec<T>, min: usize) {
+	choose_random_subset_with_rng(is_priority, v, &mut rand::thread_rng(), min)
+}
+
+/// Choose a random subset of `min` elements using a specific random generator `rng`.
+/// But always include `is_priority` elements.
+pub fn choose_random_subset_with_rng<T, F: Fn(&T) -> bool, R: rand::Rng>(
 	is_priority: F,
-	mut v: Vec<T>,
+	v: &mut Vec<T>,
+	rng: &mut R,
 	min: usize,
-) -> Vec<T> {
+) {
 	use rand::seq::SliceRandom as _;

 	// partition the elements into priority first
 	// the returned index is when non_priority elements start
-	let i = itertools::partition(&mut v, is_priority);
+	let i = itertools::partition(v.iter_mut(), is_priority);

 	if i >= min || v.len() <= i {
 		v.truncate(i);
-		return v
+		return
 	}

-	let mut rng = rand::thread_rng();
-	v[i..].shuffle(&mut rng);
+	v[i..].shuffle(rng);
 	v.truncate(min);
-	v
 }

 /// Returns a `bool` with a probability of `a / b` of being true.
 pub fn gen_ratio(a: usize, b: usize) -> bool {
-	use rand::Rng as _;
-	let mut rng = rand::thread_rng();
+	gen_ratio_rng(a, b, &mut rand::thread_rng())
+}
+
+/// Returns a `bool` with a probability of `a / b` of being true, using the provided `Rng`.
+pub fn gen_ratio_rng<R: rand::Rng>(a: usize, b: usize, rng: &mut R) -> bool {
 	rng.gen_ratio(a as u32, b as u32)
 }
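The reworked helper partitions priority elements to the front of the vector; if the priority elements alone already reach `min` (or nothing else is left), only they are kept, otherwise the non-priority tail is shuffled with the supplied RNG and the vector is truncated to `min`. The usage sketch below assumes a crate that depends on `polkadot-node-subsystem-util` (with the signatures introduced above) and on rand 0.8; a seeded `StdRng` makes the selection reproducible, which is exactly what the statement-distribution tests rely on.

```rust
// Usage sketch for the helpers introduced above, with a deterministic seeded RNG.
use polkadot_node_subsystem_util::{choose_random_subset_with_rng, gen_ratio_rng};
use rand::{rngs::StdRng, SeedableRng};

fn main() {
    let mut rng = StdRng::seed_from_u64(42);

    // Priority elements (here: even numbers, 13 of them) are always kept; the rest
    // are sampled at random until the vector holds `min` elements.
    let mut values: Vec<u8> = (0..=25).collect();
    choose_random_subset_with_rng(|v| v % 2 == 0, &mut values, &mut rng, 20);
    assert_eq!(values.len(), 20);
    assert_eq!(values.iter().filter(|v| **v % 2 == 0).count(), 13);

    // gen_ratio_rng: true with probability a / b, driven by the same injected RNG.
    let lucky = gen_ratio_rng(1, 4, &mut rng);
    println!("kept {} values, lucky = {}", values.len(), lucky);
}
```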
diff --git a/node/subsystem-util/src/tests.rs b/node/subsystem-util/src/tests.rs
index c7c6cbf6d80c..166b4d557508 100644
--- a/node/subsystem-util/src/tests.rs
+++ b/node/subsystem-util/src/tests.rs
@@ -25,7 +25,7 @@ use polkadot_node_subsystem::{
 };
 use polkadot_node_subsystem_test_helpers::{self as test_helpers, make_subsystem_context};
 use polkadot_primitives::v2::Hash;
-use polkadot_primitives_test_helpers::{dummy_candidate_receipt, dummy_hash};
+use polkadot_primitives_test_helpers::{dummy_candidate_receipt, dummy_hash, AlwaysZeroRng};
 use std::{
 	pin::Pin,
 	sync::{
@@ -248,11 +248,23 @@ fn tick_tack_metronome() {

 #[test]
 fn subset_generation_check() {
-	let values = (0_u8..=25).collect::<Vec<_>>();
+	let mut values = (0_u8..=25).collect::<Vec<_>>();
 	// 12 even numbers exist
-	let mut chosen = choose_random_subset::<u8, _>(|v| v & 0x01 == 0, values, 12);
-	chosen.sort();
-	for (idx, v) in dbg!(chosen).into_iter().enumerate() {
+	choose_random_subset::<u8, _>(|v| v & 0x01 == 0, &mut values, 12);
+	values.sort();
+	for (idx, v) in dbg!(values).into_iter().enumerate() {
 		assert_eq!(v as usize, idx * 2);
 	}
 }
+
+#[test]
+fn subset_predefined_generation_check() {
+	let mut values = (0_u8..=25).collect::<Vec<_>>();
+	choose_random_subset_with_rng::<u8, _, _>(|_| false, &mut values, &mut AlwaysZeroRng, 12);
+	assert_eq!(values.len(), 12);
+	for (idx, v) in dbg!(values).into_iter().enumerate() {
+		// With an RNG that always returns zero, the shuffle swaps every index (from the
+		// end down to 1) with index 0, which rotates the values left by one: index `idx`
+		// ends up holding `idx + 1`.
+		assert_eq!(v as usize, idx + 1);
+	}
+}
diff --git a/primitives/test-helpers/Cargo.toml b/primitives/test-helpers/Cargo.toml
index e8223c99cc5a..59fdf4e1a706 100644
--- a/primitives/test-helpers/Cargo.toml
+++ b/primitives/test-helpers/Cargo.toml
@@ -9,3 +9,4 @@ sp-keyring = { git = "https://github.com/paritytech/substrate", branch = "master" }
 sp-application-crypto = { package = "sp-application-crypto", git = "https://github.com/paritytech/substrate", branch = "master", default-features = false }
 sp-runtime = { git = "https://github.com/paritytech/substrate", branch = "master" }
 polkadot-primitives = { path = "../" }
+rand = "0.8.5"
diff --git a/primitives/test-helpers/src/lib.rs b/primitives/test-helpers/src/lib.rs
index 9d98cd0b55f5..02ba009b13cc 100644
--- a/primitives/test-helpers/src/lib.rs
+++ b/primitives/test-helpers/src/lib.rs
@@ -26,6 +26,7 @@ use polkadot_primitives::v2::{
 	CommittedCandidateReceipt, Hash, HeadData, Id as ParaId, ValidationCode, ValidationCodeHash,
 	ValidatorId,
 };
+pub use rand;
 use sp_application_crypto::sr25519;
 use sp_keyring::Sr25519Keyring;
 use sp_runtime::generic::Digest;
@@ -224,3 +225,33 @@ impl TestCandidateBuilder {
 		CandidateReceipt { descriptor, commitments_hash: self.commitments_hash }
 	}
 }
+
+/// A special `Rng` that always returns zero, for testing code paths that are meant to be
+/// random in production but must be deterministic in tests.
+pub struct AlwaysZeroRng;
+
+impl Default for AlwaysZeroRng {
+	fn default() -> Self {
+		Self {}
+	}
+}
+impl rand::RngCore for AlwaysZeroRng {
+	fn next_u32(&mut self) -> u32 {
+		0_u32
+	}
+
+	fn next_u64(&mut self) -> u64 {
+		0_u64
+	}
+
+	fn fill_bytes(&mut self, dest: &mut [u8]) {
+		for element in dest.iter_mut() {
+			*element = 0_u8;
+		}
+	}
+
+	fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
+		self.fill_bytes(dest);
+		Ok(())
+	}
+}
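Because `AlwaysZeroRng` implements `RngCore`, it picks up the `rand::Rng` trait through rand's blanket implementation and can be passed anywhere the new `_with_rng` helpers expect an RNG. The short sketch below is an aside rather than part of the PR: it demonstrates that blanket impl and shows a fixed-seed `StdRng` as an alternative when a test wants reproducible but non-degenerate randomness.

```rust
// Sketch: AlwaysZeroRng as a drop-in `Rng`, plus a seeded StdRng alternative.
use polkadot_primitives_test_helpers::AlwaysZeroRng;
use rand::{rngs::StdRng, Rng, RngCore, SeedableRng};

fn main() {
    // `AlwaysZeroRng` is fully degenerate: every draw is zero.
    let mut zero = AlwaysZeroRng;
    assert_eq!(zero.next_u64(), 0); // RngCore method implemented above
    assert_eq!(zero.gen::<u32>(), 0); // Rng method via rand's blanket impl

    // A fixed-seed StdRng is reproducible across runs but still "shuffles".
    let a: u32 = StdRng::seed_from_u64(7).gen();
    let b: u32 = StdRng::seed_from_u64(7).gen();
    assert_eq!(a, b);
}
```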