Skip to content

Commit

Permalink
Merge pull request #310 from Kuadrant/moar_tests
Browse files Browse the repository at this point in the history
Added tests... and fixes... to CachedCounterValue
  • Loading branch information
alexsnaps authored May 1, 2024
2 parents 754a319 + 7441e4d commit 43ed429
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 81 deletions.
189 changes: 125 additions & 64 deletions limitador/src/storage/redis/counters_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@ use dashmap::DashMap;
use moka::sync::Cache;
use std::collections::HashMap;
use std::future::Future;
use std::ops::Not;
use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::select;
use tokio::sync::Notify;

#[derive(Debug)]
pub struct CachedCounterValue {
value: AtomicExpiringValue,
initial_value: AtomicI64,
Expand All @@ -40,7 +42,7 @@ impl CachedCounterValue {
temp_value,
now + Duration::from_secs(counter.seconds()),
),
initial_value: AtomicI64::new(temp_value),
initial_value: AtomicI64::new(0),
expiry: AtomicExpiryTime::from_now(Duration::from_secs(counter.seconds())),
from_authority: AtomicBool::new(false),
}
Expand All @@ -64,6 +66,7 @@ impl CachedCounterValue {
.update(delta, counter.seconds(), SystemTime::now());
if value == delta {
// new window, invalidate initial value
// which happens _after_ the self.value was reset, see `pending_writes`
self.initial_value.store(0, Ordering::SeqCst);
}
value
Expand All @@ -76,9 +79,11 @@ impl CachedCounterValue {
value
} else {
let writes = value - start;
if writes > 0 {
if writes >= 0 {
writes
} else {
// self.value expired, is now less than the writes of the previous window
// which have not yet been reset... it'll be 0, so treat it as such.
value
}
};
Expand All @@ -89,7 +94,7 @@ impl CachedCounterValue {
Ok(_) => Ok(offset),
Err(newer) => {
if newer == 0 {
// We got expired in the meantime, this fresh value can wait the next iteration
// We got reset because of expiry, this fresh value can wait the next iteration
Ok(0)
} else {
// Concurrent call to this method?
Expand Down Expand Up @@ -123,7 +128,7 @@ impl CachedCounterValue {
}

pub fn requires_fast_flush(&self, within: &Duration) -> bool {
self.from_authority.load(Ordering::Acquire) || &self.value.ttl() <= within
self.from_authority.load(Ordering::Acquire).not() || &self.value.ttl() <= within
}
}

Expand Down Expand Up @@ -388,20 +393,102 @@ mod tests {
use crate::limit::Limit;
use std::collections::HashMap;

mod cached_counter_value {
use crate::storage::redis::counters_cache::tests::test_counter;
use crate::storage::redis::counters_cache::CachedCounterValue;
use std::ops::Not;
use std::time::{Duration, SystemTime};

#[test]
fn records_pending_writes() {
let counter = test_counter(10, None);
let value = CachedCounterValue::from_authority(&counter, 0, Duration::from_secs(1));
assert_eq!(value.pending_writes(), Ok(0));
value.delta(&counter, 5);
assert_eq!(value.pending_writes(), Ok(5));
}

#[test]
fn consumes_pending_writes() {
let counter = test_counter(10, None);
let value = CachedCounterValue::from_authority(&counter, 0, Duration::from_secs(1));
value.delta(&counter, 5);
assert_eq!(value.pending_writes(), Ok(5));
assert_eq!(value.pending_writes(), Ok(0));
}

#[test]
fn no_pending_writes() {
let counter = test_counter(10, None);
let value = CachedCounterValue::from_authority(&counter, 0, Duration::from_secs(1));
value.delta(&counter, 5);
assert!(value.no_pending_writes().not());
assert!(value.pending_writes().is_ok());
assert!(value.no_pending_writes());
}

#[test]
fn setting_from_auth_resets_pending_writes() {
let counter = test_counter(10, None);
let value = CachedCounterValue::from_authority(&counter, 0, Duration::from_secs(1));
value.delta(&counter, 5);
assert!(value.no_pending_writes().not());
value.set_from_authority(&counter, 6, Duration::from_secs(1));
assert!(value.no_pending_writes());
assert_eq!(value.pending_writes(), Ok(0));
}

#[test]
fn from_authority_no_need_to_flush() {
let counter = test_counter(10, None);
let value = CachedCounterValue::from_authority(&counter, 0, Duration::from_secs(10));
assert!(value.requires_fast_flush(&Duration::from_secs(30)).not());
}

#[test]
fn from_authority_needs_to_flush_within_ttl() {
let counter = test_counter(10, None);
let value = CachedCounterValue::from_authority(&counter, 0, Duration::from_secs(1));
assert!(value.requires_fast_flush(&Duration::from_secs(90)));
}

#[test]
fn fake_needs_to_flush_within_ttl() {
let counter = test_counter(10, None);
let value = CachedCounterValue::load_from_authority_asap(&counter, 0);
assert!(value.requires_fast_flush(&Duration::from_secs(30)));
}

#[test]
fn expiry_of_cached_entry() {
let counter = test_counter(10, None);
let cache_entry_ttl = Duration::from_secs(1);
let value = CachedCounterValue::from_authority(&counter, 0, cache_entry_ttl);
let now = SystemTime::now();
assert!(value.expired_at(now).not());
assert!(value.expired_at(now + cache_entry_ttl));
}

#[test]
fn delegates_to_underlying_value() {
let hits = 4;

let counter = test_counter(10, None);
let value = CachedCounterValue::from_authority(&counter, 0, Duration::from_secs(1));
value.delta(&counter, hits);
assert!(value.to_next_window() > Duration::from_millis(59999));
assert_eq!(value.hits(&counter), hits);
let remaining = counter.max_value() - hits;
assert_eq!(value.remaining(&counter), remaining);
assert!(value.is_limited(&counter, 1).not());
assert!(value.is_limited(&counter, remaining).not());
assert!(value.is_limited(&counter, remaining + 1));
}
}

#[test]
fn get_existing_counter() {
let mut values = HashMap::new();
values.insert("app_id".to_string(), "1".to_string());
let counter = Counter::new(
Limit::new(
"test_namespace",
10,
60,
vec!["req.method == 'POST'"],
vec!["app_id"],
),
values,
);
let counter = test_counter(10, None);

let cache = CountersCacheBuilder::new().build(Duration::default());
cache.insert(
Expand All @@ -417,18 +504,7 @@ mod tests {

#[test]
fn get_non_existing_counter() {
let mut values = HashMap::new();
values.insert("app_id".to_string(), "1".to_string());
let counter = Counter::new(
Limit::new(
"test_namespace",
10,
60,
vec!["req.method == 'POST'"],
vec!["app_id"],
),
values,
);
let counter = test_counter(10, None);

let cache = CountersCacheBuilder::new().build(Duration::default());

Expand All @@ -439,18 +515,7 @@ mod tests {
fn insert_saves_the_given_value_when_is_some() {
let max_val = 10;
let current_value = max_val / 2;
let mut values = HashMap::new();
values.insert("app_id".to_string(), "1".to_string());
let counter = Counter::new(
Limit::new(
"test_namespace",
max_val,
60,
vec!["req.method == 'POST'"],
vec!["app_id"],
),
values,
);
let counter = test_counter(max_val, None);

let cache = CountersCacheBuilder::new().build(Duration::default());
cache.insert(
Expand All @@ -470,18 +535,7 @@ mod tests {
#[test]
fn insert_saves_zero_when_redis_val_is_none() {
let max_val = 10;
let mut values = HashMap::new();
values.insert("app_id".to_string(), "1".to_string());
let counter = Counter::new(
Limit::new(
"test_namespace",
max_val,
60,
vec!["req.method == 'POST'"],
vec!["app_id"],
),
values,
);
let counter = test_counter(max_val, None);

let cache = CountersCacheBuilder::new().build(Duration::default());
cache.insert(
Expand All @@ -499,18 +553,7 @@ mod tests {
fn increase_by() {
let current_val = 10;
let increase_by = 8;
let mut values = HashMap::new();
values.insert("app_id".to_string(), "1".to_string());
let counter = Counter::new(
Limit::new(
"test_namespace",
current_val,
60,
vec!["req.method == 'POST'"],
vec!["app_id"],
),
values,
);
let counter = test_counter(current_val, None);

let cache = CountersCacheBuilder::new().build(Duration::default());
cache.insert(
Expand All @@ -527,4 +570,22 @@ mod tests {
(current_val + increase_by)
);
}

fn test_counter(max_val: i64, other_values: Option<HashMap<String, String>>) -> Counter {
let mut values = HashMap::new();
values.insert("app_id".to_string(), "1".to_string());
if let Some(overrides) = other_values {
values.extend(overrides);
}
Counter::new(
Limit::new(
"test_namespace",
max_val,
60,
vec!["req.method == 'POST'"],
vec!["app_id"],
),
values,
)
}
}
29 changes: 12 additions & 17 deletions limitador/src/storage/redis/redis_cached.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,8 @@ async fn update_counters<C: ConnectionLike>(
return Ok(res);
}

for (counter, delta) in counters_and_deltas {
let delta = delta.pending_writes().expect("State machine is wrong!");
for (counter, value) in counters_and_deltas {
let delta = value.pending_writes().expect("State machine is wrong!");
if delta > 0 {
script_invocation.key(key_for_counter(&counter));
script_invocation.key(key_for_counters_of_limit(counter.limit()));
Expand Down Expand Up @@ -421,14 +421,13 @@ mod tests {
Default::default(),
);

counters_and_deltas.insert(
counter.clone(),
Arc::new(CachedCounterValue::from_authority(
&counter,
1,
Duration::from_secs(60),
)),
);
let arc = Arc::new(CachedCounterValue::from_authority(
&counter,
1,
Duration::from_secs(60),
));
arc.delta(&counter, 1);
counters_and_deltas.insert(counter.clone(), arc);

let mock_response = Value::Bulk(vec![Value::Int(10), Value::Int(60)]);

Expand All @@ -440,7 +439,7 @@ mod tests {
.arg(key_for_counters_of_limit(counter.limit()))
.arg(60)
.arg(1),
Ok(mock_response.clone()),
Ok(mock_response),
)]);

let result = update_counters(&mut mock_client, counters_and_deltas).await;
Expand Down Expand Up @@ -479,17 +478,13 @@ mod tests {
.arg(key_for_counters_of_limit(counter.limit()))
.arg(60)
.arg(2),
Ok(mock_response.clone()),
Ok(mock_response),
)]);

let cache = CountersCacheBuilder::new().build(Duration::from_millis(1));
cache.batcher().add(
counter.clone(),
Arc::new(CachedCounterValue::from_authority(
&counter,
2,
Duration::from_secs(60),
)),
Arc::new(CachedCounterValue::load_from_authority_asap(&counter, 2)),
);
cache.insert(
counter.clone(),
Expand Down

0 comments on commit 43ed429

Please sign in to comment.