This commit is contained in:
Jackson Schuster 2025-07-30 09:54:01 -04:00 committed by GitHub
commit 32f9e73b5c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
72 changed files with 777 additions and 87 deletions

View File

@ -20,11 +20,23 @@ namespace Microsoft.Interop
/// <summary>
/// COM methods that require shadowing declarations on the derived interface.
/// </summary>
public IEnumerable<ComMethodContext> ShadowingMethods => Methods.Where(m => m.IsInheritedMethod && !m.IsHiddenOnDerivedInterface);
public IEnumerable<ComMethodContext> ShadowingMethods => Methods.Where(m => m.IsInheritedMethod && !m.IsHiddenOnDerivedInterface && !m.IsExternallyDefined);
/// <summary>
/// COM methods that are declared on an interface the interface inherits from.
/// </summary>
public IEnumerable<ComMethodContext> InheritedMethods => Methods.Where(m => m.IsInheritedMethod);
/// <summary>
/// The size of the vtable for this interface, including the base interface methods and IUnknown methods.
/// </summary>
public int VTableSize => Methods.Length == 0
? IUnknownConstants.VTableSize
: 1 + Methods.Max(m => m.GenerationContext.VtableIndexData.Index);
/// <summary>
/// The size of the vtable for the base interface, including it's base interface methods and IUnknown methods.
/// </summary>
public int BaseVTableSize => VTableSize - DeclaredMethods.Count();
}
}

View File

@ -54,7 +54,7 @@ namespace Microsoft.Interop
var externalInterfaceSymbols = attributedInterfaces.SelectMany(static (data, ct) =>
{
return ComInterfaceInfo.CreateInterfaceInfoForBaseInterfacesInOtherCompilations(data.Symbol);
});
}).Collect().SelectMany(static (data, ct) => data.Distinct(ComInterfaceInfo.EqualityComparerForExternalIfaces.Instance));
var interfaceSymbolsWithoutDiagnostics = interfaceSymbolsToGenerateWithoutDiagnostics.Concat(externalInterfaceSymbols);
@ -84,11 +84,7 @@ namespace Microsoft.Interop
.SelectMany(static (data, ct) =>
{
return ComMethodContext.CalculateAllMethods(data, ct);
})
// Now that we've determined method offsets, we can remove all externally defined methods.
// We'll also filter out methods originally declared on externally defined base interfaces
// as we may not be able to emit them into our assembly.
.Where(context => !context.Method.OriginalDeclaringInterface.IsExternallyDefined);
});
// Now that we've determined method offsets, we can remove all externally defined interfaces.
var interfaceContextsToGenerate = interfaceContexts.Where(context => !context.IsExternallyDefined);
@ -107,13 +103,20 @@ namespace Microsoft.Interop
return new ComMethodContext(
data.Method,
data.OwningInterface,
CalculateStubInformation(data.Method.MethodInfo.Syntax, symbolMap[data.Method.MethodInfo], data.Method.Index, env, data.OwningInterface.Info, ct));
CalculateStubInformation(
data.Method.MethodInfo.Syntax,
symbolMap[data.Method.MethodInfo],
data.Method.Index,
env,
data.OwningInterface.Info,
ct));
}).WithTrackingName(StepNames.CalculateStubInformation);
var interfaceAndMethodsContexts = comMethodContexts
.Collect()
.Combine(interfaceContextsToGenerate.Collect())
.SelectMany((data, ct) => GroupComContextsForInterfaceGeneration(data.Left, data.Right, ct));
.SelectMany((data, ct) =>
GroupComContextsForInterfaceGeneration(data.Left, data.Right, ct));
// Generate the code for the managed-to-unmanaged stubs.
var managedToNativeInterfaceImplementations = interfaceAndMethodsContexts
@ -256,12 +259,22 @@ namespace Microsoft.Interop
|| typeName.Equals("hresult", StringComparison.OrdinalIgnoreCase);
}
private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ComInterfaceInfo owningInterfaceInfo, CancellationToken ct)
/// <summary>
/// Calculates the shared information needed for both source-available and sourceless stub generation.
/// </summary>
private static IncrementalMethodStubGenerationContext CalculateSharedStubInformation(
IMethodSymbol symbol,
int index,
StubEnvironment environment,
ISignatureDiagnosticLocations diagnosticLocations,
ComInterfaceInfo owningInterfaceInfo,
CancellationToken ct)
{
ct.ThrowIfCancellationRequested();
INamedTypeSymbol? lcidConversionAttrType = environment.LcidConversionAttrType;
INamedTypeSymbol? suppressGCTransitionAttrType = environment.SuppressGCTransitionAttrType;
INamedTypeSymbol? unmanagedCallConvAttrType = environment.UnmanagedCallConvAttrType;
// Get any attributes of interest on the method
AttributeData? lcidConversionAttr = null;
AttributeData? suppressGCTransitionAttribute = null;
@ -282,8 +295,7 @@ namespace Microsoft.Interop
}
}
var locations = new MethodSignatureDiagnosticLocations(syntax);
var generatorDiagnostics = new GeneratorDiagnosticsBag(new DiagnosticDescriptorProvider(), locations, SR.ResourceManager, typeof(FxResources.Microsoft.Interop.ComInterfaceGenerator.SR));
var generatorDiagnostics = new GeneratorDiagnosticsBag(new DiagnosticDescriptorProvider(), diagnosticLocations, SR.ResourceManager, typeof(FxResources.Microsoft.Interop.ComInterfaceGenerator.SR));
if (lcidConversionAttr is not null)
{
@ -293,8 +305,8 @@ namespace Microsoft.Interop
GeneratedComInterfaceCompilationData.TryGetGeneratedComInterfaceAttributeFromInterface(symbol.ContainingType, out var generatedComAttribute);
var generatedComInterfaceAttributeData = GeneratedComInterfaceCompilationData.GetDataFromAttribute(generatedComAttribute);
// Create the stub.
// Create the stub.
var signatureContext = SignatureContext.Create(
symbol,
DefaultMarshallingInfoParser.Create(
@ -387,10 +399,6 @@ namespace Microsoft.Interop
GeneratorDiagnostics.SizeOfInCollectionMustBeDefinedAtCallReturnValue);
}
var containingSyntaxContext = new ContainingSyntaxContext(syntax);
var methodSyntaxTemplate = new ContainingSyntax(new SyntaxTokenList(syntax.Modifiers.Where(static m => !m.IsKind(SyntaxKind.NewKeyword))).StripAccessibilityModifiers(), SyntaxKind.MethodDeclaration, syntax.Identifier, syntax.TypeParameterList);
ImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax> callConv = VirtualMethodPointerStubGenerator.GenerateCallConvSyntaxFromAttributes(
suppressGCTransitionAttribute,
unmanagedCallConvAttribute,
@ -398,10 +406,7 @@ namespace Microsoft.Interop
var declaringType = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol.ContainingType);
var virtualMethodIndexData = new VirtualMethodIndexData(index, ImplicitThisParameter: true, direction, true, ExceptionMarshalling.Com);
MarshallingInfo exceptionMarshallingInfo;
if (generatedComInterfaceAttributeData.ExceptionToUnmanagedMarshaller is null)
{
exceptionMarshallingInfo = new ComExceptionMarshalling();
@ -418,11 +423,9 @@ namespace Microsoft.Interop
return new IncrementalMethodStubGenerationContext(
signatureContext,
containingSyntaxContext,
methodSyntaxTemplate,
locations,
diagnosticLocations,
callConv.ToSequenceEqualImmutableArray(SyntaxEquivalentComparer.Instance),
virtualMethodIndexData,
new VirtualMethodIndexData(index, ImplicitThisParameter: true, direction, true, ExceptionMarshalling.Com),
exceptionMarshallingInfo,
environment.EnvironmentFlags,
owningInterfaceInfo.Type,
@ -431,6 +434,45 @@ namespace Microsoft.Interop
ComInterfaceDispatchMarshallingInfo.Instance);
}
private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax? syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ComInterfaceInfo owningInterface, CancellationToken ct)
{
ISignatureDiagnosticLocations locations = syntax is null
? NoneSignatureDiagnosticLocations.Instance
: new MethodSignatureDiagnosticLocations(syntax);
var sourcelessStubInformation = CalculateSharedStubInformation(
symbol,
index,
environment,
locations,
owningInterface,
ct);
if (syntax is null)
return sourcelessStubInformation;
var containingSyntaxContext = new ContainingSyntaxContext(syntax);
var methodSyntaxTemplate = new ContainingSyntax(
new SyntaxTokenList(syntax.Modifiers.Where(static m => !m.IsKind(SyntaxKind.NewKeyword))).StripAccessibilityModifiers(),
SyntaxKind.MethodDeclaration,
syntax.Identifier,
syntax.TypeParameterList);
return new SourceAvailableIncrementalMethodStubGenerationContext(
sourcelessStubInformation.SignatureContext,
containingSyntaxContext,
methodSyntaxTemplate,
locations,
sourcelessStubInformation.CallingConvention,
sourcelessStubInformation.VtableIndexData,
sourcelessStubInformation.ExceptionMarshallingInfo,
sourcelessStubInformation.EnvironmentFlags,
sourcelessStubInformation.TypeKeyOwner,
sourcelessStubInformation.DeclaringType,
sourcelessStubInformation.Diagnostics,
ComInterfaceDispatchMarshallingInfo.Instance);
}
private static MarshalDirection GetDirectionFromOptions(ComInterfaceOptions options)
{
if (options.HasFlag(ComInterfaceOptions.ManagedObjectWrapper | ComInterfaceOptions.ComObjectWrapper))
@ -520,12 +562,12 @@ namespace Microsoft.Interop
private static InterfaceDeclarationSyntax GenerateImplementationInterface(ComInterfaceAndMethodsContext interfaceGroup, CancellationToken _)
{
var definingType = interfaceGroup.Interface.Info.Type;
var shadowImplementations = interfaceGroup.InheritedMethods.Select(m => (Method: m, ManagedToUnmanagedStub: m.ManagedToUnmanagedStub))
var shadowImplementations = interfaceGroup.InheritedMethods.Where(m => !m.IsExternallyDefined).Select(m => (Method: m, ManagedToUnmanagedStub: m.ManagedToUnmanagedStub))
.Where(p => p.ManagedToUnmanagedStub is GeneratedStubCodeContext)
.Select(ctx => ((GeneratedStubCodeContext)ctx.ManagedToUnmanagedStub).Stub.Node
.WithExplicitInterfaceSpecifier(
ExplicitInterfaceSpecifier(ParseName(definingType.FullTypeName))));
var inheritedStubs = interfaceGroup.InheritedMethods.Select(m => m.UnreachableExceptionStub);
var inheritedStubs = interfaceGroup.InheritedMethods.Where(m => !m.IsExternallyDefined).Select(m => m.UnreachableExceptionStub);
return ImplementationInterfaceTemplate
.AddBaseListTypes(SimpleBaseType(definingType.Syntax))
.WithMembers(
@ -661,7 +703,6 @@ namespace Microsoft.Interop
BlockSyntax fillBaseInterfaceSlots;
if (interfaceMethods.Interface.Base is null)
{
// If we don't have a base interface, we need to manually fill in the base iUnknown slots.
@ -740,7 +781,7 @@ namespace Microsoft.Interop
}
else
{
// NativeMemory.Copy(StrategyBasedComWrappers.DefaultIUnknownInteraceDetailsStrategy.GetIUnknownDerivedDetails(typeof(<baseInterfaceType>).TypeHandle).ManagedVirtualMethodTable, vtable, (nuint)(sizeof(void*) * <startingOffset>));
// NativeMemory.Copy(StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(<baseInterfaceType>).TypeHandle).ManagedVirtualMethodTable, vtable, (nuint)(sizeof(void*) * <baseVTableSize>));
fillBaseInterfaceSlots = Block(
MethodInvocationStatement(
TypeSyntaxes.System_Runtime_InteropServices_NativeMemory,
@ -750,7 +791,7 @@ namespace Microsoft.Interop
TypeSyntaxes.StrategyBasedComWrappers
.Dot(IdentifierName("DefaultIUnknownInterfaceDetailsStrategy")),
IdentifierName("GetIUnknownDerivedDetails"),
Argument( //baseInterfaceTypeInfo.BaseInterface.FullTypeName)),
Argument(
TypeOfExpression(ParseTypeName(interfaceMethods.Interface.Base.Info.Type.FullTypeName))
.Dot(IdentifierName("TypeHandle"))))
.Dot(IdentifierName("ManagedVirtualMethodTable"))),
@ -767,7 +808,7 @@ namespace Microsoft.Interop
ParenthesizedExpression(
BinaryExpression(SyntaxKind.MultiplyExpression,
SizeOfExpression(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword)))),
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(interfaceMethods.InheritedMethods.Count() + 3))))))));
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(interfaceMethods.BaseVTableSize))))))));
}
var validDeclaredMethods = interfaceMethods.DeclaredMethods
@ -787,7 +828,7 @@ namespace Microsoft.Interop
IdentifierName($"{declaredMethodContext.MethodInfo.MethodName}_{declaredMethodContext.GenerationContext.VtableIndexData.Index}")),
PrefixUnaryExpression(
SyntaxKind.AddressOfExpression,
IdentifierName($"ABI_{declaredMethodContext.GenerationContext.StubMethodSyntaxTemplate.Identifier}")))));
IdentifierName($"ABI_{((SourceAvailableIncrementalMethodStubGenerationContext)declaredMethodContext.GenerationContext).StubMethodSyntaxTemplate.Identifier}")))));
}
return ImplementationInterfaceTemplate

View File

@ -2,7 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using Microsoft.CodeAnalysis;
@ -10,7 +12,6 @@ using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using InterfaceInfo = (Microsoft.Interop.ComInterfaceInfo InterfaceInfo, Microsoft.CodeAnalysis.INamedTypeSymbol Symbol);
using DiagnosticOrInterfaceInfo = Microsoft.Interop.DiagnosticOr<(Microsoft.Interop.ComInterfaceInfo InterfaceInfo, Microsoft.CodeAnalysis.INamedTypeSymbol Symbol)>;
using System.Diagnostics;
namespace Microsoft.Interop
{
@ -176,6 +177,13 @@ namespace Microsoft.Interop
return builder.ToImmutable();
}
internal sealed class EqualityComparerForExternalIfaces : IEqualityComparer<(ComInterfaceInfo InterfaceInfo, INamedTypeSymbol Symbol)>
{
public bool Equals((ComInterfaceInfo, INamedTypeSymbol) x, (ComInterfaceInfo, INamedTypeSymbol) y) => SymbolEqualityComparer.Default.Equals(x.Item2, y.Item2);
public int GetHashCode((ComInterfaceInfo, INamedTypeSymbol) obj) => SymbolEqualityComparer.Default.GetHashCode(obj.Item2);
public static readonly EqualityComparerForExternalIfaces Instance = new();
}
private static bool IsInPartialContext(INamedTypeSymbol symbol, InterfaceDeclarationSyntax syntax, [NotNullWhen(false)] out DiagnosticInfo? diagnostic)
{
// Verify that the types the interface is declared in are marked partial.

View File

@ -21,8 +21,8 @@ namespace Microsoft.Interop
internal sealed class ComMethodContext : IEquatable<ComMethodContext>
{
/// <summary>
/// A partially constructed <see cref="ComMethodContext"/> that does not have a <see cref="IncrementalMethodStubGenerationContext"/> generated for it yet.
/// <see cref="Builder"/> can be constructed without a reference to an ISymbol, whereas the <see cref="IncrementalMethodStubGenerationContext"/> requires an ISymbol
/// A partially constructed <see cref="ComMethodContext"/> that does not have a <see cref="SourceAvailableIncrementalMethodStubGenerationContext"/> generated for it yet.
/// <see cref="Builder"/> can be constructed without a reference to an ISymbol, whereas the <see cref="SourceAvailableIncrementalMethodStubGenerationContext"/> requires an ISymbol
/// </summary>
/// <param name="OriginalDeclaringInterface">
/// The interface that originally declared the method in user code
@ -48,7 +48,7 @@ namespace Microsoft.Interop
/// <param name="builder">The partially constructed context</param>
/// <param name="owningInterface">The final owning interface of this method context</param>
/// <param name="generationContext">The generation context for this method</param>
public ComMethodContext(Builder builder, ComInterfaceContext owningInterface, IncrementalMethodStubGenerationContext generationContext)
public ComMethodContext(Builder builder, ComInterfaceContext owningInterface, IncrementalMethodStubGenerationContext? generationContext)
{
_state = new State(builder.OriginalDeclaringInterface, owningInterface, builder.MethodInfo, generationContext);
}
@ -65,6 +65,8 @@ namespace Microsoft.Interop
public ComMethodInfo MethodInfo => _state.MethodInfo;
public bool IsExternallyDefined => _state.OriginalDeclaringInterface.IsExternallyDefined || _state.OwningInterface.IsExternallyDefined;
public IncrementalMethodStubGenerationContext GenerationContext => _state.GenerationContext;
public bool IsInheritedMethod => OriginalDeclaringInterface != OwningInterface;
@ -77,12 +79,18 @@ namespace Microsoft.Interop
private GeneratedMethodContextBase CreateManagedToUnmanagedStub()
{
if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional) || IsHiddenOnDerivedInterface)
if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional)
|| IsHiddenOnDerivedInterface
|| IsExternallyDefined)
{
return new SkippedStubContext(OriginalDeclaringInterface.Info.Type);
}
var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(GenerationContext, ComInterfaceGeneratorHelpers.GetGeneratorResolver);
return new GeneratedStubCodeContext(GenerationContext.TypeKeyOwner, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics));
if (GenerationContext is not SourceAvailableIncrementalMethodStubGenerationContext sourceAvailableContext)
{
throw new InvalidOperationException("Cannot generate stubs for non-source available methods.");
}
var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(sourceAvailableContext, ComInterfaceGeneratorHelpers.GetGeneratorResolver);
return new GeneratedStubCodeContext(sourceAvailableContext.TypeKeyOwner, sourceAvailableContext.ContainingSyntaxContext, new(methodStub), new(diagnostics));
}
private GeneratedMethodContextBase? _unmanagedToManagedStub;
@ -91,12 +99,18 @@ namespace Microsoft.Interop
private GeneratedMethodContextBase CreateUnmanagedToManagedStub()
{
if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional) || IsHiddenOnDerivedInterface)
if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional)
|| IsHiddenOnDerivedInterface
|| IsExternallyDefined)
{
return new SkippedStubContext(GenerationContext.OriginalDefiningType);
}
var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(GenerationContext, ComInterfaceGeneratorHelpers.GetGeneratorResolver);
return new GeneratedStubCodeContext(GenerationContext.OriginalDefiningType, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics));
if (GenerationContext is not SourceAvailableIncrementalMethodStubGenerationContext sourceAvailableContext)
{
throw new InvalidOperationException("Cannot generate stubs for non-source available methods.");
}
var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(sourceAvailableContext, ComInterfaceGeneratorHelpers.GetGeneratorResolver);
return new GeneratedStubCodeContext(sourceAvailableContext.OriginalDefiningType, sourceAvailableContext.ContainingSyntaxContext, new(methodStub), new(diagnostics));
}
private MethodDeclarationSyntax? _unreachableExceptionStub;
@ -183,7 +197,7 @@ namespace Microsoft.Interop
return cachedValue;
}
int startingIndex = 3;
int startingIndex = IUnknownConstants.VTableSize;
List<Builder> methods = new();
// If we have a base interface, we should add the inherited methods to our list in vtable order
if (iface.Base is not null)

View File

@ -17,7 +17,7 @@ namespace Microsoft.Interop
/// </summary>
internal sealed record ComMethodInfo
{
public MethodDeclarationSyntax Syntax { get; init; }
public MethodDeclarationSyntax? Syntax { get; init; }
public string MethodName { get; init; }
public SequenceEqualImmutableArray<AttributeInfo> Attributes { get; init; }
public bool IsUserDefinedShadowingMethod { get; init; }
@ -94,7 +94,7 @@ namespace Microsoft.Interop
if (ifaceContext.IsExternallyDefined)
{
return DiagnosticOr<(ComMethodInfo, IMethodSymbol)>.From((
new ComMethodInfo(null!, method.Name, method.GetAttributes().Select(AttributeInfo.From).ToImmutableArray().ToSequenceEqual(), false),
new ComMethodInfo(null, method.Name, method.GetAttributes().Select(AttributeInfo.From).ToImmutableArray().ToSequenceEqual(), false),
method));
}

View File

@ -0,0 +1,13 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
namespace Microsoft.Interop
{
internal static class IUnknownConstants
{
public const int QueryInterfaceIndex = 0;
public const int AddRefIndex = 1;
public const int ReleaseIndex = 2;
public const int VTableSize = 3;
}
}

View File

@ -1,19 +1,15 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System;
namespace Microsoft.Interop
{
internal abstract record GeneratedMethodContextBase(ManagedTypeInfo OriginalDefiningType, SequenceEqualImmutableArray<DiagnosticInfo> Diagnostics);
internal sealed record IncrementalMethodStubGenerationContext(
internal record IncrementalMethodStubGenerationContext(
SignatureContext SignatureContext,
ContainingSyntaxContext ContainingSyntaxContext,
ContainingSyntax StubMethodSyntaxTemplate,
MethodSignatureDiagnosticLocations DiagnosticLocation,
ISignatureDiagnosticLocations DiagnosticLocation,
SequenceEqualImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax> CallingConvention,
VirtualMethodIndexData VtableIndexData,
MarshallingInfo ExceptionMarshallingInfo,
@ -22,4 +18,28 @@ namespace Microsoft.Interop
ManagedTypeInfo DeclaringType,
SequenceEqualImmutableArray<DiagnosticInfo> Diagnostics,
MarshallingInfo ManagedThisMarshallingInfo) : GeneratedMethodContextBase(DeclaringType, Diagnostics);
internal sealed record SourceAvailableIncrementalMethodStubGenerationContext(
SignatureContext SignatureContext,
ContainingSyntaxContext ContainingSyntaxContext,
ContainingSyntax StubMethodSyntaxTemplate,
ISignatureDiagnosticLocations DiagnosticLocation,
SequenceEqualImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax> CallingConvention,
VirtualMethodIndexData VtableIndexData,
MarshallingInfo ExceptionMarshallingInfo,
EnvironmentFlags EnvironmentFlags,
ManagedTypeInfo TypeKeyOwner,
ManagedTypeInfo DeclaringType,
SequenceEqualImmutableArray<DiagnosticInfo> Diagnostics,
MarshallingInfo ManagedThisMarshallingInfo) : IncrementalMethodStubGenerationContext(
SignatureContext,
DiagnosticLocation,
CallingConvention,
VtableIndexData,
ExceptionMarshallingInfo,
EnvironmentFlags,
TypeKeyOwner,
DeclaringType,
Diagnostics,
ManagedThisMarshallingInfo);
}

View File

@ -21,7 +21,7 @@ namespace Microsoft.Interop
internal const string VirtualMethodTarget = "__target";
public static (MethodDeclarationSyntax, ImmutableArray<DiagnosticInfo>) GenerateManagedToNativeStub(
IncrementalMethodStubGenerationContext methodStub,
SourceAvailableIncrementalMethodStubGenerationContext methodStub,
Func<EnvironmentFlags, MarshalDirection, IMarshallingGeneratorResolver> generatorResolverCreator)
{
var diagnostics = new GeneratorDiagnosticsBag(new DiagnosticDescriptorProvider(), methodStub.DiagnosticLocation, SR.ResourceManager, typeof(FxResources.Microsoft.Interop.ComInterfaceGenerator.SR));
@ -128,7 +128,7 @@ namespace Microsoft.Interop
private const string ManagedThisParameterIdentifier = "@this";
public static (MethodDeclarationSyntax, ImmutableArray<DiagnosticInfo>) GenerateNativeToManagedStub(
IncrementalMethodStubGenerationContext methodStub,
SourceAvailableIncrementalMethodStubGenerationContext methodStub,
Func<EnvironmentFlags, MarshalDirection, IMarshallingGeneratorResolver> generatorResolverCreator)
{
var diagnostics = new GeneratorDiagnosticsBag(new DiagnosticDescriptorProvider(), methodStub.DiagnosticLocation, SR.ResourceManager, typeof(FxResources.Microsoft.Interop.ComInterfaceGenerator.SR));
@ -174,7 +174,7 @@ namespace Microsoft.Interop
methodStub.Diagnostics.Array.AddRange(diagnostics.Diagnostics));
}
private static ImmutableArray<TypePositionInfo> AddManagedToUnmanagedImplicitThis(IncrementalMethodStubGenerationContext methodStub)
private static ImmutableArray<TypePositionInfo> AddManagedToUnmanagedImplicitThis(SourceAvailableIncrementalMethodStubGenerationContext methodStub)
{
ImmutableArray<TypePositionInfo> originalElements = methodStub.SignatureContext.ElementTypeInformation;
@ -232,7 +232,7 @@ namespace Microsoft.Interop
}
public static BlockSyntax GenerateVirtualMethodTableSlotAssignments(
IEnumerable<IncrementalMethodStubGenerationContext> vtableMethods,
IEnumerable<SourceAvailableIncrementalMethodStubGenerationContext> vtableMethods,
string vtableIdentifier,
Func<EnvironmentFlags, MarshalDirection, IMarshallingGeneratorResolver> generatorResolverCreator)
{

View File

@ -62,7 +62,7 @@ namespace Microsoft.Interop
// Calculate all of information to generate both managed-to-unmanaged and unmanaged-to-managed stubs
// for each method.
IncrementalValuesProvider<IncrementalMethodStubGenerationContext> generateStubInformation = methodsToGenerate
IncrementalValuesProvider<SourceAvailableIncrementalMethodStubGenerationContext> generateStubInformation = methodsToGenerate
.Combine(context.CreateStubEnvironmentProvider())
.Select(static (data, ct) => new
{
@ -89,7 +89,7 @@ namespace Microsoft.Interop
context.RegisterConcatenatedSyntaxOutputs(generateManagedToNativeStub.Select((data, ct) => data.Item1), "ManagedToNativeStubs.g.cs");
// Filter the list of all stubs to only the stubs that requested unmanaged-to-managed stub generation.
IncrementalValuesProvider<IncrementalMethodStubGenerationContext> nativeToManagedStubContexts =
IncrementalValuesProvider<SourceAvailableIncrementalMethodStubGenerationContext> nativeToManagedStubContexts =
generateStubInformation
.Where(data => data.VtableIndexData.Direction is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional);
@ -195,7 +195,7 @@ namespace Microsoft.Interop
};
}
private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, StubEnvironment environment, CancellationToken ct)
private static SourceAvailableIncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, StubEnvironment environment, CancellationToken ct)
{
ct.ThrowIfCancellationRequested();
INamedTypeSymbol? lcidConversionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.LCIDConversionAttribute);
@ -306,7 +306,7 @@ namespace Microsoft.Interop
MarshallingInfo exceptionMarshallingInfo = CreateExceptionMarshallingInfo(virtualMethodIndexAttr, symbol, environment.Compilation, generatorDiagnostics, virtualMethodIndexData);
return new IncrementalMethodStubGenerationContext(
return new SourceAvailableIncrementalMethodStubGenerationContext(
signatureContext,
containingSyntaxContext,
methodSyntaxTemplate,
@ -363,7 +363,7 @@ namespace Microsoft.Interop
}
private static (MemberDeclarationSyntax, ImmutableArray<DiagnosticInfo>) GenerateManagedToNativeStub(
IncrementalMethodStubGenerationContext methodStub)
SourceAvailableIncrementalMethodStubGenerationContext methodStub)
{
var (stub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(methodStub, VtableIndexStubGeneratorHelpers.GetGeneratorResolver);
@ -376,7 +376,7 @@ namespace Microsoft.Interop
}
private static (MemberDeclarationSyntax, ImmutableArray<DiagnosticInfo>) GenerateNativeToManagedStub(
IncrementalMethodStubGenerationContext methodStub)
SourceAvailableIncrementalMethodStubGenerationContext methodStub)
{
var (stub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(methodStub, VtableIndexStubGeneratorHelpers.GetGeneratorResolver);
@ -433,7 +433,7 @@ namespace Microsoft.Interop
.AddAttributeLists(AttributeList(SingletonSeparatedList(Attribute(NameSyntaxes.System_Runtime_InteropServices_DynamicInterfaceCastableImplementationAttribute)))));
}
private static MemberDeclarationSyntax GeneratePopulateVTableMethod(IGrouping<ContainingSyntaxContext, IncrementalMethodStubGenerationContext> vtableMethods)
private static MemberDeclarationSyntax GeneratePopulateVTableMethod(IGrouping<ContainingSyntaxContext, SourceAvailableIncrementalMethodStubGenerationContext> vtableMethods)
{
ContainingSyntaxContext containingSyntax = vtableMethods.Key.AddContainingSyntax(NativeTypeContainingSyntax);

View File

@ -42,6 +42,15 @@ namespace Microsoft.Interop
DiagnosticInfo CreateDiagnosticInfo(DiagnosticDescriptor descriptor, GeneratorDiagnostic diagnostic);
}
public class NoneSignatureDiagnosticLocations : ISignatureDiagnosticLocations
{
public static readonly NoneSignatureDiagnosticLocations Instance = new();
public DiagnosticInfo CreateDiagnosticInfo(DiagnosticDescriptor descriptor, GeneratorDiagnostic diagnostic)
{
return diagnostic.ToDiagnosticInfo(descriptor, Location.None, string.Empty);
}
}
public sealed record MethodSignatureDiagnosticLocations(string MethodIdentifier, ImmutableArray<Location> ManagedParameterLocations, Location FallbackLocation) : ISignatureDiagnosticLocations
{
public MethodSignatureDiagnosticLocations(MethodDeclarationSyntax syntax)

View File

@ -96,7 +96,7 @@ namespace Microsoft.Interop
ByValueContentsMarshalKind = byValueContentsMarshalKind,
ByValueMarshalAttributeLocations = (inLocation, outLocation),
ScopedKind = paramSymbol.ScopedKind,
IsExplicitThis = ((ParameterSyntax)paramSymbol.DeclaringSyntaxReferences[0].GetSyntax()).Modifiers.Any(SyntaxKind.ThisKeyword)
IsExplicitThis = ((ParameterSyntax?)paramSymbol.DeclaringSyntaxReferences.FirstOrDefault()?.GetSyntax())?.Modifiers.Any(SyntaxKind.ThisKeyword) ?? false
};
return typeInfo;

View File

@ -17,7 +17,7 @@
<ItemGroup>
<EnabledGenerators Include="ComInterfaceGenerator" />
<Compile Include="$(CommonPath)DisableRuntimeMarshalling.cs" Link="Common\DisableRuntimeMarshalling.cs" />
<Compile Include="..\TestAssets\SharedTypes\ComInterfaces\**\*.cs" Link="ComInterfaces\%(RecursiveDir)\%(FileName).cs" />
<Compile Include="..\Common\ComInterfaces\**\*.cs" Link="ComInterfaces\%(RecursiveDir)\%(FileName).cs" />
</ItemGroup>
<ItemGroup>

View File

@ -0,0 +1,256 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using System.Linq;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
using SharedTypes.ComInterfaces;
using Xunit;
namespace ComInterfaceGenerator.Tests
{
public partial class CrossAssemblyInheritanceTests
{
[GeneratedComClass]
[Guid("e0c6b35f-1234-4567-8901-123456789abc")]
internal partial class DerivedExternalBaseImpl : IDerivedExternalBase
{
private int _value = 10;
public int GetInt() => _value;
public void SetInt(int x) => _value = x;
public string GetName() => "DerivedExternalBase";
}
[GeneratedComClass]
[Guid("e0c6b35f-1234-4567-8901-123456789abd")]
internal partial class DerivedExternalBase2Impl : IDerivedExternalBase2
{
private int _value = 20;
public int GetInt() => _value;
public void SetInt(int x) => _value = x;
public string GetName() => "DerivedExternalBase2";
}
[GeneratedComClass]
[Guid("e0c6b35f-1234-4567-8901-123456789abe")]
internal partial class DerivedFromExternalDerivedImpl : IDerivedFromExternalDerived
{
private int _value = 30;
private bool _boolValue = true;
public int GetInt() => _value;
public void SetInt(int x) => _value = x;
public bool GetBool() => _boolValue;
public void SetBool(bool x) => _boolValue = x;
public string GetName() => "DerivedFromExternalDerived";
}
[GeneratedComClass]
[Guid("e0c6b35f-1234-4567-8901-123456789abf")]
internal partial class DerivedFromDerivedExternalDerivedImpl : IDerivedFromDerivedExternalDerived
{
private int _value = 40;
private bool _boolValue = false;
private float _floatValue = 3.14f;
public int GetInt() => _value;
public void SetInt(int x) => _value = x;
public bool GetBool() => _boolValue;
public void SetBool(bool x) => _boolValue = x;
public string GetName() => "DerivedFromDerivedExternalDerived";
public float GetFloat() => _floatValue;
}
[Fact]
public void IDerivedExternalBase_CanCallMethods()
{
var implementation = new DerivedExternalBaseImpl();
var comWrappers = new StrategyBasedComWrappers();
var nativeObj = comWrappers.GetOrCreateComInterfaceForObject(implementation, CreateComInterfaceFlags.None);
var managedObj = comWrappers.GetOrCreateObjectForComInstance(nativeObj, CreateObjectFlags.None);
var externalBase = (IExternalBase)managedObj;
Assert.Equal(10, externalBase.GetInt());
externalBase.SetInt(15);
Assert.Equal(15, externalBase.GetInt());
var derivedExternalBase = (IDerivedExternalBase)managedObj;
Assert.Equal(15, derivedExternalBase.GetInt());
Assert.Equal("DerivedExternalBase", derivedExternalBase.GetName());
}
[Fact]
public void IDerivedExternalBase2_CanCallMethods()
{
var implementation = new DerivedExternalBase2Impl();
var comWrappers = new StrategyBasedComWrappers();
var nativeObj = comWrappers.GetOrCreateComInterfaceForObject(implementation, CreateComInterfaceFlags.None);
var managedObj = comWrappers.GetOrCreateObjectForComInstance(nativeObj, CreateObjectFlags.None);
// Test as base interface
var externalBase = (IExternalBase)managedObj;
Assert.Equal(20, externalBase.GetInt());
externalBase.SetInt(25);
Assert.Equal(25, externalBase.GetInt());
// Test as derived interface
var derivedExternalBase2 = (IDerivedExternalBase2)managedObj;
Assert.Equal(25, derivedExternalBase2.GetInt());
Assert.Equal("DerivedExternalBase2", derivedExternalBase2.GetName());
}
[Fact]
public void IDerivedFromExternalDerived_CanCallMethods()
{
var implementation = new DerivedFromExternalDerivedImpl();
var comWrappers = new StrategyBasedComWrappers();
var nativeObj = comWrappers.GetOrCreateComInterfaceForObject(implementation, CreateComInterfaceFlags.None);
var managedObj = comWrappers.GetOrCreateObjectForComInstance(nativeObj, CreateObjectFlags.None);
var externalBase = (IExternalBase)managedObj;
Assert.Equal(30, externalBase.GetInt());
externalBase.SetInt(35);
Assert.Equal(35, externalBase.GetInt());
var externalDerived = (IExternalDerived)managedObj;
Assert.Equal(35, externalDerived.GetInt());
Assert.True(externalDerived.GetBool());
externalDerived.SetBool(false);
Assert.False(externalDerived.GetBool());
var derivedFromExternalDerived = (IDerivedFromExternalDerived)managedObj;
Assert.Equal(35, derivedFromExternalDerived.GetInt());
Assert.False(derivedFromExternalDerived.GetBool());
Assert.Equal("DerivedFromExternalDerived", derivedFromExternalDerived.GetName());
}
[Fact]
public void IDerivedFromDerivedExternalDerived_CanCallMethods()
{
var implementation = new DerivedFromDerivedExternalDerivedImpl();
var comWrappers = new StrategyBasedComWrappers();
var nativeObj = comWrappers.GetOrCreateComInterfaceForObject(implementation, CreateComInterfaceFlags.None);
var managedObj = comWrappers.GetOrCreateObjectForComInstance(nativeObj, CreateObjectFlags.None);
var externalBase = (IExternalBase)managedObj;
Assert.Equal(40, externalBase.GetInt());
externalBase.SetInt(45);
Assert.Equal(45, externalBase.GetInt());
var externalDerived = (IExternalDerived)managedObj;
Assert.Equal(45, externalDerived.GetInt());
Assert.False(externalDerived.GetBool());
externalDerived.SetBool(true);
Assert.True(externalDerived.GetBool());
var derivedFromExternalDerived = (IDerivedFromExternalDerived)managedObj;
Assert.Equal(45, derivedFromExternalDerived.GetInt());
Assert.True(derivedFromExternalDerived.GetBool());
Assert.Equal("DerivedFromDerivedExternalDerived", derivedFromExternalDerived.GetName());
var derivedFromDerivedExternalDerived = (IDerivedFromDerivedExternalDerived)managedObj;
Assert.Equal(45, derivedFromDerivedExternalDerived.GetInt());
Assert.True(derivedFromDerivedExternalDerived.GetBool());
Assert.Equal("DerivedFromDerivedExternalDerived", derivedFromDerivedExternalDerived.GetName());
Assert.Equal(3.14f, derivedFromDerivedExternalDerived.GetFloat());
}
[Fact]
public unsafe void MultipleInterfacesDerivedFromSameBase_ShareCommonVTableLayout()
{
IIUnknownDerivedDetails baseDetails = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy
.GetIUnknownDerivedDetails(typeof(IExternalBase).TypeHandle);
IIUnknownDerivedDetails derived1Details = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy
.GetIUnknownDerivedDetails(typeof(IDerivedExternalBase).TypeHandle);
IIUnknownDerivedDetails derived2Details = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy
.GetIUnknownDerivedDetails(typeof(IDerivedExternalBase2).TypeHandle);
var numBaseMethods = typeof(IExternalBase).GetMethods().Length;
var numPointersToCompare = 3 + numBaseMethods; // IUnknown (3) + base methods
// Both derived interfaces should have the same base vtable layout
var baseVTable = new ReadOnlySpan<nint>(baseDetails.ManagedVirtualMethodTable, numPointersToCompare);
var derived1BaseVTable = new ReadOnlySpan<nint>(derived1Details.ManagedVirtualMethodTable, numPointersToCompare);
var derived2BaseVTable = new ReadOnlySpan<nint>(derived2Details.ManagedVirtualMethodTable, numPointersToCompare);
Assert.True(baseVTable.SequenceEqual(derived1BaseVTable),
"IDerivedExternalBase should have consistent base vtable layout");
Assert.True(baseVTable.SequenceEqual(derived2BaseVTable),
"IDerivedExternalBase2 should have consistent base vtable layout");
Assert.True(derived1BaseVTable.SequenceEqual(derived2BaseVTable),
"Both derived interfaces should have identical base vtable layouts");
}
[Fact]
public unsafe void CrossAssemblyInheritance_VTableLayoutIsCorrect()
{
IIUnknownDerivedDetails baseInterfaceDetails = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy
.GetIUnknownDerivedDetails(typeof(IExternalBase).TypeHandle);
IIUnknownDerivedDetails derivedInterfaceDetails = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy
.GetIUnknownDerivedDetails(typeof(IDerivedExternalBase).TypeHandle);
var numBaseMethods = typeof(IExternalBase).GetMethods().Length;
var numPointersToCompare = 3 + numBaseMethods;
// The first part of the vtable should match between base and derived
var expectedBaseVTable = new ReadOnlySpan<nint>(baseInterfaceDetails.ManagedVirtualMethodTable, numPointersToCompare);
var actualDerivedVTable = new ReadOnlySpan<nint>(derivedInterfaceDetails.ManagedVirtualMethodTable, numPointersToCompare);
Assert.True(expectedBaseVTable.SequenceEqual(actualDerivedVTable),
"Base interface methods should have the same vtable entries in derived interface");
}
[Fact]
public unsafe void IDerivedFromDerivedExternalDerived_VTableLayoutIsCorrect()
{
var baseDetails = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy
.GetIUnknownDerivedDetails(typeof(IExternalBase).TypeHandle);
var externalDerivedDetails = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy
.GetIUnknownDerivedDetails(typeof(IExternalDerived).TypeHandle);
var derivedFromExternalDerivedDetails = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy
.GetIUnknownDerivedDetails(typeof(IDerivedFromExternalDerived).TypeHandle);
var deepDerivedDetails = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy
.GetIUnknownDerivedDetails(typeof(IDerivedFromDerivedExternalDerived).TypeHandle);
var baseMethods = typeof(IExternalBase).GetMethods().Length;
var baseVTableSize = 3 + baseMethods;
var baseVTable = new ReadOnlySpan<nint>(baseDetails.ManagedVirtualMethodTable, baseVTableSize);
var externalDerivedBaseVTable = new ReadOnlySpan<nint>(externalDerivedDetails.ManagedVirtualMethodTable, baseVTableSize);
var derivedFromExternalDerivedBaseVTable = new ReadOnlySpan<nint>(derivedFromExternalDerivedDetails.ManagedVirtualMethodTable, baseVTableSize);
var deepDerivedBaseVTable = new ReadOnlySpan<nint>(deepDerivedDetails.ManagedVirtualMethodTable, baseVTableSize);
Assert.True(baseVTable.SequenceEqual(externalDerivedBaseVTable),
"IExternalDerived should have consistent base vtable layout");
Assert.True(baseVTable.SequenceEqual(derivedFromExternalDerivedBaseVTable),
"IDerivedFromExternalDerived should have consistent base vtable layout");
Assert.True(baseVTable.SequenceEqual(deepDerivedBaseVTable),
"IDerivedFromDerivedExternalDerived should have consistent base vtable layout");
var externalDerivedMethods = typeof(IExternalDerived).GetMethods().Length;
var externalDerivedVTableSize = 3 + externalDerivedMethods;
var externalDerivedVTable = new ReadOnlySpan<nint>(externalDerivedDetails.ManagedVirtualMethodTable, externalDerivedVTableSize);
var derivedFromExternalDerivedIntermediateVTable = new ReadOnlySpan<nint>(derivedFromExternalDerivedDetails.ManagedVirtualMethodTable, externalDerivedVTableSize);
var deepDerivedIntermediateVTable = new ReadOnlySpan<nint>(deepDerivedDetails.ManagedVirtualMethodTable, externalDerivedVTableSize);
Assert.True(externalDerivedVTable.SequenceEqual(derivedFromExternalDerivedIntermediateVTable),
"IDerivedFromExternalDerived should have consistent IExternalDerived vtable layout");
Assert.True(externalDerivedVTable.SequenceEqual(deepDerivedIntermediateVTable),
"IDerivedFromDerivedExternalDerived should have consistent IExternalDerived vtable layout");
var derivedFromExternalDerivedMethods = typeof(IDerivedFromExternalDerived).GetMethods().Length;
var derivedFromExternalDerivedVTableSize = 3 + derivedFromExternalDerivedMethods;
var derivedFromExternalDerivedVTable = new ReadOnlySpan<nint>(derivedFromExternalDerivedDetails.ManagedVirtualMethodTable, derivedFromExternalDerivedVTableSize);
var deepDerivedVTable = new ReadOnlySpan<nint>(deepDerivedDetails.ManagedVirtualMethodTable, derivedFromExternalDerivedVTableSize);
Assert.True(derivedFromExternalDerivedVTable.SequenceEqual(deepDerivedVTable),
"IDerivedFromDerivedExternalDerived should have consistent IDerivedFromExternalDerived vtable layout");
}
}
}

View File

@ -13,6 +13,56 @@ namespace ComInterfaceGenerator.Tests
{
public partial class IDerivedTests
{
[GeneratedComInterface]
[Guid("7F0DB364-3C04-4487-9193-4BB05DC7B654")]
internal partial interface IDerivedFromSharedType2 : SharedTypes.ComInterfaces.IGetAndSetInt
{
int GetTwoTimesInt();
}
[GeneratedComInterface]
[Guid("7F0DB364-3C04-4487-9194-4BB05DC7B654")]
#pragma warning disable SYSLIB1230 // Specifying 'GeneratedComInterfaceAttribute' on an interface that has a base interface defined in another assembly is not supported
internal partial interface IDerivedFromSharedType : SharedTypes.ComInterfaces.IGetAndSetInt
#pragma warning restore SYSLIB1230
{
int GetIntPlusOne();
}
[GeneratedComClass]
[Guid("7F0DB364-3C04-4487-9195-4BB05DC7B654")]
internal partial class DerivedFromSharedTypeImpl : IDerivedFromSharedType, IDerivedFromSharedType2
{
int _value = 42;
public int GetInt() => _value;
public int GetIntPlusOne() => _value + 1;
public int GetTwoTimesInt() => _value * 2;
public void SetInt(int value) { _value = value; }
}
[Fact]
public unsafe void TypesDerivedFromSharedTypeHaveCorrectVTableSize()
{
var managedSourceObject = new DerivedFromSharedTypeImpl();
var cw = new StrategyBasedComWrappers();
var nativeObj = cw.GetOrCreateComInterfaceForObject(managedSourceObject, CreateComInterfaceFlags.None);
object managedObj = cw.GetOrCreateObjectForComInstance(nativeObj, CreateObjectFlags.None);
IGetAndSetInt getAndSetInt = (IGetAndSetInt)managedObj;
IDerivedFromSharedType derivedFromSharedType = (IDerivedFromSharedType)managedObj;
IDerivedFromSharedType2 derivedFromSharedType2 = (IDerivedFromSharedType2)managedObj;
Assert.Equal(42, getAndSetInt.GetInt());
Assert.Equal(42, derivedFromSharedType.GetInt());
Assert.Equal(42, derivedFromSharedType2.GetInt());
getAndSetInt.SetInt(100);
Assert.Equal(100, getAndSetInt.GetInt());
Assert.Equal(101, derivedFromSharedType.GetIntPlusOne());
Assert.Equal(200, derivedFromSharedType2.GetTwoTimesInt());
}
[Fact]
public unsafe void DerivedInterfaceTypeProvidesBaseInterfaceUnmanagedToManagedMembers()
{
@ -36,16 +86,16 @@ namespace ComInterfaceGenerator.Tests
public unsafe void CallBaseInterfaceMethod_EnsureQiCalledOnce()
{
var cw = new SingleQIComWrapper();
var derivedImpl = new DerivedImpl();
var derivedImpl = new Derived();
var nativeObj = cw.GetOrCreateComInterfaceForObject(derivedImpl, CreateComInterfaceFlags.None);
var obj = cw.GetOrCreateObjectForComInstance(nativeObj, CreateObjectFlags.None);
IDerived iface = (IDerived)obj;
Assert.Equal(3, iface.GetInt());
Assert.Equal(0, iface.GetInt());
iface.SetInt(5);
Assert.Equal(5, iface.GetInt());
Assert.Equal("myName", iface.GetName());
Assert.Equal("hello", iface.GetName());
iface.SetName("updated");
Assert.Equal("updated", iface.GetName());
@ -58,22 +108,6 @@ namespace ComInterfaceGenerator.Tests
Assert.Equal(1, countQi.QiCallCount);
}
[GeneratedComClass]
partial class DerivedImpl : IDerived
{
int data = 3;
string myName = "myName";
public void DoThingWithString(string name) => throw new NotImplementedException();
public int GetInt() => data;
public string GetName() => myName;
public void SetInt(int n) => data = n;
public void SetName(string name) => myName = name;
}
/// <summary>
/// Used to ensure that QI is only called once when calling base methods on a derived COM interface
/// </summary>

View File

@ -0,0 +1,24 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
using SharedTypes.ComInterfaces;
[GeneratedComInterface(StringMarshalling = StringMarshalling.Utf16)]
[Guid("da8eed10-f3f4-42c3-86de-04f2dc56514e")]
#pragma warning disable SYSLIB1230 // Specifying 'GeneratedComInterfaceAttribute' on an interface that has a base interface defined in another assembly is not supported
internal partial interface IDerivedExternalBase : IExternalBase
#pragma warning restore SYSLIB1230
{
string GetName();
}
[GeneratedComInterface(StringMarshalling = StringMarshalling.Utf16)]
[Guid("c3d3990e-5b05-4a9b-adc4-58c521700ece")]
#pragma warning disable SYSLIB1230 // Specifying 'GeneratedComInterfaceAttribute' on an interface that has a base interface defined in another assembly is not supported
internal partial interface IDerivedExternalBase2 : IExternalBase
#pragma warning restore SYSLIB1230
{
string GetName();
}

View File

@ -0,0 +1,15 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
using SharedTypes.ComInterfaces;
[GeneratedComInterface(StringMarshalling = StringMarshalling.Utf16)]
[Guid("f252bddd-aac0-4004-acfd-b39f73fb9791")]
#pragma warning disable SYSLIB1230 // Specifying 'GeneratedComInterfaceAttribute' on an interface that has a base interface defined in another assembly is not supported
internal partial interface IDerivedFromExternalDerived : IExternalDerived
#pragma warning restore SYSLIB1230
{
string GetName();
}

View File

@ -0,0 +1,15 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
using SharedTypes.ComInterfaces;
[GeneratedComInterface(StringMarshalling = StringMarshalling.Utf16)]
[Guid("b158aaf2-85a3-40e7-805f-5797580a05f2")]
#pragma warning disable SYSLIB1230 // Specifying 'GeneratedComInterfaceAttribute' on an interface that has a base interface defined in another assembly is not supported
internal partial interface IDerivedFromDerivedExternalDerived : IDerivedFromExternalDerived
#pragma warning restore SYSLIB1230
{
float GetFloat();
}

View File

@ -8,8 +8,10 @@ using System.Reflection.Metadata;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.Testing;
using Microsoft.CodeAnalysis.Text;
using Microsoft.Interop;
using Xunit;
@ -317,7 +319,7 @@ namespace ComInterfaceGenerator.Unit.Tests
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
[GeneratedComInterface]
[Guid("0A617667-4961-4F90-B74F-6DC368E9817A")]
partial interface {|#1:IComInterface2|} : IComInterface
@ -342,6 +344,74 @@ namespace ComInterfaceGenerator.Unit.Tests
VerifyCS.DiagnosticWithArguments(GeneratorDiagnostics.BaseInterfaceDefinedInOtherAssembly, "IComInterface2", "IComInterface").WithLocation(1).WithSeverity(DiagnosticSeverity.Warning));
}
[Fact]
public async Task ComInterfacesInheritingFromTheSameInterfaceAcrossCompilationsCalculatesCorrectVTableIndex()
{
string baseSource = $$"""
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
[GeneratedComInterface]
[Guid("0A617667-4961-4F90-B74F-6DC368E98179")]
public partial interface IComInterface
{
void Method();
}
""";
string derivedSource = $$"""
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
[GeneratedComInterface]
[Guid("0A617667-4961-4F90-B74F-6DC368E9817A")]
internal partial interface {|#1:IComInterface2|} : IComInterface
{
void DerivedMethod();
}
[GeneratedComInterface]
[Guid("0951f7b7-a700-4de4-930e-0b1fbc4684a9")]
internal partial interface {|#2:IComInterface3|} : IComInterface
{
void DerivedMethod();
}
""";
await VerifyInvocationWithMultipleProjectsAsync(
derivedSource,
baseSource,
"IComInterface2",
"DerivedMethod",
(newComp, _) =>
{
// Validate VTable sizes for interfaces inheriting from the same base
// IUnknown has 3 methods, IComInterface adds 1 method = 4 total
// Both IComInterface2 and IComInterface3 inherit from IComInterface and add 1 method each = 5 total
ValidateInterface("IComInterface2", 5);
ValidateInterface("IComInterface3", 5);
void ValidateInterface(string name, int expectedVTableSize)
{
INamedTypeSymbol? userDefinedInterface = newComp.Assembly.GetTypeByMetadataName(name);
Assert.NotNull(userDefinedInterface);
ITypeSymbol vtableType = new ComInterfaceImplementationLocator().FindVTableStructType(newComp, userDefinedInterface);
int actualVTableSize = vtableType.GetMembers().OfType<IFieldSymbol>().Count();
if (expectedVTableSize != actualVTableSize)
{
Assert.Fail($"VTable size mismatch for {name}. Expected: {expectedVTableSize}, Actual: {actualVTableSize}. VTable structure:\n{vtableType.DeclaringSyntaxReferences[0].GetSyntax().SyntaxTree.GetText()}");
}
Assert.Equal(expectedVTableSize, actualVTableSize);
}
},
VerifyCS.DiagnosticWithArguments(GeneratorDiagnostics.BaseInterfaceDefinedInOtherAssembly, "IComInterface2", "IComInterface").WithLocation(1).WithSeverity(DiagnosticSeverity.Warning),
VerifyCS.DiagnosticWithArguments(GeneratorDiagnostics.BaseInterfaceDefinedInOtherAssembly, "IComInterface3", "IComInterface").WithLocation(2).WithSeverity(DiagnosticSeverity.Warning));
}
[Fact]
public async Task ComInterfaceInheritingAcrossCompilationsChainInBaseCalculatesCorrectVTableIndex()
{
@ -369,7 +439,7 @@ namespace ComInterfaceGenerator.Unit.Tests
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
[GeneratedComInterface]
[Guid("0A617667-4961-4F90-B74F-6DC368E9817A")]
partial interface {|#1:IComInterface3|} : IComInterface2
@ -446,6 +516,84 @@ namespace ComInterfaceGenerator.Unit.Tests
VerifyCS.DiagnosticWithArguments(GeneratorDiagnostics.BaseInterfaceDefinedInOtherAssembly, "IComInterface2", "IComInterface").WithLocation(1).WithSeverity(DiagnosticSeverity.Warning));
}
[Fact]
public async Task ComInterfaceDeepInheritanceChainCalculatesCorrectVTableSizes()
{
string baseSource = $$"""
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
[GeneratedComInterface]
[Guid("0A617667-4961-4F90-B74F-6DC368E98179")]
public partial interface IComInterface
{
void BaseMethod();
}
[GeneratedComInterface]
[Guid("0A617667-4961-4F90-B74F-6DC368E98178")]
public partial interface IComInterface2 : IComInterface
{
void MiddleMethod();
}
""";
string derivedSource = $$"""
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
[GeneratedComInterface]
[Guid("0A617667-4961-4F90-B74F-6DC368E9817A")]
partial interface {|#1:IComInterface3|} : IComInterface2
{
void DerivedMethod();
}
[GeneratedComInterface]
[Guid("0A617667-4961-4F90-B74F-6DC368E9817B")]
partial interface IComInterface4 : IComInterface3
{
void DeepDerivedMethod();
}
""";
await VerifyInvocationWithMultipleProjectsAsync(
derivedSource,
baseSource,
"IComInterface4",
"DeepDerivedMethod",
(newComp, _) =>
{
// Validate VTable sizes for deep inheritance chain
// IUnknown has 3 methods (QueryInterface=0, AddRef=1, Release=2)
// IComInterface: IUnknown (3) + BaseMethod (1) = 4 total
// IComInterface2: IComInterface (4) + MiddleMethod (1) = 5 total
// IComInterface3: IComInterface2 (5) + DerivedMethod (1) = 6 total
// IComInterface4: IComInterface3 (6) + DeepDerivedMethod (1) = 7 total
ValidateInterface("IComInterface3", 6);
ValidateInterface("IComInterface4", 7);
void ValidateInterface(string name, int expectedVTableSize)
{
INamedTypeSymbol? userDefinedInterface = newComp.Assembly.GetTypeByMetadataName(name);
Assert.NotNull(userDefinedInterface);
ITypeSymbol vtableType = new ComInterfaceImplementationLocator().FindVTableStructType(newComp, userDefinedInterface);
int actualVTableSize = vtableType.GetMembers().OfType<IFieldSymbol>().Count();
if (expectedVTableSize != actualVTableSize)
{
Assert.Fail($"VTable size mismatch for {name}. Expected: {expectedVTableSize}, Actual: {actualVTableSize}. VTable structure:\n{vtableType.DeclaringSyntaxReferences[0].GetSyntax().SyntaxTree.GetText()}");
}
Assert.Equal(expectedVTableSize, actualVTableSize);
}
},
VerifyCS.DiagnosticWithArguments(GeneratorDiagnostics.BaseInterfaceDefinedInOtherAssembly, "IComInterface3", "IComInterface2").WithLocation(1).WithSeverity(DiagnosticSeverity.Warning));
}
private static async Task VerifyInvocationWithMultipleProjectsAsync(
string thisSource,
string baseSource,
@ -581,6 +729,14 @@ namespace ComInterfaceGenerator.Unit.Tests
return (INamedTypeSymbol)iUnknownDerivedAttribute.AttributeClass!.TypeArguments[1];
}
public ITypeSymbol FindVTableStructType(Compilation compilation, INamedTypeSymbol userDefinedInterface)
{
INamedTypeSymbol? implementationInterface = FindImplementationInterface(compilation, userDefinedInterface);
var vtableField = implementationInterface.GetMembers("Vtable").OfType<IFieldSymbol>().SingleOrDefault();
Assert.NotNull(vtableField);
return vtableField.Type;
}
}
}
}

View File

@ -18,6 +18,15 @@ namespace SharedTypes.ComInterfaces
internal new const string IID = "7F0DB364-3C04-4487-9193-4BB05DC7B654";
}
[GeneratedComInterface]
[Guid("D38D8B40-54A4-4685-B048-D04E215E6A93")]
internal partial interface IDerivedBool : IBool
{
void SetName([MarshalUsing(typeof(Utf16StringMarshaller))] string name);
[return: MarshalUsing(typeof(Utf16StringMarshaller))]
string GetName();
}
[GeneratedComClass]
internal partial class Derived : GetAndSetInt, IDerived

View File

@ -0,0 +1,28 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
namespace SharedTypes.ComInterfaces
{
[GeneratedComInterface]
[Guid(IID)]
internal partial interface IDerivedDerived : IDerived
{
void SetFloat(float name);
float GetFloat();
internal new const string IID = "7F0DB364-3C04-4487-9193-4BB05DC7B654";
}
[GeneratedComClass]
internal partial class DerivedDerived : Derived, IDerivedDerived
{
float _data = 0;
public float GetFloat() => _data;
public void SetFloat(float name) => _data = name;
}
}

View File

@ -17,6 +17,7 @@ namespace SharedTypes.ComInterfaces
public const string IID = "2c3f9903-b586-46b1-881b-adfce9af47b1";
}
[GeneratedComClass]
internal partial class GetAndSetInt : IGetAndSetInt
{

View File

@ -20,7 +20,7 @@
</PropertyGroup>
<ItemGroup>
<Compile Include="..\..\TestAssets\SharedTypes\ComInterfaces\**\*.cs" Link="ComInterfaceGenerator\ComInterfaces\%(RecursiveDir)\%(FileName).cs" />
<Compile Include="..\..\Common\ComInterfaces\**\*.cs" Link="ComInterfaceGenerator\ComInterfaces\%(RecursiveDir)\%(FileName).cs" />
<Compile Include="$(CommonPath)DisableRuntimeMarshalling.cs" Link="Common\DisableRuntimeMarshalling.cs" />
</ItemGroup>

View File

@ -0,0 +1,18 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
namespace SharedTypes.ComInterfaces
{
[GeneratedComInterface]
[Guid(IID)]
public partial interface IExternalBase
{
public int GetInt();
public void SetInt(int x);
public const string IID = "2c3f9903-b586-46b1-881b-adfce9af47b1";
}
}

View File

@ -0,0 +1,19 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
namespace SharedTypes.ComInterfaces
{
[GeneratedComInterface]
[Guid(IID)]
public partial interface IExternalDerived : IExternalBase
{
[return: MarshalAs(UnmanagedType.Bool)]
bool GetBool();
public void SetBool([MarshalAs(UnmanagedType.Bool)] bool x);
new public const string IID = "594DF2B9-66CE-490D-9D05-34646675B188";
}
}

View File

@ -11,8 +11,6 @@
<ItemGroup>
<ProjectReference Include="..\..\Ancillary.Interop\Ancillary.Interop.csproj" />
<!-- This project just has shared types, but shouldn't generate code for the types -->
<EnabledGenerators Remove="@(EnabledGenerators)" />
</ItemGroup>
</Project>