Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add UnmanagedCallersOnly attribute to SafeDeleteSslContext.ReadFromConnection/WriteToConnection methods #55947

Merged
merged 8 commits into from
Aug 17, 2021
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ internal enum PAL_TlsIo
[DllImport(Interop.Libraries.AppleCryptoNative, EntryPoint = "AppleCryptoNative_SslCreateContext")]
internal static extern System.Net.SafeSslHandle SslCreateContext(int isServer);

[DllImport(Interop.Libraries.AppleCryptoNative, EntryPoint = "AppleCryptoNative_SslSetConnection")]
internal static extern int SslSetConnection(
SafeSslHandle sslHandle,
IntPtr sslConnection);

[DllImport(Interop.Libraries.AppleCryptoNative)]
private static extern int AppleCryptoNative_SslSetMinProtocolVersion(
SafeSslHandle sslHandle,
Expand Down Expand Up @@ -119,10 +124,10 @@ private static extern int AppleCryptoNative_SslSetTargetName(
private static extern int AppleCryptoNative_SslSetAcceptClientCert(SafeSslHandle sslHandle);

[DllImport(Interop.Libraries.AppleCryptoNative, EntryPoint = "AppleCryptoNative_SslSetIoCallbacks")]
internal static extern int SslSetIoCallbacks(
internal static extern unsafe int SslSetIoCallbacks(
SafeSslHandle sslHandle,
SSLReadFunc readCallback,
SSLWriteFunc writeCallback);
delegate* unmanaged<IntPtr, byte*, void**, int> readCallback,
delegate* unmanaged<IntPtr, byte*, void**, int> writeCallback);

[DllImport(Interop.Libraries.AppleCryptoNative, EntryPoint = "AppleCryptoNative_SslWrite")]
internal static extern unsafe PAL_TlsIo SslWrite(SafeSslHandle sslHandle, byte* writeFrom, int count, out int bytesWritten);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ SSLContextRef AppleCryptoNative_SslCreateContext(int32_t isServer)
#pragma clang diagnostic pop
}

int32_t AppleCryptoNative_SslSetConnection(SSLContextRef sslContext, SSLConnectionRef sslConnection)
{
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wdeprecated-declarations"
return SSLSetConnection(sslContext, sslConnection);
#pragma clang diagnostic pop
}

int32_t AppleCryptoNative_SslSetAcceptClientCert(SSLContextRef sslContext)
{
#pragma clang diagnostic push
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ Returns NULL if an invalid boolean is given for isServer, an SSLContextRef other
*/
PALEXPORT SSLContextRef AppleCryptoNative_SslCreateContext(int32_t isServer);

/*
Data that is used to uniquely identify an SSL session.

Returns the result of SSLSetConnection
*/
PALEXPORT int32_t AppleCryptoNative_SslSetConnection(SSLContextRef sslContext, SSLConnectionRef sslConnection);

/*
Indicate that an SSL Context (in server mode) should allow a client to present a mutual auth cert.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Diagnostics;
using System.Net.Http;
using System.Net.Security;
using System.Runtime.InteropServices;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using Microsoft.Win32.SafeHandles;
Expand All @@ -22,8 +23,6 @@ internal sealed class SafeDeleteSslContext : SafeDeleteContext
private const int OSStatus_errSSLWouldBlock = -9803;
private const int InitialBufferSize = 2048;
private SafeSslHandle _sslContext;
private Interop.AppleCrypto.SSLReadFunc _readCallback;
private Interop.AppleCrypto.SSLWriteFunc _writeCallback;
private ArrayBuffer _inputBuffer = new ArrayBuffer(InitialBufferSize);
private ArrayBuffer _outputBuffer = new ArrayBuffer(InitialBufferSize);

Expand All @@ -38,19 +37,20 @@ public SafeDeleteSslContext(SafeFreeSslCredentials credential, SslAuthentication
{
int osStatus;

_sslContext = CreateSslContext(credential, sslAuthenticationOptions.IsServer);

// Make sure the class instance is associated to the session and is provided
// in the Read/Write callback connection parameter
SslSetConnection(_sslContext);

unsafe
{
_readCallback = ReadFromConnection;
_writeCallback = WriteToConnection;
osStatus = Interop.AppleCrypto.SslSetIoCallbacks(
_sslContext,
&ReadFromConnection,
&WriteToConnection);
}

_sslContext = CreateSslContext(credential, sslAuthenticationOptions.IsServer);

osStatus = Interop.AppleCrypto.SslSetIoCallbacks(
_sslContext,
_readCallback,
_writeCallback);

if (osStatus != 0)
{
throw Interop.AppleCrypto.CreateExceptionForOSStatus(osStatus);
Expand Down Expand Up @@ -142,6 +142,13 @@ private static SafeSslHandle CreateSslContext(SafeFreeSslCredentials credential,
return sslContext;
}

private void SslSetConnection(SafeSslHandle sslContext)
{
GCHandle handle = GCHandle.Alloc(this, GCHandleType.Weak);

Interop.AppleCrypto.SslSetConnection(sslContext, GCHandle.ToIntPtr(handle));
}

public override bool IsInvalid => _sslContext?.IsInvalid ?? true;

protected override void Dispose(bool disposing)
Expand All @@ -160,8 +167,12 @@ protected override void Dispose(bool disposing)
base.Dispose(disposing);
}

private unsafe int WriteToConnection(void* connection, byte* data, void** dataLength)
[UnmanagedCallersOnly]
private static unsafe int WriteToConnection(IntPtr connection, byte* data, void** dataLength)
{
SafeDeleteSslContext? context = (SafeDeleteSslContext?)GCHandle.FromIntPtr(connection).Target;
Debug.Assert(context != null);

// We don't pool these buffers and we can't because there's a race between their us in the native
// read/write callbacks and being disposed when the SafeHandle is disposed. This race is benign currently,
// but if we were to pool the buffers we would have a potential use-after-free issue.
Expand All @@ -173,23 +184,27 @@ private unsafe int WriteToConnection(void* connection, byte* data, void** dataLe
int toWrite = (int)length;
var inputBuffer = new ReadOnlySpan<byte>(data, toWrite);

_outputBuffer.EnsureAvailableSpace(toWrite);
inputBuffer.CopyTo(_outputBuffer.AvailableSpan);
_outputBuffer.Commit(toWrite);
context._outputBuffer.EnsureAvailableSpace(toWrite);
inputBuffer.CopyTo(context._outputBuffer.AvailableSpan);
context._outputBuffer.Commit(toWrite);
// Since we can enqueue everything, no need to re-assign *dataLength.

return OSStatus_noErr;
}
catch (Exception e)
{
if (NetEventSource.Log.IsEnabled())
NetEventSource.Error(this, $"WritingToConnection failed: {e.Message}");
NetEventSource.Error(context, $"WritingToConnection failed: {e.Message}");
return OSStatus_writErr;
}
}

private unsafe int ReadFromConnection(void* connection, byte* data, void** dataLength)
[UnmanagedCallersOnly]
private static unsafe int ReadFromConnection(IntPtr connection, byte* data, void** dataLength)
{
SafeDeleteSslContext? context = (SafeDeleteSslContext?)GCHandle.FromIntPtr(connection).Target;
Debug.Assert(context != null);

try
{
ulong toRead = (ulong)*dataLength;
Expand All @@ -201,16 +216,16 @@ private unsafe int ReadFromConnection(void* connection, byte* data, void** dataL

uint transferred = 0;

if (_inputBuffer.ActiveLength == 0)
if (context._inputBuffer.ActiveLength == 0)
{
*dataLength = (void*)0;
return OSStatus_errSSLWouldBlock;
}

int limit = Math.Min((int)toRead, _inputBuffer.ActiveLength);
int limit = Math.Min((int)toRead, context._inputBuffer.ActiveLength);

_inputBuffer.ActiveSpan.Slice(0, limit).CopyTo(new Span<byte>(data, limit));
_inputBuffer.Discard(limit);
context._inputBuffer.ActiveSpan.Slice(0, limit).CopyTo(new Span<byte>(data, limit));
context._inputBuffer.Discard(limit);
transferred = (uint)limit;

*dataLength = (void*)transferred;
Expand All @@ -219,7 +234,7 @@ private unsafe int ReadFromConnection(void* connection, byte* data, void** dataL
catch (Exception e)
{
if (NetEventSource.Log.IsEnabled())
NetEventSource.Error(this, $"ReadFromConnectionfailed: {e.Message}");
NetEventSource.Error(context, $"ReadFromConnectionfailed: {e.Message}");
return OSStatus_readErr;
}
}
Expand Down