Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement discriminator of models #3599

Closed
wants to merge 23 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
b7a391a
remove the link temporarily
ArcturusZhang Jul 2, 2024
8180b12
add a step to build the test project
ArcturusZhang Jul 2, 2024
9d534d8
fix the base type issue
ArcturusZhang Jul 2, 2024
f1b692d
fix issues
ArcturusZhang Jul 2, 2024
ec92093
Merge branch 'main' into fix-base-type-issue
ArcturusZhang Jul 3, 2024
2afcc7b
Merge branch 'main' into remove-the-link-to-make-the-project-build
ArcturusZhang Jul 3, 2024
3504d18
update the base ctor
ArcturusZhang Jul 3, 2024
ca1e5b9
fix the deserialization static method issue
ArcturusZhang Jul 3, 2024
942f936
refactor
ArcturusZhang Jul 3, 2024
65ebb99
fix typo
ArcturusZhang Jul 3, 2024
a649a22
add the link back
ArcturusZhang Jul 3, 2024
a9a84f7
fix typo
ArcturusZhang Jul 3, 2024
f79423c
Merge remote-tracking branch 'forked/remove-the-link-to-make-the-proj…
ArcturusZhang Jul 3, 2024
1704461
fix the return type issue
ArcturusZhang Jul 3, 2024
e1edc70
fix test cases
ArcturusZhang Jul 3, 2024
b38bf1f
Merge remote-tracking branch 'origin/main' into fix-base-type-issue
ArcturusZhang Jul 3, 2024
47ee984
fix after merge
ArcturusZhang Jul 3, 2024
3b44a4b
overhauls the test cases for mrw serialization
ArcturusZhang Jul 4, 2024
0d53c60
refine
ArcturusZhang Jul 4, 2024
b21a84b
Merge branch 'main' into fix-base-type-issue
ArcturusZhang Jul 5, 2024
2e3caa4
Merge remote-tracking branch 'origin/main' into fix-base-type-issue
ArcturusZhang Jul 8, 2024
4b635b6
resolve comments
ArcturusZhang Jul 8, 2024
6c0ccb6
implement the discriminator
ArcturusZhang Jul 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix the deserialization static method issue
  • Loading branch information
ArcturusZhang committed Jul 3, 2024
commit ca1e5b97bd1d690566b7cd0c9b3012e5c865ae24
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ internal sealed class MrwSerializationTypeProvider : TypeProvider
private MethodProvider? _serializationConstructor;
// Flag to determine if the model should override the serialization methods
private readonly bool _shouldOverrideMethods;
private readonly MrwSerializationTypeProvider? _baseSerializationProvider;

public MrwSerializationTypeProvider(TypeProvider provider, InputModelType inputModel)
{
Expand All @@ -67,6 +68,7 @@ public MrwSerializationTypeProvider(TypeProvider provider, InputModelType inputM
_jsonModelObjectInterface = _isStruct ? (CSharpType)typeof(IJsonModel<object>) : null;
_persistableModelTInterface = new CSharpType(typeof(IPersistableModel<>), provider.Type);
_persistableModelObjectInterface = _isStruct ? (CSharpType)typeof(IPersistableModel<object>) : null;
_baseSerializationProvider = FindSerializationOnBase(_model);
_rawDataField = BuildRawDataField();
_shouldOverrideMethods = _model.Inherits != null && _model.Inherits is { IsFrameworkType: false, Implementation: TypeProvider };
_utf8JsonWriterSnippet = new Utf8JsonWriterSnippet(_utf8JsonWriterParameter);
Expand Down Expand Up @@ -146,8 +148,20 @@ protected override MethodProvider[] BuildConstructors()
return null;
}

// check if there is a raw data field on my base, if so, we do not have to have one here
if (_baseSerializationProvider?._rawDataField != null)
{
return null;
}

var modifiers = FieldModifiers.Private;
if (!_model.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Sealed))
{
modifiers |= FieldModifiers.Protected;
}

var FieldProvider = new FieldProvider(
modifiers: FieldModifiers.Private,
modifiers: modifiers,
type: _privateAdditionalRawDataPropertyType,
name: PrivateAdditionalPropertiesPropertyName);

Expand Down Expand Up @@ -324,7 +338,7 @@ internal MethodProvider BuildPersistableModelWriteCoreMethod()
// BinaryData PersistableModelWriteCore(ModelReaderWriterOptions options)
return new MethodProvider
(
new MethodSignature(PersistableModelWriteCoreMethodName, null, modifiers, returnType, null, [ _serializationOptionsParameter]),
new MethodSignature(PersistableModelWriteCoreMethodName, null, modifiers, returnType, null, [_serializationOptionsParameter]),
BuildPersistableModelWriteCoreMethodBody(),
this
);
Expand Down Expand Up @@ -454,14 +468,16 @@ internal MethodProvider BuildPersistableModelGetFormatFromOptionsObjectDeclarati
/// <returns>The constructed serialization constructor.</returns>
internal MethodProvider BuildSerializationConstructor()
{
var serializationCtorParameters = BuildSerializationConstructorParameters();
var (baseCtor, initializer) = BuildConstructorInitializer();
var serializationCtorParameters = BuildSerializationConstructorParameters(baseCtor?.Parameters ?? []);

return new MethodProvider(
signature: new ConstructorSignature(
Type,
$"Initializes a new instance of {Type:C}",
MethodSignatureModifiers.Internal,
serializationCtorParameters),
serializationCtorParameters,
Initializer: initializer),
bodyStatements: new MethodBodyStatement[]
{
GetPropertyInitializers(serializationCtorParameters)
Expand Down Expand Up @@ -500,7 +516,20 @@ private MethodBodyStatement[] BuildDeserializationMethodBody()
MethodBodyStatement rawDataDictionaryDeclaration = MethodBodyStatement.Empty;
MethodBodyStatement assignRawData = MethodBodyStatement.Empty;

if (_rawDataField != null)
// recusively get the raw data field from myself and all the base
var rawDataField = _rawDataField;
var baseSerialization = _baseSerializationProvider;
while (rawDataField == null)
{
if (baseSerialization == null)
{
break;
}
rawDataField = baseSerialization._rawDataField;
baseSerialization = baseSerialization._baseSerializationProvider;
}

if (rawDataField != null)
{
var rawDataType = new CSharpType(typeof(Dictionary<string, BinaryData>));
// IDictionary<string, BinaryData> serializedAdditionalRawData = default;
Expand All @@ -519,32 +548,34 @@ private MethodBodyStatement[] BuildDeserializationMethodBody()
assignRawData = additionalRawDataDictionary.Assign(rawDataDictionary).Terminate();
}

var allProperties = GetAllProperties();

// Build the deserialization statements for each property
ForeachStatement deserializePropertiesForEachStatement = new("prop", _jsonElementParameterSnippet.EnumerateObject(), out var prop)
{
BuildDeserializePropertiesStatements(new JsonPropertySnippet(prop), rawDataDictionary)
BuildDeserializePropertiesStatements(allProperties, new JsonPropertySnippet(prop), rawDataDictionary)
};

return
[
new IfStatement(_jsonElementParameterSnippet.ValueKindEqualsNull()) { Return(Null) },
GetPropertyVariableDeclarations(),
GetPropertyVariableDeclarations(allProperties),
additionalRawDataDictionaryDeclaration,
rawDataDictionaryDeclaration,
deserializePropertiesForEachStatement,
assignRawData,
Return(New.Instance(_model.Type, GetSerializationCtorParameterValues(additionalRawDataDictionary)))
Return(New.Instance(_model.Type, GetSerializationCtorParameterValues(allProperties, additionalRawDataDictionary)))
];
}

private MethodBodyStatement[] GetPropertyVariableDeclarations()
private MethodBodyStatement[] GetPropertyVariableDeclarations(IReadOnlyList<PropertyProvider> properties)
{
var propertyCount = _model.Properties.Count;
var propertyCount = properties.Count;
MethodBodyStatement[] propertyDeclarationStatements = new MethodBodyStatement[propertyCount];

for (var i = 0; i < propertyCount; i++)
{
var property = _model.Properties[i];
var property = properties[i];
var variableRef = property.AsVariableExpression;
propertyDeclarationStatements[i] = Declare(variableRef, Default);
}
Expand Down Expand Up @@ -599,24 +630,26 @@ private MethodBodyStatement CallBaseJsonModelWriteCore()

private MethodBodyStatement GetPropertyInitializers(IReadOnlyList<ParameterProvider> parameters)
{
var parameterDict = parameters.ToDictionary(p => p.Name, p => p);
List<MethodBodyStatement> methodBodyStatements = new();

foreach (var param in parameters)
foreach (var property in _model.Properties)
{
if (param.Name == _rawDataField?.Name.ToVariableName())
{
methodBodyStatements.Add(_rawDataField.Assign(param).Terminate());
continue;
}

ValueExpression initializationValue = param;
var initializationStatement = param.AsPropertyExpression.Assign(initializationValue).Terminate();
if (initializationStatement != null)
var parameterName = property.Name.FirstCharToLowerCase();
if (parameterDict.TryGetValue(parameterName, out var parameter))
{
ValueExpression initializationValue = parameter;
var initializationStatement = property.Assign(initializationValue).Terminate();
methodBodyStatements.Add(initializationStatement);
}
}

if (_rawDataField != null)
{
var parameterName = _rawDataField.Name.ToVariableName();
methodBodyStatements.Add(_rawDataField.Assign(parameterDict[parameterName]).Terminate());
}

return methodBodyStatements;
}

Expand All @@ -625,17 +658,18 @@ private MethodBodyStatement GetPropertyInitializers(IReadOnlyList<ParameterProvi
/// <paramref name="additionalRawDataDictionary"/> is the variable reference for the additional raw data dictionary.
/// </summary>
private ValueExpression[] GetSerializationCtorParameterValues(
IReadOnlyList<PropertyProvider> properties,
VariableExpression? additionalRawDataDictionary)
{
var propertyCount = _model.Properties.Count;
var propertyCount = properties.Count;
var serializationCtorParametersCount = SerializationConstructor.Signature.Parameters.Count;
ValueExpression[] serializationCtorParameters = new ValueExpression[serializationCtorParametersCount];
var serializationCtorParameterValues = new Dictionary<string, ValueExpression>(propertyCount);

// Map property variable names to their corresponding parameter values
for (var i = 0; i < propertyCount; i++)
{
var property = _model.Properties[i];
var property = properties[i];
var propertyVarName = property.Name.ToVariableName();
var propertyVarRef = property.AsVariableExpression;
serializationCtorParameterValues[propertyVarName] = GetValueForSerializationConstructor(property, propertyVarRef, property.WireInfo);
Expand All @@ -651,7 +685,7 @@ private ValueExpression[] GetSerializationCtorParameterValues(
for (var i = 0; i < serializationCtorParametersCount; i++)
{
var parameter = SerializationConstructor.Signature.Parameters[i];
var paramVarName = parameter.Name.ToVariableName();
var paramVarName = parameter.Name;
serializationCtorParameters[i] = serializationCtorParameterValues.TryGetValue(paramVarName, out var value) ? value : Default;
}

Expand Down Expand Up @@ -679,14 +713,15 @@ private static ValueExpression GetValueForSerializationConstructor(
}

private List<MethodBodyStatement> BuildDeserializePropertiesStatements(
IReadOnlyList<PropertyProvider> properties,
JsonPropertySnippet jsonPropertySnippet,
DictionarySnippet? rawDataDictionary)
{
List<MethodBodyStatement> propertyDeserializationStatements = new();
// Create each property's deserialization statement
for (var i = 0; i < _model.Properties.Count; i++)
for (var i = 0; i < properties.Count; i++)
{
var property = _model.Properties[i];
var property = properties[i];
var propertyWireInfo = property.WireInfo;
var propertySerializationName = propertyWireInfo?.SerializedName ?? property.Name;
var checkIfJsonPropEqualsName = new IfStatement(jsonPropertySnippet.NameEquals(propertySerializationName.ToVariableName()))
Expand All @@ -710,6 +745,20 @@ private List<MethodBodyStatement> BuildDeserializePropertiesStatements(
return propertyDeserializationStatements;
}

private IReadOnlyList<PropertyProvider> GetAllProperties()
{
var provider = _model;
var properties = new List<PropertyProvider>(provider.Properties);

while (provider.Inherits is { IsFrameworkType: false, Implementation: TypeProvider baseType })
{
provider = baseType;
properties.AddRange(provider.Properties);
}

return properties;
}

private MethodBodyStatement[] DeserializeProperty(
PropertyProvider property,
JsonPropertySnippet jsonPropertySnippet)
Expand Down Expand Up @@ -891,24 +940,63 @@ private static MethodBodyStatement NullCheckCollectionItemIfRequired(
? new IfElseStatement(arrayItemVar.ValueKindEqualsNull(), assignNull, deserializeValue)
: deserializeValue;

private static MrwSerializationTypeProvider? FindSerializationOnBase(TypeProvider model)
{
if (model.Inherits is not { IsFrameworkType: false, Implementation: TypeProvider baseType })
{
return null;
}

if (baseType.SerializationProviders.Count == 0)
{
return null;
}

// finds the first MrwSerializationTypeProvider in serialization providers
return baseType.SerializationProviders.FirstOrDefault(s => s is MrwSerializationTypeProvider) as MrwSerializationTypeProvider;
}

private (ConstructorSignature? BaseSignature, ConstructorInitializer? Initializer) BuildConstructorInitializer()
{
// find the constructor on the base type
if (_baseSerializationProvider == null || _baseSerializationProvider.Constructors.Count == 0)
{
return (null, null);
}

// we cannot know which ctor to call, but in our implemenation, it should only be one
var ctor = _baseSerializationProvider.Constructors[0];
if (ctor.Signature is not ConstructorSignature ctorSignature || ctorSignature.Parameters.Count == 0)
{
return (null, null);
}
// construct the initializer using the parameters from base signature
var initializer = new ConstructorInitializer(true, ctorSignature.Parameters);

return (ctorSignature, initializer);
}

/// <summary>
/// Builds the parameters for the serialization constructor by iterating through the input model properties.
/// It then adds raw data field to the constructor if it doesn't already exist in the list of constructed parameters.
/// </summary>
/// <returns>The list of parameters for the serialization parameter.</returns>
private List<ParameterProvider> BuildSerializationConstructorParameters()
private List<ParameterProvider> BuildSerializationConstructorParameters(IReadOnlyList<ParameterProvider> baseParameters)
{
List<ParameterProvider> constructorParameters = new List<ParameterProvider>();
var parameterNames = baseParameters.Select(p => p.Name).ToHashSet();
var parameterCapacity = baseParameters.Count + _inputModel.Properties.Count;
var constructorParameters = new List<ParameterProvider>(parameterCapacity);
bool shouldAddRawDataField = _rawDataField != null;

// add the base parameters
constructorParameters.AddRange(baseParameters);

foreach (var property in _inputModel.Properties)
{
var parameter = new ParameterProvider(property);
constructorParameters.Add(parameter);

if (shouldAddRawDataField && string.Equals(parameter.Name, _rawDataField?.Name, StringComparison.OrdinalIgnoreCase))
if (!parameterNames.Contains(parameter.Name))
{
shouldAddRawDataField = false;
constructorParameters.Add(parameter);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ protected override MethodProvider[] BuildConstructors()
? MethodSignatureModifiers.Public
: MethodSignatureModifiers.Internal;
var (baseConstructor, constructorInitializer) = BuildConstructorInitializer();
var constructorParameters = BuildConstructorParameters(baseConstructor);
var constructorParameters = BuildConstructorParameters(baseConstructor?.Parameters ?? []);

var constructor = new MethodProvider(
signature: new ConstructorSignature(
Expand Down Expand Up @@ -128,19 +128,14 @@ protected override MethodProvider[] BuildConstructors()
return (ctorSignature, initializer);
}

private IReadOnlyList<ParameterProvider> BuildConstructorParameters(ConstructorSignature? baseConstructor)
private IReadOnlyList<ParameterProvider> BuildConstructorParameters(IReadOnlyList<ParameterProvider> baseParameters)
{
var baseParameters = baseConstructor?.Parameters ?? Array.Empty<ParameterProvider>();
var parameterCapacity = baseParameters.Count + _inputModel.Properties.Count;
var parameterNames = new HashSet<string>(parameterCapacity);
var constructorParameters = new List<ParameterProvider>(baseParameters.Count + _inputModel.Properties.Count);
var parameterNames = baseParameters.Select(p => p.Name).ToHashSet();
var constructorParameters = new List<ParameterProvider>(parameterCapacity);

// add the base parameters
foreach (var parameter in baseParameters)
{
parameterNames.Add(parameter.Name);
constructorParameters.Add(parameter);
}
constructorParameters.AddRange(baseParameters);

foreach (var property in _inputModel.Properties)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,5 @@ private static VariableExpression GetVariableExpression(ParameterProvider parame

private VariableExpression? _asVariable;
public VariableExpression AsExpression => _asVariable ??= this;

private MemberExpression? _asProperty;
public MemberExpression AsPropertyExpression => _asProperty ??= new MemberExpression(null, Name.FirstCharToUpperCase());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,22 @@ public static string FirstCharToUpperCase(this string str)
return new string(span);
}

public static string FirstCharToLowerCase(this string str)
{
if (string.IsNullOrEmpty(str))
return str;

var strSpan = str.AsSpan();

if (char.IsLower(strSpan[0]))
return str;

Span<char> span = stackalloc char[strSpan.Length];
strSpan.CopyTo(span);
span[0] = char.ToLower(span[0]);
return new string(span);
}

public static IEnumerable<string> SplitByCamelCase(this string camelCase)
{
var humanizedString = HumanizedCamelCaseRegex.Replace(camelCase, "$1");
Expand Down
Loading
Loading