Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,53 @@ protected override string BuildNamespace() => string.IsNullOrEmpty(_inputModel.N

protected override CSharpType? BuildBaseType()
{
return BaseModelProvider?.Type;
if (CustomCodeView?.BaseType != null)
{
var customBase = CustomCodeView.BaseType;

// If the custom base type doesn't have a resolved namespace, then try to resolve it from the input model map.
// This will happen if a model is customized to inherit from another generated model, but that generated model
// was not also defined in custom code so Roslyn does not recognize it.
if (string.IsNullOrEmpty(customBase.Namespace))
{
if (CodeModelGenerator.Instance.TypeFactory.TypeProvidersByName.TryGetValue(
customBase.Name, out var resolvedProvider) &&
resolvedProvider is ModelProvider resolvedModel)
{
return resolvedModel.Type;
}

// Force-create all input models so that visitors run (which may rename models
// via TypeProvider.Update) and TypeProvidersByName is fully populated.
foreach (var model in CodeModelGenerator.Instance.InputLibrary.InputNamespace.Models)
{
CodeModelGenerator.Instance.TypeFactory.CreateModel(model);
}

if (CodeModelGenerator.Instance.TypeFactory.TypeProvidersByName.TryGetValue(
customBase.Name, out resolvedProvider) &&
resolvedProvider is ModelProvider resolvedAfterCreate)
{
return resolvedAfterCreate.Type;
}
}

if (CodeModelGenerator.Instance.TypeFactory.CSharpTypeMap.TryGetValue(
customBase, out var mappedProvider) &&
mappedProvider is ModelProvider mappedModel)
{
return mappedModel.Type;
}

return customBase;
}

if (_inputModel.BaseModel == null)
{
return null;
}

return CodeModelGenerator.Instance.TypeFactory.CreateModel(_inputModel.BaseModel)?.Type;
}

protected override TypeProvider[] BuildSerializationProviders()
Expand Down Expand Up @@ -293,63 +339,16 @@ private static bool IsDiscriminator(InputProperty property)

private ModelProvider? BuildBaseModelProvider()
{
// consider models that have been customized to inherit from a different generated model
if (CustomCodeView?.BaseType != null)
{
var baseType = CustomCodeView.BaseType;

// If the custom base type doesn't have a resolved namespace, then try to resolve it from the input model map.
// This will happen if a model is customized to inherit from another generated model, but that generated model
// was not also defined in custom code so Roslyn does not recognize it.
if (string.IsNullOrEmpty(baseType.Namespace))
{
// Cheap check: the base model may already be created and registered under the right name.
if (CodeModelGenerator.Instance.TypeFactory.TypeProvidersByName.TryGetValue(
baseType.Name, out var resolvedProvider) &&
resolvedProvider is ModelProvider resolvedModel)
{
return resolvedModel;
}

// Force-create all input models so that visitors run (which may rename models
// via TypeProvider.Update) and TypeProvidersByName is fully populated.
// This is a no-op for models that have already been created.
foreach (var model in CodeModelGenerator.Instance.InputLibrary.InputNamespace.Models)
{
CodeModelGenerator.Instance.TypeFactory.CreateModel(model);
}

if (CodeModelGenerator.Instance.TypeFactory.TypeProvidersByName.TryGetValue(
baseType.Name, out resolvedProvider) &&
resolvedProvider is ModelProvider resolvedAfterCreate)
{
return resolvedAfterCreate;
}
}

// Try to find the base type in the CSharpTypeMap
if (baseType != null && CodeModelGenerator.Instance.TypeFactory.CSharpTypeMap.TryGetValue(
baseType,
out var customBaseType) &&
customBaseType is ModelProvider customBaseModel)
{
return customBaseModel;
}

// If the custom base type has a namespace (external type), we don't return it here
// as it's handled by BuildBaseTypeProvider() which returns a TypeProvider
if (!string.IsNullOrEmpty(baseType?.Namespace))
{
return null;
}
}

if (_inputModel.BaseModel == null)
var baseType = BaseType;
if (baseType is null)
{
return null;
}

return CodeModelGenerator.Instance.TypeFactory.CreateModel(_inputModel.BaseModel);
return CodeModelGenerator.Instance.TypeFactory.CSharpTypeMap.TryGetValue(baseType, out var provider)
&& provider is ModelProvider modelProvider
? modelProvider
: null;
}

private List<FieldProvider> BuildAdditionalPropertyFields()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,107 @@ public void BuildBaseType()
Assert.AreEqual(baseModel!.Type, derivedModel!.Type.BaseType);
}

[Test]
public void OverridingBuildBaseType_AutoResolvesBaseModelProviderForGeneratedModel()
{
var inputBase = InputFactory.Model("baseModel", usage: InputModelTypeUsage.Input, properties: []);
var inputDerived = InputFactory.Model("derivedModel", usage: InputModelTypeUsage.Input, properties: []);
ModelProvider? baseProvider = null;
MockHelpers.LoadMockGenerator(createModelCore: input =>
{
if (input == inputBase)
{
return baseProvider = new ModelProvider(input);
}
if (input == inputDerived)
{
return new BuildBaseTypeOverridingModelProvider(input, baseProvider!.Type);
}
return null;
});

var actualBase = CodeModelGenerator.Instance.TypeFactory.CreateModel(inputBase);
var actualDerived = CodeModelGenerator.Instance.TypeFactory.CreateModel(inputDerived);

Assert.IsNotNull(actualBase);
Assert.IsNotNull(actualDerived);
Assert.AreEqual(actualBase!.Type, actualDerived!.BaseType);
Assert.AreSame(actualBase, actualDerived.BaseModelProvider);
}

[Test]
public void OverridingBuildBaseType_AutoResolvesBaseModelProviderToNullForFrameworkType()
{
var inputDerived = InputFactory.Model("derivedModel", usage: InputModelTypeUsage.Input, properties: []);
var frameworkBase = new CSharpType(typeof(InvalidOperationException));
MockHelpers.LoadMockGenerator(createModelCore: input =>
input == inputDerived ? new BuildBaseTypeOverridingModelProvider(input, frameworkBase) : null);

var actualDerived = CodeModelGenerator.Instance.TypeFactory.CreateModel(inputDerived);

Assert.IsNotNull(actualDerived);
Assert.AreEqual(frameworkBase, actualDerived!.BaseType);
Assert.IsNull(actualDerived.BaseModelProvider);
}

[Test]
public void BaseModelProvider_DefaultResolvesViaCSharpTypeMap()
{
var inputBase = InputFactory.Model("baseModel", usage: InputModelTypeUsage.Input, properties: []);
var inputDerived = InputFactory.Model("derivedModel", usage: InputModelTypeUsage.Input, properties: [], baseModel: inputBase);

var derivedProvider = CodeModelGenerator.Instance.TypeFactory.CreateModel(inputDerived);
Assert.IsNotNull(derivedProvider);
Assert.IsNotNull(derivedProvider!.BaseModelProvider);
Assert.AreEqual(derivedProvider.BaseModelProvider!.Type, derivedProvider.BaseType);
}

[Test]
public void BaseModelProvider_NullWhenNoBase()
{
var inputModel = InputFactory.Model("standaloneModel", usage: InputModelTypeUsage.Input, properties: []);
var modelProvider = CodeModelGenerator.Instance.TypeFactory.CreateModel(inputModel);

Assert.IsNotNull(modelProvider);
Assert.IsNull(modelProvider!.BaseType);
Assert.IsNull(modelProvider.BaseModelProvider);
}

[Test]
public void OverridingBuildBaseType_AutoResolvesBaseModelProviderToNullForNonModelTypeProvider()
{
var inputDerived = InputFactory.Model("derivedModel", usage: InputModelTypeUsage.Input, properties: []);
var nonModelTypeProvider = new NonModelTypeProvider();
MockHelpers.LoadMockGenerator(createModelCore: input =>
input == inputDerived ? new BuildBaseTypeOverridingModelProvider(input, nonModelTypeProvider.Type) : null);
CodeModelGenerator.Instance.TypeFactory.CSharpTypeMap[nonModelTypeProvider.Type] = nonModelTypeProvider;

var actualDerived = CodeModelGenerator.Instance.TypeFactory.CreateModel(inputDerived);

Assert.IsNotNull(actualDerived);
Assert.AreEqual(nonModelTypeProvider.Type, actualDerived!.BaseType);
Assert.IsNull(actualDerived.BaseModelProvider);
}

private class NonModelTypeProvider : TypeProvider
{
protected override string BuildRelativeFilePath() => ".";
protected override string BuildName() => "NonModelBase";
protected override string BuildNamespace() => "Custom.Namespace";
}

private class BuildBaseTypeOverridingModelProvider : ModelProvider
{
private readonly CSharpType? _redirectedBaseType;

public BuildBaseTypeOverridingModelProvider(InputModelType inputModel, CSharpType? redirectedBaseType) : base(inputModel)
{
_redirectedBaseType = redirectedBaseType;
}

protected override CSharpType? BuildBaseType() => _redirectedBaseType;
}

[Test]
public void BuildModelAsStruct()
{
Expand Down
Loading