Skip to content

Commit

Permalink
loadbalancer-experimental: track the length of outstanding requests (#…
Browse files Browse the repository at this point in the history
…2833)

Motivation:

DefaultRequestTracker is intended to track how long requests are
taking but right now it can only add data once it gets a response.
As requests start to pile up we get a multiplicative effect but
we don't consider how long requests have remained outstanding.

Modifications:

- Keep track of how long a request has been pending by marking
  its start time and considering it when computing score.

Result:

We're more sensitive to sudden failures.
  • Loading branch information
bryce-anderson committed Feb 19, 2024
1 parent d2794c4 commit fb39471
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@

import io.servicetalk.client.api.ScoreSupplier;

import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import java.util.function.IntBinaryOperator;
import java.util.concurrent.locks.StampedLock;

import static io.servicetalk.utils.internal.NumberUtils.ensurePositive;
import static java.lang.Integer.MAX_VALUE;
Expand All @@ -30,7 +29,6 @@
import static java.lang.Math.min;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.NANOSECONDS;
import static java.util.concurrent.atomic.AtomicIntegerFieldUpdater.newUpdater;

/**
* Latency tracker using exponential weighted moving average based on the work by Andreas Eckner and subsequent
Expand All @@ -40,11 +38,11 @@
* (2019) Algorithms for Unevenly Spaced Time Series: Moving Averages and Other Rolling Operators (4.2 pp. 10)</a>
*/
abstract class DefaultRequestTracker implements RequestTracker, ScoreSupplier {
private static final AtomicIntegerFieldUpdater<DefaultRequestTracker> pendingUpdater =
newUpdater(DefaultRequestTracker.class, "pending");
private static final long MAX_MS_TO_NS = NANOSECONDS.convert(MAX_VALUE, MILLISECONDS);
static final long DEFAULT_CANCEL_PENALTY = 5L;
static final long DEFAULT_ERROR_PENALTY = 10L;

private final StampedLock lock = new StampedLock();
/**
* Mean lifetime, exponential decay. inverted tau
*/
Expand All @@ -60,7 +58,8 @@ abstract class DefaultRequestTracker implements RequestTracker, ScoreSupplier {
* Current weighted average.
*/
private int ewma;
private volatile int pending;
private int pendingCount;
private long pendingStamp = Long.MIN_VALUE;

DefaultRequestTracker(final long halfLifeNanos) {
this(halfLifeNanos, DEFAULT_CANCEL_PENALTY, DEFAULT_ERROR_PENALTY);
Expand All @@ -81,27 +80,70 @@ abstract class DefaultRequestTracker implements RequestTracker, ScoreSupplier {

/**
 * Marks the start of a request and returns the timestamp that was read for it.
 * <p>
 * NOTE(review): this span is a rendered diff. The two statements directly below the signature
 * ({@code pendingUpdater.incrementAndGet(this)} and {@code return currentTimeNanos()}) look like the replaced
 * implementation shown next to the lock-based one that follows - confirm against the real file.
 *
 * @return the value of {@code currentTimeNanos()} read for this request.
 */
@Override
public final long beforeRequestStart() {
pendingUpdater.incrementAndGet(this);
return currentTimeNanos();
final long stamp = lock.writeLock();
try {
long timestamp = currentTimeNanos();
pendingCount++;
// Long.MIN_VALUE is the "no timestamp recorded" sentinel for pendingStamp.
if (pendingStamp == Long.MIN_VALUE) {
// only update the pending timestamp if it doesn't already have a value.
pendingStamp = timestamp;
}
return timestamp;
} finally {
lock.unlockWrite(stamp);
}
}

/**
 * Records the successful completion of a request.
 * <p>
 * NOTE(review): this span is a rendered diff. The {@code pendingUpdater}/{@code calculateAndStore} statements
 * look like the replaced implementation; the {@code onComplete} call passes a penalty of {@code 0}.
 *
 * @param startTimeNanos the start timestamp of the request (presumably the value returned by
 * {@code beforeRequestStart()} - confirm against callers).
 */
@Override
public void onRequestSuccess(final long startTimeNanos) {
pendingUpdater.decrementAndGet(this);
calculateAndStore((ewma, currentLatency) -> currentLatency, startTimeNanos);
onComplete(startTimeNanos, 0);
}

/**
 * Records the failed completion of a request.
 * <p>
 * {@code ErrorClass.CANCELLED} selects the cancel penalty; every other error class selects the error penalty.
 * NOTE(review): this span is a rendered diff. The {@code pendingUpdater}/{@code calculateAndStore} statements
 * look like the replaced implementation shown next to the {@code onComplete} call.
 *
 * @param startTimeNanos the start timestamp of the request (presumably the value returned by
 * {@code beforeRequestStart()} - confirm against callers).
 * @param errorClass the classification of the failure.
 */
@Override
public void onRequestError(final long startTimeNanos, ErrorClass errorClass) {
pendingUpdater.decrementAndGet(this);
calculateAndStore(errorClass == ErrorClass.CANCELLED ? this:: cancelPenalty : this::errorPenalty,
startTimeNanos);
onComplete(startTimeNanos, errorClass == ErrorClass.CANCELLED ? cancelPenalty : errorPenalty);
}

// Shared completion path: adjusts the pending bookkeeping and folds the finished request into the EWMA,
// all while holding the write lock.
private void onComplete(final long startTimeNanos, long penalty) {
    final long writeStamp = lock.writeLock();
    try {
        // We cannot tell which in-flight request recorded the pending timestamp, so it is always reset here.
        // That imprecision is accepted; the alternative is keeping a collection of per-request start times.
        pendingStamp = Long.MIN_VALUE;
        pendingCount--;
        updateEwma(penalty, startTimeNanos);
    } finally {
        lock.unlockWrite(writeStamp);
    }
}

/**
 * Computes the current score from a snapshot of the tracker state.
 * <p>
 * The EWMA, last update time, pending count and pending timestamp are copied under the read lock; the decay and
 * the pending-request adjustments below are applied to those local copies only. The returned value is never
 * positive: {@code 0} is the best score and {@code MIN_VALUE} the worst.
 * <p>
 * NOTE(review): this span is a rendered diff. The first two statements of the body look like the replaced
 * implementation, and the {@code Expand All} markers hide lines (including the declaration of
 * {@code pendingPenalty}) - confirm against the real file.
 *
 * @return the score; higher values are better.
 */
@Override
public final int score() {
final int currentEWMA = calculateAndStore((ewma, lastTimeNanos) -> 0, 0);
final int cPending = pendingUpdater.get(this);
final long lastTimeNanos;
final int cPending;
final long pendingStamp;
int currentEWMA;
// read all the relevant state using the read lock
final long stamp = lock.readLock();
try {
currentEWMA = ewma;
lastTimeNanos = this.lastTimeNanos;
cPending = pendingCount;
pendingStamp = this.pendingStamp;
} finally {
lock.unlockRead(stamp);
}
// It's fine to get this after releasing the lock since it will still happen after whatever last
// wrote the value to `lastTimeNanos`.
final long currentTimeNanos = currentTimeNanos();

if (currentEWMA != 0) {
// need to apply the exponential decay.
final double tmp = (currentTimeNanos - lastTimeNanos) * invTau;
final double w = exp(-tmp);
currentEWMA = (int) ceil(currentEWMA * w);
}

if (currentEWMA == 0) {
// If EWMA has decayed to 0 (or isn't yet initialized) and there are no pending requests we return the
// maximum score to increase the likelihood this entity is selected. If there are pending requests we
Expand All @@ -110,6 +152,12 @@ public final int score() {
return cPending == 0 ? 0 : MIN_VALUE;
}

if (cPending > 0 && pendingStamp != Long.MIN_VALUE) {
// If we have a request outstanding we should consider how long it has been outstanding so that sudden
// interruptions don't have to wait for timeouts before our scores can be adjusted.
currentEWMA = max(currentEWMA, nanoToMillis(currentTimeNanos - pendingStamp));
}

// Add penalty for pending requests to account for "unaccounted" load.
// Penalty is the observed latency if known, else an arbitrarily high value which makes entities for which
// no latency data has yet been received (eg: request sent but not received), un-selectable.
Expand All @@ -119,39 +167,31 @@ public final int score() {
// Saturate to MIN_VALUE instead of overflowing when the sum would exceed MAX_VALUE.
return MAX_VALUE - currentEWMA <= pendingPenalty ? MIN_VALUE : -(currentEWMA + pendingPenalty);
}

/**
 * Computes the penalized latency for a cancelled request by delegating to {@code applyPenalty} with the
 * configured cancel multiplier.
 *
 * @param currentEWMA the current weighted average.
 * @param currentLatency the latency observed for the cancelled request.
 * @return the value produced by {@code applyPenalty}.
 */
private int cancelPenalty(int currentEWMA, int currentLatency) {
// There is no significance to the choice of this multiplier (other than it is half of the error penalty)
// and it is selected to gather empirical evidence as the algorithm is evaluated.
return applyPenalty(currentEWMA, currentLatency, cancelPenalty);
}

/**
 * Computes the penalized latency for a failed request by delegating to {@code applyPenalty} with the
 * configured error multiplier.
 *
 * @param currentEWMA the current weighted average.
 * @param currentLatency the latency observed for the failed request.
 * @return the value produced by {@code applyPenalty}.
 */
private int errorPenalty(int currentEWMA, int currentLatency) {
// There is no significance to the choice of this multiplier (other than it is double of the cancel penalty)
// and it is selected to gather empirical evidence as the algorithm is evaluated.
return applyPenalty(currentEWMA, currentLatency, errorPenalty);
}

/**
 * Scales the larger of the current EWMA and the just-observed latency by {@code penalty}, saturating at
 * {@link Integer#MAX_VALUE}.
 *
 * @param currentEWMA the current weighted average.
 * @param currentLatency the latency observed for the completed request.
 * @param penalty the multiplier to apply.
 * @return the penalized value, capped at {@link Integer#MAX_VALUE}.
 */
private static int applyPenalty(int currentEWMA, int currentLatency, long penalty) {
    // Relatively large latencies will have a bigger impact on the penalty, while smaller latencies (e.g. premature
    // cancel/error) rely on the penalty.
    final long worst = max(currentEWMA, currentLatency);
    // Saturate before multiplying: `worst * penalty` is a long multiplication that wraps silently for very large
    // penalties, and a wrapped (negative) product would slip past the min() clamp below.
    if (worst > 0 && penalty > MAX_VALUE / worst) {
        return MAX_VALUE;
    }
    return (int) min(MAX_VALUE, worst * penalty);
}

// Folds one completed request into the EWMA and records the time of the update.
// NOTE(review): this span is a rendered diff. Lines mentioning `calculateAndStore`, `latencyInitializer`, the
// "synchronized block" comments and `return nextEWMA` look like the replaced implementation, and the
// `Expand All` marker hides the decay branch - confirm against the real file.
private synchronized int calculateAndStore(final IntBinaryOperator latencyInitializer, long startTimeNanos) {
final int nextEWMA;
// We capture the current time inside the synchronized block to exploit the monotonic time source
// must be called while holding the lock in write mode.
private void updateEwma(long penalty, long startTimeNanos) {
// We capture the current time while holding the lock to exploit the monotonic time source
// properties which prevent the time duration from going negative. This will result in a latency penalty
// as concurrency increases, but is a trade-off for simplicity.
final long currentTimeNanos = currentTimeNanos();
// When modifying the EWMA and lastTime we read/write both values in a synchronized block as they are
// When modifying the EWMA and lastTime we read/write both values while holding the lock as they are
// tightly coupled in the EWMA formula below.
final int currentEWMA = ewma;
final int currentLatency = latencyInitializer.applyAsInt(ewma, nanoToMillis(currentTimeNanos - startTimeNanos));
// Note the currentLatency cannot be <0 or else the EWMA equation properties are violated
// (e.g. "degree of weighting decrease" is not in [0, 1]).
final int currentLatency;
// A positive penalty scales the observation via applyPenalty; otherwise the raw elapsed time is used.
if (penalty > 0) {
currentLatency = applyPenalty(currentEWMA, nanoToMillis(currentTimeNanos - startTimeNanos), penalty);
} else {
currentLatency = nanoToMillis(currentTimeNanos - startTimeNanos);
}
assert currentLatency >= 0;

// Peak EWMA from finagle for the score to be extremely sensitive to higher than normal latencies.
final int nextEWMA;
if (currentLatency > currentEWMA) {
nextEWMA = currentLatency;
} else {
Expand All @@ -161,7 +201,6 @@ private synchronized int calculateAndStore(final IntBinaryOperator latencyInitia
}
lastTimeNanos = currentTimeNanos;
ewma = nextEWMA;
return nextEWMA;
}

private static int nanoToMillis(long nanos) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@
*/
package io.servicetalk.loadbalancer;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.time.Duration;
import java.util.function.LongUnaryOperator;

import static java.lang.System.nanoTime;
import static java.time.Duration.ofSeconds;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand All @@ -34,28 +33,61 @@ void test() {
final LongUnaryOperator nextValueProvider = mock(LongUnaryOperator.class);
when(nextValueProvider.applyAsLong(anyLong())).thenAnswer(__ -> ofSeconds(1).toNanos());
final DefaultRequestTracker requestTracker = new TestRequestTracker(Duration.ofSeconds(1), nextValueProvider);
Assertions.assertEquals(0, requestTracker.score());
assertEquals(0, requestTracker.score());

// upon success score
requestTracker.onRequestSuccess(requestTracker.beforeRequestStart());
Assertions.assertEquals(-500, requestTracker.score());
assertEquals(-500, requestTracker.score());

// error penalty
requestTracker.onRequestError(requestTracker.beforeRequestStart(), ErrorClass.EXT_ORIGIN_REQUEST_FAILED);
Assertions.assertEquals(-5000, requestTracker.score());
assertEquals(-5_000, requestTracker.score());

// cancellation penalty
requestTracker.onRequestError(requestTracker.beforeRequestStart(), ErrorClass.CANCELLED);
Assertions.assertEquals(-12_500, requestTracker.score());
assertEquals(-25_000, requestTracker.score());

// decay
when(nextValueProvider.applyAsLong(anyLong())).thenAnswer(__ -> ofSeconds(20).toNanos());
Assertions.assertEquals(-1, requestTracker.score());
assertEquals(-1, requestTracker.score());
}

// With no latency recorded yet (EWMA is 0) and a request pending, score() must return Integer.MIN_VALUE.
@Test
void zeroDataScoreWithPendingRequestIsIntMinValue() {
final LongUnaryOperator nextValueProvider = mock(LongUnaryOperator.class);
when(nextValueProvider.applyAsLong(anyLong())).thenAnswer(__ -> ofSeconds(0).toNanos());
final DefaultRequestTracker requestTracker = new TestRequestTracker(Duration.ofSeconds(1), nextValueProvider);
assertEquals(0, requestTracker.score());

// start a request but never complete it, so there is a pending request and still no latency data.
requestTracker.beforeRequestStart();
assertEquals(Integer.MIN_VALUE, requestTracker.score());
}

// Verifies that the time a request has been outstanding is reflected in the score before it completes.
@Test
void outstandingLatencyIsTracked() {
final LongUnaryOperator nextValueProvider = mock(LongUnaryOperator.class);
when(nextValueProvider.applyAsLong(anyLong())).thenAnswer(__ -> ofSeconds(1).toNanos());

final DefaultRequestTracker requestTracker = new TestRequestTracker(Duration.ofSeconds(1), nextValueProvider);
assertEquals(0, requestTracker.score());

// upon success score
requestTracker.onRequestSuccess(requestTracker.beforeRequestStart());
// The test clock appears to advance 1s per reading (confirm in TestRequestTracker), so this is not an
// instantaneous request: the expected -500 matches a ~1000ms latency decayed over one 1s half-life.
assertEquals(-500, requestTracker.score());

// start a request. This is the 5th reading of the time provider so far: score(), beforeRequestStart(),
// the completion, score() and this beforeRequestStart() each read it once.
assertEquals(5_000_000_000L, requestTracker.beforeRequestStart());
// re-stub the time provider; it keeps answering with a 1s step.
when(nextValueProvider.applyAsLong(anyLong())).thenAnswer(__ -> ofSeconds(1).toNanos());
// score() reads the clock once more, so the request has been outstanding for 1s (1000ms), which exceeds the
// decayed EWMA. The expected -2_000 is consistent with the pending penalty adding the same 1000 again
// (the penalty computation is not visible in this diff hunk - confirm).
assertEquals(-2_000, requestTracker.score());
}

static final class TestRequestTracker extends DefaultRequestTracker {
private final LongUnaryOperator nextValueProvider;
private long lastValue = nanoTime();
private long lastValue;

TestRequestTracker(Duration measurementHalfLife, final LongUnaryOperator nextValueProvider) {
super(measurementHalfLife.toNanos(), DEFAULT_CANCEL_PENALTY, DEFAULT_ERROR_PENALTY);
Expand Down

0 comments on commit fb39471

Please sign in to comment.