Skip to content

Commit

Permalink
Add UnmanagedCallersOnly attribute to SafeDeleteSslContext.ReadFromCo…
Browse files Browse the repository at this point in the history
…nnection/WriteToConnection methods (#55947)

* Add UnmanagedCallersOnly attribute to  SafeDeleteSslContext.ReadFromConnection/WriteToConnection methods

* Fix the build

Co-authored-by: Jan Kotas <jkotas@microsoft.com>

* Fix the build

* Add SslSetConnection interop method to make sure the right SafeDeleteSslContext instance is associated to an ssl session

* Update entrypoints.c with new DllImport

Co-authored-by: Jan Kotas <jkotas@microsoft.com>
Co-authored-by: Steve Pfister <steve.pfister@microsoft.com>
Co-authored-by: Alexander Köplinger <alex.koeplinger@outlook.com>
  • Loading branch information
4 people committed Aug 17, 2021
1 parent bbf4e79 commit fb9dac8
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ internal enum PAL_TlsIo
[DllImport(Interop.Libraries.AppleCryptoNative, EntryPoint = "AppleCryptoNative_SslCreateContext")]
internal static extern System.Net.SafeSslHandle SslCreateContext(int isServer);

[DllImport(Interop.Libraries.AppleCryptoNative, EntryPoint = "AppleCryptoNative_SslSetConnection")]
internal static extern int SslSetConnection(
SafeSslHandle sslHandle,
IntPtr sslConnection);

[DllImport(Interop.Libraries.AppleCryptoNative)]
private static extern int AppleCryptoNative_SslSetMinProtocolVersion(
SafeSslHandle sslHandle,
Expand Down Expand Up @@ -119,10 +124,10 @@ private static extern int AppleCryptoNative_SslSetTargetName(
private static extern int AppleCryptoNative_SslSetAcceptClientCert(SafeSslHandle sslHandle);

[DllImport(Interop.Libraries.AppleCryptoNative, EntryPoint = "AppleCryptoNative_SslSetIoCallbacks")]
internal static extern int SslSetIoCallbacks(
internal static extern unsafe int SslSetIoCallbacks(
SafeSslHandle sslHandle,
SSLReadFunc readCallback,
SSLWriteFunc writeCallback);
delegate* unmanaged<IntPtr, byte*, void**, int> readCallback,
delegate* unmanaged<IntPtr, byte*, void**, int> writeCallback);

[DllImport(Interop.Libraries.AppleCryptoNative, EntryPoint = "AppleCryptoNative_SslWrite")]
internal static extern unsafe PAL_TlsIo SslWrite(SafeSslHandle sslHandle, byte* writeFrom, int count, out int bytesWritten);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ static const Entry s_cryptoAppleNative[] =
DllImportEntry(AppleCryptoNative_SecKeyCopyExternalRepresentation)
DllImportEntry(AppleCryptoNative_SecKeyCopyPublicKey)
DllImportEntry(AppleCryptoNative_SslCreateContext)
DllImportEntry(AppleCryptoNative_SslSetConnection)
DllImportEntry(AppleCryptoNative_SslSetAcceptClientCert)
DllImportEntry(AppleCryptoNative_SslSetMinProtocolVersion)
DllImportEntry(AppleCryptoNative_SslSetMaxProtocolVersion)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ SSLContextRef AppleCryptoNative_SslCreateContext(int32_t isServer)
#pragma clang diagnostic pop
}

int32_t AppleCryptoNative_SslSetConnection(SSLContextRef sslContext, SSLConnectionRef sslConnection)
{
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wdeprecated-declarations"
return SSLSetConnection(sslContext, sslConnection);
#pragma clang diagnostic pop
}

int32_t AppleCryptoNative_SslSetAcceptClientCert(SSLContextRef sslContext)
{
#pragma clang diagnostic push
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ Returns NULL if an invalid boolean is given for isServer, an SSLContextRef other
*/
PALEXPORT SSLContextRef AppleCryptoNative_SslCreateContext(int32_t isServer);

/*
Data that is used to uniquely identify an SSL session.
Returns the result of SSLSetConnection
*/
PALEXPORT int32_t AppleCryptoNative_SslSetConnection(SSLContextRef sslContext, SSLConnectionRef sslConnection);

/*
Indicate that an SSL Context (in server mode) should allow a client to present a mutual auth cert.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Diagnostics;
using System.Net.Http;
using System.Net.Security;
using System.Runtime.InteropServices;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using Microsoft.Win32.SafeHandles;
Expand All @@ -22,8 +23,6 @@ internal sealed class SafeDeleteSslContext : SafeDeleteContext
private const int OSStatus_errSSLWouldBlock = -9803;
private const int InitialBufferSize = 2048;
private SafeSslHandle _sslContext;
private Interop.AppleCrypto.SSLReadFunc _readCallback;
private Interop.AppleCrypto.SSLWriteFunc _writeCallback;
private ArrayBuffer _inputBuffer = new ArrayBuffer(InitialBufferSize);
private ArrayBuffer _outputBuffer = new ArrayBuffer(InitialBufferSize);

Expand All @@ -38,19 +37,20 @@ public SafeDeleteSslContext(SafeFreeSslCredentials credential, SslAuthentication
{
int osStatus;

_sslContext = CreateSslContext(credential, sslAuthenticationOptions.IsServer);

// Make sure the class instance is associated to the session and is provided
// in the Read/Write callback connection parameter
SslSetConnection(_sslContext);

unsafe
{
_readCallback = ReadFromConnection;
_writeCallback = WriteToConnection;
osStatus = Interop.AppleCrypto.SslSetIoCallbacks(
_sslContext,
&ReadFromConnection,
&WriteToConnection);
}

_sslContext = CreateSslContext(credential, sslAuthenticationOptions.IsServer);

osStatus = Interop.AppleCrypto.SslSetIoCallbacks(
_sslContext,
_readCallback,
_writeCallback);

if (osStatus != 0)
{
throw Interop.AppleCrypto.CreateExceptionForOSStatus(osStatus);
Expand Down Expand Up @@ -142,6 +142,13 @@ private static SafeSslHandle CreateSslContext(SafeFreeSslCredentials credential,
return sslContext;
}

private void SslSetConnection(SafeSslHandle sslContext)
{
GCHandle handle = GCHandle.Alloc(this, GCHandleType.Weak);

Interop.AppleCrypto.SslSetConnection(sslContext, GCHandle.ToIntPtr(handle));
}

public override bool IsInvalid => _sslContext?.IsInvalid ?? true;

protected override void Dispose(bool disposing)
Expand All @@ -160,8 +167,12 @@ protected override void Dispose(bool disposing)
base.Dispose(disposing);
}

private unsafe int WriteToConnection(void* connection, byte* data, void** dataLength)
[UnmanagedCallersOnly]
private static unsafe int WriteToConnection(IntPtr connection, byte* data, void** dataLength)
{
SafeDeleteSslContext? context = (SafeDeleteSslContext?)GCHandle.FromIntPtr(connection).Target;
Debug.Assert(context != null);

// We don't pool these buffers and we can't because there's a race between their us in the native
// read/write callbacks and being disposed when the SafeHandle is disposed. This race is benign currently,
// but if we were to pool the buffers we would have a potential use-after-free issue.
Expand All @@ -173,23 +184,27 @@ private unsafe int WriteToConnection(void* connection, byte* data, void** dataLe
int toWrite = (int)length;
var inputBuffer = new ReadOnlySpan<byte>(data, toWrite);

_outputBuffer.EnsureAvailableSpace(toWrite);
inputBuffer.CopyTo(_outputBuffer.AvailableSpan);
_outputBuffer.Commit(toWrite);
context._outputBuffer.EnsureAvailableSpace(toWrite);
inputBuffer.CopyTo(context._outputBuffer.AvailableSpan);
context._outputBuffer.Commit(toWrite);
// Since we can enqueue everything, no need to re-assign *dataLength.

return OSStatus_noErr;
}
catch (Exception e)
{
if (NetEventSource.Log.IsEnabled())
NetEventSource.Error(this, $"WritingToConnection failed: {e.Message}");
NetEventSource.Error(context, $"WritingToConnection failed: {e.Message}");
return OSStatus_writErr;
}
}

private unsafe int ReadFromConnection(void* connection, byte* data, void** dataLength)
[UnmanagedCallersOnly]
private static unsafe int ReadFromConnection(IntPtr connection, byte* data, void** dataLength)
{
SafeDeleteSslContext? context = (SafeDeleteSslContext?)GCHandle.FromIntPtr(connection).Target;
Debug.Assert(context != null);

try
{
ulong toRead = (ulong)*dataLength;
Expand All @@ -201,16 +216,16 @@ private unsafe int ReadFromConnection(void* connection, byte* data, void** dataL

uint transferred = 0;

if (_inputBuffer.ActiveLength == 0)
if (context._inputBuffer.ActiveLength == 0)
{
*dataLength = (void*)0;
return OSStatus_errSSLWouldBlock;
}

int limit = Math.Min((int)toRead, _inputBuffer.ActiveLength);
int limit = Math.Min((int)toRead, context._inputBuffer.ActiveLength);

_inputBuffer.ActiveSpan.Slice(0, limit).CopyTo(new Span<byte>(data, limit));
_inputBuffer.Discard(limit);
context._inputBuffer.ActiveSpan.Slice(0, limit).CopyTo(new Span<byte>(data, limit));
context._inputBuffer.Discard(limit);
transferred = (uint)limit;

*dataLength = (void*)transferred;
Expand All @@ -219,7 +234,7 @@ private unsafe int ReadFromConnection(void* connection, byte* data, void** dataL
catch (Exception e)
{
if (NetEventSource.Log.IsEnabled())
NetEventSource.Error(this, $"ReadFromConnectionfailed: {e.Message}");
NetEventSource.Error(context, $"ReadFromConnectionfailed: {e.Message}");
return OSStatus_readErr;
}
}
Expand Down

0 comments on commit fb9dac8

Please sign in to comment.