Skip to content

Commit

Permalink
Merge pull request #345 from chirino/distributed-batcher
Browse files Browse the repository at this point in the history
  • Loading branch information
chirino authored May 27, 2024
2 parents c88628f + 3baf07b commit 5525734
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 35 deletions.
104 changes: 96 additions & 8 deletions limitador/src/storage/distributed/grpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH};
use std::{error::Error, io::ErrorKind, pin::Pin};

use tokio::sync::mpsc::error::TrySendError;
use tokio::sync::mpsc::Sender;
use tokio::sync::{broadcast, mpsc, RwLock};
use tokio::sync::mpsc::{Permit, Sender};
use tokio::sync::{broadcast, mpsc, Notify, RwLock};
use tokio::time::sleep;
use tokio_stream::{wrappers::ReceiverStream, Stream, StreamExt};
use tonic::{Code, Request, Response, Status, Streaming};
use tracing::debug;

use crate::storage::distributed::cr_counter_value::CrCounterValue;
use crate::storage::distributed::grpc::v1::packet::Message;
use crate::storage::distributed::grpc::v1::replication_client::ReplicationClient;
use crate::storage::distributed::grpc::v1::replication_server::{Replication, ReplicationServer};
Expand Down Expand Up @@ -145,12 +146,47 @@ impl Session {
TrySendError::Closed(_) => Status::unavailable("re-sync channel closed"),
})?;

let mut updates = self.broker_state.publisher.subscribe();
let mut udpates_to_send = self.broker_state.publisher.subscribe();
let mut tx_updates_by_key = HashMap::new();
let mut tx_updates_order = vec![];
let notifier = Notify::default();

loop {
tokio::select! {
update = updates.recv() => {
update = udpates_to_send.recv() => {
let update = update.map_err(|_| Status::unknown("broadcast error"))?;
self.send(Message::CounterUpdate(update)).await?;
// Multiple updates collapse into a single update for the same key
if !tx_updates_by_key.contains_key(&update.key) {
tx_updates_by_key.insert(update.key.clone(), update.value);
tx_updates_order.push(update.key);
notifier.notify_one();
}
}
_ = notifier.notified() => {
// while we have pending updates to send...
while !tx_updates_order.is_empty() {
// and we have space on the transmission channel to send the update...
match self.out_stream.clone().try_reserve() {
Err(_) => {
break
},
Ok(permit) => {

let key = tx_updates_order.remove(0);
let cr_counter_value = tx_updates_by_key.remove(&key).unwrap().clone();
let (expiry, values) = (*cr_counter_value).clone().into_inner();

// only send the update if it has not expired.
if expiry > SystemTime::now() {
permit.send(Ok(Message::CounterUpdate(CounterUpdate {
key,
values: values.into_iter().collect(),
expires_at: expiry.duration_since(UNIX_EPOCH).unwrap().as_secs(),
})))?;
}
}
}
}
}
result = in_stream.next() => {
match result {
Expand Down Expand Up @@ -354,14 +390,66 @@ impl MessageSender {
},
}
}
fn try_reserve(&self) -> Result<MessagePermit<'_>, Status> {
match self {
MessageSender::Client(sender) => {
let permit = sender
.try_reserve()
.map_err(|_| Status::unknown("send error"))?;
Ok(MessagePermit::Client(permit))
}
MessageSender::Server(sender) => {
let permit = sender
.try_reserve()
.map_err(|_| Status::unknown("send error"))?;
Ok(MessagePermit::Server(permit))
}
}
}
}

enum MessagePermit<'a> {
Server(Permit<'a, Result<Packet, Status>>),
Client(Permit<'a, Packet>),
}
impl<'a> MessagePermit<'a> {
fn send(self, message: Result<Message, Status>) -> Result<(), Status> {
match self {
MessagePermit::Server(sender) => {
let value = message.map(|x| Packet { message: Some(x) });
sender.send(value);
Ok(())
}
MessagePermit::Client(sender) => match message {
Ok(message) => {
sender.send(Packet {
message: Some(message),
});
Ok(())
}
Err(err) => Err(err),
},
}
}
}

type CounterUpdateFn = Pin<Box<dyn Fn(CounterUpdate) + Sync + Send>>;
#[derive(Clone, Debug)]
pub struct CounterEntry {
pub key: Vec<u8>,
pub value: Arc<CrCounterValue<String>>,
}

impl CounterEntry {
pub fn new(key: Vec<u8>, value: Arc<CrCounterValue<String>>) -> Self {
Self { key, value }
}
}

#[derive(Clone)]
struct BrokerState {
id: String,
publisher: broadcast::Sender<CounterUpdate>,
publisher: broadcast::Sender<CounterEntry>,
on_counter_update: Arc<CounterUpdateFn>,
on_re_sync: Arc<Sender<Sender<Option<CounterUpdate>>>>,
}
Expand All @@ -383,7 +471,7 @@ impl Broker {
on_re_sync: Sender<Sender<Option<CounterUpdate>>>,
) -> Broker {
let (tx, _) = broadcast::channel(16);
let publisher: broadcast::Sender<CounterUpdate> = tx;
let publisher: broadcast::Sender<CounterEntry> = tx;

Broker {
listen_address,
Expand All @@ -401,7 +489,7 @@ impl Broker {
}
}

pub fn publish(&self, counter_update: CounterUpdate) {
pub fn publish(&self, counter_update: CounterEntry) {
// ignore the send error, it just means there are no active subscribers
_ = self.broker_state.publisher.send(counter_update);
}
Expand Down
52 changes: 25 additions & 27 deletions limitador/src/storage/distributed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ use crate::counter::Counter;
use crate::limit::{Limit, Namespace};
use crate::storage::distributed::cr_counter_value::CrCounterValue;
use crate::storage::distributed::grpc::v1::CounterUpdate;
use crate::storage::distributed::grpc::Broker;
use crate::storage::distributed::grpc::{Broker, CounterEntry};
use crate::storage::{Authorization, CounterStorage, StorageErr};

mod cr_counter_value;
mod grpc;

pub type LimitsMap = HashMap<Vec<u8>, CrCounterValue<String>>;
pub type LimitsMap = HashMap<Vec<u8>, Arc<CrCounterValue<String>>>;

pub struct CrInMemoryStorage {
identifier: String,
Expand All @@ -45,11 +45,11 @@ impl CounterStorage for CrInMemoryStorage {
if limit.variables().is_empty() {
let mut limits = self.limits.write().unwrap();
let key = encode_limit_to_key(limit);
limits.entry(key).or_insert(CrCounterValue::new(
limits.entry(key).or_insert(Arc::new(CrCounterValue::new(
self.identifier.clone(),
limit.max_value(),
Duration::from_secs(limit.seconds()),
));
)));
}
Ok(())
}
Expand All @@ -63,13 +63,16 @@ impl CounterStorage for CrInMemoryStorage {
match limits.entry(key.clone()) {
Entry::Vacant(entry) => {
let duration = counter.window();
let store_value =
CrCounterValue::new(self.identifier.clone(), counter.max_value(), duration);
self.increment_counter(counter, key, &store_value, delta, now);
let store_value = Arc::new(CrCounterValue::new(
self.identifier.clone(),
counter.max_value(),
duration,
));
self.increment_counter(counter, key, store_value.clone(), delta, now);
entry.insert(store_value);
}
Entry::Occupied(entry) => {
self.increment_counter(counter, key, entry.get(), delta, now);
self.increment_counter(counter, key, entry.get().clone(), delta, now);
}
};
Ok(())
Expand Down Expand Up @@ -132,11 +135,14 @@ impl CounterStorage for CrInMemoryStorage {
if !counter_existed {
// try again with a write lock to create the counter if it's still missing.
let mut limits = self.limits.write().unwrap();
let store_value = limits.entry(key.clone()).or_insert(CrCounterValue::new(
self.identifier.clone(),
counter.max_value(),
counter.window(),
));
let store_value =
limits
.entry(key.clone())
.or_insert(Arc::new(CrCounterValue::new(
self.identifier.clone(),
counter.max_value(),
counter.window(),
)));

if let Some(limited) = process_counter(counter, store_value.read(), delta) {
if !load_counters {
Expand All @@ -157,7 +163,7 @@ impl CounterStorage for CrInMemoryStorage {
.into_iter()
.for_each(|(counter, key)| {
let store_value = limits.get(&key).unwrap();
self.increment_counter(&counter, key, store_value, delta, now);
self.increment_counter(&counter, key, store_value.clone(), delta, now);
});

Ok(Authorization::Ok)
Expand All @@ -181,7 +187,7 @@ impl CounterStorage for CrInMemoryStorage {
};

if limits.contains(&limit_key) {
let counter = (&counter_key, counter_value);
let counter = (&counter_key, &*counter_value.clone());
let mut counter: Counter = counter.into();
counter.set_remaining(counter.max_value() - counter_value.read());
counter.set_expires_in(counter_value.ttl());
Expand Down Expand Up @@ -280,25 +286,17 @@ impl CrInMemoryStorage {
&self,
counter: &Counter,
store_key: Vec<u8>,
store_value: &CrCounterValue<String>,
store_value: Arc<CrCounterValue<String>>,
delta: u64,
when: SystemTime,
) {
store_value.inc_at(delta, counter.window(), when);

let (expiry, values) = store_value.clone().into_inner();
self.broker.publish(CounterUpdate {
key: store_key,
values: values.into_iter().collect(),
expires_at: expiry.duration_since(UNIX_EPOCH).unwrap().as_secs(),
})
self.broker
.publish(CounterEntry::new(store_key, store_value))
}
}

async fn process_re_sync(
limits: &Arc<RwLock<HashMap<Vec<u8>, CrCounterValue<String>>>>,
sender: Sender<Option<CounterUpdate>>,
) {
async fn process_re_sync(limits: &Arc<RwLock<LimitsMap>>, sender: Sender<Option<CounterUpdate>>) {
// sending all the counters to the peer might take a while, so we don't want to lock
// the limits map for too long, lets figure first get the list of keys that needs to be sent.
let keys: Vec<_> = {
Expand Down

0 comments on commit 5525734

Please sign in to comment.