diff --git a/src/native/managed/cdac/Microsoft.Diagnostics.DataContractReader.Abstractions/ContractRegistry.cs b/src/native/managed/cdac/Microsoft.Diagnostics.DataContractReader.Abstractions/ContractRegistry.cs index 76b7ca71eb0537..53fefbc8ea812e 100644 --- a/src/native/managed/cdac/Microsoft.Diagnostics.DataContractReader.Abstractions/ContractRegistry.cs +++ b/src/native/managed/cdac/Microsoft.Diagnostics.DataContractReader.Abstractions/ContractRegistry.cs @@ -168,8 +168,9 @@ public bool TryGetContract([NotNullWhen(true)] out TContract contract } /// - /// Register a contract implementation for a specific version. - /// External packages use this to add contract versions or entirely new contract interfaces. + /// Register a contract implementation for a specific version. An empty + /// is used as the fallback when the target + /// does not advertise a version for the contract. /// public abstract void Register(string version, Func creator) where TContract : IContract; diff --git a/src/native/managed/cdac/Microsoft.Diagnostics.DataContractReader/CachingContractRegistry.cs b/src/native/managed/cdac/Microsoft.Diagnostics.DataContractReader/CachingContractRegistry.cs index 376b680f35966c..e72eb5d4dfce6d 100644 --- a/src/native/managed/cdac/Microsoft.Diagnostics.DataContractReader/CachingContractRegistry.cs +++ b/src/native/managed/cdac/Microsoft.Diagnostics.DataContractReader/CachingContractRegistry.cs @@ -48,15 +48,22 @@ public override bool TryGetContract([NotNullWhen(true)] out TContract return true; } - if (!_tryGetContractVersion(TContract.Name, out string? version)) + Func? creator; + if (_tryGetContractVersion(TContract.Name, out string? version)) { - failureReason = $"Target does not support contract '{typeof(TContract).Name}'."; - return false; + // Target declares a version — require an implementation for it. + // Do NOT fall back to the default registration in this case: a + // missing version-specific impl is a real version-skew failure + // and silently using a default would mask it. + if (!_creators.TryGetValue((typeof(TContract), version), out creator)) + { + failureReason = $"Target supports contract '{typeof(TContract).Name}' version {version}, but no implementation is registered for that version."; + return false; + } } - - if (!_creators.TryGetValue((typeof(TContract), version), out Func? creator)) + else if (!_creators.TryGetValue((typeof(TContract), string.Empty), out creator)) { - failureReason = $"Target supports contract '{typeof(TContract).Name}' version {version}, but no implementation is registered for that version."; + failureReason = $"Target does not support contract '{typeof(TContract).Name}'."; return false; } diff --git a/src/native/managed/cdac/gen/CdacGenerator.cs b/src/native/managed/cdac/gen/CdacGenerator.cs index 2375b2fdb60d05..1ccbe542f76db4 100644 --- a/src/native/managed/cdac/gen/CdacGenerator.cs +++ b/src/native/managed/cdac/gen/CdacGenerator.cs @@ -34,10 +34,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context) // sees Contracts' copy via [InternalsVisibleTo] and shouldn't emit its own). // Each helper is gated independently to handle version-skew scenarios where // one helper is present but the other is not. - IncrementalValueProvider<(bool EmitLayoutSet, bool EmitTypeNameResolver)> shouldEmitHelpers = context.CompilationProvider + IncrementalValueProvider<(bool EmitLayoutSet, bool EmitTypeNameResolver, bool EmitGeneratedTypeCacheContract)> shouldEmitHelpers = context.CompilationProvider .Select(static (compilation, _) => ( EmitLayoutSet: !IsTypeAccessible(compilation, LayoutSetSource.FullyQualifiedName), - EmitTypeNameResolver: !IsTypeAccessible(compilation, TypeNameResolverSource.FullyQualifiedName))); + EmitTypeNameResolver: !IsTypeAccessible(compilation, TypeNameResolverSource.FullyQualifiedName), + EmitGeneratedTypeCacheContract: !IsTypeAccessible(compilation, GeneratedTypeCacheContractSource.FullyQualifiedName))); context.RegisterSourceOutput(shouldEmitHelpers, static (ctx, flags) => { @@ -54,6 +55,13 @@ public void Initialize(IncrementalGeneratorInitializationContext context) TypeNameResolverSource.HintName, SourceText.From(TypeNameResolverSource.Source, Encoding.UTF8)); } + + if (flags.EmitGeneratedTypeCacheContract) + { + ctx.AddSource( + GeneratedTypeCacheContractSource.HintName, + SourceText.From(GeneratedTypeCacheContractSource.Source, Encoding.UTF8)); + } }); IncrementalValuesProvider models = context.SyntaxProvider diff --git a/src/native/managed/cdac/gen/Emitter.cs b/src/native/managed/cdac/gen/Emitter.cs index 162d75d2ddcce4..920346ef18af87 100644 --- a/src/native/managed/cdac/gen/Emitter.cs +++ b/src/native/managed/cdac/gen/Emitter.cs @@ -134,7 +134,7 @@ private static void EmitWriteBackMethod(StringBuilder sb, MemberModel member) sb.AppendLine($" public void Write{member.Name}({propType} value)"); sb.AppendLine(" {"); - sb.AppendLine($" LayoutSet layouts = LayoutSet.Resolve(_target, _typeNames);"); + sb.AppendLine($" LayoutSet layouts = _target.GetCachedLayoutSet(_typeNames);"); sb.AppendLine($" layouts.Select(Address, out var t, out var b, out var n, {NameArgs(member)});"); if (member.ReadKind == FieldReadKind.Bool) { @@ -176,7 +176,7 @@ private static void EmitConstructor(StringBuilder sb, CdacTypeModel model, bool if (needsDescriptor) { sb.AppendLine(); - sb.AppendLine($" LayoutSet layouts = LayoutSet.Resolve(target, _typeNames);"); + sb.AppendLine($" LayoutSet layouts = target.GetCachedLayoutSet(_typeNames);"); } sb.AppendLine(); diff --git a/src/native/managed/cdac/gen/GeneratedTypeCacheContractSource.cs b/src/native/managed/cdac/gen/GeneratedTypeCacheContractSource.cs new file mode 100644 index 00000000000000..15e9d1efc26470 --- /dev/null +++ b/src/native/managed/cdac/gen/GeneratedTypeCacheContractSource.cs @@ -0,0 +1,81 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Diagnostics.DataContractReader.DataGenerator; + +/// +/// Source for the IGeneratedTypeCache contract and its supporting +/// types, emitted into each consuming assembly. +/// +internal static class GeneratedTypeCacheContractSource +{ + public const string HintName = "GeneratedTypeCacheContract.g.cs"; + + public const string Namespace = "Microsoft.Diagnostics.DataContractReader.Generated"; + + public const string FullyQualifiedName = Namespace + ".IGeneratedTypeCache"; + + public const string Source = """ +// +#nullable enable + +using System; +using System.Collections.Generic; +using Microsoft.Diagnostics.DataContractReader; +using Microsoft.Diagnostics.DataContractReader.Contracts; + +namespace Microsoft.Diagnostics.DataContractReader.Generated; + +/// +/// Target-agnostic contract that caches results +/// per target, keyed by the generator-emitted _typeNames array reference. +/// +internal interface IGeneratedTypeCache : IContract +{ + static string IContract.Name => nameof(GeneratedTypeCache); + + LayoutSet GetOrAddLayoutSet(string[] typeNames); +} + +internal static class GeneratedTypeCache { } + +internal sealed class GeneratedTypeCache_1 : IGeneratedTypeCache +{ + private readonly Target _target; + private readonly Dictionary _cache = + new(ReferenceEqualityComparer.Instance); + + public GeneratedTypeCache_1(Target target) => _target = target; + + public LayoutSet GetOrAddLayoutSet(string[] typeNames) + { + if (!_cache.TryGetValue(typeNames, out LayoutSet? cached)) + { + cached = LayoutSet.Resolve(_target, typeNames); + _cache[typeNames] = cached; + } + return cached; + } + + public void Flush(FlushScope scope) + { + // LayoutSets are immutable across ForwardExecution; only clear on All. + if (scope == FlushScope.All) + _cache.Clear(); + } +} + +internal static class TargetExtensions +{ + public static LayoutSet GetCachedLayoutSet(this Target target, string[] typeNames) + { + if (!target.Contracts.TryGetContract(out IGeneratedTypeCache contract)) + { + target.Contracts.Register(string.Empty, static t => new GeneratedTypeCache_1(t)); + contract = target.Contracts.GetContract(); + } + return contract.GetOrAddLayoutSet(typeNames); + } +} +"""; +} diff --git a/src/native/managed/cdac/gen/LayoutSetSource.cs b/src/native/managed/cdac/gen/LayoutSetSource.cs index d7e7ac962a1997..b910d95ee56dfb 100644 --- a/src/native/managed/cdac/gen/LayoutSetSource.cs +++ b/src/native/managed/cdac/gen/LayoutSetSource.cs @@ -33,7 +33,7 @@ namespace Microsoft.Diagnostics.DataContractReader.Generated; /// first, then managed type metadata), trying each candidate field name /// per source. Sources are resolved lazily. /// -internal readonly struct LayoutSet +internal sealed class LayoutSet { private readonly LazyLayout[] _layouts; diff --git a/src/native/managed/cdac/tests/DataGenerator/TestTarget.cs b/src/native/managed/cdac/tests/DataGenerator/TestTarget.cs index b823054d3b0220..a959d7381b9ce6 100644 --- a/src/native/managed/cdac/tests/DataGenerator/TestTarget.cs +++ b/src/native/managed/cdac/tests/DataGenerator/TestTarget.cs @@ -36,6 +36,7 @@ public TestTarget(int pointerSize = 8, bool isLittleEndian = true) IsLittleEndian = isLittleEndian; ManagedTypeSourceMock = new Mock(); _contracts = new TestContractRegistry(ManagedTypeSourceMock.Object); + _contracts.SetTarget(this); _processedData = new TestDataCache(this); } @@ -262,12 +263,17 @@ public override bool TryReadGlobalPointer(string name, [NotNullWhen(true)] out T private sealed class TestContractRegistry : ContractRegistry { private readonly IManagedTypeSource _managedTypeSource; + private readonly Dictionary<(Type, string), Func> _creators = new(); + private readonly Dictionary _resolved = new(); + private Target? _target; public TestContractRegistry(IManagedTypeSource managedTypeSource) { _managedTypeSource = managedTypeSource; } + public void SetTarget(Target target) => _target = target; + public override IManagedTypeSource ManagedTypeSource => _managedTypeSource; public override bool TryGetContract([NotNullWhen(true)] out TContract contract, out string? failureReason) @@ -278,13 +284,31 @@ public override bool TryGetContract([NotNullWhen(true)] out TContract failureReason = null; return true; } + if (_resolved.TryGetValue(typeof(TContract), out IContract? cached)) + { + contract = (TContract)cached; + failureReason = null; + return true; + } + // No target-declared versions in this stub — fall through directly + // to the empty-string "default" registration. + if (_creators.TryGetValue((typeof(TContract), string.Empty), out Func? fallback)) + { + if (_target is null) + throw new InvalidOperationException("TestContractRegistry: SetTarget must be called before TryGetContract."); + IContract created = fallback(_target); + _resolved[typeof(TContract)] = created; + contract = (TContract)created; + failureReason = null; + return true; + } contract = default!; failureReason = "Not registered in TestContractRegistry."; return false; } public override void Register(string version, Func creator) - => throw new NotImplementedException(); + => _creators[(typeof(TContract), version)] = t => creator(t); public override void Flush(FlushScope scope) { } } diff --git a/src/native/managed/cdac/tests/TestInfrastructure/TestPlaceholderTarget.cs b/src/native/managed/cdac/tests/TestInfrastructure/TestPlaceholderTarget.cs index 12561ffd8676a7..bf9db469f9d3e1 100644 --- a/src/native/managed/cdac/tests/TestInfrastructure/TestPlaceholderTarget.cs +++ b/src/native/managed/cdac/tests/TestInfrastructure/TestPlaceholderTarget.cs @@ -635,14 +635,20 @@ public override bool TryGetContract([NotNullWhen(true)] out TContract } else if (_versions.TryGetValue(typeof(TContract), out string? version)) { + // Target declares a version — require an implementation for it. + // No fallback to the empty-string default in this case. if (!_creators.TryGetValue((typeof(TContract), version), out var creator)) { failureReason = $"Target supports contract '{typeof(TContract).Name}' version {version}, but no implementation is registered for that version."; return false; } - resolved = creator(_target); } + else if (_creators.TryGetValue((typeof(TContract), string.Empty), out var fallback)) + { + // No target-declared version — fall back to the empty-string default. + resolved = fallback(_target); + } else { failureReason = $"Contract '{typeof(TContract).Name}' is not supported by the target.";