Skip to content

Commit

Permalink
Extend PreInitialization Support to readonly GC fields
Browse files Browse the repository at this point in the history
  • Loading branch information
Suchiman committed Apr 6, 2023
1 parent 657865f commit 6a53219
Showing 1 changed file with 186 additions and 25 deletions.
211 changes: 186 additions & 25 deletions src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/TypePreinit.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Runtime.InteropServices;

using ILCompiler.DependencyAnalysis;
Expand Down Expand Up @@ -49,7 +50,7 @@ private TypePreinit(MetadataType owningType, CompilationModuleGroup compilationG
if (!field.IsStatic || field.IsLiteral || field.IsThreadStatic || field.HasRva)
continue;

_fieldValues.Add(field, NewUninitializedLocationValue(field.FieldType));
_fieldValues.Add(field, NewUninitializedLocationValue(field.FieldType));
}
}

Expand Down Expand Up @@ -352,11 +353,11 @@ private Status TryScanMethod(MethodIL methodIL, Value[] parameters, Stack<Method
// and resetting it would lead to unpredictable analysis durations.
int baseInstructionCounter = instructionCounter;
Status status = nestedPreinit.TryScanMethod(field.OwningType.GetStaticConstructor(), null, recursionProtect, ref instructionCounter, out Value _);
recursionProtect.Pop();
if (!status.IsSuccessful)
{
return Status.Fail(methodIL.OwningMethod, opcode, "Nested cctor failed to preinit");
}
recursionProtect.Pop();
Value value = nestedPreinit._fieldValues[field];
if (value is ValueTypeValue)
stack.PushFromLocation(field.FieldType, value);
Expand Down Expand Up @@ -440,12 +441,11 @@ private Status TryScanMethod(MethodIL methodIL, Value[] parameters, Stack<Method
recursionProtect ??= new Stack<MethodDesc>();
recursionProtect.Push(methodIL.OwningMethod);
Status callResult = TryScanMethod(method, methodParams, recursionProtect, ref instructionCounter, out retVal);
recursionProtect.Pop();
if (!callResult.IsSuccessful)
{
recursionProtect.Pop();
return callResult;
}
recursionProtect.Pop();
}

if (!methodSig.ReturnType.IsVoid)
Expand Down Expand Up @@ -575,13 +575,11 @@ private Status TryScanMethod(MethodIL methodIL, Value[] parameters, Stack<Method
recursionProtect ??= new Stack<MethodDesc>();
recursionProtect.Push(methodIL.OwningMethod);
Status ctorCallResult = TryScanMethod(ctor, ctorParameters, recursionProtect, ref instructionCounter, out _);
recursionProtect.Pop();
if (!ctorCallResult.IsSuccessful)
{
recursionProtect.Pop();
return ctorCallResult;
}

recursionProtect.Pop();
}

stack.PushFromLocation(owningType, instance);
Expand All @@ -600,7 +598,7 @@ private Status TryScanMethod(MethodIL methodIL, Value[] parameters, Stack<Method
Value value = stack.PopIntoLocation(field.FieldType);
StackEntry instance = stack.Pop();

if (field.FieldType.IsGCPointer && value != null)
if (field.FieldType.IsGCPointer && value != null && !field.IsInitOnly)
{
return Status.Fail(methodIL.OwningMethod, opcode, "Reference field");
}
Expand Down Expand Up @@ -1891,8 +1889,9 @@ private abstract class BaseValueTypeValue : Value
}

// Also represents pointers and function pointer.
private sealed class ValueTypeValue : BaseValueTypeValue, IAssignableValue
private sealed class ValueTypeValue : BaseValueTypeValue, IAssignableValue, IHasInstanceFields
{
public Dictionary<FieldDesc, Value> GCFields;
public readonly byte[] InstanceBytes;

public override int Size => InstanceBytes.Length;
Expand All @@ -1903,19 +1902,20 @@ public ValueTypeValue(TypeDesc type)
InstanceBytes = new byte[type.GetElementSize().AsInt];
}

private ValueTypeValue(byte[] bytes)
private ValueTypeValue(byte[] bytes, Dictionary<FieldDesc, Value> gcFields = null)
{
InstanceBytes = bytes;
GCFields = gcFields;
}

public override Value Clone()
{
return new ValueTypeValue((byte[])InstanceBytes.Clone());
return new ValueTypeValue((byte[])InstanceBytes.Clone(), GCFields == null ? null : new(GCFields));
}

public override bool TryCreateByRef(out Value value)
{
value = new ByRefValue(InstanceBytes, 0);
value = new ByRefValue(InstanceBytes, 0, this);
return true;
}

Expand All @@ -1933,6 +1933,7 @@ bool IAssignableValue.TryAssign(Value value)
}

Array.Copy(vtvalue.InstanceBytes, InstanceBytes, InstanceBytes.Length);
GCFields = vtvalue.GCFields == null ? null : new(vtvalue.GCFields);
return true;
}

Expand All @@ -1942,24 +1943,69 @@ public override bool Equals(Value value)
|| vtvalue.InstanceBytes.Length != InstanceBytes.Length)
{
ThrowHelper.ThrowInvalidProgramException();
return false;
}

if ((GCFields == null) != (vtvalue.GCFields == null) || GCFields != null && GCFields.Count != vtvalue.GCFields.Count)
return false;

for (int i = 0; i < InstanceBytes.Length; i++)
{
if (InstanceBytes[i] != ((ValueTypeValue)value).InstanceBytes[i])
if (InstanceBytes[i] != vtvalue.InstanceBytes[i])
return false;
}

if (GCFields != null)
{
foreach (var (field, myValue) in GCFields)
{
if (!vtvalue.GCFields.TryGetValue(field, out var otherValue))
return false;

if ((myValue == null) != (otherValue == null) || myValue != null && !myValue.Equals(otherValue))
return false;
}
}

return true;
}

public override void WriteFieldData(ref ObjectDataBuilder builder, NodeFactory factory)
{
builder.EmitBytes(InstanceBytes);
if (GCFields != null)
{
int bytesWritten = 0;
foreach (var (field, value) in GCFields.OrderBy(x => x.Key.Offset.AsInt))
{
int fieldOffset = field.Offset.AsInt;
int fieldSize = field.FieldType.GetElementSize().AsInt;
//if (fieldOffset + fieldSize > _instanceBytes.Length - _offset)
// ThrowHelper.ThrowInvalidProgramException();

int bytesFromInstance = bytesWritten - fieldOffset;
if (bytesFromInstance > 0)
{
builder.EmitBytes(InstanceBytes, bytesWritten, bytesFromInstance);
bytesWritten += bytesFromInstance;
}
value.WriteFieldData(ref builder, factory);
bytesWritten += fieldSize;
}
}
else
{
builder.EmitBytes(InstanceBytes);
}
}

public override bool GetRawData(NodeFactory factory, out object data)
{
if (GCFields != null)
{
data = null;
return false;
}

data = InstanceBytes;
return true;
}
Expand All @@ -1973,6 +2019,33 @@ private byte[] AsExactByteCount(int size)
return InstanceBytes;
}

Value IHasInstanceFields.GetField(FieldDesc field)
{
if (field.FieldType.IsGCPointer)
{
return GCFields.GetValueOrDefault(field);
}
else
{
return new FieldAccessor(InstanceBytes, 0).GetField(field);
}
}

void IHasInstanceFields.SetField(FieldDesc field, Value value)
{
if (field.FieldType.IsGCPointer)
{
GCFields ??= new();
GCFields[field] = value;
}
else
{
new FieldAccessor(InstanceBytes, 0).SetField(field, value);
}
}

ByRefValue IHasInstanceFields.GetFieldAddress(FieldDesc field) => new FieldAccessor(InstanceBytes, 0).GetFieldAddress(field);

public override sbyte AsSByte() => (sbyte)AsExactByteCount(1)[0];
public override short AsInt16() => BitConverter.ToInt16(AsExactByteCount(2), 0);
public override int AsInt32() => BitConverter.ToInt32(AsExactByteCount(4), 0);
Expand Down Expand Up @@ -2144,11 +2217,13 @@ private sealed class ByRefValue : Value, IHasInstanceFields
{
public readonly byte[] PointedToBytes;
public readonly int PointedToOffset;
public readonly IHasInstanceFields PointedToType;

public ByRefValue(byte[] pointedToBytes, int pointedToOffset)
public ByRefValue(byte[] pointedToBytes, int pointedToOffset, IHasInstanceFields pointedToType = null)
{
PointedToBytes = pointedToBytes;
PointedToOffset = pointedToOffset;
PointedToType = pointedToType;
}

public override bool Equals(Value value)
Expand All @@ -2162,9 +2237,39 @@ public override bool Equals(Value value)
&& PointedToOffset == ((ByRefValue)value).PointedToOffset;
}

Value IHasInstanceFields.GetField(FieldDesc field) => new FieldAccessor(PointedToBytes, PointedToOffset).GetField(field);
void IHasInstanceFields.SetField(FieldDesc field, Value value) => new FieldAccessor(PointedToBytes, PointedToOffset).SetField(field, value);
ByRefValue IHasInstanceFields.GetFieldAddress(FieldDesc field) => new FieldAccessor(PointedToBytes, PointedToOffset).GetFieldAddress(field);
Value IHasInstanceFields.GetField(FieldDesc field)
{
if (PointedToType != null)
{
return PointedToType.GetField(field);
}
else
{
return new FieldAccessor(PointedToBytes, PointedToOffset).GetField(field);
}
}

void IHasInstanceFields.SetField(FieldDesc field, Value value)
{
if (PointedToType != null)
{
PointedToType.SetField(field, value);
}
else
{
new FieldAccessor(PointedToBytes, PointedToOffset).SetField(field, value);
}
}

ByRefValue IHasInstanceFields.GetFieldAddress(FieldDesc field)
{
if (field.FieldType.IsGCPointer)
{
throw new NotSupportedException();
}

return new FieldAccessor(PointedToBytes, PointedToOffset).GetFieldAddress(field);
}

public void Initialize(int size)
{
Expand Down Expand Up @@ -2523,6 +2628,7 @@ private class ObjectInstance : AllocatedReferenceTypeValue, IHasInstanceFields,
#pragma warning restore CA1852
{
private readonly byte[] _data;
public Dictionary<FieldDesc, Value> GCFields;

public ObjectInstance(DefType type, AllocationSite allocationSite)
: base(type, allocationSite)
Expand Down Expand Up @@ -2565,9 +2671,32 @@ public bool TryUnboxAny(TypeDesc type, out Value value)
return true;
}

Value IHasInstanceFields.GetField(FieldDesc field) => new FieldAccessor(_data).GetField(field);
void IHasInstanceFields.SetField(FieldDesc field, Value value) => new FieldAccessor(_data).SetField(field, value);
ByRefValue IHasInstanceFields.GetFieldAddress(FieldDesc field) => new FieldAccessor(_data).GetFieldAddress(field);
Value IHasInstanceFields.GetField(FieldDesc field)
{
if (field.FieldType.IsGCPointer)
{
return GCFields.GetValueOrDefault(field);
}
else
{
return new FieldAccessor(_data, 0).GetField(field);
}
}

void IHasInstanceFields.SetField(FieldDesc field, Value value)
{
if (field.FieldType.IsGCPointer)
{
GCFields ??= new();
GCFields[field] = value;
}
else
{
new FieldAccessor(_data, 0).SetField(field, value);
}
}

ByRefValue IHasInstanceFields.GetFieldAddress(FieldDesc field) => new FieldAccessor(_data, 0).GetFieldAddress(field);

public override void WriteFieldData(ref ObjectDataBuilder builder, NodeFactory factory)
{
Expand All @@ -2581,10 +2710,42 @@ public virtual void WriteContent(ref ObjectDataBuilder builder, ISymbolNode this
Debug.Assert(!node.RepresentsIndirectionCell); // Shouldn't have allowed preinitializing this
builder.EmitPointerReloc(node);

// We skip the first pointer because that's the MethodTable pointer
// we just initialized above.
int pointerSize = factory.Target.PointerSize;
builder.EmitBytes(_data, pointerSize, _data.Length - pointerSize);
if (GCFields != null)
{
// We skip the first pointer because that's the MethodTable pointer
// we just initialized above.
int bytesWritten = factory.Target.PointerSize;
foreach (var (field, value) in GCFields.OrderBy(x => x.Key.Offset.AsInt))
{
int fieldOffset = field.Offset.AsInt;
int fieldSize = field.FieldType.GetElementSize().AsInt;
//if (fieldOffset + fieldSize > _instanceBytes.Length - _offset)
// ThrowHelper.ThrowInvalidProgramException();

int bytesFromInstance = bytesWritten - fieldOffset;
if (bytesFromInstance > 0)
{
builder.EmitBytes(_data, bytesWritten, bytesFromInstance);
bytesWritten += bytesFromInstance;
}
if (value == null)
{
builder.EmitZeroPointer();
}
else
{
value.WriteFieldData(ref builder, factory);
}
bytesWritten += fieldSize;
}
}
else
{
// We skip the first pointer because that's the MethodTable pointer
// we just initialized above.
int pointerSize = factory.Target.PointerSize;
builder.EmitBytes(_data, pointerSize, _data.Length - pointerSize);
}
}

public bool IsKnownImmutable => !Type.GetFields().GetEnumerator().MoveNext();
Expand Down

0 comments on commit 6a53219

Please sign in to comment.