Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mechanism for runtime to query host for information #78798

Merged
merged 5 commits into from
Dec 6, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions src/coreclr/dlls/mscoree/exports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#endif // FEATURE_GDBJIT
#include "bundle.h"
#include "pinvokeoverride.h"
#include <hostinformation.h>
#include <corehost/host_runtime_contract.h>

#define ASSERTE_ALL_BUILDS(expr) _ASSERTE_ALL_BUILDS((expr))

Expand Down Expand Up @@ -122,7 +124,8 @@ static void ConvertConfigPropertiesToUnicode(
LPCWSTR** propertyValuesWRef,
BundleProbeFn** bundleProbe,
PInvokeOverrideFn** pinvokeOverride,
bool* hostPolicyEmbedded)
bool* hostPolicyEmbedded,
host_runtime_contract** hostContract)
{
LPCWSTR* propertyKeysW = new (nothrow) LPCWSTR[propertyCount];
ASSERTE_ALL_BUILDS(propertyKeysW != nullptr);
Expand All @@ -139,7 +142,8 @@ static void ConvertConfigPropertiesToUnicode(
{
// If this application is a single-file bundle, the bundle-probe callback
// is passed in as the value of "BUNDLE_PROBE" property (encoded as a string).
*bundleProbe = (BundleProbeFn*)_wcstoui64(propertyValuesW[propertyIndex], nullptr, 0);
if (*bundleProbe == nullptr)
*bundleProbe = (BundleProbeFn*)_wcstoui64(propertyValuesW[propertyIndex], nullptr, 0);
}
else if (strcmp(propertyKeys[propertyIndex], "PINVOKE_OVERRIDE") == 0)
{
Expand All @@ -152,6 +156,14 @@ static void ConvertConfigPropertiesToUnicode(
// The HOSTPOLICY_EMBEDDED property indicates if the executable has hostpolicy statically linked in
*hostPolicyEmbedded = (wcscmp(propertyValuesW[propertyIndex], W("true")) == 0);
}
else if (strcmp(propertyKeys[propertyIndex], HOST_PROPERTY_RUNTIME_CONTRACT) == 0)
{
// Host contract is passed in as the value of HOST_RUNTIME_CONTRACT property (encoded as a string).
host_runtime_contract* hostContractLocal = (host_runtime_contract*)_wcstoui64(propertyValuesW[propertyIndex], nullptr, 0);
*hostContract = hostContractLocal;
if (hostContractLocal->bundle_probe != nullptr)
*bundleProbe = hostContractLocal->bundle_probe;
}
}

*propertyKeysWRef = propertyKeysW;
Expand Down Expand Up @@ -196,6 +208,7 @@ int coreclr_initialize(
BundleProbeFn* bundleProbe = nullptr;
bool hostPolicyEmbedded = false;
PInvokeOverrideFn* pinvokeOverride = nullptr;
host_runtime_contract* hostContract = nullptr;

ConvertConfigPropertiesToUnicode(
propertyKeys,
Expand All @@ -205,7 +218,8 @@ int coreclr_initialize(
&propertyValuesW,
&bundleProbe,
&pinvokeOverride,
&hostPolicyEmbedded);
&hostPolicyEmbedded,
&hostContract);

#ifdef TARGET_UNIX
DWORD error = PAL_InitializeCoreCLR(exePath, g_coreclr_embedded);
Expand All @@ -221,7 +235,12 @@ int coreclr_initialize(

g_hostpolicy_embedded = hostPolicyEmbedded;

if (pinvokeOverride != nullptr)
if (hostContract != nullptr)
{
HostInformation::SetContract(hostContract);
}

if (pinvokeOverride != nullptr && (hostContract == nullptr || hostContract->pinvoke_override == nullptr))
{
PInvokeOverride::SetPInvokeOverride(pinvokeOverride, PInvokeOverride::Source::RuntimeConfiguration);
}
Expand Down
16 changes: 16 additions & 0 deletions src/coreclr/inc/hostinformation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

#ifndef _HOSTINFORMATION_H_
#define _HOSTINFORMATION_H_

#include <corehost/host_runtime_contract.h>

class HostInformation
{
public:
static void SetContract(_In_ host_runtime_contract* hostContract);
static bool GetProperty(_In_z_ const char* name, SString& value);
};

#endif // _HOSTINFORMATION_H_
1 change: 1 addition & 0 deletions src/coreclr/vm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ set(VM_SOURCES_WKS
genanalysis.cpp
genmeth.cpp
hosting.cpp
hostinformation.cpp
ilmarshalers.cpp
interopconverter.cpp
interoputil.cpp
Expand Down
46 changes: 46 additions & 0 deletions src/coreclr/vm/hostinformation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

#include "common.h"
#include "hostinformation.h"
#include "pinvokeoverride.h"

namespace
{
host_runtime_contract* s_hostContract = nullptr;
}

void HostInformation::SetContract(_In_ host_runtime_contract* hostContract)
{
_ASSERTE(s_hostContract == nullptr);
s_hostContract = hostContract;

if (s_hostContract->pinvoke_override != nullptr)
PInvokeOverride::SetPInvokeOverride(s_hostContract->pinvoke_override, PInvokeOverride::Source::RuntimeConfiguration);
}

bool HostInformation::GetProperty(_In_z_ const char* name, SString& value)
{
if (s_hostContract == nullptr || s_hostContract->get_runtime_property == nullptr)
return false;

size_t len = MAX_PATH + 1;
char* dest = value.OpenUTF8Buffer(static_cast<COUNT_T>(len));
size_t lenActual = s_hostContract->get_runtime_property(name, dest, len, s_hostContract->context);
value.CloseBuffer();

// Doesn't exist or failed to get property
if (lenActual == (size_t)-1 || lenActual == 0)
return false;

if (lenActual <= len)
return true;

// Buffer was not large enough
len = lenActual;
dest = value.OpenUTF8Buffer(static_cast<COUNT_T>(len));
lenActual = s_hostContract->get_runtime_property(name, dest, len, s_hostContract->context);
value.CloseBuffer();

return lenActual > 0 && lenActual <= len;
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
<OutputType>Exe</OutputType>
<RuntimeFrameworkVersion>$(MNAVersion)</RuntimeFrameworkVersion>
<DefineConstants Condition="'$(OS)' == 'Windows_NT'">WINDOWS;$(DefineConstants)</DefineConstants>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Newtonsoft.Json" Version="$(NewtonsoftJsonVersion)" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;

namespace HostApiInvokerApp
{
public static unsafe class HostRuntimeContract
{
internal struct host_runtime_contract
{
public void* context;
public IntPtr bundle_probe;
public IntPtr pinvoke_override;
public delegate* unmanaged[Stdcall]<byte*, byte*, nint, void*, nint> get_runtime_property;
}

private static host_runtime_contract GetContract()
{
string contractString = (string)AppContext.GetData("HOST_RUNTIME_CONTRACT");
if (string.IsNullOrEmpty(contractString))
throw new Exception("HOST_RUNTIME_CONTRACT not found");

host_runtime_contract* contract = (host_runtime_contract*)Convert.ToUInt64(contractString, 16);
return *contract;
}

private static void Test_get_runtime_property(string[] args)
{
host_runtime_contract contract = GetContract();

foreach (string name in args)
{
string value = GetProperty(name, contract);
Console.WriteLine($"{nameof(host_runtime_contract.get_runtime_property)}: {name} = {(value == null ? "<none>" : value)}");
}

static string GetProperty(string name, host_runtime_contract contract)
{
Span<byte> nameSpan = stackalloc byte[Encoding.UTF8.GetMaxByteCount(name.Length)];
byte* namePtr = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(nameSpan));
int nameLen = Encoding.UTF8.GetBytes(name, nameSpan);
nameSpan[nameLen] = 0;

nint len = 256;
byte* buffer = stackalloc byte[(int)len];
nint lenActual = contract.get_runtime_property(namePtr, buffer, len, contract.context);
if (lenActual <= 0)
{
Console.WriteLine($"No value for {name} - {nameof(host_runtime_contract.get_runtime_property)} returned {lenActual}");
return null;
}

if (lenActual <= len)
return Encoding.UTF8.GetString(buffer, (int)lenActual);

len = lenActual;
byte* expandedBuffer = stackalloc byte[(int)len];
lenActual = contract.get_runtime_property(namePtr, expandedBuffer, len, contract.context);
return Encoding.UTF8.GetString(expandedBuffer, (int)lenActual);
}
}

public static bool RunTest(string apiToTest, string[] args)
{
switch (apiToTest)
{
case $"{nameof(host_runtime_contract)}.{nameof(host_runtime_contract.get_runtime_property)}":
Test_get_runtime_property(args);
break;
default:
return false;
}

return true;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,13 @@ public static void MainCore(string[] args)
Console.WriteLine("Hello World!");
Console.WriteLine(string.Join(Environment.NewLine, args));

// A small operation involving NewtonSoft.Json to ensure the assembly is loaded properly
var t = typeof(Newtonsoft.Json.JsonReader);

// Enable tracing so that test assertion failures are easier to diagnose.
Environment.SetEnvironmentVariable("COREHOST_TRACE", "1");

// If requested, test multilevel lookup using fake Global SDK directories:
// 1. using a fake ProgramFiles location
// 2. using a fake SDK Self-Registered location
// Note that this has to be set here and not in the calling test process because
// Note that this has to be set here and not in the calling test process because
// %ProgramFiles% gets reset on process creation.
string testMultilevelLookupProgramFiles = Environment.GetEnvironmentVariable("TEST_MULTILEVEL_LOOKUP_PROGRAM_FILES");
string testMultilevelLookupSelfRegistered = Environment.GetEnvironmentVariable("TEST_MULTILEVEL_LOOKUP_SELF_REGISTERED");
Expand All @@ -65,17 +62,15 @@ public static void MainCore(string[] args)

string apiToTest = args[0];
if (HostFXR.RunTest(apiToTest, args))
{
return;
}
else if (HostPolicy.RunTest(apiToTest, args))
{

if (HostPolicy.RunTest(apiToTest, args))
return;
}
else
{
throw new ArgumentException($"Invalid API to test passed as args[0]): {apiToTest}");
}

if (HostRuntimeContract.RunTest(apiToTest, args))
return;

throw new ArgumentException($"Invalid API to test passed as args[0]): {apiToTest}");
}
}
}
15 changes: 15 additions & 0 deletions src/installer/tests/HostActivation.Tests/NativeHostApis.cs
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,21 @@ public void Hostpolicy_corehost_set_error_writer_test()
.Should().Pass();
}

[Fact]
public void HostRuntimeContract_get_runtime_property()
{
var fixture = sharedTestState.HostApiInvokerAppFixture;

fixture.BuiltDotnet.Exec(fixture.TestProject.AppDll, "host_runtime_contract.get_runtime_property", "APP_CONTEXT_BASE_DIRECTORY", "ENTRY_ASSEMBLY_NAME", "DOES_NOT_EXIST")
.CaptureStdOut()
.CaptureStdErr()
.Execute()
.Should().Pass()
.And.HaveStdOutContaining($"APP_CONTEXT_BASE_DIRECTORY = {Path.GetDirectoryName(fixture.TestProject.AppDll)}")
.And.HaveStdOutContaining($"ENTRY_ASSEMBLY_NAME = {fixture.TestProject.AssemblyName}")
.And.HaveStdOutContaining($"DOES_NOT_EXIST = <none>");
}

public class SharedTestState : IDisposable
{
public TestProjectFixture HostApiInvokerAppFixture { get; }
Expand Down
41 changes: 41 additions & 0 deletions src/native/corehost/host_runtime_contract.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

#ifndef __HOST_RUNTIME_CONTRACT_H__
#define __HOST_RUNTIME_CONTRACT_H__

#include <stddef.h>
#include <stdint.h>

#if defined(_WIN32)
#define HOST_CONTRACT_CALLTYPE __stdcall
#else
#define HOST_CONTRACT_CALLTYPE
#endif

// Known host property names
#define HOST_PROPERTY_RUNTIME_CONTRACT "HOST_RUNTIME_CONTRACT"
#define HOST_PROPERTY_ENTRY_ASSEMBLY_NAME "ENTRY_ASSEMBLY_NAME"

struct host_runtime_contract
{
void* context;

bool(HOST_CONTRACT_CALLTYPE* bundle_probe)(
const char* path,
int64_t* offset,
int64_t* size,
int64_t* compressedSize);

const void* (HOST_CONTRACT_CALLTYPE* pinvoke_override)(
const char* library_name,
const char* entry_point_name);

size_t(HOST_CONTRACT_CALLTYPE* get_runtime_property)(
const char* key,
char* value_buffer,
size_t value_buffer_size,
void* contract_context);
};

#endif // __HOST_RUNTIME_CONTRACT_H__
11 changes: 11 additions & 0 deletions src/native/corehost/hostmisc/pal.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ namespace pal
return buffer;
}

size_t pal_utf8string(const string_t& str, char* out_buffer, size_t len);
bool pal_utf8string(const string_t& str, std::vector<char>* out);
bool pal_clrstring(const string_t& str, std::vector<char>* out);
bool clr_palstring(const char* cstr, string_t* out);
Expand Down Expand Up @@ -236,6 +237,16 @@ namespace pal

inline const string_t strerror(int errnum) { return ::strerror(errnum); }

inline size_t pal_utf8string(const string_t& str, char* out_buffer, size_t buffer_len)
{
size_t len = str.size() + 1;
if (buffer_len < len)
return len;

::strncpy(out_buffer, str.c_str(), str.size());
out_buffer[len - 1] = '\0';
return len;
}
inline bool pal_utf8string(const string_t& str, std::vector<char>* out) { out->assign(str.begin(), str.end()); out->push_back('\0'); return true; }
inline bool pal_clrstring(const string_t& str, std::vector<char>* out) { return pal_utf8string(str, out); }
inline bool clr_palstring(const char* cstr, string_t* out) { out->assign(cstr); return true; }
Expand Down
11 changes: 11 additions & 0 deletions src/native/corehost/hostmisc/pal.windows.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,17 @@ static bool wchar_convert_helper(DWORD code_page, const char* cstr, size_t len,
return ::MultiByteToWideChar(code_page, 0, cstr, static_cast<uint32_t>(len), &(*out)[0], static_cast<uint32_t>(out->size())) != 0;
}

size_t pal::pal_utf8string(const pal::string_t& str, char* out_buffer, size_t len)
{
// Pass -1 as we want explicit null termination in the char buffer.
size_t size = ::WideCharToMultiByte(CP_UTF8, 0, str.c_str(), -1, nullptr, 0, nullptr, nullptr);
if (size == 0 || size > len)
return size;

// Pass -1 as we want explicit null termination in the char buffer.
return ::WideCharToMultiByte(CP_UTF8, 0, str.c_str(), -1, out_buffer, static_cast<uint32_t>(len), nullptr, nullptr);
}

bool pal::pal_utf8string(const pal::string_t& str, std::vector<char>* out)
{
out->clear();
Expand Down
Loading