Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added ability to specify custom ArrayPool #1190

Merged
merged 3 commits into from
May 6, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions projects/RabbitMQ.Client/client/api/ConnectionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

using System;
using System.Collections.Generic;
using System.Buffers;
using System.Linq;
using System.Net.Security;
using System.Security.Authentication;
Expand Down Expand Up @@ -188,6 +189,15 @@ public sealed class ConnectionFactory : ConnectionFactoryBase, IAsyncConnectionF

// just here to hold the value that was set through the setter
private Uri _uri;
private readonly ArrayPool<byte> _memoryPool;

/// <summary>
/// The memory pool used for allocating buffers. Default is <see cref="MemoryPool{T}.Shared"/>.
/// </summary>
public ArrayPool<byte> MemoryPool
lukebakken marked this conversation as resolved.
Show resolved Hide resolved
{
get => _memoryPool;
}

/// <summary>
/// Amount of time protocol handshake operations are allowed to take before
Expand Down Expand Up @@ -258,6 +268,18 @@ public TimeSpan ContinuationTimeout
public ConnectionFactory()
{
ClientProperties = Connection.DefaultClientProperties();
_memoryPool = ArrayPool<byte>.Shared;
}

/// <summary>
/// Construct a fresh instance, with all fields set to their respective defaults,
/// using your own memory pool.
/// <param name="memoryPool">Memory pool to use with all Connections</param>
/// </summary>
public ConnectionFactory(ArrayPool<byte> memoryPool)
lukebakken marked this conversation as resolved.
Show resolved Hide resolved
{
ClientProperties = Connection.DefaultClientProperties();
_memoryPool = memoryPool;
}

/// <summary>
Expand Down Expand Up @@ -497,7 +519,8 @@ public IConnection CreateConnection(IEndpointResolver endpointResolver, string c
else
{
var protocol = new RabbitMQ.Client.Framing.Protocol();
conn = protocol.CreateConnection(this, false, endpointResolver.SelectOne(CreateFrameHandler), clientProvidedName);
conn = protocol.CreateConnection(this, false, endpointResolver.SelectOne(CreateFrameHandler),
_memoryPool, clientProvidedName);
}
}
catch (Exception e)
Expand All @@ -510,7 +533,7 @@ public IConnection CreateConnection(IEndpointResolver endpointResolver, string c

internal IFrameHandler CreateFrameHandler(AmqpTcpEndpoint endpoint)
{
IFrameHandler fh = Protocols.DefaultProtocol.CreateFrameHandler(endpoint, SocketFactory,
IFrameHandler fh = Protocols.DefaultProtocol.CreateFrameHandler(endpoint, _memoryPool, SocketFactory,
RequestedConnectionTimeout, SocketReadTimeout, SocketWriteTimeout);
return ConfigureFrameHandler(fh);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@ public void HandleBasicDeliver(IBasicConsumer consumer,
IBasicProperties basicProperties,
ReadOnlySpan<byte> body)
{
byte[] bodyBytes = ArrayPool<byte>.Shared.Rent(body.Length);
var pool = _model.Session.Connection.MemoryPool;
byte[] bodyBytes = pool.Rent(body.Length);
Memory<byte> bodyCopy = new Memory<byte>(bodyBytes, 0, body.Length);
body.CopyTo(bodyCopy.Span);
ScheduleUnlessShuttingDown(new BasicDeliver(consumer, consumerTag, deliveryTag, redelivered, exchange, routingKey, basicProperties, bodyCopy));
ScheduleUnlessShuttingDown(new BasicDeliver(consumer, consumerTag, deliveryTag, redelivered, exchange, routingKey, basicProperties, bodyCopy, pool));
}

public void HandleBasicCancelOk(IBasicConsumer consumer, string consumerTag)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -683,8 +683,7 @@ private void Init(IFrameHandler fh)
throw new ObjectDisposedException(GetType().FullName);
}

_delegate = new Connection(_factory, false,
fh, ClientProvidedName);
_delegate = new Connection(_factory, false, fh, _factory.MemoryPool, ClientProvidedName);

_recoveryTask = Task.Run(MainRecoveryLoop);

Expand Down Expand Up @@ -1017,7 +1016,7 @@ private bool TryRecoverConnectionDelegate()
try
{
IFrameHandler fh = _endpoints.SelectOne(_factory.CreateFrameHandler);
_delegate = new Connection(_factory, false, fh, ClientProvidedName);
_delegate = new Connection(_factory, false, fh, _factory.MemoryPool, ClientProvidedName);
lukebakken marked this conversation as resolved.
Show resolved Hide resolved
return true;
}
catch (Exception e)
Expand Down
7 changes: 5 additions & 2 deletions projects/RabbitMQ.Client/client/impl/BasicDeliver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ internal sealed class BasicDeliver : Work
private readonly string _routingKey;
private readonly IBasicProperties _basicProperties;
private readonly ReadOnlyMemory<byte> _body;
private readonly ArrayPool<byte> _bodyOwner;

public override string Context => "HandleBasicDeliver";

Expand All @@ -24,7 +25,8 @@ public BasicDeliver(IBasicConsumer consumer,
string exchange,
string routingKey,
IBasicProperties basicProperties,
ReadOnlyMemory<byte> body) : base(consumer)
ReadOnlyMemory<byte> body,
ArrayPool<byte> pool) : base(consumer)
{
_consumerTag = consumerTag;
_deliveryTag = deliveryTag;
Expand All @@ -33,6 +35,7 @@ public BasicDeliver(IBasicConsumer consumer,
_routingKey = routingKey;
_basicProperties = basicProperties;
_body = body;
_bodyOwner = pool;
}

protected override Task Execute(IAsyncBasicConsumer consumer)
Expand All @@ -50,7 +53,7 @@ public override void PostExecute()
{
if (MemoryMarshal.TryGetArray(_body, out ArraySegment<byte> segment))
{
ArrayPool<byte>.Shared.Return(segment.Array);
_bodyOwner.Return(segment.Array);
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions projects/RabbitMQ.Client/client/impl/CommandAssembler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public IncomingCommand HandleFrame(in InboundFrame frame)
return IncomingCommand.Empty;
}

var result = new IncomingCommand(_method, _header, _body, _bodyBytes);
var result = new IncomingCommand(_method, _header, _body, _bodyBytes, _protocol.MemoryPool);
Reset();
return result;
}
Expand Down Expand Up @@ -123,7 +123,7 @@ private void ParseHeaderFrame(in InboundFrame frame)
_remainingBodyBytes = (int) totalBodyBytes;

// Is returned by IncomingCommand.Dispose in Session.HandleFrame
_bodyBytes = ArrayPool<byte>.Shared.Rent(_remainingBodyBytes);
_bodyBytes = _protocol.MemoryPool.Rent(_remainingBodyBytes);
_body = new Memory<byte>(_bodyBytes, 0, _remainingBodyBytes);
UpdateContentBodyState();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ public void HandleBasicDeliver(IBasicConsumer consumer,
IBasicProperties basicProperties,
ReadOnlySpan<byte> body)
{
byte[] memoryCopyArray = ArrayPool<byte>.Shared.Rent(body.Length);
var pool = _model.Session.Connection.MemoryPool;
byte[] memoryCopyArray = pool.Rent(body.Length);
Memory<byte> memoryCopy = new Memory<byte>(memoryCopyArray, 0, body.Length);
body.CopyTo(memoryCopy.Span);
UnlessShuttingDown(() =>
Expand All @@ -90,7 +91,7 @@ public void HandleBasicDeliver(IBasicConsumer consumer,
}
finally
{
ArrayPool<byte>.Shared.Return(memoryCopyArray);
pool.Return(memoryCopyArray);
}
});
}
Expand Down
17 changes: 15 additions & 2 deletions projects/RabbitMQ.Client/client/impl/Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ internal sealed class Connection : IConnection
private volatile bool _running = true;
private readonly MainSession _session0;
private SessionManager _sessionManager;
private readonly ArrayPool<byte> _memoryPool = ArrayPool<byte>.Shared;

//
// Heartbeats
Expand Down Expand Up @@ -127,6 +128,18 @@ public Connection(IConnectionFactory factory, bool insist, IFrameHandler frameHa
}
}

public Connection(IConnectionFactory factory, bool insist, IFrameHandler frameHandler, ArrayPool<byte> memoryPool,
string clientProvidedName = null)
: this(factory, insist, frameHandler, clientProvidedName)
{
_memoryPool = memoryPool;
lukebakken marked this conversation as resolved.
Show resolved Hide resolved
}

internal ArrayPool<byte> MemoryPool
{
get => _memoryPool;
}

public Guid Id { get { return _id; } }

public event EventHandler<CallbackExceptionEventArgs> CallbackException;
Expand Down Expand Up @@ -908,7 +921,7 @@ public void HeartbeatWriteTimerCallback(object state)
{
if (!_closed)
{
Write(Client.Impl.Framing.Heartbeat.GetHeartbeatFrame());
Write(Client.Impl.Framing.Heartbeat.GetHeartbeatFrame(MemoryPool));
_heartbeatWriteTimer?.Change((int)_heartbeatTimeSpan.TotalMilliseconds, Timeout.Infinite);
}
}
Expand Down Expand Up @@ -945,7 +958,7 @@ public override string ToString()
return string.Format("Connection({0},{1})", _id, Endpoint);
}

public void Write(Memory<byte> memory)
public void Write(ReadOnlyMemory<byte> memory)
{
_frameHandler.Write(memory);
}
Expand Down
22 changes: 12 additions & 10 deletions projects/RabbitMQ.Client/client/impl/Frame.cs
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,12 @@ internal static class Heartbeat
Constants.FrameEnd
};

public static Memory<byte> GetHeartbeatFrame()
public static ReadOnlyMemory<byte> GetHeartbeatFrame(ArrayPool<byte> pool)
{
// Is returned by SocketFrameHandler.WriteLoop
var buffer = ArrayPool<byte>.Shared.Rent(FrameSize);
var buffer = pool.Rent(FrameSize);
lukebakken marked this conversation as resolved.
Show resolved Hide resolved
Payload.CopyTo(buffer);
return new Memory<byte>(buffer, 0, FrameSize);
return new ReadOnlyMemory<byte>(buffer, 0, FrameSize);
}
}
}
Expand All @@ -163,13 +163,15 @@ public static Memory<byte> GetHeartbeatFrame()
public readonly int Channel;
public readonly ReadOnlyMemory<byte> Payload;
private readonly byte[] _rentedArray;
private readonly ArrayPool<byte> _rentedArrayOwner;

private InboundFrame(FrameType type, int channel, ReadOnlyMemory<byte> payload, byte[] rentedArray)
private InboundFrame(FrameType type, int channel, ReadOnlyMemory<byte> payload, byte[] rentedArray, ArrayPool<byte> rentedArrayOwner)
{
Type = type;
Channel = channel;
Payload = payload;
_rentedArray = rentedArray;
_rentedArrayOwner = rentedArrayOwner;
}

private static void ProcessProtocolHeader(Stream reader)
Expand Down Expand Up @@ -203,7 +205,7 @@ private static void ProcessProtocolHeader(Stream reader)
}
}

internal static InboundFrame ReadFrom(Stream reader, byte[] frameHeaderBuffer)
internal static InboundFrame ReadFrom(Stream reader, byte[] frameHeaderBuffer, ArrayPool<byte> pool)
{
int type = default;
try
Expand Down Expand Up @@ -242,7 +244,7 @@ internal static InboundFrame ReadFrom(Stream reader, byte[] frameHeaderBuffer)
const int EndMarkerLength = 1;
// Is returned by InboundFrame.Dispose in Connection.MainLoopIteration
var readSize = payloadSize + EndMarkerLength;
byte[] payloadBytes = ArrayPool<byte>.Shared.Rent(readSize);
byte[] payloadBytes = pool.Rent(readSize);
int bytesRead = 0;
try
{
Expand All @@ -254,22 +256,22 @@ internal static InboundFrame ReadFrom(Stream reader, byte[] frameHeaderBuffer)
catch (Exception)
{
// Early EOF.
ArrayPool<byte>.Shared.Return(payloadBytes);
pool.Return(payloadBytes);
throw new MalformedFrameException($"Short frame - expected to read {readSize} bytes, only got {bytesRead} bytes");
}

if (payloadBytes[payloadSize] != Constants.FrameEnd)
{
ArrayPool<byte>.Shared.Return(payloadBytes);
pool.Return(payloadBytes);
throw new MalformedFrameException($"Bad frame end marker: {payloadBytes[payloadSize]}");
}

return new InboundFrame((FrameType)type, channel, new Memory<byte>(payloadBytes, 0, payloadSize), payloadBytes);
return new InboundFrame((FrameType)type, channel, new Memory<byte>(payloadBytes, 0, payloadSize), payloadBytes, pool);
}

public void Dispose()
{
ArrayPool<byte>.Shared.Return(_rentedArray);
_rentedArrayOwner.Return(_rentedArray);
}

public override string ToString()
Expand Down
2 changes: 1 addition & 1 deletion projects/RabbitMQ.Client/client/impl/IFrameHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,6 @@ interface IFrameHandler

void SendHeader();

void Write(Memory<byte> memory);
void Write(ReadOnlyMemory<byte> memory);
}
}
9 changes: 6 additions & 3 deletions projects/RabbitMQ.Client/client/impl/IProtocolExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@
//---------------------------------------------------------------------------

using System;

using System.Buffers;
using System.Net.Sockets;

using RabbitMQ.Client.Impl;

namespace RabbitMQ.Client.Framing.Impl
Expand All @@ -42,12 +41,16 @@ static class IProtocolExtensions
public static IFrameHandler CreateFrameHandler(
this IProtocol protocol,
AmqpTcpEndpoint endpoint,
ArrayPool<byte> pool,
Func<AddressFamily, ITcpClient> socketFactory,
TimeSpan connectionTimeout,
TimeSpan readTimeout,
TimeSpan writeTimeout)
{
return new SocketFrameHandler(endpoint, socketFactory, connectionTimeout, readTimeout, writeTimeout);
return new SocketFrameHandler(endpoint, socketFactory, connectionTimeout, readTimeout, writeTimeout)
{
MemoryPool = pool
};
}
}
}
6 changes: 4 additions & 2 deletions projects/RabbitMQ.Client/client/impl/IncomingCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,24 @@ namespace RabbitMQ.Client.Impl
public readonly ContentHeaderBase Header;
public readonly ReadOnlyMemory<byte> Body;
private readonly byte[] _rentedArray;
private readonly ArrayPool<byte> _rentedArrayOwner;

public bool IsEmpty => Method is null;

public IncomingCommand(MethodBase method, ContentHeaderBase header, ReadOnlyMemory<byte> body, byte[] rentedArray)
public IncomingCommand(MethodBase method, ContentHeaderBase header, ReadOnlyMemory<byte> body, byte[] rentedArray, ArrayPool<byte> rentedArrayOwner)
{
Method = method;
Header = header;
Body = body;
_rentedArray = rentedArray;
_rentedArrayOwner = rentedArrayOwner;
}

public void Dispose()
{
if (_rentedArray != null)
{
ArrayPool<byte>.Shared.Return(_rentedArray);
_rentedArrayOwner.Return(_rentedArray);
}
}

Expand Down
2 changes: 1 addition & 1 deletion projects/RabbitMQ.Client/client/impl/ModelBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ abstract class ModelBase : IFullModel, IRecoverable

private bool _onlyAcksReceived = true;

public IConsumerDispatcher ConsumerDispatcher { get; private set; }
public IConsumerDispatcher ConsumerDispatcher { get; }

public ModelBase(ISession session) : this(session, session.Connection.ConsumerWorkService)
{ }
Expand Down
4 changes: 2 additions & 2 deletions projects/RabbitMQ.Client/client/impl/OutgoingCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ internal void Transmit(ushort channelNumber, Connection connection)
int maxBodyPayloadBytes = (int)(connection.FrameMax == 0 ? int.MaxValue : connection.FrameMax - EmptyFrameSize);
var size = GetMaxSize(maxBodyPayloadBytes);

// Will be returned by SocketFrameWriter.WriteLoop
var memory = new Memory<byte>(ArrayPool<byte>.Shared.Rent(size), 0, size);
// Will be returned by SocketFrameHandler.WriteLoop
var memory = new Memory<byte>(connection.MemoryPool.Rent(size), 0, size);
var span = memory.Span;

var offset = Framing.Method.WriteTo(span, channelNumber, Method);
Expand Down
Loading