diff --git a/projects/RabbitMQ.Client/FrameworkExtension/Interlocked.cs b/projects/RabbitMQ.Client/FrameworkExtension/Interlocked.cs deleted file mode 100644 index 7106be7059..0000000000 --- a/projects/RabbitMQ.Client/FrameworkExtension/Interlocked.cs +++ /dev/null @@ -1,29 +0,0 @@ -// Note: -// The code in this file is inspired by the code in `dotnet/runtime`, in this file: -// src/coreclr/nativeaot/System.Private.CoreLib/src/System/Threading/Interlocked.cs -// -// The license from that file is as follows: -// -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Runtime.CompilerServices; - -namespace RabbitMQ.Client -{ -#if NETSTANDARD - // TODO GH-1308 class should be called "Interlocked" to be used by other code - internal static class InterlockedExtensions - { - public static ulong CompareExchange(ref ulong location1, ulong value, ulong comparand) - { - return (ulong)System.Threading.Interlocked.CompareExchange(ref Unsafe.As(ref location1), (long)value, (long)comparand); - } - - public static ulong Increment(ref ulong location1) - { - return (ulong)System.Threading.Interlocked.Add(ref Unsafe.As(ref location1), 1L); - } - } -#endif -} diff --git a/projects/RabbitMQ.Client/PublicAPI.Unshipped.txt b/projects/RabbitMQ.Client/PublicAPI.Unshipped.txt index d470b54eb7..e8e5583aa6 100644 --- a/projects/RabbitMQ.Client/PublicAPI.Unshipped.txt +++ b/projects/RabbitMQ.Client/PublicAPI.Unshipped.txt @@ -918,7 +918,7 @@ virtual RabbitMQ.Client.TcpClientAdapter.ReceiveTimeout.set -> void ~RabbitMQ.Client.IChannel.TxCommitAsync() -> System.Threading.Tasks.Task ~RabbitMQ.Client.IChannel.TxRollbackAsync() -> System.Threading.Tasks.Task ~RabbitMQ.Client.IChannel.TxSelectAsync() -> System.Threading.Tasks.Task -~RabbitMQ.Client.IConnection.CloseAsync(ushort reasonCode, string reasonText, System.TimeSpan timeout, bool abort) -> System.Threading.Tasks.Task +~RabbitMQ.Client.IConnection.CloseAsync(ushort reasonCode, string reasonText, System.TimeSpan timeout, bool abort, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task ~RabbitMQ.Client.IConnection.CreateChannelAsync() -> System.Threading.Tasks.Task ~RabbitMQ.Client.IConnection.UpdateSecretAsync(string newSecret, string reason) -> System.Threading.Tasks.Task ~RabbitMQ.Client.IConnectionFactory.CreateConnectionAsync(string clientProvidedName, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task diff --git a/projects/RabbitMQ.Client/client/TaskExtensions.cs b/projects/RabbitMQ.Client/client/TaskExtensions.cs index b0328875af..d8458d2dc4 100644 --- a/projects/RabbitMQ.Client/client/TaskExtensions.cs +++ b/projects/RabbitMQ.Client/client/TaskExtensions.cs @@ -53,14 +53,67 @@ public static bool IsCompletedSuccessfully(this Task task) private static readonly TaskContinuationOptions s_tco = TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously; private static void IgnoreTaskContinuation(Task t, object s) => t.Exception.Handle(e => true); - public static async Task WithCancellation(this Task task, CancellationToken cancellationToken) + // https://devblogs.microsoft.com/pfxteam/how-do-i-cancel-non-cancelable-async-operations/ + public static Task WaitAsync(this Task task, TimeSpan timeout, CancellationToken cancellationToken) { - var tcs = new TaskCompletionSource(); + if (task.IsCompletedSuccessfully()) + { + return task; + } + else + { + return DoWaitWithTimeoutAsync(task, timeout, cancellationToken); + } + } + + private static async Task DoWaitWithTimeoutAsync(this Task task, TimeSpan timeout, CancellationToken cancellationToken) + { + using var timeoutTokenCts = new CancellationTokenSource(timeout); + CancellationToken timeoutToken = timeoutTokenCts.Token; + + var linkedTokenTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(timeoutToken, cancellationToken); + using CancellationTokenRegistration cancellationTokenRegistration = + linkedCts.Token.Register(s => ((TaskCompletionSource)s).TrySetResult(true), + state: linkedTokenTcs, useSynchronizationContext: false); + + if (task != await Task.WhenAny(task, linkedTokenTcs.Task).ConfigureAwait(false)) + { + task.Ignore(); + if (timeoutToken.IsCancellationRequested) + { + throw new OperationCanceledException($"Operation timed out after {timeout}"); + } + else + { + throw new OperationCanceledException(cancellationToken); + } + } - // https://devblogs.microsoft.com/pfxteam/how-do-i-cancel-non-cancelable-async-operations/ - using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetResult(true), tcs)) + await task.ConfigureAwait(false); + } + + // https://devblogs.microsoft.com/pfxteam/how-do-i-cancel-non-cancelable-async-operations/ + public static Task WaitAsync(this Task task, CancellationToken cancellationToken) + { + if (task.IsCompletedSuccessfully()) { - if (task != await Task.WhenAny(task, tcs.Task)) + return task; + } + else + { + return DoWaitAsync(task, cancellationToken); + } + } + + private static async Task DoWaitAsync(this Task task, CancellationToken cancellationToken) + { + var cancellationTokenTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetResult(true), + state: cancellationTokenTcs, useSynchronizationContext: false)) + { + if (task != await Task.WhenAny(task, cancellationTokenTcs.Task).ConfigureAwait(false)) { task.Ignore(); throw new OperationCanceledException(cancellationToken); @@ -172,10 +225,13 @@ public static T EnsureCompleted(this ValueTask task) public static void EnsureCompleted(this ValueTask task) { - task.GetAwaiter().GetResult(); + if (false == task.IsCompletedSuccessfully) + { + task.GetAwaiter().GetResult(); + } } -#if !NET6_0_OR_GREATER +#if NETSTANDARD // https://github.com/dotnet/runtime/issues/23878 // https://github.com/dotnet/runtime/issues/23878#issuecomment-1398958645 public static void Ignore(this Task task) diff --git a/projects/RabbitMQ.Client/client/api/ConnectionFactory.cs b/projects/RabbitMQ.Client/client/api/ConnectionFactory.cs index 3cc437b4e4..036493976f 100644 --- a/projects/RabbitMQ.Client/client/api/ConnectionFactory.cs +++ b/projects/RabbitMQ.Client/client/api/ConnectionFactory.cs @@ -618,6 +618,7 @@ private ConnectionConfig CreateConfig(string clientProvidedName) internal async Task CreateFrameHandlerAsync( AmqpTcpEndpoint endpoint, CancellationToken cancellationToken) { + cancellationToken.ThrowIfCancellationRequested(); IFrameHandler fh = new SocketFrameHandler(endpoint, SocketFactory, RequestedConnectionTimeout, SocketReadTimeout, SocketWriteTimeout); await fh.ConnectAsync(cancellationToken) .ConfigureAwait(false); diff --git a/projects/RabbitMQ.Client/client/api/IConnection.cs b/projects/RabbitMQ.Client/client/api/IConnection.cs index d7fec6358c..2e3c44bc67 100644 --- a/projects/RabbitMQ.Client/client/api/IConnection.cs +++ b/projects/RabbitMQ.Client/client/api/IConnection.cs @@ -31,6 +31,7 @@ using System; using System.Collections.Generic; +using System.Threading; using System.Threading.Tasks; using RabbitMQ.Client.Events; using RabbitMQ.Client.Exceptions; @@ -222,9 +223,10 @@ public interface IConnection : INetworkConnection, IDisposable /// /// The close code (See under "Reply Codes" in the AMQP 0-9-1 specification). /// A message indicating the reason for closing the connection. - /// Operation timeout. + /// /// Whether or not this close is an abort (ignores certain exceptions). - Task CloseAsync(ushort reasonCode, string reasonText, TimeSpan timeout, bool abort); + /// + Task CloseAsync(ushort reasonCode, string reasonText, TimeSpan timeout, bool abort, CancellationToken cancellationToken = default); /// /// Asynchronously create and return a fresh channel, session, and channel. diff --git a/projects/RabbitMQ.Client/client/api/IConnectionExtensions.cs b/projects/RabbitMQ.Client/client/api/IConnectionExtensions.cs index 6f879fae0e..58c97c17ea 100644 --- a/projects/RabbitMQ.Client/client/api/IConnectionExtensions.cs +++ b/projects/RabbitMQ.Client/client/api/IConnectionExtensions.cs @@ -20,7 +20,8 @@ public static class IConnectionExtensions /// public static Task CloseAsync(this IConnection connection) { - return connection.CloseAsync(Constants.ReplySuccess, "Goodbye", InternalConstants.DefaultConnectionCloseTimeout, false); + return connection.CloseAsync(Constants.ReplySuccess, "Goodbye", InternalConstants.DefaultConnectionCloseTimeout, false, + CancellationToken.None); } /// @@ -38,7 +39,8 @@ public static Task CloseAsync(this IConnection connection) /// public static Task CloseAsync(this IConnection connection, ushort reasonCode, string reasonText) { - return connection.CloseAsync(reasonCode, reasonText, InternalConstants.DefaultConnectionCloseTimeout, false); + return connection.CloseAsync(reasonCode, reasonText, InternalConstants.DefaultConnectionCloseTimeout, false, + CancellationToken.None); } /// @@ -58,7 +60,8 @@ public static Task CloseAsync(this IConnection connection, ushort reasonCode, st /// public static Task CloseAsync(this IConnection connection, TimeSpan timeout) { - return connection.CloseAsync(Constants.ReplySuccess, "Goodbye", timeout, false); + return connection.CloseAsync(Constants.ReplySuccess, "Goodbye", timeout, false, + CancellationToken.None); } /// @@ -80,7 +83,8 @@ public static Task CloseAsync(this IConnection connection, TimeSpan timeout) /// public static Task CloseAsync(this IConnection connection, ushort reasonCode, string reasonText, TimeSpan timeout) { - return connection.CloseAsync(reasonCode, reasonText, timeout, false); + return connection.CloseAsync(reasonCode, reasonText, timeout, false, + CancellationToken.None); } /// @@ -94,7 +98,8 @@ public static Task CloseAsync(this IConnection connection, ushort reasonCode, st /// public static Task AbortAsync(this IConnection connection) { - return connection.CloseAsync(Constants.ReplySuccess, "Connection close forced", InternalConstants.DefaultConnectionAbortTimeout, true); + return connection.CloseAsync(Constants.ReplySuccess, "Connection close forced", InternalConstants.DefaultConnectionAbortTimeout, true, + CancellationToken.None); } /// @@ -112,7 +117,8 @@ public static Task AbortAsync(this IConnection connection) /// public static Task AbortAsync(this IConnection connection, ushort reasonCode, string reasonText) { - return connection.CloseAsync(reasonCode, reasonText, InternalConstants.DefaultConnectionAbortTimeout, true); + return connection.CloseAsync(reasonCode, reasonText, InternalConstants.DefaultConnectionAbortTimeout, true, + CancellationToken.None); } /// @@ -130,7 +136,8 @@ public static Task AbortAsync(this IConnection connection, ushort reasonCode, st /// public static Task AbortAsync(this IConnection connection, TimeSpan timeout) { - return connection.CloseAsync(Constants.ReplySuccess, "Connection close forced", timeout, true); + return connection.CloseAsync(Constants.ReplySuccess, "Connection close forced", timeout, true, + CancellationToken.None); } /// @@ -149,7 +156,8 @@ public static Task AbortAsync(this IConnection connection, TimeSpan timeout) /// public static Task AbortAsync(this IConnection connection, ushort reasonCode, string reasonText, TimeSpan timeout) { - return connection.CloseAsync(reasonCode, reasonText, timeout, true); + return connection.CloseAsync(reasonCode, reasonText, timeout, true, + CancellationToken.None); } } } diff --git a/projects/RabbitMQ.Client/client/api/IEndpointResolverExtensions.cs b/projects/RabbitMQ.Client/client/api/IEndpointResolverExtensions.cs index 6f0ca8e7d5..75c4da2d6f 100644 --- a/projects/RabbitMQ.Client/client/api/IEndpointResolverExtensions.cs +++ b/projects/RabbitMQ.Client/client/api/IEndpointResolverExtensions.cs @@ -45,6 +45,7 @@ public static async Task SelectOneAsync(this IEndpointResolver resolver, var exceptions = new List(); foreach (AmqpTcpEndpoint ep in resolver.All()) { + cancellationToken.ThrowIfCancellationRequested(); try { t = await selector(ep, cancellationToken).ConfigureAwait(false); diff --git a/projects/RabbitMQ.Client/client/api/TcpClientAdapter.cs b/projects/RabbitMQ.Client/client/api/TcpClientAdapter.cs index e364fb1ed4..8db8fa639b 100644 --- a/projects/RabbitMQ.Client/client/api/TcpClientAdapter.cs +++ b/projects/RabbitMQ.Client/client/api/TcpClientAdapter.cs @@ -29,7 +29,7 @@ public virtual Task ConnectAsync(IPAddress ep, int port, CancellationToken cance #else public virtual Task ConnectAsync(IPAddress ep, int port, CancellationToken cancellationToken = default) { - return _sock.ConnectAsync(ep, port).WithCancellation(cancellationToken); + return _sock.ConnectAsync(ep, port).WaitAsync(cancellationToken); } #endif diff --git a/projects/RabbitMQ.Client/client/events/EventingBasicConsumer.cs b/projects/RabbitMQ.Client/client/events/EventingBasicConsumer.cs index 153ff7bb8b..ca57918fce 100644 --- a/projects/RabbitMQ.Client/client/events/EventingBasicConsumer.cs +++ b/projects/RabbitMQ.Client/client/events/EventingBasicConsumer.cs @@ -92,7 +92,8 @@ public override async Task HandleBasicDeliverAsync(string consumerTag, ulong del BasicDeliverEventArgs eventArgs = new BasicDeliverEventArgs(consumerTag, deliveryTag, redelivered, exchange, routingKey, properties, body); using (Activity activity = RabbitMQActivitySource.SubscriberHasListeners ? RabbitMQActivitySource.Deliver(eventArgs) : default) { - await base.HandleBasicDeliverAsync(consumerTag, deliveryTag, redelivered, exchange, routingKey, properties, body); + await base.HandleBasicDeliverAsync(consumerTag, deliveryTag, redelivered, exchange, routingKey, properties, body) + .ConfigureAwait(false); Received?.Invoke(this, eventArgs); } } diff --git a/projects/RabbitMQ.Client/client/framing/Channel.cs b/projects/RabbitMQ.Client/client/framing/Channel.cs index 129c4ee809..75cb3989ca 100644 --- a/projects/RabbitMQ.Client/client/framing/Channel.cs +++ b/projects/RabbitMQ.Client/client/framing/Channel.cs @@ -29,6 +29,8 @@ // Copyright (c) 2007-2020 VMware, Inc. All rights reserved. //--------------------------------------------------------------------------- +using System.Collections.Generic; +using System.Threading; using System.Threading.Tasks; using RabbitMQ.Client.client.framing; using RabbitMQ.Client.Impl; @@ -69,19 +71,22 @@ public override void _Private_ConnectionCloseOk() public override ValueTask BasicAckAsync(ulong deliveryTag, bool multiple) { var method = new BasicAck(deliveryTag, multiple); - return ModelSendAsync(method); + // TODO cancellation token? + return ModelSendAsync(method, CancellationToken.None); } public override ValueTask BasicNackAsync(ulong deliveryTag, bool multiple, bool requeue) { var method = new BasicNack(deliveryTag, multiple, requeue); - return ModelSendAsync(method); + // TODO use cancellation token + return ModelSendAsync(method, CancellationToken.None); } public override Task BasicRejectAsync(ulong deliveryTag, bool requeue) { var method = new BasicReject(deliveryTag, requeue); - return ModelSendAsync(method).AsTask(); + // TODO cancellation token? + return ModelSendAsync(method, CancellationToken.None).AsTask(); } protected override bool DispatchAsynchronous(in IncomingCommand cmd) diff --git a/projects/RabbitMQ.Client/client/impl/AsyncRpcContinuations.cs b/projects/RabbitMQ.Client/client/impl/AsyncRpcContinuations.cs index b9e2d0274d..00f2a272c5 100644 --- a/projects/RabbitMQ.Client/client/impl/AsyncRpcContinuations.cs +++ b/projects/RabbitMQ.Client/client/impl/AsyncRpcContinuations.cs @@ -44,8 +44,8 @@ namespace RabbitMQ.Client.Impl internal abstract class AsyncRpcContinuation : IRpcContinuation, IDisposable { private readonly CancellationTokenSource _cancellationTokenSource; + private readonly CancellationTokenRegistration _cancellationTokenRegistration; private readonly ConfiguredTaskAwaitable _tcsConfiguredTaskAwaitable; - protected readonly TaskCompletionSource _tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); private bool _disposedValue; @@ -59,21 +59,43 @@ public AsyncRpcContinuation(TimeSpan continuationTimeout) */ _cancellationTokenSource = new CancellationTokenSource(continuationTimeout); - _cancellationTokenSource.Token.Register(() => +#if NET6_0_OR_GREATER + _cancellationTokenRegistration = _cancellationTokenSource.Token.UnsafeRegister((object state) => + { + var tcs = (TaskCompletionSource)state; + if (tcs.TrySetCanceled()) + { + // TODO LRB rabbitmq/rabbitmq-dotnet-client#1347 + // Cancellation was successful, does this mean we should set a TimeoutException + // in the same manner as BlockingCell? + } + }, _tcs); +#else + _cancellationTokenRegistration = _cancellationTokenSource.Token.Register((object state) => { - if (_tcs.TrySetCanceled()) + var tcs = (TaskCompletionSource)state; + if (tcs.TrySetCanceled()) { // TODO LRB rabbitmq/rabbitmq-dotnet-client#1347 // Cancellation was successful, does this mean we should set a TimeoutException // in the same manner as BlockingCell? } - }, useSynchronizationContext: false); + }, state: _tcs, useSynchronizationContext: false); +#endif _tcsConfiguredTaskAwaitable = _tcs.Task.ConfigureAwait(false); } internal DateTime StartTime { get; } = DateTime.UtcNow; + public CancellationToken CancellationToken + { + get + { + return _cancellationTokenSource.Token; + } + } + public ConfiguredTaskAwaitable.ConfiguredTaskAwaiter GetAwaiter() { return _tcsConfiguredTaskAwaitable.GetAwaiter(); @@ -92,6 +114,7 @@ protected virtual void Dispose(bool disposing) { if (disposing) { + _cancellationTokenRegistration.Dispose(); _cancellationTokenSource.Dispose(); } diff --git a/projects/RabbitMQ.Client/client/impl/AutorecoveringConnection.Recovery.cs b/projects/RabbitMQ.Client/client/impl/AutorecoveringConnection.Recovery.cs index d927592472..9ad44677e1 100644 --- a/projects/RabbitMQ.Client/client/impl/AutorecoveringConnection.Recovery.cs +++ b/projects/RabbitMQ.Client/client/impl/AutorecoveringConnection.Recovery.cs @@ -44,10 +44,7 @@ namespace RabbitMQ.Client.Framing.Impl internal sealed partial class AutorecoveringConnection { private Task? _recoveryTask; - private CancellationTokenSource? _recoveryCancellationTokenSource; - - // TODO dispose the CTS - private CancellationTokenSource RecoveryCancellationTokenSource => _recoveryCancellationTokenSource ??= new CancellationTokenSource(); + private readonly CancellationTokenSource _recoveryCancellationTokenSource = new CancellationTokenSource(); private void HandleConnectionShutdown(object _, ShutdownEventArgs args) { @@ -60,18 +57,36 @@ private void HandleConnectionShutdown(object _, ShutdownEventArgs args) } } - static bool ShouldTriggerConnectionRecovery(ShutdownEventArgs args) => - args.Initiator == ShutdownInitiator.Peer || + static bool ShouldTriggerConnectionRecovery(ShutdownEventArgs args) + { + if (args.Initiator == ShutdownInitiator.Peer) + { + if (args.ReplyCode == Constants.AccessRefused) + { + return false; + } + else + { + return true; + } + } + // happens when EOF is reached, e.g. due to RabbitMQ node // connectivity loss or abrupt shutdown - args.Initiator == ShutdownInitiator.Library; + if (args.Initiator == ShutdownInitiator.Library) + { + return true; + } + + return false; + } } private async Task RecoverConnectionAsync() { try { - CancellationToken token = RecoveryCancellationTokenSource.Token; + CancellationToken token = _recoveryCancellationTokenSource.Token; bool success; do { @@ -79,7 +94,7 @@ await Task.Delay(_config.NetworkRecoveryInterval, token) .ConfigureAwait(false); success = await TryPerformAutomaticRecoveryAsync(token) .ConfigureAwait(false); - } while (!success && !token.IsCancellationRequested); + } while (false == success && false == token.IsCancellationRequested); } catch (OperationCanceledException) { @@ -96,44 +111,32 @@ await Task.Delay(_config.NetworkRecoveryInterval, token) } } - /// - /// Cancels the main recovery loop and will block until the loop finishes, or the timeout - /// expires, to prevent Close operations overlapping with recovery operations. - /// - private void StopRecoveryLoop() - { - Task? task = _recoveryTask; - if (task is null) - { - return; - } - RecoveryCancellationTokenSource.Cancel(); - - Task timeout = Task.Delay(_config.RequestedConnectionTimeout); - if (Task.WhenAny(task, timeout).Result == timeout) - { - ESLog.Warn("Timeout while trying to stop background AutorecoveringConnection recovery loop."); - } - } - /// /// Async cancels the main recovery loop and will block until the loop finishes, or the timeout /// expires, to prevent Close operations overlapping with recovery operations. /// - private async ValueTask StopRecoveryLoopAsync() + private async ValueTask StopRecoveryLoopAsync(CancellationToken cancellationToken) { Task? task = _recoveryTask; if (task != null) { - RecoveryCancellationTokenSource.Cancel(); + _recoveryCancellationTokenSource.Cancel(); + using var timeoutTokenSource = new CancellationTokenSource(_config.RequestedConnectionTimeout); + using var lts = CancellationTokenSource.CreateLinkedTokenSource(timeoutTokenSource.Token, cancellationToken); try { - await TaskExtensions.WaitAsync(task, _config.RequestedConnectionTimeout) - .ConfigureAwait(false); + await task.WaitAsync(lts.Token).ConfigureAwait(false); } - catch (TimeoutException) + catch (OperationCanceledException) { - ESLog.Warn("Timeout while trying to stop background AutorecoveringConnection recovery loop."); + if (timeoutTokenSource.Token.IsCancellationRequested) + { + ESLog.Warn("Timeout while trying to stop background AutorecoveringConnection recovery loop."); + } + else + { + throw; + } } } } diff --git a/projects/RabbitMQ.Client/client/impl/AutorecoveringConnection.cs b/projects/RabbitMQ.Client/client/impl/AutorecoveringConnection.cs index fa6a01e0e2..02afa9b107 100644 --- a/projects/RabbitMQ.Client/client/impl/AutorecoveringConnection.cs +++ b/projects/RabbitMQ.Client/client/impl/AutorecoveringConnection.cs @@ -196,7 +196,7 @@ public override string ToString() internal Task CloseFrameHandlerAsync() { - return InnerConnection.FrameHandler.CloseAsync(); + return InnerConnection.FrameHandler.CloseAsync(CancellationToken.None); } ///API-side invocation of updating the secret. @@ -208,15 +208,44 @@ public Task UpdateSecretAsync(string newSecret, string reason) } ///Asynchronous API-side invocation of connection.close with timeout. - public async Task CloseAsync(ushort reasonCode, string reasonText, TimeSpan timeout, bool abort) + public async Task CloseAsync(ushort reasonCode, string reasonText, TimeSpan timeout, bool abort, + CancellationToken cancellationToken = default) { ThrowIfDisposed(); - await StopRecoveryLoopAsync() - .ConfigureAwait(false); - if (_innerConnection.IsOpen) + + Task CloseInnerConnectionAsync() { - await _innerConnection.CloseAsync(reasonCode, reasonText, timeout, abort) + if (_innerConnection.IsOpen) + { + return _innerConnection.CloseAsync(reasonCode, reasonText, timeout, abort, cancellationToken); + } + else + { + return Task.CompletedTask; + } + } + + try + { + await StopRecoveryLoopAsync(cancellationToken) .ConfigureAwait(false); + + await CloseInnerConnectionAsync() + .ConfigureAwait(false); + } + catch (Exception ex) + { + try + { + await CloseInnerConnectionAsync() + .ConfigureAwait(false); + } + catch (Exception innerConnectionException) + { + throw new AggregateException(ex, innerConnectionException); + } + + throw; } } @@ -255,9 +284,10 @@ public void Dispose() { _channels.Clear(); _innerConnection = null; - _disposed = true; _recordedEntitiesSemaphore.Dispose(); _channelsSemaphore.Dispose(); + _recoveryCancellationTokenSource.Dispose(); + _disposed = true; } } diff --git a/projects/RabbitMQ.Client/client/impl/ChannelBase.cs b/projects/RabbitMQ.Client/client/impl/ChannelBase.cs index 1e893cf2b2..9b1f09438f 100644 --- a/projects/RabbitMQ.Client/client/impl/ChannelBase.cs +++ b/projects/RabbitMQ.Client/client/impl/ChannelBase.cs @@ -266,7 +266,8 @@ public Task CloseAsync(ushort replyCode, string replyText, bool abort) public async Task CloseAsync(ShutdownEventArgs args, bool abort) { using var k = new ChannelCloseAsyncRpcContinuation(ContinuationTimeout); - await _rpcSemaphore.WaitAsync() + + await _rpcSemaphore.WaitAsync(k.CancellationToken) .ConfigureAwait(false); try { @@ -278,13 +279,14 @@ await _rpcSemaphore.WaitAsync() { var method = new ChannelClose( args.ReplyCode, args.ReplyText, args.ClassId, args.MethodId); - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); } bool result = await k; Debug.Assert(result); + // TODO cancellation token? await ConsumerDispatcher.WaitForShutdownAsync() .ConfigureAwait(false); } @@ -316,28 +318,28 @@ await ConsumerDispatcher.WaitForShutdownAsync() } } - internal async ValueTask ConnectionOpenAsync(string virtualHost, CancellationToken _) + internal async ValueTask ConnectionOpenAsync(string virtualHost, CancellationToken cancellationToken) { + using var timeoutTokenSource = new CancellationTokenSource(HandshakeContinuationTimeout); + using var lts = CancellationTokenSource.CreateLinkedTokenSource(timeoutTokenSource.Token, cancellationToken); var m = new ConnectionOpen(virtualHost); - // TODO linked cancellation token - await ModelSendAsync(m) - .TimeoutAfter(HandshakeContinuationTimeout) - .ConfigureAwait(false); + await ModelSendAsync(m, lts.Token).ConfigureAwait(false); } internal async ValueTask ConnectionSecureOkAsync(byte[] response) { - await _rpcSemaphore.WaitAsync() + using var k = new ConnectionSecureOrTuneAsyncRpcContinuation(ContinuationTimeout); + + await _rpcSemaphore.WaitAsync(k.CancellationToken) .ConfigureAwait(false); try { - using var k = new ConnectionSecureOrTuneAsyncRpcContinuation(ContinuationTimeout); Enqueue(k); try { var method = new ConnectionSecureOk(response); - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); } catch (AlreadyClosedException) @@ -359,17 +361,18 @@ internal async ValueTask ConnectionStartOkAsync( IDictionary clientProperties, string mechanism, byte[] response, string locale) { - await _rpcSemaphore.WaitAsync() + using var k = new ConnectionSecureOrTuneAsyncRpcContinuation(ContinuationTimeout); + + await _rpcSemaphore.WaitAsync(k.CancellationToken) .ConfigureAwait(false); try { - using var k = new ConnectionSecureOrTuneAsyncRpcContinuation(ContinuationTimeout); Enqueue(k); try { var method = new ConnectionStartOk(clientProperties, mechanism, response, locale); - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); } catch (AlreadyClosedException) @@ -403,15 +406,16 @@ protected void Enqueue(IRpcContinuation k) internal async Task OpenAsync() { - await _rpcSemaphore.WaitAsync() - .ConfigureAwait(false); using var k = new ChannelOpenAsyncRpcContinuation(ContinuationTimeout); + + await _rpcSemaphore.WaitAsync(k.CancellationToken) + .ConfigureAwait(false); try { Enqueue(k); var method = new ChannelOpen(); - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); bool result = await k; @@ -426,7 +430,7 @@ await ModelSendAsync(method) internal void FinishClose() { - var reason = CloseReason; + ShutdownEventArgs reason = CloseReason; if (reason != null) { Session.Close(reason); @@ -508,9 +512,9 @@ protected void ChannelSend(in T method) where T : struct, IOutgoingAmqpMethod } [MethodImpl(MethodImplOptions.AggressiveInlining)] - protected ValueTask ModelSendAsync(in T method) where T : struct, IOutgoingAmqpMethod + protected ValueTask ModelSendAsync(in T method, CancellationToken cancellationToken) where T : struct, IOutgoingAmqpMethod { - return Session.TransmitAsync(in method); + return Session.TransmitAsync(in method, cancellationToken); } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -527,8 +531,7 @@ protected void ChannelSend(in TMethod method, in THeader heade } [MethodImpl(MethodImplOptions.AggressiveInlining)] - protected ValueTask ModelSendAsync(in TMethod method, in THeader header, - ReadOnlyMemory body) + protected ValueTask ModelSendAsync(in TMethod method, in THeader header, ReadOnlyMemory body, CancellationToken cancellationToken) where TMethod : struct, IOutgoingAmqpMethod where THeader : IAmqpHeader { @@ -537,7 +540,7 @@ protected ValueTask ModelSendAsync(in TMethod method, in THead _flowControlBlock.Wait(); } - return Session.TransmitAsync(in method, in header, body); + return Session.TransmitAsync(in method, in header, body, cancellationToken); } internal void OnCallbackException(CallbackExceptionEventArgs args) @@ -960,7 +963,7 @@ protected void HandleConnectionClose(in IncomingCommand cmd) var reason = new ShutdownEventArgs(ShutdownInitiator.Peer, method._replyCode, method._replyText, method._classId, method._methodId); try { - Session.Connection.InternalClose(reason); + Session.Connection.ClosedViaPeer(reason); _Private_ConnectionCloseOk(); SetCloseReason(Session.Connection.CloseReason); } @@ -994,8 +997,10 @@ protected void HandleConnectionStart(in IncomingCommand cmd) if (m_connectionStartCell is null) { var reason = new ShutdownEventArgs(ShutdownInitiator.Library, Constants.CommandInvalid, "Unexpected Connection.Start"); - // TODO async! - Session.Connection.CloseAsync(reason, false, InternalConstants.DefaultConnectionCloseTimeout).EnsureCompleted(); + // // TODO async / cancellation token + Session.Connection.CloseAsync(reason, false, + InternalConstants.DefaultConnectionCloseTimeout, + CancellationToken.None).EnsureCompleted(); } else { @@ -1074,24 +1079,25 @@ protected bool HandleQueueDeclareOk(in IncomingCommand cmd) public async Task BasicCancelAsync(string consumerTag, bool noWait) { - await _rpcSemaphore.WaitAsync() + using var k = new BasicCancelAsyncRpcContinuation(consumerTag, ConsumerDispatcher, ContinuationTimeout); + + await _rpcSemaphore.WaitAsync(k.CancellationToken) .ConfigureAwait(false); try { - var method = new Client.Framing.Impl.BasicCancel(consumerTag, noWait); + var method = new BasicCancel(consumerTag, noWait); if (noWait) { - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); ConsumerDispatcher.GetAndRemoveConsumer(consumerTag); } else { - using var k = new BasicCancelAsyncRpcContinuation(consumerTag, ConsumerDispatcher, ContinuationTimeout); Enqueue(k); - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); bool result = await k; @@ -1119,15 +1125,16 @@ public async Task BasicConsumeAsync(string queue, bool autoAck, string c } } - await _rpcSemaphore.WaitAsync() + using var k = new BasicConsumeAsyncRpcContinuation(consumer, ConsumerDispatcher, ContinuationTimeout); + + await _rpcSemaphore.WaitAsync(k.CancellationToken) .ConfigureAwait(false); try { - using var k = new BasicConsumeAsyncRpcContinuation(consumer, ConsumerDispatcher, ContinuationTimeout); Enqueue(k); var method = new Client.Framing.Impl.BasicConsume(queue, consumerTag, noLocal, autoAck, exclusive, false, arguments); - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); return await k; @@ -1140,15 +1147,16 @@ await ModelSendAsync(method) public async ValueTask BasicGetAsync(string queue, bool autoAck) { - await _rpcSemaphore.WaitAsync() + using var k = new BasicGetAsyncRpcContinuation(AdjustDeliveryTag, ContinuationTimeout); + + await _rpcSemaphore.WaitAsync(k.CancellationToken) .ConfigureAwait(false); try { - using var k = new BasicGetAsyncRpcContinuation(AdjustDeliveryTag, ContinuationTimeout); Enqueue(k); var method = new BasicGet(queue, autoAck); - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); BasicGetResult result = await k; @@ -1171,7 +1179,7 @@ await ModelSendAsync(method) public abstract ValueTask BasicNackAsync(ulong deliveryTag, bool multiple, bool requeue); - public ValueTask BasicPublishAsync(string exchange, string routingKey, in TProperties basicProperties, ReadOnlyMemory body, bool mandatory) + public async ValueTask BasicPublishAsync(string exchange, string routingKey, TProperties basicProperties, ReadOnlyMemory body, bool mandatory) where TProperties : IReadOnlyBasicProperties, IAmqpHeader { if (NextPublishSeqNo > 0) @@ -1193,10 +1201,14 @@ public ValueTask BasicPublishAsync(string exchange, string routingK if (sendActivity != null) { BasicProperties props = PopulateActivityAndPropagateTraceId(basicProperties, sendActivity); - return ModelSendAsync(in cmd, in props, body); + // TODO cancellation token + await ModelSendAsync(in cmd, in props, body, CancellationToken.None); + } + else + { + // TODO cancellation token + await ModelSendAsync(in cmd, in basicProperties, body, CancellationToken.None); } - - return ModelSendAsync(in cmd, in basicProperties, body); } catch { @@ -1227,8 +1239,8 @@ private static void InjectTraceContextIntoBasicProperties(object propsObj, strin } } - public void BasicPublish(CachedString exchange, CachedString routingKey, - in TProperties basicProperties, ReadOnlyMemory body, bool mandatory) + public async void BasicPublish(CachedString exchange, CachedString routingKey, + TProperties basicProperties, ReadOnlyMemory body, bool mandatory) where TProperties : IReadOnlyBasicProperties, IAmqpHeader { if (NextPublishSeqNo > 0) @@ -1242,6 +1254,7 @@ public void BasicPublish(CachedString exchange, CachedString routin try { var cmd = new BasicPublishMemory(exchange.Bytes, routingKey.Bytes, mandatory, default); + RabbitMQActivitySource.TryGetExistingContext(basicProperties, out ActivityContext existingContext); using Activity sendActivity = RabbitMQActivitySource.PublisherHasListeners ? RabbitMQActivitySource.Send(routingKey.Value, exchange.Value, body.Length, existingContext) @@ -1250,55 +1263,14 @@ public void BasicPublish(CachedString exchange, CachedString routin if (sendActivity != null) { BasicProperties props = PopulateActivityAndPropagateTraceId(basicProperties, sendActivity); - ChannelSend(in cmd, in props, body); - return; + // TODO cancellation token + await ModelSendAsync(in cmd, in basicProperties, body, CancellationToken.None); } - - ChannelSend(in cmd, in basicProperties, body); - } - catch - { - if (NextPublishSeqNo > 0) - { - lock (_confirmLock) - { - NextPublishSeqNo--; - _pendingDeliveryTags.RemoveLast(); - } - } - - throw; - } - } - - public async ValueTask BasicPublishAsync(string exchange, string routingKey, - TProperties basicProperties, ReadOnlyMemory body, bool mandatory) - where TProperties : IReadOnlyBasicProperties, IAmqpHeader - { - if (NextPublishSeqNo > 0) - { - lock (_confirmLock) - { - _pendingDeliveryTags.AddLast(NextPublishSeqNo++); - } - } - - try - { - var cmd = new BasicPublish(exchange, routingKey, mandatory, default); - RabbitMQActivitySource.TryGetExistingContext(basicProperties, out ActivityContext existingContext); - using Activity sendActivity = RabbitMQActivitySource.PublisherHasListeners - ? RabbitMQActivitySource.Send(routingKey, exchange, body.Length, existingContext) - : default; - - if (sendActivity != null) + else { - BasicProperties props = PopulateActivityAndPropagateTraceId(basicProperties, sendActivity); - await ModelSendAsync(in cmd, in props, body); - return; + // TODO cancellation token + await ModelSendAsync(in cmd, in basicProperties, body, CancellationToken.None); } - - await ModelSendAsync(in cmd, in basicProperties, body); } catch { @@ -1330,6 +1302,7 @@ public async ValueTask BasicPublishAsync(CachedString exchange, Cac try { var cmd = new BasicPublishMemory(exchange.Bytes, routingKey.Bytes, mandatory, default); + RabbitMQActivitySource.TryGetExistingContext(basicProperties, out ActivityContext existingContext); using Activity sendActivity = RabbitMQActivitySource.PublisherHasListeners ? RabbitMQActivitySource.Send(routingKey.Value, exchange.Value, body.Length, existingContext) @@ -1338,11 +1311,14 @@ public async ValueTask BasicPublishAsync(CachedString exchange, Cac if (sendActivity != null) { BasicProperties props = PopulateActivityAndPropagateTraceId(basicProperties, sendActivity); - await ModelSendAsync(in cmd, in props, body); - return; + // TODO cancellation token + await ModelSendAsync(in cmd, in props, body, CancellationToken.None); + } + else + { + // TODO cancellation token + await ModelSendAsync(in cmd, in basicProperties, body, CancellationToken.None); } - - await ModelSendAsync(in cmd, in basicProperties, body); } catch { @@ -1398,9 +1374,9 @@ await _rpcSemaphore.WaitAsync() using var k = new SimpleAsyncRpcContinuation(ProtocolCommandId.ConnectionUpdateSecretOk, ContinuationTimeout); Enqueue(k); - var newSecretBytes = Encoding.UTF8.GetBytes(newSecret); + byte[] newSecretBytes = Encoding.UTF8.GetBytes(newSecret); var method = new ConnectionUpdateSecret(newSecretBytes, reason); - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); bool result = await k; @@ -1415,15 +1391,16 @@ await ModelSendAsync(method) public async Task BasicQosAsync(uint prefetchSize, ushort prefetchCount, bool global) { - await _rpcSemaphore.WaitAsync() + using var k = new BasicQosAsyncRpcContinuation(ContinuationTimeout); + + await _rpcSemaphore.WaitAsync(k.CancellationToken) .ConfigureAwait(false); try { - using var k = new BasicQosAsyncRpcContinuation(ContinuationTimeout); Enqueue(k); var method = new BasicQos(prefetchSize, prefetchCount, global); - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); bool result = await k; @@ -1440,7 +1417,9 @@ await ModelSendAsync(method) public async Task ConfirmSelectAsync() { - await _rpcSemaphore.WaitAsync() + using var k = new ConfirmSelectAsyncRpcContinuation(ContinuationTimeout); + + await _rpcSemaphore.WaitAsync(k.CancellationToken) .ConfigureAwait(false); try { @@ -1450,11 +1429,10 @@ await _rpcSemaphore.WaitAsync() NextPublishSeqNo = 1; } - using var k = new ConfirmSelectAsyncRpcContinuation(ContinuationTimeout); Enqueue(k); var method = new ConfirmSelect(false); - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); bool result = await k; @@ -1471,7 +1449,9 @@ await ModelSendAsync(method) public async Task ExchangeBindAsync(string destination, string source, string routingKey, IDictionary arguments, bool noWait) { - await _rpcSemaphore.WaitAsync() + using var k = new ExchangeBindAsyncRpcContinuation(ContinuationTimeout); + + await _rpcSemaphore.WaitAsync(k.CancellationToken) .ConfigureAwait(false); try { @@ -1479,15 +1459,14 @@ await _rpcSemaphore.WaitAsync() if (noWait) { - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); } else { - using var k = new ExchangeBindAsyncRpcContinuation(ContinuationTimeout); Enqueue(k); - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); bool result = await k; @@ -1511,7 +1490,9 @@ public Task ExchangeDeclarePassiveAsync(string exchange) public async Task ExchangeDeclareAsync(string exchange, string type, bool durable, bool autoDelete, IDictionary arguments, bool passive, bool noWait) { - await _rpcSemaphore.WaitAsync() + using var k = new ExchangeDeclareAsyncRpcContinuation(ContinuationTimeout); + + await _rpcSemaphore.WaitAsync(k.CancellationToken) .ConfigureAwait(false); try { @@ -1519,15 +1500,14 @@ await _rpcSemaphore.WaitAsync() if (noWait) { - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); } else { - using var k = new ExchangeDeclareAsyncRpcContinuation(ContinuationTimeout); Enqueue(k); - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); bool result = await k; @@ -1544,7 +1524,9 @@ await ModelSendAsync(method) public async Task ExchangeDeleteAsync(string exchange, bool ifUnused, bool noWait) { - await _rpcSemaphore.WaitAsync() + using var k = new ExchangeDeleteAsyncRpcContinuation(ContinuationTimeout); + + await _rpcSemaphore.WaitAsync(k.CancellationToken) .ConfigureAwait(false); try { @@ -1552,15 +1534,14 @@ await _rpcSemaphore.WaitAsync() if (noWait) { - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); } else { - using var k = new ExchangeDeleteAsyncRpcContinuation(ContinuationTimeout); Enqueue(k); - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); bool result = await k; @@ -1578,7 +1559,9 @@ await ModelSendAsync(method) public async Task ExchangeUnbindAsync(string destination, string source, string routingKey, IDictionary arguments, bool noWait) { - await _rpcSemaphore.WaitAsync() + using var k = new ExchangeUnbindAsyncRpcContinuation(ContinuationTimeout); + + await _rpcSemaphore.WaitAsync(k.CancellationToken) .ConfigureAwait(false); try { @@ -1586,15 +1569,14 @@ await _rpcSemaphore.WaitAsync() if (noWait) { - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); } else { - using var k = new ExchangeUnbindAsyncRpcContinuation(ContinuationTimeout); Enqueue(k); - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); bool result = await k; @@ -1632,7 +1614,9 @@ public async Task QueueDeclareAsync(string queue, bool durable, } } - await _rpcSemaphore.WaitAsync() + using var k = new QueueDeclareAsyncRpcContinuation(ContinuationTimeout); + + await _rpcSemaphore.WaitAsync(k.CancellationToken) .ConfigureAwait(false); try { @@ -1640,7 +1624,7 @@ await _rpcSemaphore.WaitAsync() if (noWait) { - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); if (false == passive) @@ -1652,10 +1636,9 @@ await ModelSendAsync(method) } else { - using var k = new QueueDeclareAsyncRpcContinuation(ContinuationTimeout); Enqueue(k); - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); QueueDeclareOk result = await k; @@ -1676,7 +1659,9 @@ await ModelSendAsync(method) public async Task QueueBindAsync(string queue, string exchange, string routingKey, IDictionary arguments, bool noWait) { - await _rpcSemaphore.WaitAsync() + using var k = new QueueBindAsyncRpcContinuation(ContinuationTimeout); + + await _rpcSemaphore.WaitAsync(k.CancellationToken) .ConfigureAwait(false); try { @@ -1684,15 +1669,14 @@ await _rpcSemaphore.WaitAsync() if (noWait) { - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); } else { - using var k = new QueueBindAsyncRpcContinuation(ContinuationTimeout); Enqueue(k); - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); bool result = await k; @@ -1721,7 +1705,9 @@ public async Task ConsumerCountAsync(string queue) public async Task QueueDeleteAsync(string queue, bool ifUnused, bool ifEmpty, bool noWait) { - await _rpcSemaphore.WaitAsync() + using var k = new QueueDeleteAsyncRpcContinuation(ContinuationTimeout); + + await _rpcSemaphore.WaitAsync(k.CancellationToken) .ConfigureAwait(false); try { @@ -1729,17 +1715,16 @@ await _rpcSemaphore.WaitAsync() if (noWait) { - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); return 0; } else { - var k = new QueueDeleteAsyncRpcContinuation(ContinuationTimeout); Enqueue(k); - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); return await k; @@ -1753,15 +1738,16 @@ await ModelSendAsync(method) public async Task QueuePurgeAsync(string queue) { - await _rpcSemaphore.WaitAsync() + using var k = new QueuePurgeAsyncRpcContinuation(ContinuationTimeout); + + await _rpcSemaphore.WaitAsync(k.CancellationToken) .ConfigureAwait(false); try { - var k = new QueuePurgeAsyncRpcContinuation(ContinuationTimeout); Enqueue(k); var method = new QueuePurge(queue, false); - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); return await k; @@ -1774,15 +1760,16 @@ await ModelSendAsync(method) public async Task QueueUnbindAsync(string queue, string exchange, string routingKey, IDictionary arguments) { - await _rpcSemaphore.WaitAsync() + using var k = new QueueUnbindAsyncRpcContinuation(ContinuationTimeout); + + await _rpcSemaphore.WaitAsync(k.CancellationToken) .ConfigureAwait(false); try { - using var k = new QueueUnbindAsyncRpcContinuation(ContinuationTimeout); Enqueue(k); var method = new QueueUnbind(queue, exchange, routingKey, arguments); - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); bool result = await k; @@ -1797,15 +1784,16 @@ await ModelSendAsync(method) public async Task TxCommitAsync() { - await _rpcSemaphore.WaitAsync() + using var k = new TxCommitAsyncRpcContinuation(ContinuationTimeout); + + await _rpcSemaphore.WaitAsync(k.CancellationToken) .ConfigureAwait(false); try { - using var k = new TxCommitAsyncRpcContinuation(ContinuationTimeout); Enqueue(k); var method = new TxCommit(); - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); bool result = await k; @@ -1820,15 +1808,16 @@ await ModelSendAsync(method) public async Task TxRollbackAsync() { - await _rpcSemaphore.WaitAsync() + using var k = new TxRollbackAsyncRpcContinuation(ContinuationTimeout); + + await _rpcSemaphore.WaitAsync(k.CancellationToken) .ConfigureAwait(false); try { - using var k = new TxRollbackAsyncRpcContinuation(ContinuationTimeout); Enqueue(k); var method = new TxRollback(); - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); bool result = await k; @@ -1843,15 +1832,16 @@ await ModelSendAsync(method) public async Task TxSelectAsync() { - await _rpcSemaphore.WaitAsync() + using var k = new TxSelectAsyncRpcContinuation(ContinuationTimeout); + + await _rpcSemaphore.WaitAsync(k.CancellationToken) .ConfigureAwait(false); try { - using var k = new TxSelectAsyncRpcContinuation(ContinuationTimeout); Enqueue(k); var method = new TxSelect(); - await ModelSendAsync(method) + await ModelSendAsync(method, k.CancellationToken) .ConfigureAwait(false); bool result = await k; @@ -1902,24 +1892,25 @@ public Task WaitForConfirmsAsync(CancellationToken token = default) private async Task WaitForConfirmsWithTokenAsync(TaskCompletionSource tcs, CancellationToken token) { CancellationTokenRegistration tokenRegistration = -#if NETSTANDARD - token.Register( -#else +#if NET6_0_OR_GREATER token.UnsafeRegister( -#endif state => ((TaskCompletionSource)state).TrySetCanceled(), tcs); - +#else + token.Register( + state => ((TaskCompletionSource)state).TrySetCanceled(), + state: tcs, useSynchronizationContext: false); +#endif try { return await tcs.Task.ConfigureAwait(false); } finally { -#if NETSTANDARD - tokenRegistration.Dispose(); -#else +#if NET6_0_OR_GREATER await tokenRegistration.DisposeAsync() .ConfigureAwait(false); +#else + tokenRegistration.Dispose(); #endif } } diff --git a/projects/RabbitMQ.Client/client/impl/Connection.Commands.cs b/projects/RabbitMQ.Client/client/impl/Connection.Commands.cs index ae2991945e..f0e236d57a 100644 --- a/projects/RabbitMQ.Client/client/impl/Connection.Commands.cs +++ b/projects/RabbitMQ.Client/client/impl/Connection.Commands.cs @@ -74,10 +74,22 @@ private async ValueTask StartAndTuneAsync(CancellationToken cancellationToken) { var connectionStartCell = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - using CancellationTokenRegistration ctr = cancellationToken.Register(() => +#if NET6_0_OR_GREATER + using CancellationTokenRegistration ctr = cancellationToken.UnsafeRegister((object? state) => { - connectionStartCell.TrySetCanceled(cancellationToken); - }, useSynchronizationContext: false); + if (state != null) + { + var csc = (TaskCompletionSource)state; + csc.TrySetCanceled(cancellationToken); + } + }, connectionStartCell); +#else + using CancellationTokenRegistration ctr = cancellationToken.Register((object state) => + { + var csc = (TaskCompletionSource)state; + csc.TrySetCanceled(cancellationToken); + }, state: connectionStartCell, useSynchronizationContext: false); +#endif _channel0.m_connectionStartCell = connectionStartCell; _channel0.HandshakeContinuationTimeout = _config.HandshakeContinuationTimeout; @@ -101,7 +113,8 @@ await _frameHandler.SendProtocolHeaderAsync(cancellationToken) if (!serverVersion.Equals(Protocol.Version)) { TerminateMainloop(); - FinishClose(); + // TODO hmmm + FinishCloseAsync(CancellationToken.None).EnsureCompleted(); throw new ProtocolVersionMismatchException(Protocol.MajorVersion, Protocol.MinorVersion, serverVersion.Major, serverVersion.Minor); } diff --git a/projects/RabbitMQ.Client/client/impl/Connection.Heartbeat.cs b/projects/RabbitMQ.Client/client/impl/Connection.Heartbeat.cs index bd3ae7d7dc..4e1cad73ab 100644 --- a/projects/RabbitMQ.Client/client/impl/Connection.Heartbeat.cs +++ b/projects/RabbitMQ.Client/client/impl/Connection.Heartbeat.cs @@ -91,7 +91,7 @@ private void HeartbeatReadTimerCallback(object? state) try { - if (!_closed) + if (false == _closed) { if (_heartbeatDetected) { @@ -119,7 +119,8 @@ private void HeartbeatReadTimerCallback(object? state) if (shouldTerminate) { TerminateMainloop(); - FinishClose(); + // TODO hmmm + FinishCloseAsync(CancellationToken.None).EnsureCompleted(); } else { @@ -147,7 +148,7 @@ private void HeartbeatWriteTimerCallback(object? state) try { - if (!_closed) + if (false == _closed) { Write(Client.Impl.Framing.Heartbeat.GetHeartbeatFrame()); _heartbeatWriteTimer?.Change((int)_heartbeatTimeSpan.TotalMilliseconds, Timeout.Infinite); diff --git a/projects/RabbitMQ.Client/client/impl/Connection.Receive.cs b/projects/RabbitMQ.Client/client/impl/Connection.Receive.cs index 8ddc388fae..d9b5a02453 100644 --- a/projects/RabbitMQ.Client/client/impl/Connection.Receive.cs +++ b/projects/RabbitMQ.Client/client/impl/Connection.Receive.cs @@ -31,6 +31,7 @@ using System; using System.IO; +using System.Threading; using System.Threading.Tasks; using RabbitMQ.Client.Exceptions; using RabbitMQ.Client.Impl; @@ -40,16 +41,23 @@ namespace RabbitMQ.Client.Framing.Impl #nullable enable internal sealed partial class Connection { + private readonly CancellationTokenSource _mainLoopCts = new CancellationTokenSource(); private readonly IFrameHandler _frameHandler; private Task _mainLoopTask; private async Task MainLoop() { + CancellationToken mainLoopToken = _mainLoopCts.Token; try { - await ReceiveLoopAsync() + await ReceiveLoopAsync(mainLoopToken) .ConfigureAwait(false); } + catch (OperationCanceledException) + { + // TODO what to do here? + // Debug log? + } catch (EndOfStreamException eose) { // Possible heartbeat exception @@ -60,7 +68,7 @@ await ReceiveLoopAsync() } catch (HardProtocolException hpe) { - await HardProtocolExceptionHandlerAsync(hpe) + await HardProtocolExceptionHandlerAsync(hpe, mainLoopToken) .ConfigureAwait(false); } catch (FileLoadException fileLoadException) @@ -83,13 +91,17 @@ await HardProtocolExceptionHandlerAsync(hpe) HandleMainLoopException(ea); } - FinishClose(); + // TODO is this the best way? + using var cts = new CancellationTokenSource(InternalConstants.DefaultConnectionCloseTimeout); + await FinishCloseAsync(cts.Token); } - private async Task ReceiveLoopAsync() + private async Task ReceiveLoopAsync(CancellationToken mainLoopCancelllationToken) { - while (!_closed) + while (false == _closed) { + mainLoopCancelllationToken.ThrowIfCancellationRequested(); + while (_frameHandler.TryReadFrame(out InboundFrame frame)) { NotifyHeartbeatListener(); @@ -97,8 +109,8 @@ private async Task ReceiveLoopAsync() } // Done reading frames synchronously, go async - InboundFrame asyncFrame = await _frameHandler.ReadFrameAsync() - .ConfigureAwait(false); + InboundFrame asyncFrame = await _frameHandler.ReadFrameAsync(mainLoopCancelllationToken) + .ConfigureAwait(false); NotifyHeartbeatListener(); ProcessFrame(asyncFrame); } @@ -157,6 +169,7 @@ private void ProcessFrame(InboundFrame frame) /// private void TerminateMainloop() { + _mainLoopCts.Cancel(); MaybeStopHeartbeatTimers(); } @@ -174,19 +187,20 @@ private void HandleMainLoopException(ShutdownEventArgs reason) LogCloseError($"Unexpected connection closure: {reason}", reason.Exception); } - private async Task HardProtocolExceptionHandlerAsync(HardProtocolException hpe) + private async Task HardProtocolExceptionHandlerAsync(HardProtocolException hpe, CancellationToken cancellationToken) { if (SetCloseReason(hpe.ShutdownReason)) { OnShutdown(hpe.ShutdownReason); - _session0.SetSessionClosing(false); + await _session0.SetSessionClosingAsync(false); try { var cmd = new ConnectionClose(hpe.ShutdownReason.ReplyCode, hpe.ShutdownReason.ReplyText, 0, 0); - _session0.Transmit(in cmd); + await _session0.TransmitAsync(in cmd, cancellationToken) + .ConfigureAwait(false); if (hpe.CanShutdownCleanly) { - await ClosingLoopAsync() + await ClosingLoopAsync(cancellationToken) .ConfigureAwait(false); } } @@ -204,18 +218,18 @@ await ClosingLoopAsync() /// /// Loop only used while quiescing. Use only to cleanly close connection /// - private async Task ClosingLoopAsync() + private async Task ClosingLoopAsync(CancellationToken cancellationToken) { try { _frameHandler.ReadTimeout = TimeSpan.Zero; // Wait for response/socket closure or timeout - await ReceiveLoopAsync() + await ReceiveLoopAsync(cancellationToken) .ConfigureAwait(false); } catch (ObjectDisposedException ode) { - if (!_closed) + if (false == _closed) { LogCloseError("Connection didn't close cleanly", ode); } diff --git a/projects/RabbitMQ.Client/client/impl/Connection.cs b/projects/RabbitMQ.Client/client/impl/Connection.cs index 056be859ce..b00225be38 100644 --- a/projects/RabbitMQ.Client/client/impl/Connection.cs +++ b/projects/RabbitMQ.Client/client/impl/Connection.cs @@ -33,8 +33,6 @@ using System.Collections.Generic; using System.Diagnostics; using System.IO; -using System.Net; -using System.Net.Sockets; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -52,7 +50,7 @@ internal sealed partial class Connection : IConnection private volatile bool _closed; private readonly ConnectionConfig _config; - private readonly ChannelBase _channel0; + private readonly ChannelBase _channel0; // TODO this isn't disposed, hmm private readonly MainSession _session0; private Guid _id = Guid.NewGuid(); @@ -153,7 +151,7 @@ public event EventHandler ConnectionShutdown add { ThrowIfDisposed(); - var reason = CloseReason; + ShutdownEventArgs? reason = CloseReason; if (reason is null) { _connectionShutdownWrapper.AddHandler(value); @@ -220,35 +218,42 @@ internal IConnection Open() return OpenAsync(CancellationToken.None).EnsureCompleted(); } - // TODO cancellationToken internal async ValueTask OpenAsync(CancellationToken cancellationToken) { + cancellationToken.ThrowIfCancellationRequested(); + try { RabbitMqClientEventSource.Log.ConnectionOpened(); + cancellationToken.ThrowIfCancellationRequested(); + await _frameHandler.ConnectAsync(cancellationToken) .ConfigureAwait(false); // Note: this must happen *after* the frame handler is started _mainLoopTask = Task.Run(MainLoop, cancellationToken); + cancellationToken.ThrowIfCancellationRequested(); + await StartAndTuneAsync(cancellationToken) .ConfigureAwait(false); + cancellationToken.ThrowIfCancellationRequested(); + await _channel0.ConnectionOpenAsync(_config.VirtualHost, cancellationToken) .ConfigureAwait(false); return this; } - catch // TODO - evaluate all "catch all" clauses to ensure correct exception is eventually thrown + catch { try { var ea = new ShutdownEventArgs(ShutdownInitiator.Library, Constants.InternalError, "FailedOpen"); - // TODO linked cancellation token? - await CloseAsync(ea, true, TimeSpan.FromSeconds(5)) - .ConfigureAwait(false); + await CloseAsync(ea, true, + InternalConstants.DefaultConnectionAbortTimeout, + cancellationToken).ConfigureAwait(false); } catch { } @@ -285,9 +290,10 @@ internal void EnsureIsOpen() } ///Asynchronous API-side invocation of connection.close with timeout. - public Task CloseAsync(ushort reasonCode, string reasonText, TimeSpan timeout, bool abort) + public Task CloseAsync(ushort reasonCode, string reasonText, TimeSpan timeout, bool abort, CancellationToken cancellationToken) { - return CloseAsync(new ShutdownEventArgs(ShutdownInitiator.Application, reasonCode, reasonText), abort, timeout); + var reason = new ShutdownEventArgs(ShutdownInitiator.Application, reasonCode, reasonText); + return CloseAsync(reason, abort, timeout, cancellationToken); } ///Asychronously try to close connection in a graceful way @@ -305,8 +311,7 @@ public Task CloseAsync(ushort reasonCode, string reasonText, TimeSpan timeout, b ///to complete. /// /// - // TODO cancellation token - internal async Task CloseAsync(ShutdownEventArgs reason, bool abort, TimeSpan timeout) + internal async Task CloseAsync(ShutdownEventArgs reason, bool abort, TimeSpan timeout, CancellationToken cancellationToken) { if (false == SetCloseReason(reason)) { @@ -318,22 +323,24 @@ internal async Task CloseAsync(ShutdownEventArgs reason, bool abort, TimeSpan ti } else { + cancellationToken.ThrowIfCancellationRequested(); + OnShutdown(reason); - _session0.SetSessionClosing(false); + await _session0.SetSessionClosingAsync(false); try { // Try to send connection.close wait for CloseOk in the MainLoop - if (!_closed) + if (false == _closed) { var method = new ConnectionClose(reason.ReplyCode, reason.ReplyText, 0, 0); - await _session0.TransmitAsync(method) + await _session0.TransmitAsync(method, cancellationToken) .ConfigureAwait(false); } } catch (AlreadyClosedException) { - if (!abort) + if (false == abort) { throw; } @@ -365,29 +372,28 @@ await _session0.TransmitAsync(method) try { - await _mainLoopTask.WaitAsync(timeout) + await _mainLoopTask.WaitAsync(timeout, cancellationToken) .ConfigureAwait(false); } - catch (TimeoutException) - { - } - catch (AggregateException) - { - } - finally + catch { try { - await _frameHandler.CloseAsync() + await _frameHandler.CloseAsync(cancellationToken) .ConfigureAwait(false); } catch { } + + if (false == abort) + { + throw; + } } } - internal void InternalClose(ShutdownEventArgs reason) + internal void ClosedViaPeer(ShutdownEventArgs reason) { if (!SetCloseReason(reason)) { @@ -404,13 +410,13 @@ internal void InternalClose(ShutdownEventArgs reason) } // Only call at the end of the Mainloop or HeartbeatLoop - // TODO async - private void FinishClose() + private async Task FinishCloseAsync(CancellationToken cancellationToken) { + _mainLoopCts.Cancel(); _closed = true; MaybeStopHeartbeatTimers(); - _frameHandler.Close(); + await _frameHandler.CloseAsync(cancellationToken); _channel0.SetCloseReason(CloseReason); _channel0.FinishClose(); RabbitMqClientEventSource.Log.ConnectionClosed(); @@ -449,17 +455,13 @@ internal void OnCallbackException(CallbackExceptionEventArgs args) internal void Write(RentedMemory frames) { Activity.Current.SetNetworkTags(_frameHandler); - ValueTask task = _frameHandler.WriteAsync(frames); - if (!task.IsCompletedSuccessfully) - { - task.EnsureCompleted(); - } + _frameHandler.Write(frames); } - internal ValueTask WriteAsync(RentedMemory frames) + internal ValueTask WriteAsync(RentedMemory frames, CancellationToken cancellationToken) { Activity.Current.SetNetworkTags(_frameHandler); - return _frameHandler.WriteAsync(frames); + return _frameHandler.WriteAsync(frames, cancellationToken); } public void Dispose() @@ -480,6 +482,9 @@ public void Dispose() { throw new InvalidOperationException("Connection must be closed before calling Dispose!"); } + + _session0.Dispose(); + _mainLoopCts.Dispose(); } catch (OperationInterruptedException) { diff --git a/projects/RabbitMQ.Client/client/impl/Frame.cs b/projects/RabbitMQ.Client/client/impl/Frame.cs index ad5eab4c19..1c6bcb4ad5 100644 --- a/projects/RabbitMQ.Client/client/impl/Frame.cs +++ b/projects/RabbitMQ.Client/client/impl/Frame.cs @@ -34,6 +34,7 @@ using System.IO; using System.IO.Pipelines; using System.Runtime.CompilerServices; +using System.Threading; using System.Threading.Tasks; using RabbitMQ.Client.Exceptions; @@ -253,9 +254,9 @@ private static void ProcessProtocolHeader(ReadOnlySequence buffer) } } - internal static async ValueTask ReadFromPipeAsync(PipeReader reader, uint maxMessageSize) + internal static async ValueTask ReadFromPipeAsync(PipeReader reader, uint maxMessageSize, CancellationToken cancellationToken) { - ReadResult result = await reader.ReadAsync() + ReadResult result = await reader.ReadAsync(cancellationToken) .ConfigureAwait(false); ReadOnlySequence buffer = result.Buffer; @@ -269,7 +270,7 @@ internal static async ValueTask ReadFromPipeAsync(PipeReader reade reader.AdvanceTo(buffer.Start, buffer.End); // Not enough data, read a bit more - result = await reader.ReadAsync() + result = await reader.ReadAsync(cancellationToken) .ConfigureAwait(false); MaybeThrowEndOfStream(result, buffer); @@ -381,8 +382,7 @@ private static void MaybeThrowEndOfStream(ReadResult result, ReadOnlySequenceSocket write timeout. System.Threading.Timeout.InfiniteTimeSpan signals "infinity". TimeSpan WriteTimeout { set; } - void Close(); - Task CloseAsync(); + Task CloseAsync(CancellationToken cancellationToken); ///Read a frame from the underlying ///transport. Returns null if the read operation timed out ///(see Timeout property). - // TODO cancellation token for read timeout / cancellation? - ValueTask ReadFrameAsync(); + ValueTask ReadFrameAsync(CancellationToken cancellationToken); ///Try to synchronously read a frame from the underlying transport. ///Returns false if connection buffer contains insufficient data. @@ -72,7 +70,7 @@ internal interface IFrameHandler Task SendProtocolHeaderAsync(CancellationToken cancellationToken); - // TODO cancellation token for write timeout / cancellation? - ValueTask WriteAsync(RentedMemory frames); + void Write(RentedMemory frames); // TODO remove, should be async only + ValueTask WriteAsync(RentedMemory frames, CancellationToken cancellationToken); } } diff --git a/projects/RabbitMQ.Client/client/impl/ISession.cs b/projects/RabbitMQ.Client/client/impl/ISession.cs index 1d5a34f76b..69f9a8076e 100644 --- a/projects/RabbitMQ.Client/client/impl/ISession.cs +++ b/projects/RabbitMQ.Client/client/impl/ISession.cs @@ -30,6 +30,7 @@ //--------------------------------------------------------------------------- using System; +using System.Threading; using System.Threading.Tasks; using RabbitMQ.Client.Framing.Impl; @@ -83,9 +84,9 @@ void Transmit(in TMethod cmd, in THeader header, ReadOnlyMemor where TMethod : struct, IOutgoingAmqpMethod where THeader : IAmqpHeader; - ValueTask TransmitAsync(in T cmd) where T : struct, IOutgoingAmqpMethod; + ValueTask TransmitAsync(in T cmd, CancellationToken cancellationToken) where T : struct, IOutgoingAmqpMethod; - ValueTask TransmitAsync(in TMethod cmd, in THeader header, ReadOnlyMemory body) + ValueTask TransmitAsync(in TMethod cmd, in THeader header, ReadOnlyMemory body, CancellationToken cancellationToken) where TMethod : struct, IOutgoingAmqpMethod where THeader : IAmqpHeader; } diff --git a/projects/RabbitMQ.Client/client/impl/MainSession.cs b/projects/RabbitMQ.Client/client/impl/MainSession.cs index 8b7e19f597..79423ca354 100644 --- a/projects/RabbitMQ.Client/client/impl/MainSession.cs +++ b/projects/RabbitMQ.Client/client/impl/MainSession.cs @@ -34,17 +34,20 @@ // the versions we support*. Obviously we may need to revisit this if // that ever changes. +using System; +using System.Threading; +using System.Threading.Tasks; using RabbitMQ.Client.client.framing; using RabbitMQ.Client.Framing.Impl; namespace RabbitMQ.Client.Impl { ///Small ISession implementation used only for channel 0. - internal sealed class MainSession : Session + internal sealed class MainSession : Session, IDisposable { - private volatile bool _closeServerInitiated; + private volatile bool _closeIsServerInitiated; private volatile bool _closing; - private readonly object _lock = new object(); + private readonly SemaphoreSlim _closingSemaphore = new SemaphoreSlim(1, 1); public MainSession(Connection connection) : base(connection, 0) { @@ -55,7 +58,7 @@ public override bool HandleFrame(in InboundFrame frame) if (_closing) { // We are closing - if (!_closeServerInitiated && frame.Type == FrameType.FrameMethod) + if ((false == _closeIsServerInitiated) && (frame.Type == FrameType.FrameMethod)) { // This isn't a server initiated close and we have a method frame switch (Connection.Protocol.DecodeCommandIdFrom(frame.Payload.Span)) @@ -84,33 +87,86 @@ public override bool HandleFrame(in InboundFrame frame) /// method call because that would prevent us from /// sending/receiving Close/CloseOk commands /// - public void SetSessionClosing(bool closeServerInitiated) + public void SetSessionClosing(bool closeIsServerInitiated) { - if (!_closing) + if (_closingSemaphore.Wait(InternalConstants.DefaultConnectionAbortTimeout)) { - lock (_lock) + try { - if (!_closing) + if (false == _closing) { _closing = true; - _closeServerInitiated = closeServerInitiated; + _closeIsServerInitiated = closeIsServerInitiated; } } + finally + { + _closingSemaphore.Release(); + } + } + else + { + throw new InvalidOperationException("[DEBUG] couldn't enter semaphore"); + } + } + + public async Task SetSessionClosingAsync(bool closeIsServerInitiated) + { + if (await _closingSemaphore.WaitAsync(InternalConstants.DefaultConnectionAbortTimeout).ConfigureAwait(false)) + { + try + { + if (false == _closing) + { + _closing = true; + _closeIsServerInitiated = closeIsServerInitiated; + } + } + finally + { + _closingSemaphore.Release(); + } + } + else + { + throw new InvalidOperationException("[DEBUG] couldn't async enter semaphore"); } } public override void Transmit(in T cmd) { - if (_closing && // Are we closing? - cmd.ProtocolCommandId != ProtocolCommandId.ConnectionCloseOk && // is this not a close-ok? - (_closeServerInitiated || cmd.ProtocolCommandId != ProtocolCommandId.ConnectionClose)) // is this either server initiated or not a close? + // Are we closing? + if (_closing) { - // We shouldn't do anything since we are closing, not sending a connection-close-ok command - // and this is either a server-initiated close or not a connection-close command. - return; + if ((cmd.ProtocolCommandId != ProtocolCommandId.ConnectionCloseOk) && // is this not a close-ok? + (_closeIsServerInitiated || cmd.ProtocolCommandId != ProtocolCommandId.ConnectionClose)) // is this either server initiated or not a close? + { + // We shouldn't do anything since we are closing, not sending a connection-close-ok command + // and this is either a server-initiated close or not a connection-close command. + return; + } } base.Transmit(in cmd); } + + public override ValueTask TransmitAsync(in T cmd, CancellationToken cancellationToken) + { + // Are we closing? + if (_closing) + { + if ((cmd.ProtocolCommandId != ProtocolCommandId.ConnectionCloseOk) && // is this not a close-ok? + (_closeIsServerInitiated || cmd.ProtocolCommandId != ProtocolCommandId.ConnectionClose)) // is this either server initiated or not a close? + { + // We shouldn't do anything since we are closing, not sending a connection-close-ok command + // and this is either a server-initiated close or not a connection-close command. + return default; + } + } + + return base.TransmitAsync(in cmd, cancellationToken); + } + + public void Dispose() => ((IDisposable)_closingSemaphore).Dispose(); } } diff --git a/projects/RabbitMQ.Client/client/impl/RecordedBinding.cs b/projects/RabbitMQ.Client/client/impl/RecordedBinding.cs index d0a861d8fd..51fd498708 100644 --- a/projects/RabbitMQ.Client/client/impl/RecordedBinding.cs +++ b/projects/RabbitMQ.Client/client/impl/RecordedBinding.cs @@ -92,7 +92,9 @@ public override bool Equals(object? obj) public override int GetHashCode() { -#if NETSTANDARD +#if NET6_0_OR_GREATER + return HashCode.Combine(_isQueueBinding, _destination, _source, _routingKey, _arguments); +#else unchecked { int hashCode = _isQueueBinding.GetHashCode(); @@ -102,8 +104,6 @@ public override int GetHashCode() hashCode = (hashCode * 397) ^ (_arguments != null ? _arguments.GetHashCode() : 0); return hashCode; } -#else - return HashCode.Combine(_isQueueBinding, _destination, _source, _routingKey, _arguments); #endif } diff --git a/projects/RabbitMQ.Client/client/impl/SessionBase.cs b/projects/RabbitMQ.Client/client/impl/SessionBase.cs index d2c8d3d545..961d43cf6e 100644 --- a/projects/RabbitMQ.Client/client/impl/SessionBase.cs +++ b/projects/RabbitMQ.Client/client/impl/SessionBase.cs @@ -144,7 +144,7 @@ public virtual void Transmit(in T cmd) where T : struct, IOutgoingAmqpMethod Connection.Write(bytes); } - public virtual ValueTask TransmitAsync(in T cmd) where T : struct, IOutgoingAmqpMethod + public virtual ValueTask TransmitAsync(in T cmd, CancellationToken cancellationToken) where T : struct, IOutgoingAmqpMethod { if (!IsOpen && cmd.ProtocolCommandId != client.framing.ProtocolCommandId.ChannelCloseOk) { @@ -153,7 +153,7 @@ public virtual ValueTask TransmitAsync(in T cmd) where T : struct, IOutgoingA RentedMemory bytes = Framing.SerializeToFrames(ref Unsafe.AsRef(cmd), ChannelNumber); RabbitMQActivitySource.PopulateMessageEnvelopeSize(Activity.Current, bytes.Size); - return Connection.WriteAsync(bytes); + return Connection.WriteAsync(bytes, cancellationToken); } public void Transmit(in TMethod cmd, in THeader header, ReadOnlyMemory body) @@ -171,7 +171,7 @@ public void Transmit(in TMethod cmd, in THeader header, ReadOn Connection.Write(bytes); } - public ValueTask TransmitAsync(in TMethod cmd, in THeader header, ReadOnlyMemory body) + public ValueTask TransmitAsync(in TMethod cmd, in THeader header, ReadOnlyMemory body, CancellationToken cancellationToken = default) where TMethod : struct, IOutgoingAmqpMethod where THeader : IAmqpHeader { @@ -180,10 +180,9 @@ public ValueTask TransmitAsync(in TMethod cmd, in THeader head ThrowAlreadyClosedException(); } - RentedMemory bytes = Framing.SerializeToFrames(ref Unsafe.AsRef(cmd), ref Unsafe.AsRef(header), body, ChannelNumber, - Connection.MaxPayloadSize); + RentedMemory bytes = Framing.SerializeToFrames(ref Unsafe.AsRef(cmd), ref Unsafe.AsRef(header), body, ChannelNumber, Connection.MaxPayloadSize); RabbitMQActivitySource.PopulateMessageEnvelopeSize(Activity.Current, bytes.Size); - return Connection.WriteAsync(bytes); + return Connection.WriteAsync(bytes, cancellationToken); } private void ThrowAlreadyClosedException() diff --git a/projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs b/projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs index b0f9e766a9..d2f3927f6d 100644 --- a/projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs +++ b/projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs @@ -155,6 +155,8 @@ public async Task ConnectAsync(CancellationToken cancellationToken) return; } + cancellationToken.ThrowIfCancellationRequested(); + #if NET6_0_OR_GREATER _amqpTcpEndpointAddresses = await Dns.GetHostAddressesAsync(_amqpTcpEndpoint.HostName, cancellationToken) .ConfigureAwait(false); @@ -216,7 +218,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken) } catch { - await CloseAsync() + await CloseAsync(cancellationToken) .ConfigureAwait(false); throw; } @@ -229,22 +231,17 @@ await CloseAsync() _connected = true; } - public void Close() - { - CloseAsync().EnsureCompleted(); - } - - public async Task CloseAsync() + public async Task CloseAsync(CancellationToken cancellationToken) { if (_closed || _socket == null) { return; } - await _closingSemaphore.WaitAsync() - .ConfigureAwait(false); try { + await _closingSemaphore.WaitAsync(cancellationToken) + .ConfigureAwait(false); try { _channelWriter.Complete(); @@ -271,16 +268,20 @@ await _pipeReader.CompleteAsync() // ignore, we are closing anyway } } + catch + { + } finally { _closingSemaphore.Release(); + _closingSemaphore.Dispose(); _closed = true; } } - public ValueTask ReadFrameAsync() + public ValueTask ReadFrameAsync(CancellationToken cancellationToken) { - return InboundFrame.ReadFromPipeAsync(_pipeReader, _amqpTcpEndpoint.MaxMessageSize); + return InboundFrame.ReadFromPipeAsync(_pipeReader, _amqpTcpEndpoint.MaxMessageSize, cancellationToken); } public bool TryReadFrame(out InboundFrame frame) @@ -296,17 +297,31 @@ await _pipeWriter.FlushAsync(cancellationToken) .ConfigureAwait(false); } - public async ValueTask WriteAsync(RentedMemory frames) + public void Write(RentedMemory frames) { if (_closed) { frames.Dispose(); - await Task.Yield(); } else { - await _channelWriter.WriteAsync(frames) - .ConfigureAwait(false); + if (false == _channelWriter.TryWrite(frames)) + { + // TODO what to do here? + } + } + } + + public ValueTask WriteAsync(RentedMemory frames, CancellationToken cancellationToken) + { + if (_closed) + { + frames.Dispose(); + return default; + } + else + { + return _channelWriter.WriteAsync(frames, cancellationToken); } } @@ -406,8 +421,7 @@ private static async Task ConnectOrFailAsync(ITcpClient tcpClient, IPEndPoint en * https://learn.microsoft.com/en-us/dotnet/standard/threading/how-to-listen-for-multiple-cancellation-requests */ using var timeoutTokenSource = new CancellationTokenSource(connectionTimeout); - CancellationToken timeoutToken = timeoutTokenSource.Token; - using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(timeoutToken, externalCancellationToken); + using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(timeoutTokenSource.Token, externalCancellationToken); try { @@ -428,8 +442,9 @@ await tcpClient.ConnectAsync(endpoint.Address, endpoint.Port, linkedTokenSource. } catch (OperationCanceledException e) { - if (timeoutToken.IsCancellationRequested) + if (timeoutTokenSource.Token.IsCancellationRequested) { + // TODO maybe do not use System.TimeoutException here var timeoutException = new TimeoutException(msg, e); throw new ConnectFailureException(msg, timeoutException); } diff --git a/projects/RabbitMQ.Client/client/logging/RabbitMqClientEventSource.Counters.cs b/projects/RabbitMQ.Client/client/logging/RabbitMqClientEventSource.Counters.cs index 4cbee694d9..6cc3995d18 100644 --- a/projects/RabbitMQ.Client/client/logging/RabbitMqClientEventSource.Counters.cs +++ b/projects/RabbitMQ.Client/client/logging/RabbitMqClientEventSource.Counters.cs @@ -47,7 +47,7 @@ internal sealed partial class RabbitMqClientEventSource private static long CommandsSent; private static long CommandsReceived; -#if !NETSTANDARD +#if NET6_0_OR_GREATER private PollingCounter? _connectionOpenedCounter; private PollingCounter? _openConnectionCounter; private PollingCounter? _channelOpenedCounter; diff --git a/projects/RabbitMQ.Client/util/BlockingCell.cs b/projects/RabbitMQ.Client/util/BlockingCell.cs index 8f7c0d2736..356fc91429 100644 --- a/projects/RabbitMQ.Client/util/BlockingCell.cs +++ b/projects/RabbitMQ.Client/util/BlockingCell.cs @@ -76,6 +76,8 @@ public T WaitForValue(TimeSpan timeout) { return _value; } + + // TODO do not use System.TimeoutException here throw new TimeoutException(); } diff --git a/projects/Test/Applications/CreateChannel/Program.cs b/projects/Test/Applications/CreateChannel/Program.cs index a490ae0032..d3f5d66ce6 100644 --- a/projects/Test/Applications/CreateChannel/Program.cs +++ b/projects/Test/Applications/CreateChannel/Program.cs @@ -17,8 +17,6 @@ public static class Program public static async Task Main() { - ThreadPool.SetMinThreads(16 * Environment.ProcessorCount, 16 * Environment.ProcessorCount); - doneEvent = new AutoResetEvent(false); var connectionFactory = new ConnectionFactory { DispatchConsumersAsync = true }; diff --git a/projects/Test/AsyncIntegration/TestAsyncConsumer.cs b/projects/Test/AsyncIntegration/TestAsyncConsumer.cs index 49c77aedb5..fbfc24eaa7 100644 --- a/projects/Test/AsyncIntegration/TestAsyncConsumer.cs +++ b/projects/Test/AsyncIntegration/TestAsyncConsumer.cs @@ -65,62 +65,71 @@ public async Task TestBasicRoundtripConcurrent() var publish1SyncSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var publish2SyncSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var maximumWaitTime = TimeSpan.FromSeconds(10); + var tokenSource = new CancellationTokenSource(maximumWaitTime); - tokenSource.Token.Register(() => + CancellationTokenRegistration ctsr = tokenSource.Token.Register(() => { publish1SyncSource.TrySetResult(false); publish2SyncSource.TrySetResult(false); }); - _conn.ConnectionShutdown += (o, ea) => + try { - HandleConnectionShutdown(_conn, ea, (args) => + _conn.ConnectionShutdown += (o, ea) => { - if (args.Initiator == ShutdownInitiator.Peer) + HandleConnectionShutdown(_conn, ea, (args) => { - publish1SyncSource.TrySetResult(false); - publish2SyncSource.TrySetResult(false); - } - }); - }; + if (args.Initiator == ShutdownInitiator.Peer) + { + publish1SyncSource.TrySetResult(false); + publish2SyncSource.TrySetResult(false); + } + }); + }; - _channel.ChannelShutdown += (o, ea) => - { - HandleChannelShutdown(_channel, ea, (args) => + _channel.ChannelShutdown += (o, ea) => { - if (args.Initiator == ShutdownInitiator.Peer) + HandleChannelShutdown(_channel, ea, (args) => { - publish1SyncSource.TrySetResult(false); - publish2SyncSource.TrySetResult(false); - } - }); - }; + if (args.Initiator == ShutdownInitiator.Peer) + { + publish1SyncSource.TrySetResult(false); + publish2SyncSource.TrySetResult(false); + } + }); + }; - consumer.Received += async (o, a) => - { - string decoded = _encoding.GetString(a.Body.ToArray()); - if (decoded == publish1) - { - publish1SyncSource.TrySetResult(true); - await publish2SyncSource.Task; - } - else if (decoded == publish2) + consumer.Received += async (o, a) => { - publish2SyncSource.TrySetResult(true); - await publish1SyncSource.Task; - } - }; + string decoded = _encoding.GetString(a.Body.ToArray()); + if (decoded == publish1) + { + publish1SyncSource.TrySetResult(true); + await publish2SyncSource.Task; + } + else if (decoded == publish2) + { + publish2SyncSource.TrySetResult(true); + await publish1SyncSource.Task; + } + }; - await _channel.BasicConsumeAsync(q.QueueName, true, string.Empty, false, false, null, consumer); + await _channel.BasicConsumeAsync(q.QueueName, true, string.Empty, false, false, null, consumer); - // ensure we get a delivery - await AssertRanToCompletion(publish1SyncSource.Task, publish2SyncSource.Task); + // ensure we get a delivery + await AssertRanToCompletion(publish1SyncSource.Task, publish2SyncSource.Task); - bool result1 = await publish1SyncSource.Task; - Assert.True(result1, $"1 - Non concurrent dispatch lead to deadlock after {maximumWaitTime}"); + bool result1 = await publish1SyncSource.Task; + Assert.True(result1, $"1 - Non concurrent dispatch lead to deadlock after {maximumWaitTime}"); - bool result2 = await publish2SyncSource.Task; - Assert.True(result2, $"2 - Non concurrent dispatch lead to deadlock after {maximumWaitTime}"); + bool result2 = await publish2SyncSource.Task; + Assert.True(result2, $"2 - Non concurrent dispatch lead to deadlock after {maximumWaitTime}"); + } + finally + { + tokenSource.Dispose(); + ctsr.Dispose(); + } } [Fact] @@ -138,115 +147,123 @@ public async Task TestBasicRoundtripConcurrentManyMessages() var publish2SyncSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var maximumWaitTime = TimeSpan.FromSeconds(30); var tokenSource = new CancellationTokenSource(maximumWaitTime); - tokenSource.Token.Register(() => + CancellationTokenRegistration ctsr = tokenSource.Token.Register(() => { publish1SyncSource.TrySetResult(false); publish2SyncSource.TrySetResult(false); }); - _conn.ConnectionShutdown += (o, ea) => + try { - HandleConnectionShutdown(_conn, ea, (args) => + _conn.ConnectionShutdown += (o, ea) => { - if (args.Initiator == ShutdownInitiator.Peer) + HandleConnectionShutdown(_conn, ea, (args) => { - publish1SyncSource.TrySetResult(false); - publish2SyncSource.TrySetResult(false); - } - }); - }; + if (args.Initiator == ShutdownInitiator.Peer) + { + publish1SyncSource.TrySetResult(false); + publish2SyncSource.TrySetResult(false); + } + }); + }; - _channel.ChannelShutdown += (o, ea) => - { - HandleChannelShutdown(_channel, ea, (args) => + _channel.ChannelShutdown += (o, ea) => { - if (args.Initiator == ShutdownInitiator.Peer) + HandleChannelShutdown(_channel, ea, (args) => { - publish1SyncSource.TrySetResult(false); - publish2SyncSource.TrySetResult(false); - } - }); - }; + if (args.Initiator == ShutdownInitiator.Peer) + { + publish1SyncSource.TrySetResult(false); + publish2SyncSource.TrySetResult(false); + } + }); + }; - QueueDeclareOk q = await _channel.QueueDeclareAsync(queue: queueName, exclusive: false, durable: true); - Assert.Equal(q, queueName); + QueueDeclareOk q = await _channel.QueueDeclareAsync(queue: queueName, exclusive: false, durable: true); + Assert.Equal(q, queueName); - Task publishTask = Task.Run(async () => - { - using (IChannel publishChannel = await _conn.CreateChannelAsync()) + Task publishTask = Task.Run(async () => { - QueueDeclareOk pubQ = await publishChannel.QueueDeclareAsync(queue: queueName, exclusive: false, durable: true); - Assert.Equal(queueName, pubQ.QueueName); - for (int i = 0; i < publish_total; i++) + using (IChannel publishChannel = await _conn.CreateChannelAsync()) { - await publishChannel.BasicPublishAsync(string.Empty, queueName, body1); - await publishChannel.BasicPublishAsync(string.Empty, queueName, body2); - } + QueueDeclareOk pubQ = await publishChannel.QueueDeclareAsync(queue: queueName, exclusive: false, durable: true); + Assert.Equal(queueName, pubQ.QueueName); + for (int i = 0; i < publish_total; i++) + { + await publishChannel.BasicPublishAsync(string.Empty, queueName, body1); + await publishChannel.BasicPublishAsync(string.Empty, queueName, body2); + } - await publishChannel.CloseAsync(); - } - }); + await publishChannel.CloseAsync(); + } + }); - Task consumeTask = Task.Run(async () => - { - using (IChannel consumeChannel = await _conn.CreateChannelAsync()) + Task consumeTask = Task.Run(async () => { - var consumer = new AsyncEventingBasicConsumer(consumeChannel); + using (IChannel consumeChannel = await _conn.CreateChannelAsync()) + { + var consumer = new AsyncEventingBasicConsumer(consumeChannel); - int publish1_count = 0; - int publish2_count = 0; + int publish1_count = 0; + int publish2_count = 0; - consumer.Received += async (o, a) => - { - string decoded = _encoding.GetString(a.Body.ToArray()); - if (decoded == publish1) + consumer.Received += async (o, a) => { - if (Interlocked.Increment(ref publish1_count) >= publish_total) + string decoded = _encoding.GetString(a.Body.ToArray()); + if (decoded == publish1) { - publish1SyncSource.TrySetResult(true); - await publish2SyncSource.Task; + if (Interlocked.Increment(ref publish1_count) >= publish_total) + { + publish1SyncSource.TrySetResult(true); + await publish2SyncSource.Task; + } } - } - else if (decoded == publish2) - { - if (Interlocked.Increment(ref publish2_count) >= publish_total) + else if (decoded == publish2) { - publish2SyncSource.TrySetResult(true); - await publish1SyncSource.Task; + if (Interlocked.Increment(ref publish2_count) >= publish_total) + { + publish2SyncSource.TrySetResult(true); + await publish1SyncSource.Task; + } } - } - }; + }; - await consumeChannel.BasicConsumeAsync(queueName, true, string.Empty, false, false, null, consumer); + await consumeChannel.BasicConsumeAsync(queueName, true, string.Empty, false, false, null, consumer); - // ensure we get a delivery - await AssertRanToCompletion(publish1SyncSource.Task, publish2SyncSource.Task); + // ensure we get a delivery + await AssertRanToCompletion(publish1SyncSource.Task, publish2SyncSource.Task); - bool result1 = await publish1SyncSource.Task; - Assert.True(result1, $"Non concurrent dispatch lead to deadlock after {maximumWaitTime}"); + bool result1 = await publish1SyncSource.Task; + Assert.True(result1, $"Non concurrent dispatch lead to deadlock after {maximumWaitTime}"); - bool result2 = await publish2SyncSource.Task; - Assert.True(result2, $"Non concurrent dispatch lead to deadlock after {maximumWaitTime}"); + bool result2 = await publish2SyncSource.Task; + Assert.True(result2, $"Non concurrent dispatch lead to deadlock after {maximumWaitTime}"); - await consumeChannel.CloseAsync(); - } - }); + await consumeChannel.CloseAsync(); + } + }); - await AssertRanToCompletion(publishTask, consumeTask); + await AssertRanToCompletion(publishTask, consumeTask); + } + finally + { + tokenSource.Dispose(); + ctsr.Dispose(); + } } [Fact] public async Task TestBasicRejectAsync() { var publishSyncSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - using (var cancellationTokenSource = new CancellationTokenSource(TestTimeout)) + var cancellationTokenSource = new CancellationTokenSource(TestTimeout); + CancellationTokenRegistration ctsr = cancellationTokenSource.Token.Register(() => { - cancellationTokenSource.Token.Register(() => - { - publishSyncSource.SetCanceled(); - }); + publishSyncSource.SetCanceled(); + }); + try + { _conn.ConnectionShutdown += (o, ea) => { HandleConnectionShutdown(_conn, ea, (args) => @@ -330,6 +347,11 @@ await _channel.BasicConsumeAsync(queue: queueName, autoAck: false, Assert.Equal((uint)0, consumerCount); } } + finally + { + cancellationTokenSource.Dispose(); + ctsr.Dispose(); + } } [Fact] diff --git a/projects/Test/AsyncIntegration/TestAsyncConsumerExceptions.cs b/projects/Test/AsyncIntegration/TestAsyncConsumerExceptions.cs index dc1051d002..63f35a5fd5 100644 --- a/projects/Test/AsyncIntegration/TestAsyncConsumerExceptions.cs +++ b/projects/Test/AsyncIntegration/TestAsyncConsumerExceptions.cs @@ -93,20 +93,27 @@ protected async Task TestExceptionHandlingWith(IBasicConsumer consumer, var waitSpan = TimeSpan.FromSeconds(5); var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var cts = new CancellationTokenSource(waitSpan); - cts.Token.Register(() => tcs.TrySetResult(false)); - - string q = await _channel.QueueDeclareAsync(string.Empty, false, true, false); - _channel.CallbackException += (ch, evt) => + CancellationTokenRegistration ctsr = cts.Token.Register(() => tcs.TrySetResult(false)); + try { - if (evt.Exception == TestException) + string q = await _channel.QueueDeclareAsync(string.Empty, false, true, false); + _channel.CallbackException += (ch, evt) => { - tcs.SetResult(true); - } - }; - - string tag = await _channel.BasicConsumeAsync(q, true, string.Empty, false, false, null, consumer); - await action(_channel, q, consumer, tag); - Assert.True(await tcs.Task); + if (evt.Exception == TestException) + { + tcs.SetResult(true); + } + }; + + string tag = await _channel.BasicConsumeAsync(q, true, string.Empty, false, false, null, consumer); + await action(_channel, q, consumer, tag); + Assert.True(await tcs.Task); + } + finally + { + cts.Dispose(); + ctsr.Dispose(); + } } private class ConsumerFailingOnDelivery : AsyncEventingBasicConsumer diff --git a/projects/Test/AsyncIntegration/TestConcurrentAccessWithSharedConnectionAsync.cs b/projects/Test/AsyncIntegration/TestConcurrentAccessWithSharedConnectionAsync.cs index 32a16dc546..e91088b47c 100644 --- a/projects/Test/AsyncIntegration/TestConcurrentAccessWithSharedConnectionAsync.cs +++ b/projects/Test/AsyncIntegration/TestConcurrentAccessWithSharedConnectionAsync.cs @@ -91,48 +91,56 @@ private Task TestConcurrentChannelOpenAndPublishingWithBodyAsync(byte[] body, in { var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var tokenSource = new CancellationTokenSource(LongWaitSpan); - tokenSource.Token.Register(() => + CancellationTokenRegistration ctsr = tokenSource.Token.Register(() => { tcs.TrySetResult(false); }); - using (IChannel ch = await _conn.CreateChannelAsync()) + try { - ch.ChannelShutdown += (o, ea) => + using (IChannel ch = await _conn.CreateChannelAsync()) { - HandleChannelShutdown(ch, ea, (args) => + ch.ChannelShutdown += (o, ea) => { - if (args.Initiator == ShutdownInitiator.Peer) + HandleChannelShutdown(ch, ea, (args) => { - tcs.TrySetResult(false); + if (args.Initiator == ShutdownInitiator.Peer) + { + tcs.TrySetResult(false); + } + }); + }; + + await ch.ConfirmSelectAsync(); + + ch.BasicAcks += (object sender, BasicAckEventArgs e) => + { + if (e.DeliveryTag >= _messageCount) + { + tcs.SetResult(true); } - }); - }; + }; - await ch.ConfirmSelectAsync(); + ch.BasicNacks += (object sender, BasicNackEventArgs e) => + { + tcs.SetResult(false); + _output.WriteLine($"channel #{ch.ChannelNumber} saw a nack, deliveryTag: {e.DeliveryTag}, multiple: {e.Multiple}"); + }; - ch.BasicAcks += (object sender, BasicAckEventArgs e) => - { - if (e.DeliveryTag >= _messageCount) + QueueDeclareOk q = await ch.QueueDeclareAsync(queue: string.Empty, passive: false, durable: false, exclusive: true, autoDelete: true, arguments: null); + for (ushort j = 0; j < _messageCount; j++) { - tcs.SetResult(true); + await ch.BasicPublishAsync("", q.QueueName, body, mandatory: true); } - }; - ch.BasicNacks += (object sender, BasicNackEventArgs e) => - { - tcs.SetResult(false); - _output.WriteLine($"channel #{ch.ChannelNumber} saw a nack, deliveryTag: {e.DeliveryTag}, multiple: {e.Multiple}"); - }; - - QueueDeclareOk q = await ch.QueueDeclareAsync(queue: string.Empty, passive: false, durable: false, exclusive: true, autoDelete: true, arguments: null); - for (ushort j = 0; j < _messageCount; j++) - { - await ch.BasicPublishAsync("", q.QueueName, body, mandatory: true); + Assert.True(await tcs.Task); + await ch.CloseAsync(); } - - Assert.True(await tcs.Task); - await ch.CloseAsync(); + } + finally + { + tokenSource.Dispose(); + ctsr.Dispose(); } }, iterations); } diff --git a/projects/Test/AsyncIntegration/TestFloodPublishingAsync.cs b/projects/Test/AsyncIntegration/TestFloodPublishingAsync.cs index e196ea063a..98a626b459 100644 --- a/projects/Test/AsyncIntegration/TestFloodPublishingAsync.cs +++ b/projects/Test/AsyncIntegration/TestFloodPublishingAsync.cs @@ -190,46 +190,54 @@ public async Task TestMultithreadFloodPublishingAsync() }); var cts = new CancellationTokenSource(WaitSpan); - cts.Token.Register(() => + CancellationTokenRegistration ctsr = cts.Token.Register(() => { tcs.TrySetResult(false); }); - using (IChannel consumeCh = await _conn.CreateChannelAsync()) + try { - consumeCh.ChannelShutdown += (o, ea) => + using (IChannel consumeCh = await _conn.CreateChannelAsync()) { - HandleChannelShutdown(consumeCh, ea, (args) => + consumeCh.ChannelShutdown += (o, ea) => { - if (args.Initiator == ShutdownInitiator.Peer) + HandleChannelShutdown(consumeCh, ea, (args) => { - tcs.TrySetResult(false); - } - }); - }; + if (args.Initiator == ShutdownInitiator.Peer) + { + tcs.TrySetResult(false); + } + }); + }; - var consumer = new AsyncEventingBasicConsumer(consumeCh); - consumer.Received += async (o, a) => - { - string receivedMessage = _encoding.GetString(a.Body.ToArray()); - Assert.Equal(message, receivedMessage); - if (Interlocked.Increment(ref receivedCount) == publishCount) + var consumer = new AsyncEventingBasicConsumer(consumeCh); + consumer.Received += async (o, a) => { - tcs.SetResult(true); - } - await Task.Yield(); - }; + string receivedMessage = _encoding.GetString(a.Body.ToArray()); + Assert.Equal(message, receivedMessage); + if (Interlocked.Increment(ref receivedCount) == publishCount) + { + tcs.SetResult(true); + } + await Task.Yield(); + }; - await consumeCh.BasicConsumeAsync(queue: queueName, autoAck: true, - consumerTag: string.Empty, noLocal: false, exclusive: false, - arguments: null, consumer: consumer); + await consumeCh.BasicConsumeAsync(queue: queueName, autoAck: true, + consumerTag: string.Empty, noLocal: false, exclusive: false, + arguments: null, consumer: consumer); - Assert.True(await tcs.Task); - await consumeCh.CloseAsync(); - } + Assert.True(await tcs.Task); + await consumeCh.CloseAsync(); + } - await pub; - Assert.Equal(publishCount, receivedCount); + await pub; + Assert.Equal(publishCount, receivedCount); + } + finally + { + cts.Dispose(); + ctsr.Dispose(); + } } } } diff --git a/projects/Test/Common/IntegrationFixtureBase.cs b/projects/Test/Common/IntegrationFixtureBase.cs index 53f25680e9..68e5098887 100644 --- a/projects/Test/Common/IntegrationFixtureBase.cs +++ b/projects/Test/Common/IntegrationFixtureBase.cs @@ -106,6 +106,8 @@ public IntegrationFixtureBase(ITestOutputHelper output) .Replace("AsyncIntegration.", "AI.") .Replace("Integration.", "I.") .Replace("SequentialI.", "SI."); + + // Console.SetOut(new TestOutputWriter(output, _testDisplayName)); } public virtual async Task InitializeAsync() @@ -388,11 +390,11 @@ protected static async Task WaitAsync(TaskCompletionSource tcs, TimeSpan t await tcs.Task.WaitAsync(timeSpan); bool result = await tcs.Task; Assert.True((true == result) && (tcs.Task.IsCompletedSuccessfully()), - $"waiting {timeSpan.TotalSeconds} seconds on a tcs for '{desc}' timed out"); + $"waiting {timeSpan.TotalSeconds} seconds on a tcs for '{desc}' failed"); } - catch (TimeoutException) + catch (TimeoutException ex) { - Assert.Fail($"waiting {timeSpan.TotalSeconds} seconds on a tcs for '{desc}' timed out"); + Assert.Fail($"waiting {timeSpan.TotalSeconds} seconds on a tcs for '{desc}' timed out, ex: {ex}"); } } diff --git a/projects/Test/Common/ProcessUtil.cs b/projects/Test/Common/ProcessUtil.cs index eb262e451c..8f21482003 100644 --- a/projects/Test/Common/ProcessUtil.cs +++ b/projects/Test/Common/ProcessUtil.cs @@ -1,8 +1,40 @@ -using System; +// This source code is dual-licensed under the Apache License, version +// 2.0, and the Mozilla Public License, version 2.0. +// +// The APL v2.0: +// +//--------------------------------------------------------------------------- +// Copyright (c) 2007-2020 VMware, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//--------------------------------------------------------------------------- +// +// The MPL v2.0: +// +//--------------------------------------------------------------------------- +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// +// Copyright (c) 2007-2020 VMware, Inc. All rights reserved. +//--------------------------------------------------------------------------- + +using System; using System.Collections.Generic; using System.Diagnostics; using System.Text; using System.Threading.Tasks; +using RabbitMQ.Client; namespace Test { diff --git a/projects/Test/Common/TaskExtensions.cs b/projects/Test/Common/TaskExtensions.cs deleted file mode 100644 index f67222a448..0000000000 --- a/projects/Test/Common/TaskExtensions.cs +++ /dev/null @@ -1,105 +0,0 @@ -// This source code is dual-licensed under the Apache License, version -// 2.0, and the Mozilla Public License, version 2.0. -// -// The APL v2.0: -// -//--------------------------------------------------------------------------- -// Copyright (c) 2007-2020 VMware, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -//--------------------------------------------------------------------------- -// -// The MPL v2.0: -// -//--------------------------------------------------------------------------- -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. -// -// Copyright (c) 2007-2020 VMware, Inc. All rights reserved. -//--------------------------------------------------------------------------- - -using System; -using System.Threading; -using System.Threading.Tasks; - -namespace Test -{ - internal static class TaskExtensions - { -#if NET6_0_OR_GREATER - public static Task WaitAsync(this Task task, TimeSpan timeout) - { - if (task.IsCompletedSuccessfully) - { - return task; - } - - return task.WaitAsync(timeout); - } -#else - private static readonly TaskContinuationOptions s_tco = TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously; - private static void IgnoreTaskContinuation(Task t, object s) => t.Exception.Handle(e => true); - - public static Task WaitAsync(this Task task, TimeSpan timeout) - { - if (task.Status == TaskStatus.RanToCompletion) - { - return task; - } - - return DoTimeoutAfter(task, timeout); - } - - // https://github.com/davidfowl/AspNetCoreDiagnosticScenarios/blob/master/AsyncGuidance.md#using-a-timeout - private static async Task DoTimeoutAfter(Task task, TimeSpan timeout) - { - using (var cts = new CancellationTokenSource()) - { - Task delayTask = Task.Delay(timeout, cts.Token); - Task resultTask = await Task.WhenAny(task, delayTask).ConfigureAwait(false); - if (resultTask == delayTask) - { - task.Ignore(); - throw new TimeoutException(); - } - else - { - cts.Cancel(); - } - - await task.ConfigureAwait(false); - } - } - - // https://github.com/dotnet/runtime/issues/23878 - // https://github.com/dotnet/runtime/issues/23878#issuecomment-1398958645 - private static void Ignore(this Task task) - { - if (task.IsCompleted) - { - _ = task.Exception; - } - else - { - _ = task.ContinueWith( - continuationAction: IgnoreTaskContinuation, - state: null, - cancellationToken: CancellationToken.None, - continuationOptions: s_tco, - scheduler: TaskScheduler.Default); - } - } -#endif - } -} diff --git a/projects/Test/Common/TestOutputWriter.cs b/projects/Test/Common/TestOutputWriter.cs new file mode 100644 index 0000000000..c2e47f8847 --- /dev/null +++ b/projects/Test/Common/TestOutputWriter.cs @@ -0,0 +1,63 @@ +// This source code is dual-licensed under the Apache License, version +// 2.0, and the Mozilla Public License, version 2.0. +// +// The APL v2.0: +// +//--------------------------------------------------------------------------- +// Copyright (c) 2007-2020 VMware, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//--------------------------------------------------------------------------- +// +// The MPL v2.0: +// +//--------------------------------------------------------------------------- +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// +// Copyright (c) 2007-2020 VMware, Inc. All rights reserved. +//--------------------------------------------------------------------------- + +using System.IO; +using System.Text; +using Xunit.Abstractions; + +namespace Test +{ + public class TestOutputWriter : TextWriter + { + private readonly ITestOutputHelper _output; + private readonly string _testDisplayName; + + public TestOutputWriter(ITestOutputHelper output, string testDisplayName) + { + _output = output; + _testDisplayName = testDisplayName; + } + + public override Encoding Encoding => Encoding.UTF8; + + public override void Write(char[] buffer, int index, int count) + { + if (count > 2) + { + var sb = new StringBuilder("[DEBUG] "); + sb.Append(_testDisplayName); + sb.Append(" | "); + sb.Append(buffer, index, count); + _output.WriteLine(sb.ToString().TrimEnd()); + } + } + } +} diff --git a/projects/Test/Integration/IntegrationFixture.cs b/projects/Test/Integration/IntegrationFixture.cs index b3bffde877..e8d389fd8f 100644 --- a/projects/Test/Integration/IntegrationFixture.cs +++ b/projects/Test/Integration/IntegrationFixture.cs @@ -29,7 +29,6 @@ // Copyright (c) 2007-2020 VMware, Inc. All rights reserved. //--------------------------------------------------------------------------- -using System.Threading; using Xunit.Abstractions; namespace Test.Integration @@ -39,18 +38,6 @@ public class IntegrationFixture : IntegrationFixtureBase public IntegrationFixture(ITestOutputHelper output) : base(output) { - int threadCount; - if (IsRunningInCI) - { - threadCount = _processorCount * 16; - } - else - { - // Assuming that dev machines have more cores - threadCount = _processorCount * 8; - } - - ThreadPool.SetMinThreads(threadCount, threadCount); } } } diff --git a/projects/Test/Integration/TestAuth.cs b/projects/Test/Integration/TestAuth.cs index 637c9490b3..6b0f9d534c 100644 --- a/projects/Test/Integration/TestAuth.cs +++ b/projects/Test/Integration/TestAuth.cs @@ -59,6 +59,8 @@ public async Task TestAuthFailure() ConnectionFactory connFactory = CreateConnectionFactory(); connFactory.UserName = "guest"; connFactory.Password = "incorrect-password"; + connFactory.AutomaticRecoveryEnabled = true; + connFactory.TopologyRecoveryEnabled = true; try { diff --git a/projects/Test/Integration/TestConnectionFactory.cs b/projects/Test/Integration/TestConnectionFactory.cs index 71d4ca0d3a..6a291f552a 100644 --- a/projects/Test/Integration/TestConnectionFactory.cs +++ b/projects/Test/Integration/TestConnectionFactory.cs @@ -123,7 +123,7 @@ public void TestConnectionFactoryWithCustomSocketFactory() } [Fact] - public async Task TestCreateConnectionUsesSpecifiedPort() + public async Task TestCreateConnectionWithInvalidPortThrows() { ConnectionFactory cf = CreateConnectionFactory(); cf.AutomaticRecoveryEnabled = true; @@ -312,7 +312,7 @@ public async Task TestCreateConnectionWithForcedAddressFamily() } [Fact] - public async Task TestCreateConnectionUsesInvalidAmqpTcpEndpoint() + public async Task TestCreateConnectionWithInvalidAmqpTcpEndpointThrows() { ConnectionFactory cf = CreateConnectionFactory(); var ep = new AmqpTcpEndpoint("localhost", 1234); diff --git a/projects/Test/Integration/TestHeartbeats.cs b/projects/Test/Integration/TestHeartbeats.cs index 93f29bed39..3d43a14e3a 100644 --- a/projects/Test/Integration/TestHeartbeats.cs +++ b/projects/Test/Integration/TestHeartbeats.cs @@ -100,46 +100,35 @@ public async Task TestHundredsOfConnectionsWithRandomHeartbeatInterval() const ushort connectionCount = 200; - ThreadPool.GetMinThreads(out int origWorkerThreads, out int origCompletionPortThreads); + var rnd = new Random(); + var conns = new List(); + try { - var rnd = new Random(); - var conns = new List(); - - // Since we are using the ThreadPool, let's set MinThreads to a high-enough value. - ThreadPool.SetMinThreads(connectionCount, connectionCount); - - try + for (int i = 0; i < connectionCount; i++) { - for (int i = 0; i < connectionCount; i++) - { - ushort n = Convert.ToUInt16(rnd.Next(2, 6)); - ConnectionFactory cf = CreateConnectionFactory(); - cf.RequestedHeartbeat = TimeSpan.FromSeconds(n); - cf.AutomaticRecoveryEnabled = false; - - IConnection conn = await cf.CreateConnectionAsync($"{_testDisplayName}:{i}"); - conns.Add(conn); - IChannel ch = await conn.CreateChannelAsync(); - conn.ConnectionShutdown += (sender, evt) => - { - CheckInitiator(evt); - }; - } - - await SleepFor(60); - } - finally - { - foreach (IConnection conn in conns) - { - await conn.CloseAsync(); - } + ushort n = Convert.ToUInt16(rnd.Next(2, 6)); + ConnectionFactory cf = CreateConnectionFactory(); + cf.RequestedHeartbeat = TimeSpan.FromSeconds(n); + cf.AutomaticRecoveryEnabled = false; + + IConnection conn = await cf.CreateConnectionAsync($"{_testDisplayName}:{i}"); + conns.Add(conn); + IChannel ch = await conn.CreateChannelAsync(); + conn.ConnectionShutdown += (sender, evt) => + { + CheckInitiator(evt); + }; } + + await SleepFor(60); } finally { - Assert.True(ThreadPool.SetMinThreads(origWorkerThreads, origCompletionPortThreads)); + foreach (IConnection conn in conns) + { + await conn.CloseAsync(); + } } } diff --git a/projects/Test/Integration/TestInitialConnection.cs b/projects/Test/Integration/TestInitialConnection.cs index 267d0b61c8..6aecda6928 100644 --- a/projects/Test/Integration/TestInitialConnection.cs +++ b/projects/Test/Integration/TestInitialConnection.cs @@ -45,24 +45,38 @@ public TestInitialConnection(ITestOutputHelper output) : base(output) { } + public override Task InitializeAsync() + { + // NB: nothing to do here since each test creates its own factory, + // connections and channels + Assert.Null(_connFactory); + Assert.Null(_conn); + Assert.Null(_channel); + return Task.CompletedTask; + } + [Fact] - public async Task TestBasicConnectionRecoveryWithHostnameList() + public async Task TestWithHostnameList() { - AutorecoveringConnection c = await CreateAutorecoveringConnectionAsync(new List() { "127.0.0.1", "localhost" }); - Assert.True(c.IsOpen); - await c.CloseAsync(); + using (AutorecoveringConnection c = await CreateAutorecoveringConnectionAsync(new List() { "127.0.0.1", "localhost" })) + { + Assert.True(c.IsOpen); + await c.CloseAsync(); + } } [Fact] - public async Task TestBasicConnectionRecoveryWithHostnameListAndUnreachableHosts() + public async Task TestWithHostnameListAndUnreachableHosts() { - AutorecoveringConnection c = await CreateAutorecoveringConnectionAsync(new List() { "191.72.44.22", "127.0.0.1", "localhost" }); - Assert.True(c.IsOpen); - await c.CloseAsync(); + using (AutorecoveringConnection c = await CreateAutorecoveringConnectionAsync(new List() { "191.72.44.22", "127.0.0.1", "localhost" })) + { + Assert.True(c.IsOpen); + await c.CloseAsync(); + } } [Fact] - public async Task TestBasicConnectionRecoveryWithHostnameListWithOnlyUnreachableHosts() + public async Task TestWithHostnameListWithOnlyUnreachableHosts() { await Assert.ThrowsAsync(() => { diff --git a/projects/Test/SequentialIntegration/TestConnectionBlocked.cs b/projects/Test/SequentialIntegration/TestConnectionBlocked.cs index bfba03f359..f3415af09e 100644 --- a/projects/Test/SequentialIntegration/TestConnectionBlocked.cs +++ b/projects/Test/SequentialIntegration/TestConnectionBlocked.cs @@ -44,6 +44,12 @@ public TestConnectionBlocked(ITestOutputHelper output) : base(output) { } + public override async Task InitializeAsync() + { + await UnblockAsync(); + await base.InitializeAsync(); + } + public override async Task DisposeAsync() { await UnblockAsync(); @@ -78,16 +84,26 @@ public async Task TestDisposeOnBlockedConnectionDoesNotHang() Task disposeTask = Task.Run(async () => { - await _conn.AbortAsync(); - _conn.Dispose(); - _conn = null; - tcs.SetResult(true); + try + { + await _conn.AbortAsync(); + _conn.Dispose(); + tcs.SetResult(true); + } + catch (Exception) + { + tcs.SetResult(false); + } + finally + { + _conn = null; + } }); Task anyTask = Task.WhenAny(tcs.Task, disposeTask); await anyTask.WaitAsync(LongWaitSpan); bool disposeSuccess = await tcs.Task; - Assert.True(disposeSuccess, "Dispose must have finished within 20 seconds after starting"); + Assert.True(disposeSuccess, $"Dispose must have finished within {LongWaitSpan.TotalSeconds} seconds after starting"); } } }