Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ComponentActivator #37784

Merged
merged 5 commits into from
Jun 13, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Reflection;
using System.Runtime.InteropServices;
using System.Runtime.Loader;

namespace Internal.Runtime.InteropServices
{
Expand Down Expand Up @@ -39,6 +39,31 @@ private static string MarshalToString(IntPtr arg, string argName)
return result;
Copy link
Member

@jkotas jkotas Jun 12, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The two checks for null in MarshalToString is redundant. You can delete one of them while you are on it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was probably added so the compiler knows we can't return null. Is adding '!' preferred over this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

! is preferred over adding unreachable code.

}

private static (Type? delegateType, string? delegateTypeName) GetDelegateTypeArgs(IntPtr delegateTypeNative)
{
// Determine the signature of the type. There are 3 possibilities:
// * No delegate type was supplied - use the default (i.e. ComponentEntryPoint).
// * A sentinel value was supplied - the function is marked UnmanagedCallersOnly. This means
// a function pointer can be returned without creating a delegate.
// * A delegate type was supplied - Load the type and create a delegate for that method.
string? delegateTypeName = null;
Type? delegateType = null;
if (delegateTypeNative == IntPtr.Zero)
{
delegateType = typeof(ComponentEntryPoint);
}
else if (delegateTypeNative == (IntPtr)(-1))
{
// Leave both null.
}
else
{
delegateTypeName = MarshalToString(delegateTypeNative, nameof(delegateTypeNative));
}

return (delegateType, delegateTypeName);
}

/// <summary>
/// Native hosting entry point for creating a native delegate
/// </summary>
Expand All @@ -50,45 +75,19 @@ private static string MarshalToString(IntPtr arg, string argName)
/// <param name="functionHandle">Pointer where to store the function pointer result</param>
[UnmanagedCallersOnly]
public static unsafe int LoadAssemblyAndGetFunctionPointer(IntPtr assemblyPathNative,
IntPtr typeNameNative,
IntPtr methodNameNative,
IntPtr delegateTypeNative,
IntPtr reserved,
IntPtr functionHandle)
IntPtr typeNameNative,
IntPtr methodNameNative,
IntPtr delegateTypeNative,
IntPtr reserved,
IntPtr functionHandle)
{
try
{
// Load the assembly and create a resolver callback for types.
// Validate all parameters first.
string assemblyPath = MarshalToString(assemblyPathNative, nameof(assemblyPathNative));
IsolatedComponentLoadContext alc = GetIsolatedComponentLoadContext(assemblyPath);
Func<AssemblyName, Assembly> resolver = name => alc.LoadFromAssemblyName(name);

// Get the requested type.
string typeName = MarshalToString(typeNameNative, nameof(typeNameNative));
Type type = Type.GetType(typeName, resolver, null, throwOnError: true)!;

// Get the method name on the type.
string methodName = MarshalToString(methodNameNative, nameof(methodNameNative));

// Determine the signature of the type. There are 3 possibilities:
// * No delegate type was supplied - use the default (i.e. ComponentEntryPoint).
// * A sentinel value was supplied - the function is marked UnmanagedCallersOnly. This means
// a function pointer can be returned without creating a delegate.
// * A delegate type was supplied - Load the type and create a delegate for that method.
Type? delegateType;
if (delegateTypeNative == IntPtr.Zero)
{
delegateType = typeof(ComponentEntryPoint);
}
else if (delegateTypeNative == (IntPtr)(-1))
{
delegateType = null;
}
else
{
string delegateTypeName = MarshalToString(delegateTypeNative, nameof(delegateTypeNative));
delegateType = Type.GetType(delegateTypeName, resolver, null, throwOnError: true)!;
}
(Type? delegateType, string? delegateTypeName) = GetDelegateTypeArgs(delegateTypeNative);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the logic is harder to follow now that this is split over 3 different functions. Was it really necessary to split it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can InternalGetFunctionPointer just take delegateTypeNative?

I do not see the point by doing little bit of upfront work in GetDelegateTypeArgs, passing the tuple around, and finishing it in InternalGetFunctionPointer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea is to validate the parameters before doing anything meaningful like creating an ALC. Can't the Marshal call throw an exception if given a bad pointer?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't the Marshal call throw an exception if given a bad pointer?

Yes, but it's not really a parameter validation - it will try to read from the pointer, if that throws an AV it will turn into an exception.

I understand the goal here (to validate everything up-front), so I'm basically fine with this, but if @jkotas or others would prefer to not use this, I would also be OK if we duplicate the parameter validation (basically inline the method here as it was before, and then make a copy into the second one we're introducing) - or I would not mind if we simply postpone the marshal call to the point where we're going to actually get the delegate type.


if (reserved != IntPtr.Zero)
{
Expand All @@ -100,35 +99,11 @@ public static unsafe int LoadAssemblyAndGetFunctionPointer(IntPtr assemblyPathNa
throw new ArgumentNullException(nameof(functionHandle));
}

IntPtr functionPtr;
if (delegateType == null)
{
// Match search semantics of the CreateDelegate() function below.
BindingFlags bindingFlags = BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic;
MethodInfo? methodInfo = type.GetMethod(methodName, bindingFlags);
if (methodInfo == null)
throw new MissingMethodException(typeName, methodName);
// Set up the AssemblyLoadContext for this delegate.
AssemblyLoadContext alc = InternalLoadAssembly(assemblyPath);

// Verify the function is properly marked.
if (null == methodInfo.GetCustomAttribute<UnmanagedCallersOnlyAttribute>())
throw new InvalidOperationException(SR.InvalidOperation_FunctionMissingUnmanagedCallersOnly);

functionPtr = methodInfo.MethodHandle.GetFunctionPointer();
}
else
{
Delegate d = Delegate.CreateDelegate(delegateType, type, methodName)!;

functionPtr = Marshal.GetFunctionPointerForDelegate(d);

lock (s_delegates)
{
// Keep a reference to the delegate to prevent it from being garbage collected
s_delegates[functionPtr] = d;
}
}

*(IntPtr*)functionHandle = functionPtr;
// Create the function pointer.
*(IntPtr*)functionHandle = InternalGetFunctionPointer(alc, typeName, methodName, delegateType, delegateTypeName);
}
catch (Exception e)
{
Expand All @@ -153,5 +128,61 @@ private static IsolatedComponentLoadContext GetIsolatedComponentLoadContext(stri

return alc;
}

private static AssemblyLoadContext InternalLoadAssembly(string assemblyPath)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the point of wrapping GetIsolatedComponentLoadContext by another method?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#37696 (comment). It's easy to remove InternalLoadAssembly until we actually need it if we don't want it today.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would also prefer to not have this method (if we ever introduce a native hosting way to create an ALC, it would have to take the parameter as IntPtr anyway, so I don't really see the usefulness of this wrapper).

{

IsolatedComponentLoadContext alc = GetIsolatedComponentLoadContext(assemblyPath);
return alc;
}

private static IntPtr InternalGetFunctionPointer(AssemblyLoadContext alc,
string typeName,
string methodName,
Type? delegateType,
string? delegateTypeName)
{
// Create a resolver callback for types.
Func<AssemblyName, Assembly> resolver = name => alc.LoadFromAssemblyName(name);

// Get the requested delegateType if a name was given.
if (delegateTypeName != null)
{
delegateType = Type.GetType(delegateTypeName, resolver, null, throwOnError: true)!;
}

// Get the requested type.
Type type = Type.GetType(typeName, resolver, null, throwOnError: true)!;

IntPtr functionPtr;
if (delegateType == null)
{
// Match search semantics of the CreateDelegate() function below.
BindingFlags bindingFlags = BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic;
MethodInfo? methodInfo = type.GetMethod(methodName, bindingFlags);
if (methodInfo == null)
throw new MissingMethodException(typeName, methodName);

// Verify the function is properly marked.
if (null == methodInfo.GetCustomAttribute<UnmanagedCallersOnlyAttribute>())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be changed to methodInfo.IsDefined(typeof(UnmanagedCallersOnlyAttribute), false). If memory serves @jkotas indicated this was a cheaper call at one point and when I originally wrote this I didn't know that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't want to make this change in this PR. The equivalent code is methodInfo.IsDefined(typeof(UnmanagedCallersOnlyAttribute), true).

throw new InvalidOperationException(SR.InvalidOperation_FunctionMissingUnmanagedCallersOnly);

functionPtr = methodInfo.MethodHandle.GetFunctionPointer();
}
else
{
Delegate d = Delegate.CreateDelegate(delegateType, type, methodName)!;

functionPtr = Marshal.GetFunctionPointerForDelegate(d);

lock (s_delegates)
{
// Keep a reference to the delegate to prevent it from being garbage collected
s_delegates[functionPtr] = d;
}
}

return functionPtr;
}
}
}
12 changes: 6 additions & 6 deletions src/installer/corehost/cli/test/nativehost/host_context_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ namespace
return _printable_delegate_name_t{ delegate_name };
}

int call_delegate_flavour(
int call_load_assembly_and_get_function_pointer_flavour(
load_assembly_and_get_function_pointer_fn delegate,
const pal::char_t *assembly_path,
const pal::char_t *type_name,
Expand Down Expand Up @@ -294,7 +294,7 @@ namespace
return rc;
}

bool load_assembly_and_get_function_pointer_test(
bool component_load_assembly_and_get_function_pointer_test(
const hostfxr_exports &hostfxr,
const pal::char_t *config_path,
int argc,
Expand Down Expand Up @@ -327,7 +327,7 @@ namespace
else
{
test_output << log_prefix << _X("hostfxr_get_runtime_delegate succeeded: ") << std::hex << std::showbase << rc << std::endl;
rc = call_delegate_flavour(delegate, assembly_path, type_name, method_name, log_prefix, test_output);
rc = call_load_assembly_and_get_function_pointer_flavour(delegate, assembly_path, type_name, method_name, log_prefix, test_output);
}
}

Expand Down Expand Up @@ -370,7 +370,7 @@ namespace
else
{
test_output << log_prefix << _X("hostfxr_get_runtime_delegate succeeded: ") << std::hex << std::showbase << rc << std::endl;
rc = call_delegate_flavour(delegate, assembly_path, type_name, method_name, log_prefix, test_output);
rc = call_load_assembly_and_get_function_pointer_flavour(delegate, assembly_path, type_name, method_name, log_prefix, test_output);
}
}

Expand Down Expand Up @@ -603,7 +603,7 @@ bool host_context_test::non_context_mixed(
return success;
}

bool host_context_test::load_assembly_and_get_function_pointer(
bool host_context_test::component_load_assembly_and_get_function_pointer(
const pal::string_t &hostfxr_path,
const pal::char_t *config_path,
int argc,
Expand All @@ -612,7 +612,7 @@ bool host_context_test::load_assembly_and_get_function_pointer(
{
hostfxr_exports hostfxr{ hostfxr_path };

return load_assembly_and_get_function_pointer_test(hostfxr, config_path, argc, argv, config_log_prefix, test_output);
return component_load_assembly_and_get_function_pointer_test(hostfxr, config_path, argc, argv, config_log_prefix, test_output);
}

bool host_context_test::app_load_assembly_and_get_function_pointer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ namespace host_context_test
int argc,
const pal::char_t *argv[],
pal::stringstream_t &test_output);
bool load_assembly_and_get_function_pointer(
bool component_load_assembly_and_get_function_pointer(
const pal::string_t &hostfxr_path,
const pal::char_t *config_path,
int argc,
Expand Down
4 changes: 2 additions & 2 deletions src/installer/corehost/cli/test/nativehost/nativehost.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ int main(const int argc, const pal::char_t *argv[])
std::cout << tostr(test_output.str()).data() << std::endl;
return success ? EXIT_SUCCESS : EXIT_FAILURE;
}
else if (pal::strcmp(command, _X("load_assembly_and_get_function_pointer")) == 0)
else if (pal::strcmp(command, _X("component_load_assembly_and_get_function_pointer")) == 0)
{
// args: ... <hostfxr_path> <app_or_config_path> <assembly_path> <type_name> <method_name> [<assembly_path> <type_name> <method_name>...]
const int min_argc = 4;
Expand All @@ -236,7 +236,7 @@ int main(const int argc, const pal::char_t *argv[])
pal::stringstream_t test_output;
bool success = false;

success = host_context_test::load_assembly_and_get_function_pointer(hostfxr_path, app_or_config_path, remaining_argc, remaining_argv, test_output);
success = host_context_test::component_load_assembly_and_get_function_pointer(hostfxr_path, app_or_config_path, remaining_argc, remaining_argv, test_output);

std::cout << tostr(test_output.str()).data() << std::endl;
return success ? EXIT_SUCCESS : EXIT_FAILURE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
<PropertyGroup>
<TargetFramework>$(NetCoreAppCurrent)</TargetFramework>
<RuntimeFrameworkVersion>$(MNAVersion)</RuntimeFrameworkVersion>
<GenerateRuntimeConfigurationFiles>true</GenerateRuntimeConfigurationFiles>
<EnableDynamicLoading>true</EnableDynamicLoading>
</PropertyGroup>

</Project>
Loading