Skip to content

Commit

Permalink
Add back Netty4WriteThrottlingHandler to HTTP pipeline (elastic#87407)
Browse files Browse the repository at this point in the history
Follow-up to elastic#86922 bringing back the write throttling handler (with necessary adjustments)
as removing has measurably reduced scroll performance in nightly Rally runs.
Throttling at a lower level instead of only at the 1M HTTP chunk level provides a measurable
benefit to latency as it turns out in benchmarks so lets bring it back.
This requires adjusting the write throttling handler to pass through writes that could be flushed
directly so that the upstream throttling sees the correct channel writability status.
Before this change the channel would always look writable because we wouldn't be buffering
anything in the actual outbound buffer but just in the internal queue in the write throttling
handler.
Also, this PR adds coverage for the new code paths in the write throttling handler which together
with the existing coverage should give us safe coverage of all possible throttling and message
size combinations.
  • Loading branch information
original-brownbear committed Jun 8, 2022
1 parent f5ceed1 commit d50d47c
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.netty4.NetUtils;
import org.elasticsearch.transport.netty4.Netty4Utils;
import org.elasticsearch.transport.netty4.Netty4WriteThrottlingHandler;
import org.elasticsearch.transport.netty4.NettyAllocator;
import org.elasticsearch.transport.netty4.NettyByteBufSizer;
import org.elasticsearch.transport.netty4.SharedGroupFactory;
Expand Down Expand Up @@ -299,6 +300,7 @@ protected HttpChannelHandler(final Netty4HttpServerTransport transport, final Ht
protected void initChannel(Channel ch) throws Exception {
Netty4HttpChannel nettyHttpChannel = new Netty4HttpChannel(ch);
ch.attr(HTTP_CHANNEL_KEY).set(nettyHttpChannel);
ch.pipeline().addLast("chunked_writer", new Netty4WriteThrottlingHandler(transport.getThreadPool().getThreadContext()));
ch.pipeline().addLast("byte_buf_sizer", NettyByteBufSizer.INSTANCE);
ch.pipeline().addLast("read_timeout", new ReadTimeoutHandler(transport.readTimeoutMillis, TimeUnit.MILLISECONDS));
final HttpRequestDecoder decoder = new HttpRequestDecoder(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;

import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.transport.Transports;
Expand All @@ -29,6 +31,7 @@
*/
public final class Netty4WriteThrottlingHandler extends ChannelDuplexHandler {

public static final int MAX_BYTES_PER_WRITE = 1 << 18;
private final Queue<WriteOperation> queuedWrites = new ArrayDeque<>();

private final ThreadContext threadContext;
Expand All @@ -43,7 +46,45 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
assert msg instanceof ByteBuf;
assert Transports.assertDefaultThreadContext(threadContext);
assert Transports.assertTransportThread();
final boolean queued = queuedWrites.offer(new WriteOperation((ByteBuf) msg, promise));
final ByteBuf buf = (ByteBuf) msg;
if (ctx.channel().isWritable() && currentWrite == null && queuedWrites.isEmpty()) {
// nothing is queued for writing and the channel is writable, just pass the write down the pipeline directly
if (buf.readableBytes() > MAX_BYTES_PER_WRITE) {
writeInSlices(ctx, promise, buf);
} else {
ctx.write(msg, promise);
}
} else {
queueWrite(buf, promise);
}
}

/**
* Writes slices of up to the max write size until the channel stops being writable or the message has been written in full.
*/
private void writeInSlices(ChannelHandlerContext ctx, ChannelPromise promise, ByteBuf buf) {
while (true) {
final int readableBytes = buf.readableBytes();
final int bufferSize = Math.min(readableBytes, MAX_BYTES_PER_WRITE);
if (readableBytes == bufferSize) {
// last write for this chunk we're done
ctx.write(buf).addListener(forwardResultListener(ctx, promise));
return;
}
final int readerIndex = buf.readerIndex();
final ByteBuf writeBuffer = buf.retainedSlice(readerIndex, bufferSize);
buf.readerIndex(readerIndex + bufferSize);
ctx.write(writeBuffer).addListener(forwardFailureListener(ctx, promise));
if (ctx.channel().isWritable() == false) {
// channel isn't writable any longer -> move to queuing
queueWrite(buf, promise);
return;
}
}
}

private void queueWrite(ByteBuf buf, ChannelPromise promise) {
final boolean queued = queuedWrites.offer(new WriteOperation(buf, promise));
assert queued;
}

Expand Down Expand Up @@ -85,7 +126,7 @@ private boolean doFlush(ChannelHandlerContext ctx) {
}
final WriteOperation write = currentWrite;
final int readableBytes = write.buf.readableBytes();
final int bufferSize = Math.min(readableBytes, 1 << 18);
final int bufferSize = Math.min(readableBytes, MAX_BYTES_PER_WRITE);
final int readerIndex = write.buf.readerIndex();
final boolean sliced = readableBytes != bufferSize;
final ByteBuf writeBuffer;
Expand All @@ -99,21 +140,9 @@ private boolean doFlush(ChannelHandlerContext ctx) {
needsFlush = true;
if (sliced == false) {
currentWrite = null;
writeFuture.addListener(future -> {
assert ctx.executor().inEventLoop();
if (future.isSuccess()) {
write.promise.trySuccess();
} else {
write.promise.tryFailure(future.cause());
}
});
writeFuture.addListener(forwardResultListener(ctx, write.promise));
} else {
writeFuture.addListener(future -> {
assert ctx.executor().inEventLoop();
if (future.isSuccess() == false) {
write.promise.tryFailure(future.cause());
}
});
writeFuture.addListener(forwardFailureListener(ctx, write.promise));
}
if (channel.isWritable() == false) {
// try flushing to make channel writable again, loop will only continue if channel becomes writable again
Expand All @@ -130,6 +159,26 @@ private boolean doFlush(ChannelHandlerContext ctx) {
return true;
}

private static GenericFutureListener<Future<Void>> forwardFailureListener(ChannelHandlerContext ctx, ChannelPromise promise) {
return future -> {
assert ctx.executor().inEventLoop();
if (future.isSuccess() == false) {
promise.tryFailure(future.cause());
}
};
}

private static GenericFutureListener<Future<Void>> forwardResultListener(ChannelHandlerContext ctx, ChannelPromise promise) {
return future -> {
assert ctx.executor().inEventLoop();
if (future.isSuccess()) {
promise.trySuccess();
} else {
promise.tryFailure(future.cause());
}
};
}

private void failQueuedWrites() {
if (currentWrite != null) {
final WriteOperation current = currentWrite;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.transport.netty4;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel;

import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.test.ESTestCase;
import org.junit.After;
import org.junit.Before;

import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutionException;

import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.lessThanOrEqualTo;

public class Netty4WriteThrottlingHandlerTests extends ESTestCase {

private SharedGroupFactory.SharedGroup transportGroup;

@Before
public void createGroup() {
final SharedGroupFactory sharedGroupFactory = new SharedGroupFactory(Settings.EMPTY);
transportGroup = sharedGroupFactory.getTransportGroup();
}

@After
public void stopGroup() {
transportGroup.shutdown();
}

public void testThrottlesLargeMessage() throws ExecutionException, InterruptedException {
final List<ByteBuf> seen = new CopyOnWriteArrayList<>();
final CapturingHandler capturingHandler = new CapturingHandler(seen);
final EmbeddedChannel embeddedChannel = new EmbeddedChannel(
capturingHandler,
new Netty4WriteThrottlingHandler(new ThreadContext(Settings.EMPTY))
);
// we assume that the channel outbound buffer is smaller than Netty4WriteThrottlingHandler.MAX_BYTES_PER_WRITE
final int writeableBytes = Math.toIntExact(embeddedChannel.bytesBeforeUnwritable());
assertThat(writeableBytes, lessThan(Netty4WriteThrottlingHandler.MAX_BYTES_PER_WRITE));
final int fullSizeChunks = randomIntBetween(2, 10);
final int extraChunkSize = randomIntBetween(0, 10);
final ByteBuf message = Unpooled.wrappedBuffer(
randomByteArrayOfLength(Netty4WriteThrottlingHandler.MAX_BYTES_PER_WRITE * fullSizeChunks + extraChunkSize)
);
final ChannelPromise promise = embeddedChannel.newPromise();
transportGroup.getLowLevelGroup().submit(() -> embeddedChannel.write(message, promise)).get();
assertThat(seen, hasSize(1));
assertEquals(message.slice(0, Netty4WriteThrottlingHandler.MAX_BYTES_PER_WRITE), seen.get(0));
assertFalse(promise.isDone());
transportGroup.getLowLevelGroup().submit(embeddedChannel::flush).get();
assertTrue(promise.isDone());
assertThat(seen, hasSize(fullSizeChunks + (extraChunkSize == 0 ? 0 : 1)));
assertTrue(capturingHandler.didWriteAfterThrottled);
if (extraChunkSize != 0) {
assertEquals(
message.slice(Netty4WriteThrottlingHandler.MAX_BYTES_PER_WRITE * fullSizeChunks, extraChunkSize),
seen.get(seen.size() - 1)
);
}
}

public void testPassesSmallMessageDirectly() throws ExecutionException, InterruptedException {
final List<ByteBuf> seen = new CopyOnWriteArrayList<>();
final CapturingHandler capturingHandler = new CapturingHandler(seen);
final EmbeddedChannel embeddedChannel = new EmbeddedChannel(
capturingHandler,
new Netty4WriteThrottlingHandler(new ThreadContext(Settings.EMPTY))
);
final int writeableBytes = Math.toIntExact(embeddedChannel.bytesBeforeUnwritable());
assertThat(writeableBytes, lessThan(Netty4WriteThrottlingHandler.MAX_BYTES_PER_WRITE));
final ByteBuf message = Unpooled.wrappedBuffer(
randomByteArrayOfLength(randomIntBetween(0, Netty4WriteThrottlingHandler.MAX_BYTES_PER_WRITE))
);
final ChannelPromise promise = embeddedChannel.newPromise();
transportGroup.getLowLevelGroup().submit(() -> embeddedChannel.write(message, promise)).get();
assertThat(seen, hasSize(1)); // first message should be passed through straight away
assertSame(message, seen.get(0));
assertFalse(promise.isDone());
transportGroup.getLowLevelGroup().submit(embeddedChannel::flush).get();
assertTrue(promise.isDone());
assertThat(seen, hasSize(1));
assertFalse(capturingHandler.didWriteAfterThrottled);
}

public void testThrottlesOnUnwritable() throws ExecutionException, InterruptedException {
final List<ByteBuf> seen = new CopyOnWriteArrayList<>();
final EmbeddedChannel embeddedChannel = new EmbeddedChannel(
new CapturingHandler(seen),
new Netty4WriteThrottlingHandler(new ThreadContext(Settings.EMPTY))
);
final int writeableBytes = Math.toIntExact(embeddedChannel.bytesBeforeUnwritable());
assertThat(writeableBytes, lessThan(Netty4WriteThrottlingHandler.MAX_BYTES_PER_WRITE));
final ByteBuf message = Unpooled.wrappedBuffer(randomByteArrayOfLength(writeableBytes + randomIntBetween(0, 10)));
final ChannelPromise promise = embeddedChannel.newPromise();
transportGroup.getLowLevelGroup().submit(() -> embeddedChannel.write(message, promise)).get();
assertThat(seen, hasSize(1)); // first message should be passed through straight away
assertSame(message, seen.get(0));
assertFalse(promise.isDone());
final ByteBuf messageToQueue = Unpooled.wrappedBuffer(
randomByteArrayOfLength(randomIntBetween(0, Netty4WriteThrottlingHandler.MAX_BYTES_PER_WRITE))
);
final ChannelPromise promiseForQueued = embeddedChannel.newPromise();
transportGroup.getLowLevelGroup().submit(() -> embeddedChannel.write(messageToQueue, promiseForQueued)).get();
assertThat(seen, hasSize(1));
assertFalse(promiseForQueued.isDone());
assertFalse(promise.isDone());
transportGroup.getLowLevelGroup().submit(embeddedChannel::flush).get();
assertTrue(promise.isDone());
assertTrue(promiseForQueued.isDone());
}

private static class CapturingHandler extends ChannelOutboundHandlerAdapter {
private final List<ByteBuf> seen;

private boolean wasThrottled = false;

private boolean didWriteAfterThrottled = false;

CapturingHandler(List<ByteBuf> seen) {
this.seen = seen;
}

@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
assertTrue("should only write to writeable channel", ctx.channel().isWritable());
assertThat(msg, instanceOf(ByteBuf.class));
final ByteBuf buf = (ByteBuf) msg;
assertThat(buf.readableBytes(), lessThanOrEqualTo(Netty4WriteThrottlingHandler.MAX_BYTES_PER_WRITE));
seen.add(buf);
if (wasThrottled) {
didWriteAfterThrottled = true;
}
super.write(ctx, msg, promise);
if (ctx.channel().isWritable() == false) {
wasThrottled = true;
}
}
}
}

0 comments on commit d50d47c

Please sign in to comment.