Skip to content

Commit

Permalink
Refactor ComponentActivator (dotnet#37784)
Browse files Browse the repository at this point in the history
* Refactor ComponentActivator.

* Add ComponentActivation test for initialization from self-contained app.

* Rename tests around load_assembly_and_get_function_pointer.

* Update app ComponentActivation tests to initialize from an app.
Address PR feedback.

* Rely on PtrToStringAuto to call Uni or UTF8
  • Loading branch information
rseanhall committed Jun 13, 2020
1 parent ba4dc36 commit e393cde
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Reflection;
using System.Runtime.InteropServices;
using System.Runtime.Loader;

namespace Internal.Runtime.InteropServices
{
Expand All @@ -21,17 +21,7 @@ public static class ComponentActivator

private static string MarshalToString(IntPtr arg, string argName)
{
if (arg == IntPtr.Zero)
{
throw new ArgumentNullException(argName);
}

#if TARGET_WINDOWS
string? result = Marshal.PtrToStringUni(arg);
#else
string? result = Marshal.PtrToStringUTF8(arg);
#endif

string? result = Marshal.PtrToStringAuto(arg);
if (result == null)
{
throw new ArgumentNullException(argName);
Expand All @@ -50,46 +40,19 @@ private static string MarshalToString(IntPtr arg, string argName)
/// <param name="functionHandle">Pointer where to store the function pointer result</param>
[UnmanagedCallersOnly]
public static unsafe int LoadAssemblyAndGetFunctionPointer(IntPtr assemblyPathNative,
IntPtr typeNameNative,
IntPtr methodNameNative,
IntPtr delegateTypeNative,
IntPtr reserved,
IntPtr functionHandle)
IntPtr typeNameNative,
IntPtr methodNameNative,
IntPtr delegateTypeNative,
IntPtr reserved,
IntPtr functionHandle)
{
try
{
// Load the assembly and create a resolver callback for types.
// Validate all parameters first.
string assemblyPath = MarshalToString(assemblyPathNative, nameof(assemblyPathNative));
IsolatedComponentLoadContext alc = GetIsolatedComponentLoadContext(assemblyPath);
Func<AssemblyName, Assembly> resolver = name => alc.LoadFromAssemblyName(name);

// Get the requested type.
string typeName = MarshalToString(typeNameNative, nameof(typeNameNative));
Type type = Type.GetType(typeName, resolver, null, throwOnError: true)!;

// Get the method name on the type.
string methodName = MarshalToString(methodNameNative, nameof(methodNameNative));

// Determine the signature of the type. There are 3 possibilities:
// * No delegate type was supplied - use the default (i.e. ComponentEntryPoint).
// * A sentinel value was supplied - the function is marked UnmanagedCallersOnly. This means
// a function pointer can be returned without creating a delegate.
// * A delegate type was supplied - Load the type and create a delegate for that method.
Type? delegateType;
if (delegateTypeNative == IntPtr.Zero)
{
delegateType = typeof(ComponentEntryPoint);
}
else if (delegateTypeNative == (IntPtr)(-1))
{
delegateType = null;
}
else
{
string delegateTypeName = MarshalToString(delegateTypeNative, nameof(delegateTypeNative));
delegateType = Type.GetType(delegateTypeName, resolver, null, throwOnError: true)!;
}

if (reserved != IntPtr.Zero)
{
throw new ArgumentOutOfRangeException(nameof(reserved));
Expand All @@ -100,35 +63,11 @@ public static unsafe int LoadAssemblyAndGetFunctionPointer(IntPtr assemblyPathNa
throw new ArgumentNullException(nameof(functionHandle));
}

IntPtr functionPtr;
if (delegateType == null)
{
// Match search semantics of the CreateDelegate() function below.
BindingFlags bindingFlags = BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic;
MethodInfo? methodInfo = type.GetMethod(methodName, bindingFlags);
if (methodInfo == null)
throw new MissingMethodException(typeName, methodName);

// Verify the function is properly marked.
if (null == methodInfo.GetCustomAttribute<UnmanagedCallersOnlyAttribute>())
throw new InvalidOperationException(SR.InvalidOperation_FunctionMissingUnmanagedCallersOnly);

functionPtr = methodInfo.MethodHandle.GetFunctionPointer();
}
else
{
Delegate d = Delegate.CreateDelegate(delegateType, type, methodName)!;

functionPtr = Marshal.GetFunctionPointerForDelegate(d);

lock (s_delegates)
{
// Keep a reference to the delegate to prevent it from being garbage collected
s_delegates[functionPtr] = d;
}
}
// Set up the AssemblyLoadContext for this delegate.
AssemblyLoadContext alc = GetIsolatedComponentLoadContext(assemblyPath);

*(IntPtr*)functionHandle = functionPtr;
// Create the function pointer.
*(IntPtr*)functionHandle = InternalGetFunctionPointer(alc, typeName, methodName, delegateTypeNative);
}
catch (Exception e)
{
Expand All @@ -153,5 +92,67 @@ private static IsolatedComponentLoadContext GetIsolatedComponentLoadContext(stri

return alc;
}

private static IntPtr InternalGetFunctionPointer(AssemblyLoadContext alc,
string typeName,
string methodName,
IntPtr delegateTypeNative)
{
// Create a resolver callback for types.
Func<AssemblyName, Assembly> resolver = name => alc.LoadFromAssemblyName(name);

// Determine the signature of the type. There are 3 possibilities:
// * No delegate type was supplied - use the default (i.e. ComponentEntryPoint).
// * A sentinel value was supplied - the function is marked UnmanagedCallersOnly. This means
// a function pointer can be returned without creating a delegate.
// * A delegate type was supplied - Load the type and create a delegate for that method.
Type? delegateType;
if (delegateTypeNative == IntPtr.Zero)
{
delegateType = typeof(ComponentEntryPoint);
}
else if (delegateTypeNative == (IntPtr)(-1))
{
delegateType = null;
}
else
{
string delegateTypeName = MarshalToString(delegateTypeNative, nameof(delegateTypeNative));
delegateType = Type.GetType(delegateTypeName, resolver, null, throwOnError: true)!;
}

// Get the requested type.
Type type = Type.GetType(typeName, resolver, null, throwOnError: true)!;

IntPtr functionPtr;
if (delegateType == null)
{
// Match search semantics of the CreateDelegate() function below.
BindingFlags bindingFlags = BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic;
MethodInfo? methodInfo = type.GetMethod(methodName, bindingFlags);
if (methodInfo == null)
throw new MissingMethodException(typeName, methodName);

// Verify the function is properly marked.
if (null == methodInfo.GetCustomAttribute<UnmanagedCallersOnlyAttribute>())
throw new InvalidOperationException(SR.InvalidOperation_FunctionMissingUnmanagedCallersOnly);

functionPtr = methodInfo.MethodHandle.GetFunctionPointer();
}
else
{
Delegate d = Delegate.CreateDelegate(delegateType, type, methodName)!;

functionPtr = Marshal.GetFunctionPointerForDelegate(d);

lock (s_delegates)
{
// Keep a reference to the delegate to prevent it from being garbage collected
s_delegates[functionPtr] = d;
}
}

return functionPtr;
}
}
}
12 changes: 6 additions & 6 deletions src/installer/corehost/cli/test/nativehost/host_context_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ namespace
return _printable_delegate_name_t{ delegate_name };
}

int call_delegate_flavour(
int call_load_assembly_and_get_function_pointer_flavour(
load_assembly_and_get_function_pointer_fn delegate,
const pal::char_t *assembly_path,
const pal::char_t *type_name,
Expand Down Expand Up @@ -294,7 +294,7 @@ namespace
return rc;
}

bool load_assembly_and_get_function_pointer_test(
bool component_load_assembly_and_get_function_pointer_test(
const hostfxr_exports &hostfxr,
const pal::char_t *config_path,
int argc,
Expand Down Expand Up @@ -327,7 +327,7 @@ namespace
else
{
test_output << log_prefix << _X("hostfxr_get_runtime_delegate succeeded: ") << std::hex << std::showbase << rc << std::endl;
rc = call_delegate_flavour(delegate, assembly_path, type_name, method_name, log_prefix, test_output);
rc = call_load_assembly_and_get_function_pointer_flavour(delegate, assembly_path, type_name, method_name, log_prefix, test_output);
}
}

Expand Down Expand Up @@ -370,7 +370,7 @@ namespace
else
{
test_output << log_prefix << _X("hostfxr_get_runtime_delegate succeeded: ") << std::hex << std::showbase << rc << std::endl;
rc = call_delegate_flavour(delegate, assembly_path, type_name, method_name, log_prefix, test_output);
rc = call_load_assembly_and_get_function_pointer_flavour(delegate, assembly_path, type_name, method_name, log_prefix, test_output);
}
}

Expand Down Expand Up @@ -603,7 +603,7 @@ bool host_context_test::non_context_mixed(
return success;
}

bool host_context_test::load_assembly_and_get_function_pointer(
bool host_context_test::component_load_assembly_and_get_function_pointer(
const pal::string_t &hostfxr_path,
const pal::char_t *config_path,
int argc,
Expand All @@ -612,7 +612,7 @@ bool host_context_test::load_assembly_and_get_function_pointer(
{
hostfxr_exports hostfxr{ hostfxr_path };

return load_assembly_and_get_function_pointer_test(hostfxr, config_path, argc, argv, config_log_prefix, test_output);
return component_load_assembly_and_get_function_pointer_test(hostfxr, config_path, argc, argv, config_log_prefix, test_output);
}

bool host_context_test::app_load_assembly_and_get_function_pointer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ namespace host_context_test
int argc,
const pal::char_t *argv[],
pal::stringstream_t &test_output);
bool load_assembly_and_get_function_pointer(
bool component_load_assembly_and_get_function_pointer(
const pal::string_t &hostfxr_path,
const pal::char_t *config_path,
int argc,
Expand Down
4 changes: 2 additions & 2 deletions src/installer/corehost/cli/test/nativehost/nativehost.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ int main(const int argc, const pal::char_t *argv[])
std::cout << tostr(test_output.str()).data() << std::endl;
return success ? EXIT_SUCCESS : EXIT_FAILURE;
}
else if (pal::strcmp(command, _X("load_assembly_and_get_function_pointer")) == 0)
else if (pal::strcmp(command, _X("component_load_assembly_and_get_function_pointer")) == 0)
{
// args: ... <hostfxr_path> <app_or_config_path> <assembly_path> <type_name> <method_name> [<assembly_path> <type_name> <method_name>...]
const int min_argc = 4;
Expand All @@ -236,7 +236,7 @@ int main(const int argc, const pal::char_t *argv[])
pal::stringstream_t test_output;
bool success = false;

success = host_context_test::load_assembly_and_get_function_pointer(hostfxr_path, app_or_config_path, remaining_argc, remaining_argv, test_output);
success = host_context_test::component_load_assembly_and_get_function_pointer(hostfxr_path, app_or_config_path, remaining_argc, remaining_argv, test_output);

std::cout << tostr(test_output.str()).data() << std::endl;
return success ? EXIT_SUCCESS : EXIT_FAILURE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
<PropertyGroup>
<TargetFramework>$(NetCoreAppCurrent)</TargetFramework>
<RuntimeFrameworkVersion>$(MNAVersion)</RuntimeFrameworkVersion>
<GenerateRuntimeConfigurationFiles>true</GenerateRuntimeConfigurationFiles>
<EnableDynamicLoading>true</EnableDynamicLoading>
</PropertyGroup>

</Project>
Loading

0 comments on commit e393cde

Please sign in to comment.