diff --git a/limitador/src/storage/redis/counters_cache.rs b/limitador/src/storage/redis/counters_cache.rs index 3754a46e..caa9a06d 100644 --- a/limitador/src/storage/redis/counters_cache.rs +++ b/limitador/src/storage/redis/counters_cache.rs @@ -9,12 +9,14 @@ use dashmap::DashMap; use moka::sync::Cache; use std::collections::HashMap; use std::future::Future; +use std::ops::Not; use std::sync::atomic::{AtomicBool, AtomicI64, Ordering}; use std::sync::Arc; use std::time::{Duration, SystemTime}; use tokio::select; use tokio::sync::Notify; +#[derive(Debug)] pub struct CachedCounterValue { value: AtomicExpiringValue, initial_value: AtomicI64, @@ -40,7 +42,7 @@ impl CachedCounterValue { temp_value, now + Duration::from_secs(counter.seconds()), ), - initial_value: AtomicI64::new(temp_value), + initial_value: AtomicI64::new(0), expiry: AtomicExpiryTime::from_now(Duration::from_secs(counter.seconds())), from_authority: AtomicBool::new(false), } @@ -64,6 +66,7 @@ impl CachedCounterValue { .update(delta, counter.seconds(), SystemTime::now()); if value == delta { // new window, invalidate initial value + // which happens _after_ the self.value was reset, see `pending_writes` self.initial_value.store(0, Ordering::SeqCst); } value @@ -76,9 +79,11 @@ impl CachedCounterValue { value } else { let writes = value - start; - if writes > 0 { + if writes >= 0 { writes } else { + // self.value expired, is now less than the writes of the previous window + // which have not yet been reset... it'll be 0, so treat it as such. value } }; @@ -89,7 +94,7 @@ impl CachedCounterValue { Ok(_) => Ok(offset), Err(newer) => { if newer == 0 { - // We got expired in the meantime, this fresh value can wait the next iteration + // We got reset because of expiry, this fresh value can wait the next iteration Ok(0) } else { // Concurrent call to this method? @@ -123,7 +128,7 @@ impl CachedCounterValue { } pub fn requires_fast_flush(&self, within: &Duration) -> bool { - self.from_authority.load(Ordering::Acquire) || &self.value.ttl() <= within + self.from_authority.load(Ordering::Acquire).not() || &self.value.ttl() <= within } } @@ -388,20 +393,102 @@ mod tests { use crate::limit::Limit; use std::collections::HashMap; + mod cached_counter_value { + use crate::storage::redis::counters_cache::tests::test_counter; + use crate::storage::redis::counters_cache::CachedCounterValue; + use std::ops::Not; + use std::time::{Duration, SystemTime}; + + #[test] + fn records_pending_writes() { + let counter = test_counter(10, None); + let value = CachedCounterValue::from_authority(&counter, 0, Duration::from_secs(1)); + assert_eq!(value.pending_writes(), Ok(0)); + value.delta(&counter, 5); + assert_eq!(value.pending_writes(), Ok(5)); + } + + #[test] + fn consumes_pending_writes() { + let counter = test_counter(10, None); + let value = CachedCounterValue::from_authority(&counter, 0, Duration::from_secs(1)); + value.delta(&counter, 5); + assert_eq!(value.pending_writes(), Ok(5)); + assert_eq!(value.pending_writes(), Ok(0)); + } + + #[test] + fn no_pending_writes() { + let counter = test_counter(10, None); + let value = CachedCounterValue::from_authority(&counter, 0, Duration::from_secs(1)); + value.delta(&counter, 5); + assert!(value.no_pending_writes().not()); + assert!(value.pending_writes().is_ok()); + assert!(value.no_pending_writes()); + } + + #[test] + fn setting_from_auth_resets_pending_writes() { + let counter = test_counter(10, None); + let value = CachedCounterValue::from_authority(&counter, 0, Duration::from_secs(1)); + value.delta(&counter, 5); + assert!(value.no_pending_writes().not()); + value.set_from_authority(&counter, 6, Duration::from_secs(1)); + assert!(value.no_pending_writes()); + assert_eq!(value.pending_writes(), Ok(0)); + } + + #[test] + fn from_authority_no_need_to_flush() { + let counter = test_counter(10, None); + let value = CachedCounterValue::from_authority(&counter, 0, Duration::from_secs(10)); + assert!(value.requires_fast_flush(&Duration::from_secs(30)).not()); + } + + #[test] + fn from_authority_needs_to_flush_within_ttl() { + let counter = test_counter(10, None); + let value = CachedCounterValue::from_authority(&counter, 0, Duration::from_secs(1)); + assert!(value.requires_fast_flush(&Duration::from_secs(90))); + } + + #[test] + fn fake_needs_to_flush_within_ttl() { + let counter = test_counter(10, None); + let value = CachedCounterValue::load_from_authority_asap(&counter, 0); + assert!(value.requires_fast_flush(&Duration::from_secs(30))); + } + + #[test] + fn expiry_of_cached_entry() { + let counter = test_counter(10, None); + let cache_entry_ttl = Duration::from_secs(1); + let value = CachedCounterValue::from_authority(&counter, 0, cache_entry_ttl); + let now = SystemTime::now(); + assert!(value.expired_at(now).not()); + assert!(value.expired_at(now + cache_entry_ttl)); + } + + #[test] + fn delegates_to_underlying_value() { + let hits = 4; + + let counter = test_counter(10, None); + let value = CachedCounterValue::from_authority(&counter, 0, Duration::from_secs(1)); + value.delta(&counter, hits); + assert!(value.to_next_window() > Duration::from_millis(59999)); + assert_eq!(value.hits(&counter), hits); + let remaining = counter.max_value() - hits; + assert_eq!(value.remaining(&counter), remaining); + assert!(value.is_limited(&counter, 1).not()); + assert!(value.is_limited(&counter, remaining).not()); + assert!(value.is_limited(&counter, remaining + 1)); + } + } + #[test] fn get_existing_counter() { - let mut values = HashMap::new(); - values.insert("app_id".to_string(), "1".to_string()); - let counter = Counter::new( - Limit::new( - "test_namespace", - 10, - 60, - vec!["req.method == 'POST'"], - vec!["app_id"], - ), - values, - ); + let counter = test_counter(10, None); let cache = CountersCacheBuilder::new().build(Duration::default()); cache.insert( @@ -417,18 +504,7 @@ mod tests { #[test] fn get_non_existing_counter() { - let mut values = HashMap::new(); - values.insert("app_id".to_string(), "1".to_string()); - let counter = Counter::new( - Limit::new( - "test_namespace", - 10, - 60, - vec!["req.method == 'POST'"], - vec!["app_id"], - ), - values, - ); + let counter = test_counter(10, None); let cache = CountersCacheBuilder::new().build(Duration::default()); @@ -439,18 +515,7 @@ mod tests { fn insert_saves_the_given_value_when_is_some() { let max_val = 10; let current_value = max_val / 2; - let mut values = HashMap::new(); - values.insert("app_id".to_string(), "1".to_string()); - let counter = Counter::new( - Limit::new( - "test_namespace", - max_val, - 60, - vec!["req.method == 'POST'"], - vec!["app_id"], - ), - values, - ); + let counter = test_counter(max_val, None); let cache = CountersCacheBuilder::new().build(Duration::default()); cache.insert( @@ -470,18 +535,7 @@ mod tests { #[test] fn insert_saves_zero_when_redis_val_is_none() { let max_val = 10; - let mut values = HashMap::new(); - values.insert("app_id".to_string(), "1".to_string()); - let counter = Counter::new( - Limit::new( - "test_namespace", - max_val, - 60, - vec!["req.method == 'POST'"], - vec!["app_id"], - ), - values, - ); + let counter = test_counter(max_val, None); let cache = CountersCacheBuilder::new().build(Duration::default()); cache.insert( @@ -499,18 +553,7 @@ mod tests { fn increase_by() { let current_val = 10; let increase_by = 8; - let mut values = HashMap::new(); - values.insert("app_id".to_string(), "1".to_string()); - let counter = Counter::new( - Limit::new( - "test_namespace", - current_val, - 60, - vec!["req.method == 'POST'"], - vec!["app_id"], - ), - values, - ); + let counter = test_counter(current_val, None); let cache = CountersCacheBuilder::new().build(Duration::default()); cache.insert( @@ -527,4 +570,22 @@ mod tests { (current_val + increase_by) ); } + + fn test_counter(max_val: i64, other_values: Option>) -> Counter { + let mut values = HashMap::new(); + values.insert("app_id".to_string(), "1".to_string()); + if let Some(overrides) = other_values { + values.extend(overrides); + } + Counter::new( + Limit::new( + "test_namespace", + max_val, + 60, + vec!["req.method == 'POST'"], + vec!["app_id"], + ), + values, + ) + } } diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index fbea6ac3..5a092f61 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -301,8 +301,8 @@ async fn update_counters( return Ok(res); } - for (counter, delta) in counters_and_deltas { - let delta = delta.pending_writes().expect("State machine is wrong!"); + for (counter, value) in counters_and_deltas { + let delta = value.pending_writes().expect("State machine is wrong!"); if delta > 0 { script_invocation.key(key_for_counter(&counter)); script_invocation.key(key_for_counters_of_limit(counter.limit())); @@ -421,14 +421,13 @@ mod tests { Default::default(), ); - counters_and_deltas.insert( - counter.clone(), - Arc::new(CachedCounterValue::from_authority( - &counter, - 1, - Duration::from_secs(60), - )), - ); + let arc = Arc::new(CachedCounterValue::from_authority( + &counter, + 1, + Duration::from_secs(60), + )); + arc.delta(&counter, 1); + counters_and_deltas.insert(counter.clone(), arc); let mock_response = Value::Bulk(vec![Value::Int(10), Value::Int(60)]); @@ -440,7 +439,7 @@ mod tests { .arg(key_for_counters_of_limit(counter.limit())) .arg(60) .arg(1), - Ok(mock_response.clone()), + Ok(mock_response), )]); let result = update_counters(&mut mock_client, counters_and_deltas).await; @@ -479,17 +478,13 @@ mod tests { .arg(key_for_counters_of_limit(counter.limit())) .arg(60) .arg(2), - Ok(mock_response.clone()), + Ok(mock_response), )]); let cache = CountersCacheBuilder::new().build(Duration::from_millis(1)); cache.batcher().add( counter.clone(), - Arc::new(CachedCounterValue::from_authority( - &counter, - 2, - Duration::from_secs(60), - )), + Arc::new(CachedCounterValue::load_from_authority_asap(&counter, 2)), ); cache.insert( counter.clone(),