From 27af5684df4e8618e67073fdafe7df5b5bb3ca64 Mon Sep 17 00:00:00 2001 From: Alex Snaps Date: Thu, 9 May 2024 13:10:38 -0400 Subject: [PATCH] Revert on transient redis error --- limitador/src/storage/redis/counters_cache.rs | 80 +++++++++++++---- limitador/src/storage/redis/redis_cached.rs | 90 +++++++++++++++++-- 2 files changed, 145 insertions(+), 25 deletions(-) diff --git a/limitador/src/storage/redis/counters_cache.rs b/limitador/src/storage/redis/counters_cache.rs index 3949d0c6..69477d78 100644 --- a/limitador/src/storage/redis/counters_cache.rs +++ b/limitador/src/storage/redis/counters_cache.rs @@ -103,6 +103,22 @@ impl CachedCounterValue { value - start == 0 } + fn revert_writes(&self, writes: i64) -> Result<(), ()> { + let newer = self.initial_value.load(Ordering::SeqCst); + if newer > writes { + return match self.initial_value.compare_exchange( + newer, + newer - writes, + Ordering::SeqCst, + Ordering::SeqCst, + ) { + Err(expiry) if expiry != 0 => Err(()), + _ => Ok(()), + }; + } + Ok(()) + } + pub fn hits(&self, _: &Counter) -> i64 { self.value.value_at(SystemTime::now()) } @@ -160,10 +176,10 @@ impl Batcher { self.notifier.notify_one(); } - pub async fn consume(&self, max: usize, consumer: F) -> O + pub async fn consume(&self, max: usize, consumer: F) -> Result where F: FnOnce(HashMap>) -> Fut, - Fut: Future, + Fut: Future>, { let mut ready = self.batch_ready(max); loop { @@ -185,10 +201,12 @@ impl Batcher { result.insert(counter.clone(), value); } let result = consumer(result).await; - batch.iter().for_each(|counter| { - self.updates - .remove_if(counter, |_, v| v.no_pending_writes()); - }); + if result.is_ok() { + batch.iter().for_each(|counter| { + self.updates + .remove_if(counter, |_, v| v.no_pending_writes()); + }); + } return result; } else { ready = select! { @@ -240,6 +258,31 @@ impl CountersCache { &self.batcher } + pub fn return_pending_writes( + &self, + counter: &Counter, + value: i64, + writes: i64, + ) -> Result<(), ()> { + if writes != 0 { + let mut miss = false; + let value = self.cache.get_with_by_ref(counter, || { + if let Some(entry) = self.batcher.updates.get(counter) { + entry.value().clone() + } else { + miss = true; + let value = Arc::new(CachedCounterValue::from_authority(counter, value)); + value.delta(counter, writes); + value + } + }); + if miss.not() { + return value.revert_writes(writes); + } + } + Ok(()) + } + pub fn apply_remote_delta( &self, counter: Counter, @@ -415,9 +458,10 @@ mod tests { .consume(2, |items| { assert!(items.is_empty()); assert!(SystemTime::now().duration_since(start).unwrap() >= duration); - async {} + async { Ok::<(), ()>(()) } }) - .await; + .await + .expect("Always Ok!"); } #[tokio::test] @@ -441,9 +485,10 @@ mod tests { SystemTime::now().duration_since(start).unwrap() >= Duration::from_millis(100) ); - async {} + async { Ok::<(), ()>(()) } }) - .await; + .await + .expect("Always Ok!"); } #[tokio::test] @@ -466,9 +511,10 @@ mod tests { let wait_period = SystemTime::now().duration_since(start).unwrap(); assert!(wait_period >= Duration::from_millis(40)); assert!(wait_period < Duration::from_millis(50)); - async {} + async { Ok::<(), ()>(()) } }) - .await; + .await + .expect("Always Ok!"); } #[tokio::test] @@ -487,9 +533,10 @@ mod tests { assert!( SystemTime::now().duration_since(start).unwrap() < Duration::from_millis(5) ); - async {} + async { Ok::<(), ()>(()) } }) - .await; + .await + .expect("Always Ok!"); } #[tokio::test] @@ -512,9 +559,10 @@ mod tests { let wait_period = SystemTime::now().duration_since(start).unwrap(); assert!(wait_period >= Duration::from_millis(40)); assert!(wait_period < Duration::from_millis(50)); - async {} + async { Ok::<(), ()>(()) } }) - .await; + .await + .expect("Always Ok!"); } } diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index c69f2a24..0cf0e8fb 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -280,7 +280,7 @@ impl CachedRedisStorageBuilder { async fn update_counters( redis_conn: &mut C, counters_and_deltas: HashMap>, -) -> Result, StorageErr> { +) -> Result, (Vec<(Counter, i64, i64, i64)>, StorageErr)> { let redis_script = redis::Script::new(BATCH_UPDATE_COUNTERS); let mut script_invocation = redis_script.prepare_invoke(); @@ -299,16 +299,22 @@ async fn update_counters( script_invocation.arg(counter.seconds()); script_invocation.arg(delta); // We need to store the counter in the actual order we are sending it to the script - res.push((counter, 0, last_value_from_redis, 0)); + res.push((counter, last_value_from_redis, delta, 0)); } } let span = debug_span!("datastore"); // The redis crate is not working with tables, thus the response will be a Vec of counter values - let script_res: Vec = script_invocation + let script_res: Vec = match script_invocation .invoke_async(redis_conn) .instrument(span) - .await?; + .await + { + Ok(res) => res, + Err(err) => { + return Err((res, err.into())); + } + }; // We need to update the values and ttls returned by redis let counters_range = 0..res.len(); @@ -316,8 +322,8 @@ async fn update_counters( for (i, j) in counters_range.zip(script_res_range) { let (_, val, delta, expires_at) = &mut res[i]; - *val = script_res[j]; - *delta = script_res[j] - *delta; + *delta = script_res[j] - *val; // new value - previous one = remote writes + *val = script_res[j]; // update to value to newest *expires_at = script_res[j + 1]; } res @@ -344,9 +350,22 @@ async fn flush_batcher_and_update_counters( update_counters(&mut redis_conn, counters) }) .await - .or_else(|err| { + .or_else(|(data, err)| { if err.is_transient() { flip_partitioned(&partitioned, true); + let counters = data.len(); + let mut reverted = 0; + for (counter, old_value, pending_writes, _) in data { + if cached_counters + .return_pending_writes(&counter, old_value, pending_writes) + .is_err() + { + tracing::log::error!("Couldn't revert writes back to {:?}", &counter); + } else { + reverted += 1; + } + } + tracing::log::warn!("Reverted {} of {} counter increments", reverted, counters); Ok(Vec::new()) } else { Err(err) @@ -370,9 +389,10 @@ mod tests { }; use crate::storage::redis::redis_cached::{flush_batcher_and_update_counters, update_counters}; use crate::storage::redis::CachedRedisStorage; - use redis::{ErrorKind, Value}; + use redis::{Cmd, ErrorKind, RedisError, Value}; use redis_test::{MockCmd, MockRedisConnection}; use std::collections::HashMap; + use std::io; use std::ops::Add; use std::sync::atomic::AtomicBool; use std::sync::Arc; @@ -510,8 +530,60 @@ mod tests { ) .await; + let c = cached_counters.get(&counter).unwrap(); + assert_eq!(c.hits(&counter), 8); + assert_eq!(c.pending_writes(), Ok(0)); + } + + #[tokio::test] + async fn flush_batcher_reverts_on_err() { + let counter = Counter::new( + Limit::new( + "test_namespace", + 10, + 60, + vec!["req.method == 'POST'"], + vec!["app_id"], + ), + Default::default(), + ); + + let error: RedisError = io::Error::new(io::ErrorKind::TimedOut, "That was long!").into(); + assert!(error.is_timeout()); + let mock_client = MockRedisConnection::new(vec![MockCmd::new::<&mut Cmd, Value>( + redis::cmd("EVALSHA") + .arg("95a717e821d8fbdd667b5e4c6fede4c9cad16006") + .arg("2") + .arg(key_for_counter(&counter)) + .arg(key_for_counters_of_limit(counter.limit())) + .arg(60) + .arg(3), + Err(error), + )]); + + let cache = CountersCacheBuilder::new().build(Duration::from_millis(10)); + let value = Arc::new(CachedCounterValue::from_authority(&counter, 2)); + value.delta(&counter, 3); + cache.batcher().add(counter.clone(), value); + + let cached_counters: Arc = Arc::new(cache); + let partitioned = Arc::new(AtomicBool::new(false)); + if let Some(c) = cached_counters.get(&counter) { - assert_eq!(c.hits(&counter), 8); + assert_eq!(c.hits(&counter), 5); } + + flush_batcher_and_update_counters( + mock_client, + true, + cached_counters.clone(), + partitioned, + 100, + ) + .await; + + let c = cached_counters.get(&counter).unwrap(); + assert_eq!(c.hits(&counter), 5); + assert_eq!(c.pending_writes(), Ok(3)); } }