mirror of https://github.com/dotnet/runtime
Merge a24cf60673
into 02596ba8d9
This commit is contained in:
commit
32f9e73b5c
|
@ -20,11 +20,23 @@ namespace Microsoft.Interop
|
|||
/// <summary>
|
||||
/// COM methods that require shadowing declarations on the derived interface.
|
||||
/// </summary>
|
||||
public IEnumerable<ComMethodContext> ShadowingMethods => Methods.Where(m => m.IsInheritedMethod && !m.IsHiddenOnDerivedInterface);
|
||||
public IEnumerable<ComMethodContext> ShadowingMethods => Methods.Where(m => m.IsInheritedMethod && !m.IsHiddenOnDerivedInterface && !m.IsExternallyDefined);
|
||||
|
||||
/// <summary>
|
||||
/// COM methods that are declared on an interface the interface inherits from.
|
||||
/// </summary>
|
||||
public IEnumerable<ComMethodContext> InheritedMethods => Methods.Where(m => m.IsInheritedMethod);
|
||||
|
||||
/// <summary>
|
||||
/// The size of the vtable for this interface, including the base interface methods and IUnknown methods.
|
||||
/// </summary>
|
||||
public int VTableSize => Methods.Length == 0
|
||||
? IUnknownConstants.VTableSize
|
||||
: 1 + Methods.Max(m => m.GenerationContext.VtableIndexData.Index);
|
||||
|
||||
/// <summary>
|
||||
/// The size of the vtable for the base interface, including it's base interface methods and IUnknown methods.
|
||||
/// </summary>
|
||||
public int BaseVTableSize => VTableSize - DeclaredMethods.Count();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -54,7 +54,7 @@ namespace Microsoft.Interop
|
|||
var externalInterfaceSymbols = attributedInterfaces.SelectMany(static (data, ct) =>
|
||||
{
|
||||
return ComInterfaceInfo.CreateInterfaceInfoForBaseInterfacesInOtherCompilations(data.Symbol);
|
||||
});
|
||||
}).Collect().SelectMany(static (data, ct) => data.Distinct(ComInterfaceInfo.EqualityComparerForExternalIfaces.Instance));
|
||||
|
||||
var interfaceSymbolsWithoutDiagnostics = interfaceSymbolsToGenerateWithoutDiagnostics.Concat(externalInterfaceSymbols);
|
||||
|
||||
|
@ -84,11 +84,7 @@ namespace Microsoft.Interop
|
|||
.SelectMany(static (data, ct) =>
|
||||
{
|
||||
return ComMethodContext.CalculateAllMethods(data, ct);
|
||||
})
|
||||
// Now that we've determined method offsets, we can remove all externally defined methods.
|
||||
// We'll also filter out methods originally declared on externally defined base interfaces
|
||||
// as we may not be able to emit them into our assembly.
|
||||
.Where(context => !context.Method.OriginalDeclaringInterface.IsExternallyDefined);
|
||||
});
|
||||
|
||||
// Now that we've determined method offsets, we can remove all externally defined interfaces.
|
||||
var interfaceContextsToGenerate = interfaceContexts.Where(context => !context.IsExternallyDefined);
|
||||
|
@ -107,13 +103,20 @@ namespace Microsoft.Interop
|
|||
return new ComMethodContext(
|
||||
data.Method,
|
||||
data.OwningInterface,
|
||||
CalculateStubInformation(data.Method.MethodInfo.Syntax, symbolMap[data.Method.MethodInfo], data.Method.Index, env, data.OwningInterface.Info, ct));
|
||||
CalculateStubInformation(
|
||||
data.Method.MethodInfo.Syntax,
|
||||
symbolMap[data.Method.MethodInfo],
|
||||
data.Method.Index,
|
||||
env,
|
||||
data.OwningInterface.Info,
|
||||
ct));
|
||||
}).WithTrackingName(StepNames.CalculateStubInformation);
|
||||
|
||||
var interfaceAndMethodsContexts = comMethodContexts
|
||||
.Collect()
|
||||
.Combine(interfaceContextsToGenerate.Collect())
|
||||
.SelectMany((data, ct) => GroupComContextsForInterfaceGeneration(data.Left, data.Right, ct));
|
||||
.SelectMany((data, ct) =>
|
||||
GroupComContextsForInterfaceGeneration(data.Left, data.Right, ct));
|
||||
|
||||
// Generate the code for the managed-to-unmanaged stubs.
|
||||
var managedToNativeInterfaceImplementations = interfaceAndMethodsContexts
|
||||
|
@ -256,12 +259,22 @@ namespace Microsoft.Interop
|
|||
|| typeName.Equals("hresult", StringComparison.OrdinalIgnoreCase);
|
||||
}
|
||||
|
||||
private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ComInterfaceInfo owningInterfaceInfo, CancellationToken ct)
|
||||
/// <summary>
|
||||
/// Calculates the shared information needed for both source-available and sourceless stub generation.
|
||||
/// </summary>
|
||||
private static IncrementalMethodStubGenerationContext CalculateSharedStubInformation(
|
||||
IMethodSymbol symbol,
|
||||
int index,
|
||||
StubEnvironment environment,
|
||||
ISignatureDiagnosticLocations diagnosticLocations,
|
||||
ComInterfaceInfo owningInterfaceInfo,
|
||||
CancellationToken ct)
|
||||
{
|
||||
ct.ThrowIfCancellationRequested();
|
||||
INamedTypeSymbol? lcidConversionAttrType = environment.LcidConversionAttrType;
|
||||
INamedTypeSymbol? suppressGCTransitionAttrType = environment.SuppressGCTransitionAttrType;
|
||||
INamedTypeSymbol? unmanagedCallConvAttrType = environment.UnmanagedCallConvAttrType;
|
||||
|
||||
// Get any attributes of interest on the method
|
||||
AttributeData? lcidConversionAttr = null;
|
||||
AttributeData? suppressGCTransitionAttribute = null;
|
||||
|
@ -282,8 +295,7 @@ namespace Microsoft.Interop
|
|||
}
|
||||
}
|
||||
|
||||
var locations = new MethodSignatureDiagnosticLocations(syntax);
|
||||
var generatorDiagnostics = new GeneratorDiagnosticsBag(new DiagnosticDescriptorProvider(), locations, SR.ResourceManager, typeof(FxResources.Microsoft.Interop.ComInterfaceGenerator.SR));
|
||||
var generatorDiagnostics = new GeneratorDiagnosticsBag(new DiagnosticDescriptorProvider(), diagnosticLocations, SR.ResourceManager, typeof(FxResources.Microsoft.Interop.ComInterfaceGenerator.SR));
|
||||
|
||||
if (lcidConversionAttr is not null)
|
||||
{
|
||||
|
@ -293,8 +305,8 @@ namespace Microsoft.Interop
|
|||
|
||||
GeneratedComInterfaceCompilationData.TryGetGeneratedComInterfaceAttributeFromInterface(symbol.ContainingType, out var generatedComAttribute);
|
||||
var generatedComInterfaceAttributeData = GeneratedComInterfaceCompilationData.GetDataFromAttribute(generatedComAttribute);
|
||||
// Create the stub.
|
||||
|
||||
// Create the stub.
|
||||
var signatureContext = SignatureContext.Create(
|
||||
symbol,
|
||||
DefaultMarshallingInfoParser.Create(
|
||||
|
@ -387,10 +399,6 @@ namespace Microsoft.Interop
|
|||
GeneratorDiagnostics.SizeOfInCollectionMustBeDefinedAtCallReturnValue);
|
||||
}
|
||||
|
||||
var containingSyntaxContext = new ContainingSyntaxContext(syntax);
|
||||
|
||||
var methodSyntaxTemplate = new ContainingSyntax(new SyntaxTokenList(syntax.Modifiers.Where(static m => !m.IsKind(SyntaxKind.NewKeyword))).StripAccessibilityModifiers(), SyntaxKind.MethodDeclaration, syntax.Identifier, syntax.TypeParameterList);
|
||||
|
||||
ImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax> callConv = VirtualMethodPointerStubGenerator.GenerateCallConvSyntaxFromAttributes(
|
||||
suppressGCTransitionAttribute,
|
||||
unmanagedCallConvAttribute,
|
||||
|
@ -398,10 +406,7 @@ namespace Microsoft.Interop
|
|||
|
||||
var declaringType = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol.ContainingType);
|
||||
|
||||
var virtualMethodIndexData = new VirtualMethodIndexData(index, ImplicitThisParameter: true, direction, true, ExceptionMarshalling.Com);
|
||||
|
||||
MarshallingInfo exceptionMarshallingInfo;
|
||||
|
||||
if (generatedComInterfaceAttributeData.ExceptionToUnmanagedMarshaller is null)
|
||||
{
|
||||
exceptionMarshallingInfo = new ComExceptionMarshalling();
|
||||
|
@ -418,11 +423,9 @@ namespace Microsoft.Interop
|
|||
|
||||
return new IncrementalMethodStubGenerationContext(
|
||||
signatureContext,
|
||||
containingSyntaxContext,
|
||||
methodSyntaxTemplate,
|
||||
locations,
|
||||
diagnosticLocations,
|
||||
callConv.ToSequenceEqualImmutableArray(SyntaxEquivalentComparer.Instance),
|
||||
virtualMethodIndexData,
|
||||
new VirtualMethodIndexData(index, ImplicitThisParameter: true, direction, true, ExceptionMarshalling.Com),
|
||||
exceptionMarshallingInfo,
|
||||
environment.EnvironmentFlags,
|
||||
owningInterfaceInfo.Type,
|
||||
|
@ -431,6 +434,45 @@ namespace Microsoft.Interop
|
|||
ComInterfaceDispatchMarshallingInfo.Instance);
|
||||
}
|
||||
|
||||
private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax? syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ComInterfaceInfo owningInterface, CancellationToken ct)
|
||||
{
|
||||
ISignatureDiagnosticLocations locations = syntax is null
|
||||
? NoneSignatureDiagnosticLocations.Instance
|
||||
: new MethodSignatureDiagnosticLocations(syntax);
|
||||
|
||||
var sourcelessStubInformation = CalculateSharedStubInformation(
|
||||
symbol,
|
||||
index,
|
||||
environment,
|
||||
locations,
|
||||
owningInterface,
|
||||
ct);
|
||||
|
||||
if (syntax is null)
|
||||
return sourcelessStubInformation;
|
||||
|
||||
var containingSyntaxContext = new ContainingSyntaxContext(syntax);
|
||||
var methodSyntaxTemplate = new ContainingSyntax(
|
||||
new SyntaxTokenList(syntax.Modifiers.Where(static m => !m.IsKind(SyntaxKind.NewKeyword))).StripAccessibilityModifiers(),
|
||||
SyntaxKind.MethodDeclaration,
|
||||
syntax.Identifier,
|
||||
syntax.TypeParameterList);
|
||||
|
||||
return new SourceAvailableIncrementalMethodStubGenerationContext(
|
||||
sourcelessStubInformation.SignatureContext,
|
||||
containingSyntaxContext,
|
||||
methodSyntaxTemplate,
|
||||
locations,
|
||||
sourcelessStubInformation.CallingConvention,
|
||||
sourcelessStubInformation.VtableIndexData,
|
||||
sourcelessStubInformation.ExceptionMarshallingInfo,
|
||||
sourcelessStubInformation.EnvironmentFlags,
|
||||
sourcelessStubInformation.TypeKeyOwner,
|
||||
sourcelessStubInformation.DeclaringType,
|
||||
sourcelessStubInformation.Diagnostics,
|
||||
ComInterfaceDispatchMarshallingInfo.Instance);
|
||||
}
|
||||
|
||||
private static MarshalDirection GetDirectionFromOptions(ComInterfaceOptions options)
|
||||
{
|
||||
if (options.HasFlag(ComInterfaceOptions.ManagedObjectWrapper | ComInterfaceOptions.ComObjectWrapper))
|
||||
|
@ -520,12 +562,12 @@ namespace Microsoft.Interop
|
|||
private static InterfaceDeclarationSyntax GenerateImplementationInterface(ComInterfaceAndMethodsContext interfaceGroup, CancellationToken _)
|
||||
{
|
||||
var definingType = interfaceGroup.Interface.Info.Type;
|
||||
var shadowImplementations = interfaceGroup.InheritedMethods.Select(m => (Method: m, ManagedToUnmanagedStub: m.ManagedToUnmanagedStub))
|
||||
var shadowImplementations = interfaceGroup.InheritedMethods.Where(m => !m.IsExternallyDefined).Select(m => (Method: m, ManagedToUnmanagedStub: m.ManagedToUnmanagedStub))
|
||||
.Where(p => p.ManagedToUnmanagedStub is GeneratedStubCodeContext)
|
||||
.Select(ctx => ((GeneratedStubCodeContext)ctx.ManagedToUnmanagedStub).Stub.Node
|
||||
.WithExplicitInterfaceSpecifier(
|
||||
ExplicitInterfaceSpecifier(ParseName(definingType.FullTypeName))));
|
||||
var inheritedStubs = interfaceGroup.InheritedMethods.Select(m => m.UnreachableExceptionStub);
|
||||
var inheritedStubs = interfaceGroup.InheritedMethods.Where(m => !m.IsExternallyDefined).Select(m => m.UnreachableExceptionStub);
|
||||
return ImplementationInterfaceTemplate
|
||||
.AddBaseListTypes(SimpleBaseType(definingType.Syntax))
|
||||
.WithMembers(
|
||||
|
@ -661,7 +703,6 @@ namespace Microsoft.Interop
|
|||
|
||||
BlockSyntax fillBaseInterfaceSlots;
|
||||
|
||||
|
||||
if (interfaceMethods.Interface.Base is null)
|
||||
{
|
||||
// If we don't have a base interface, we need to manually fill in the base iUnknown slots.
|
||||
|
@ -740,7 +781,7 @@ namespace Microsoft.Interop
|
|||
}
|
||||
else
|
||||
{
|
||||
// NativeMemory.Copy(StrategyBasedComWrappers.DefaultIUnknownInteraceDetailsStrategy.GetIUnknownDerivedDetails(typeof(<baseInterfaceType>).TypeHandle).ManagedVirtualMethodTable, vtable, (nuint)(sizeof(void*) * <startingOffset>));
|
||||
// NativeMemory.Copy(StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(<baseInterfaceType>).TypeHandle).ManagedVirtualMethodTable, vtable, (nuint)(sizeof(void*) * <baseVTableSize>));
|
||||
fillBaseInterfaceSlots = Block(
|
||||
MethodInvocationStatement(
|
||||
TypeSyntaxes.System_Runtime_InteropServices_NativeMemory,
|
||||
|
@ -750,7 +791,7 @@ namespace Microsoft.Interop
|
|||
TypeSyntaxes.StrategyBasedComWrappers
|
||||
.Dot(IdentifierName("DefaultIUnknownInterfaceDetailsStrategy")),
|
||||
IdentifierName("GetIUnknownDerivedDetails"),
|
||||
Argument( //baseInterfaceTypeInfo.BaseInterface.FullTypeName)),
|
||||
Argument(
|
||||
TypeOfExpression(ParseTypeName(interfaceMethods.Interface.Base.Info.Type.FullTypeName))
|
||||
.Dot(IdentifierName("TypeHandle"))))
|
||||
.Dot(IdentifierName("ManagedVirtualMethodTable"))),
|
||||
|
@ -767,7 +808,7 @@ namespace Microsoft.Interop
|
|||
ParenthesizedExpression(
|
||||
BinaryExpression(SyntaxKind.MultiplyExpression,
|
||||
SizeOfExpression(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword)))),
|
||||
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(interfaceMethods.InheritedMethods.Count() + 3))))))));
|
||||
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(interfaceMethods.BaseVTableSize))))))));
|
||||
}
|
||||
|
||||
var validDeclaredMethods = interfaceMethods.DeclaredMethods
|
||||
|
@ -787,7 +828,7 @@ namespace Microsoft.Interop
|
|||
IdentifierName($"{declaredMethodContext.MethodInfo.MethodName}_{declaredMethodContext.GenerationContext.VtableIndexData.Index}")),
|
||||
PrefixUnaryExpression(
|
||||
SyntaxKind.AddressOfExpression,
|
||||
IdentifierName($"ABI_{declaredMethodContext.GenerationContext.StubMethodSyntaxTemplate.Identifier}")))));
|
||||
IdentifierName($"ABI_{((SourceAvailableIncrementalMethodStubGenerationContext)declaredMethodContext.GenerationContext).StubMethodSyntaxTemplate.Identifier}")))));
|
||||
}
|
||||
|
||||
return ImplementationInterfaceTemplate
|
||||
|
|
|
@ -2,7 +2,9 @@
|
|||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Collections.Immutable;
|
||||
using System.Diagnostics;
|
||||
using System.Diagnostics.CodeAnalysis;
|
||||
using System.Threading;
|
||||
using Microsoft.CodeAnalysis;
|
||||
|
@ -10,7 +12,6 @@ using Microsoft.CodeAnalysis.CSharp;
|
|||
using Microsoft.CodeAnalysis.CSharp.Syntax;
|
||||
using InterfaceInfo = (Microsoft.Interop.ComInterfaceInfo InterfaceInfo, Microsoft.CodeAnalysis.INamedTypeSymbol Symbol);
|
||||
using DiagnosticOrInterfaceInfo = Microsoft.Interop.DiagnosticOr<(Microsoft.Interop.ComInterfaceInfo InterfaceInfo, Microsoft.CodeAnalysis.INamedTypeSymbol Symbol)>;
|
||||
using System.Diagnostics;
|
||||
|
||||
namespace Microsoft.Interop
|
||||
{
|
||||
|
@ -176,6 +177,13 @@ namespace Microsoft.Interop
|
|||
return builder.ToImmutable();
|
||||
}
|
||||
|
||||
internal sealed class EqualityComparerForExternalIfaces : IEqualityComparer<(ComInterfaceInfo InterfaceInfo, INamedTypeSymbol Symbol)>
|
||||
{
|
||||
public bool Equals((ComInterfaceInfo, INamedTypeSymbol) x, (ComInterfaceInfo, INamedTypeSymbol) y) => SymbolEqualityComparer.Default.Equals(x.Item2, y.Item2);
|
||||
public int GetHashCode((ComInterfaceInfo, INamedTypeSymbol) obj) => SymbolEqualityComparer.Default.GetHashCode(obj.Item2);
|
||||
public static readonly EqualityComparerForExternalIfaces Instance = new();
|
||||
}
|
||||
|
||||
private static bool IsInPartialContext(INamedTypeSymbol symbol, InterfaceDeclarationSyntax syntax, [NotNullWhen(false)] out DiagnosticInfo? diagnostic)
|
||||
{
|
||||
// Verify that the types the interface is declared in are marked partial.
|
||||
|
|
|
@ -21,8 +21,8 @@ namespace Microsoft.Interop
|
|||
internal sealed class ComMethodContext : IEquatable<ComMethodContext>
|
||||
{
|
||||
/// <summary>
|
||||
/// A partially constructed <see cref="ComMethodContext"/> that does not have a <see cref="IncrementalMethodStubGenerationContext"/> generated for it yet.
|
||||
/// <see cref="Builder"/> can be constructed without a reference to an ISymbol, whereas the <see cref="IncrementalMethodStubGenerationContext"/> requires an ISymbol
|
||||
/// A partially constructed <see cref="ComMethodContext"/> that does not have a <see cref="SourceAvailableIncrementalMethodStubGenerationContext"/> generated for it yet.
|
||||
/// <see cref="Builder"/> can be constructed without a reference to an ISymbol, whereas the <see cref="SourceAvailableIncrementalMethodStubGenerationContext"/> requires an ISymbol
|
||||
/// </summary>
|
||||
/// <param name="OriginalDeclaringInterface">
|
||||
/// The interface that originally declared the method in user code
|
||||
|
@ -48,7 +48,7 @@ namespace Microsoft.Interop
|
|||
/// <param name="builder">The partially constructed context</param>
|
||||
/// <param name="owningInterface">The final owning interface of this method context</param>
|
||||
/// <param name="generationContext">The generation context for this method</param>
|
||||
public ComMethodContext(Builder builder, ComInterfaceContext owningInterface, IncrementalMethodStubGenerationContext generationContext)
|
||||
public ComMethodContext(Builder builder, ComInterfaceContext owningInterface, IncrementalMethodStubGenerationContext? generationContext)
|
||||
{
|
||||
_state = new State(builder.OriginalDeclaringInterface, owningInterface, builder.MethodInfo, generationContext);
|
||||
}
|
||||
|
@ -65,6 +65,8 @@ namespace Microsoft.Interop
|
|||
|
||||
public ComMethodInfo MethodInfo => _state.MethodInfo;
|
||||
|
||||
public bool IsExternallyDefined => _state.OriginalDeclaringInterface.IsExternallyDefined || _state.OwningInterface.IsExternallyDefined;
|
||||
|
||||
public IncrementalMethodStubGenerationContext GenerationContext => _state.GenerationContext;
|
||||
|
||||
public bool IsInheritedMethod => OriginalDeclaringInterface != OwningInterface;
|
||||
|
@ -77,12 +79,18 @@ namespace Microsoft.Interop
|
|||
|
||||
private GeneratedMethodContextBase CreateManagedToUnmanagedStub()
|
||||
{
|
||||
if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional) || IsHiddenOnDerivedInterface)
|
||||
if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional)
|
||||
|| IsHiddenOnDerivedInterface
|
||||
|| IsExternallyDefined)
|
||||
{
|
||||
return new SkippedStubContext(OriginalDeclaringInterface.Info.Type);
|
||||
}
|
||||
var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(GenerationContext, ComInterfaceGeneratorHelpers.GetGeneratorResolver);
|
||||
return new GeneratedStubCodeContext(GenerationContext.TypeKeyOwner, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics));
|
||||
if (GenerationContext is not SourceAvailableIncrementalMethodStubGenerationContext sourceAvailableContext)
|
||||
{
|
||||
throw new InvalidOperationException("Cannot generate stubs for non-source available methods.");
|
||||
}
|
||||
var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(sourceAvailableContext, ComInterfaceGeneratorHelpers.GetGeneratorResolver);
|
||||
return new GeneratedStubCodeContext(sourceAvailableContext.TypeKeyOwner, sourceAvailableContext.ContainingSyntaxContext, new(methodStub), new(diagnostics));
|
||||
}
|
||||
|
||||
private GeneratedMethodContextBase? _unmanagedToManagedStub;
|
||||
|
@ -91,12 +99,18 @@ namespace Microsoft.Interop
|
|||
|
||||
private GeneratedMethodContextBase CreateUnmanagedToManagedStub()
|
||||
{
|
||||
if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional) || IsHiddenOnDerivedInterface)
|
||||
if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional)
|
||||
|| IsHiddenOnDerivedInterface
|
||||
|| IsExternallyDefined)
|
||||
{
|
||||
return new SkippedStubContext(GenerationContext.OriginalDefiningType);
|
||||
}
|
||||
var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(GenerationContext, ComInterfaceGeneratorHelpers.GetGeneratorResolver);
|
||||
return new GeneratedStubCodeContext(GenerationContext.OriginalDefiningType, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics));
|
||||
if (GenerationContext is not SourceAvailableIncrementalMethodStubGenerationContext sourceAvailableContext)
|
||||
{
|
||||
throw new InvalidOperationException("Cannot generate stubs for non-source available methods.");
|
||||
}
|
||||
var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(sourceAvailableContext, ComInterfaceGeneratorHelpers.GetGeneratorResolver);
|
||||
return new GeneratedStubCodeContext(sourceAvailableContext.OriginalDefiningType, sourceAvailableContext.ContainingSyntaxContext, new(methodStub), new(diagnostics));
|
||||
}
|
||||
|
||||
private MethodDeclarationSyntax? _unreachableExceptionStub;
|
||||
|
@ -183,7 +197,7 @@ namespace Microsoft.Interop
|
|||
return cachedValue;
|
||||
}
|
||||
|
||||
int startingIndex = 3;
|
||||
int startingIndex = IUnknownConstants.VTableSize;
|
||||
List<Builder> methods = new();
|
||||
// If we have a base interface, we should add the inherited methods to our list in vtable order
|
||||
if (iface.Base is not null)
|
||||
|
|
|
@ -17,7 +17,7 @@ namespace Microsoft.Interop
|
|||
/// </summary>
|
||||
internal sealed record ComMethodInfo
|
||||
{
|
||||
public MethodDeclarationSyntax Syntax { get; init; }
|
||||
public MethodDeclarationSyntax? Syntax { get; init; }
|
||||
public string MethodName { get; init; }
|
||||
public SequenceEqualImmutableArray<AttributeInfo> Attributes { get; init; }
|
||||
public bool IsUserDefinedShadowingMethod { get; init; }
|
||||
|
@ -94,7 +94,7 @@ namespace Microsoft.Interop
|
|||
if (ifaceContext.IsExternallyDefined)
|
||||
{
|
||||
return DiagnosticOr<(ComMethodInfo, IMethodSymbol)>.From((
|
||||
new ComMethodInfo(null!, method.Name, method.GetAttributes().Select(AttributeInfo.From).ToImmutableArray().ToSequenceEqual(), false),
|
||||
new ComMethodInfo(null, method.Name, method.GetAttributes().Select(AttributeInfo.From).ToImmutableArray().ToSequenceEqual(), false),
|
||||
method));
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
|
||||
namespace Microsoft.Interop
|
||||
{
|
||||
internal static class IUnknownConstants
|
||||
{
|
||||
public const int QueryInterfaceIndex = 0;
|
||||
public const int AddRefIndex = 1;
|
||||
public const int ReleaseIndex = 2;
|
||||
public const int VTableSize = 3;
|
||||
}
|
||||
}
|
|
@ -1,19 +1,15 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
|
||||
using Microsoft.CodeAnalysis;
|
||||
using Microsoft.CodeAnalysis.CSharp.Syntax;
|
||||
using System;
|
||||
|
||||
namespace Microsoft.Interop
|
||||
{
|
||||
internal abstract record GeneratedMethodContextBase(ManagedTypeInfo OriginalDefiningType, SequenceEqualImmutableArray<DiagnosticInfo> Diagnostics);
|
||||
|
||||
internal sealed record IncrementalMethodStubGenerationContext(
|
||||
internal record IncrementalMethodStubGenerationContext(
|
||||
SignatureContext SignatureContext,
|
||||
ContainingSyntaxContext ContainingSyntaxContext,
|
||||
ContainingSyntax StubMethodSyntaxTemplate,
|
||||
MethodSignatureDiagnosticLocations DiagnosticLocation,
|
||||
ISignatureDiagnosticLocations DiagnosticLocation,
|
||||
SequenceEqualImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax> CallingConvention,
|
||||
VirtualMethodIndexData VtableIndexData,
|
||||
MarshallingInfo ExceptionMarshallingInfo,
|
||||
|
@ -22,4 +18,28 @@ namespace Microsoft.Interop
|
|||
ManagedTypeInfo DeclaringType,
|
||||
SequenceEqualImmutableArray<DiagnosticInfo> Diagnostics,
|
||||
MarshallingInfo ManagedThisMarshallingInfo) : GeneratedMethodContextBase(DeclaringType, Diagnostics);
|
||||
|
||||
internal sealed record SourceAvailableIncrementalMethodStubGenerationContext(
|
||||
SignatureContext SignatureContext,
|
||||
ContainingSyntaxContext ContainingSyntaxContext,
|
||||
ContainingSyntax StubMethodSyntaxTemplate,
|
||||
ISignatureDiagnosticLocations DiagnosticLocation,
|
||||
SequenceEqualImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax> CallingConvention,
|
||||
VirtualMethodIndexData VtableIndexData,
|
||||
MarshallingInfo ExceptionMarshallingInfo,
|
||||
EnvironmentFlags EnvironmentFlags,
|
||||
ManagedTypeInfo TypeKeyOwner,
|
||||
ManagedTypeInfo DeclaringType,
|
||||
SequenceEqualImmutableArray<DiagnosticInfo> Diagnostics,
|
||||
MarshallingInfo ManagedThisMarshallingInfo) : IncrementalMethodStubGenerationContext(
|
||||
SignatureContext,
|
||||
DiagnosticLocation,
|
||||
CallingConvention,
|
||||
VtableIndexData,
|
||||
ExceptionMarshallingInfo,
|
||||
EnvironmentFlags,
|
||||
TypeKeyOwner,
|
||||
DeclaringType,
|
||||
Diagnostics,
|
||||
ManagedThisMarshallingInfo);
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ namespace Microsoft.Interop
|
|||
internal const string VirtualMethodTarget = "__target";
|
||||
|
||||
public static (MethodDeclarationSyntax, ImmutableArray<DiagnosticInfo>) GenerateManagedToNativeStub(
|
||||
IncrementalMethodStubGenerationContext methodStub,
|
||||
SourceAvailableIncrementalMethodStubGenerationContext methodStub,
|
||||
Func<EnvironmentFlags, MarshalDirection, IMarshallingGeneratorResolver> generatorResolverCreator)
|
||||
{
|
||||
var diagnostics = new GeneratorDiagnosticsBag(new DiagnosticDescriptorProvider(), methodStub.DiagnosticLocation, SR.ResourceManager, typeof(FxResources.Microsoft.Interop.ComInterfaceGenerator.SR));
|
||||
|
@ -128,7 +128,7 @@ namespace Microsoft.Interop
|
|||
private const string ManagedThisParameterIdentifier = "@this";
|
||||
|
||||
public static (MethodDeclarationSyntax, ImmutableArray<DiagnosticInfo>) GenerateNativeToManagedStub(
|
||||
IncrementalMethodStubGenerationContext methodStub,
|
||||
SourceAvailableIncrementalMethodStubGenerationContext methodStub,
|
||||
Func<EnvironmentFlags, MarshalDirection, IMarshallingGeneratorResolver> generatorResolverCreator)
|
||||
{
|
||||
var diagnostics = new GeneratorDiagnosticsBag(new DiagnosticDescriptorProvider(), methodStub.DiagnosticLocation, SR.ResourceManager, typeof(FxResources.Microsoft.Interop.ComInterfaceGenerator.SR));
|
||||
|
@ -174,7 +174,7 @@ namespace Microsoft.Interop
|
|||
methodStub.Diagnostics.Array.AddRange(diagnostics.Diagnostics));
|
||||
}
|
||||
|
||||
private static ImmutableArray<TypePositionInfo> AddManagedToUnmanagedImplicitThis(IncrementalMethodStubGenerationContext methodStub)
|
||||
private static ImmutableArray<TypePositionInfo> AddManagedToUnmanagedImplicitThis(SourceAvailableIncrementalMethodStubGenerationContext methodStub)
|
||||
{
|
||||
ImmutableArray<TypePositionInfo> originalElements = methodStub.SignatureContext.ElementTypeInformation;
|
||||
|
||||
|
@ -232,7 +232,7 @@ namespace Microsoft.Interop
|
|||
}
|
||||
|
||||
public static BlockSyntax GenerateVirtualMethodTableSlotAssignments(
|
||||
IEnumerable<IncrementalMethodStubGenerationContext> vtableMethods,
|
||||
IEnumerable<SourceAvailableIncrementalMethodStubGenerationContext> vtableMethods,
|
||||
string vtableIdentifier,
|
||||
Func<EnvironmentFlags, MarshalDirection, IMarshallingGeneratorResolver> generatorResolverCreator)
|
||||
{
|
||||
|
|
|
@ -62,7 +62,7 @@ namespace Microsoft.Interop
|
|||
|
||||
// Calculate all of information to generate both managed-to-unmanaged and unmanaged-to-managed stubs
|
||||
// for each method.
|
||||
IncrementalValuesProvider<IncrementalMethodStubGenerationContext> generateStubInformation = methodsToGenerate
|
||||
IncrementalValuesProvider<SourceAvailableIncrementalMethodStubGenerationContext> generateStubInformation = methodsToGenerate
|
||||
.Combine(context.CreateStubEnvironmentProvider())
|
||||
.Select(static (data, ct) => new
|
||||
{
|
||||
|
@ -89,7 +89,7 @@ namespace Microsoft.Interop
|
|||
context.RegisterConcatenatedSyntaxOutputs(generateManagedToNativeStub.Select((data, ct) => data.Item1), "ManagedToNativeStubs.g.cs");
|
||||
|
||||
// Filter the list of all stubs to only the stubs that requested unmanaged-to-managed stub generation.
|
||||
IncrementalValuesProvider<IncrementalMethodStubGenerationContext> nativeToManagedStubContexts =
|
||||
IncrementalValuesProvider<SourceAvailableIncrementalMethodStubGenerationContext> nativeToManagedStubContexts =
|
||||
generateStubInformation
|
||||
.Where(data => data.VtableIndexData.Direction is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional);
|
||||
|
||||
|
@ -195,7 +195,7 @@ namespace Microsoft.Interop
|
|||
};
|
||||
}
|
||||
|
||||
private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, StubEnvironment environment, CancellationToken ct)
|
||||
private static SourceAvailableIncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, StubEnvironment environment, CancellationToken ct)
|
||||
{
|
||||
ct.ThrowIfCancellationRequested();
|
||||
INamedTypeSymbol? lcidConversionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.LCIDConversionAttribute);
|
||||
|
@ -306,7 +306,7 @@ namespace Microsoft.Interop
|
|||
|
||||
MarshallingInfo exceptionMarshallingInfo = CreateExceptionMarshallingInfo(virtualMethodIndexAttr, symbol, environment.Compilation, generatorDiagnostics, virtualMethodIndexData);
|
||||
|
||||
return new IncrementalMethodStubGenerationContext(
|
||||
return new SourceAvailableIncrementalMethodStubGenerationContext(
|
||||
signatureContext,
|
||||
containingSyntaxContext,
|
||||
methodSyntaxTemplate,
|
||||
|
@ -363,7 +363,7 @@ namespace Microsoft.Interop
|
|||
}
|
||||
|
||||
private static (MemberDeclarationSyntax, ImmutableArray<DiagnosticInfo>) GenerateManagedToNativeStub(
|
||||
IncrementalMethodStubGenerationContext methodStub)
|
||||
SourceAvailableIncrementalMethodStubGenerationContext methodStub)
|
||||
{
|
||||
var (stub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(methodStub, VtableIndexStubGeneratorHelpers.GetGeneratorResolver);
|
||||
|
||||
|
@ -376,7 +376,7 @@ namespace Microsoft.Interop
|
|||
}
|
||||
|
||||
private static (MemberDeclarationSyntax, ImmutableArray<DiagnosticInfo>) GenerateNativeToManagedStub(
|
||||
IncrementalMethodStubGenerationContext methodStub)
|
||||
SourceAvailableIncrementalMethodStubGenerationContext methodStub)
|
||||
{
|
||||
var (stub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(methodStub, VtableIndexStubGeneratorHelpers.GetGeneratorResolver);
|
||||
|
||||
|
@ -433,7 +433,7 @@ namespace Microsoft.Interop
|
|||
.AddAttributeLists(AttributeList(SingletonSeparatedList(Attribute(NameSyntaxes.System_Runtime_InteropServices_DynamicInterfaceCastableImplementationAttribute)))));
|
||||
}
|
||||
|
||||
private static MemberDeclarationSyntax GeneratePopulateVTableMethod(IGrouping<ContainingSyntaxContext, IncrementalMethodStubGenerationContext> vtableMethods)
|
||||
private static MemberDeclarationSyntax GeneratePopulateVTableMethod(IGrouping<ContainingSyntaxContext, SourceAvailableIncrementalMethodStubGenerationContext> vtableMethods)
|
||||
{
|
||||
ContainingSyntaxContext containingSyntax = vtableMethods.Key.AddContainingSyntax(NativeTypeContainingSyntax);
|
||||
|
||||
|
|
|
@ -42,6 +42,15 @@ namespace Microsoft.Interop
|
|||
DiagnosticInfo CreateDiagnosticInfo(DiagnosticDescriptor descriptor, GeneratorDiagnostic diagnostic);
|
||||
}
|
||||
|
||||
public class NoneSignatureDiagnosticLocations : ISignatureDiagnosticLocations
|
||||
{
|
||||
public static readonly NoneSignatureDiagnosticLocations Instance = new();
|
||||
public DiagnosticInfo CreateDiagnosticInfo(DiagnosticDescriptor descriptor, GeneratorDiagnostic diagnostic)
|
||||
{
|
||||
return diagnostic.ToDiagnosticInfo(descriptor, Location.None, string.Empty);
|
||||
}
|
||||
}
|
||||
|
||||
public sealed record MethodSignatureDiagnosticLocations(string MethodIdentifier, ImmutableArray<Location> ManagedParameterLocations, Location FallbackLocation) : ISignatureDiagnosticLocations
|
||||
{
|
||||
public MethodSignatureDiagnosticLocations(MethodDeclarationSyntax syntax)
|
||||
|
|
|
@ -96,7 +96,7 @@ namespace Microsoft.Interop
|
|||
ByValueContentsMarshalKind = byValueContentsMarshalKind,
|
||||
ByValueMarshalAttributeLocations = (inLocation, outLocation),
|
||||
ScopedKind = paramSymbol.ScopedKind,
|
||||
IsExplicitThis = ((ParameterSyntax)paramSymbol.DeclaringSyntaxReferences[0].GetSyntax()).Modifiers.Any(SyntaxKind.ThisKeyword)
|
||||
IsExplicitThis = ((ParameterSyntax?)paramSymbol.DeclaringSyntaxReferences.FirstOrDefault()?.GetSyntax())?.Modifiers.Any(SyntaxKind.ThisKeyword) ?? false
|
||||
};
|
||||
|
||||
return typeInfo;
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
<ItemGroup>
|
||||
<EnabledGenerators Include="ComInterfaceGenerator" />
|
||||
<Compile Include="$(CommonPath)DisableRuntimeMarshalling.cs" Link="Common\DisableRuntimeMarshalling.cs" />
|
||||
<Compile Include="..\TestAssets\SharedTypes\ComInterfaces\**\*.cs" Link="ComInterfaces\%(RecursiveDir)\%(FileName).cs" />
|
||||
<Compile Include="..\Common\ComInterfaces\**\*.cs" Link="ComInterfaces\%(RecursiveDir)\%(FileName).cs" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
|
|
|
@ -0,0 +1,256 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
|
||||
using System;
|
||||
using System.Linq;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Runtime.InteropServices.Marshalling;
|
||||
using SharedTypes.ComInterfaces;
|
||||
using Xunit;
|
||||
|
||||
namespace ComInterfaceGenerator.Tests
|
||||
{
|
||||
public partial class CrossAssemblyInheritanceTests
|
||||
{
|
||||
[GeneratedComClass]
|
||||
[Guid("e0c6b35f-1234-4567-8901-123456789abc")]
|
||||
internal partial class DerivedExternalBaseImpl : IDerivedExternalBase
|
||||
{
|
||||
private int _value = 10;
|
||||
|
||||
public int GetInt() => _value;
|
||||
public void SetInt(int x) => _value = x;
|
||||
public string GetName() => "DerivedExternalBase";
|
||||
}
|
||||
|
||||
[GeneratedComClass]
|
||||
[Guid("e0c6b35f-1234-4567-8901-123456789abd")]
|
||||
internal partial class DerivedExternalBase2Impl : IDerivedExternalBase2
|
||||
{
|
||||
private int _value = 20;
|
||||
|
||||
public int GetInt() => _value;
|
||||
public void SetInt(int x) => _value = x;
|
||||
public string GetName() => "DerivedExternalBase2";
|
||||
}
|
||||
|
||||
[GeneratedComClass]
|
||||
[Guid("e0c6b35f-1234-4567-8901-123456789abe")]
|
||||
internal partial class DerivedFromExternalDerivedImpl : IDerivedFromExternalDerived
|
||||
{
|
||||
private int _value = 30;
|
||||
private bool _boolValue = true;
|
||||
|
||||
public int GetInt() => _value;
|
||||
public void SetInt(int x) => _value = x;
|
||||
public bool GetBool() => _boolValue;
|
||||
public void SetBool(bool x) => _boolValue = x;
|
||||
public string GetName() => "DerivedFromExternalDerived";
|
||||
}
|
||||
|
||||
[GeneratedComClass]
|
||||
[Guid("e0c6b35f-1234-4567-8901-123456789abf")]
|
||||
internal partial class DerivedFromDerivedExternalDerivedImpl : IDerivedFromDerivedExternalDerived
|
||||
{
|
||||
private int _value = 40;
|
||||
private bool _boolValue = false;
|
||||
private float _floatValue = 3.14f;
|
||||
|
||||
public int GetInt() => _value;
|
||||
public void SetInt(int x) => _value = x;
|
||||
public bool GetBool() => _boolValue;
|
||||
public void SetBool(bool x) => _boolValue = x;
|
||||
public string GetName() => "DerivedFromDerivedExternalDerived";
|
||||
public float GetFloat() => _floatValue;
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void IDerivedExternalBase_CanCallMethods()
|
||||
{
|
||||
var implementation = new DerivedExternalBaseImpl();
|
||||
var comWrappers = new StrategyBasedComWrappers();
|
||||
var nativeObj = comWrappers.GetOrCreateComInterfaceForObject(implementation, CreateComInterfaceFlags.None);
|
||||
var managedObj = comWrappers.GetOrCreateObjectForComInstance(nativeObj, CreateObjectFlags.None);
|
||||
|
||||
var externalBase = (IExternalBase)managedObj;
|
||||
Assert.Equal(10, externalBase.GetInt());
|
||||
externalBase.SetInt(15);
|
||||
Assert.Equal(15, externalBase.GetInt());
|
||||
|
||||
var derivedExternalBase = (IDerivedExternalBase)managedObj;
|
||||
Assert.Equal(15, derivedExternalBase.GetInt());
|
||||
Assert.Equal("DerivedExternalBase", derivedExternalBase.GetName());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void IDerivedExternalBase2_CanCallMethods()
|
||||
{
|
||||
var implementation = new DerivedExternalBase2Impl();
|
||||
var comWrappers = new StrategyBasedComWrappers();
|
||||
var nativeObj = comWrappers.GetOrCreateComInterfaceForObject(implementation, CreateComInterfaceFlags.None);
|
||||
var managedObj = comWrappers.GetOrCreateObjectForComInstance(nativeObj, CreateObjectFlags.None);
|
||||
|
||||
// Test as base interface
|
||||
var externalBase = (IExternalBase)managedObj;
|
||||
Assert.Equal(20, externalBase.GetInt());
|
||||
externalBase.SetInt(25);
|
||||
Assert.Equal(25, externalBase.GetInt());
|
||||
|
||||
// Test as derived interface
|
||||
var derivedExternalBase2 = (IDerivedExternalBase2)managedObj;
|
||||
Assert.Equal(25, derivedExternalBase2.GetInt());
|
||||
Assert.Equal("DerivedExternalBase2", derivedExternalBase2.GetName());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void IDerivedFromExternalDerived_CanCallMethods()
|
||||
{
|
||||
var implementation = new DerivedFromExternalDerivedImpl();
|
||||
var comWrappers = new StrategyBasedComWrappers();
|
||||
var nativeObj = comWrappers.GetOrCreateComInterfaceForObject(implementation, CreateComInterfaceFlags.None);
|
||||
var managedObj = comWrappers.GetOrCreateObjectForComInstance(nativeObj, CreateObjectFlags.None);
|
||||
|
||||
var externalBase = (IExternalBase)managedObj;
|
||||
Assert.Equal(30, externalBase.GetInt());
|
||||
externalBase.SetInt(35);
|
||||
Assert.Equal(35, externalBase.GetInt());
|
||||
|
||||
var externalDerived = (IExternalDerived)managedObj;
|
||||
Assert.Equal(35, externalDerived.GetInt());
|
||||
Assert.True(externalDerived.GetBool());
|
||||
externalDerived.SetBool(false);
|
||||
Assert.False(externalDerived.GetBool());
|
||||
|
||||
var derivedFromExternalDerived = (IDerivedFromExternalDerived)managedObj;
|
||||
Assert.Equal(35, derivedFromExternalDerived.GetInt());
|
||||
Assert.False(derivedFromExternalDerived.GetBool());
|
||||
Assert.Equal("DerivedFromExternalDerived", derivedFromExternalDerived.GetName());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void IDerivedFromDerivedExternalDerived_CanCallMethods()
|
||||
{
|
||||
var implementation = new DerivedFromDerivedExternalDerivedImpl();
|
||||
var comWrappers = new StrategyBasedComWrappers();
|
||||
var nativeObj = comWrappers.GetOrCreateComInterfaceForObject(implementation, CreateComInterfaceFlags.None);
|
||||
var managedObj = comWrappers.GetOrCreateObjectForComInstance(nativeObj, CreateObjectFlags.None);
|
||||
|
||||
var externalBase = (IExternalBase)managedObj;
|
||||
Assert.Equal(40, externalBase.GetInt());
|
||||
externalBase.SetInt(45);
|
||||
Assert.Equal(45, externalBase.GetInt());
|
||||
|
||||
var externalDerived = (IExternalDerived)managedObj;
|
||||
Assert.Equal(45, externalDerived.GetInt());
|
||||
Assert.False(externalDerived.GetBool());
|
||||
externalDerived.SetBool(true);
|
||||
Assert.True(externalDerived.GetBool());
|
||||
|
||||
var derivedFromExternalDerived = (IDerivedFromExternalDerived)managedObj;
|
||||
Assert.Equal(45, derivedFromExternalDerived.GetInt());
|
||||
Assert.True(derivedFromExternalDerived.GetBool());
|
||||
Assert.Equal("DerivedFromDerivedExternalDerived", derivedFromExternalDerived.GetName());
|
||||
|
||||
var derivedFromDerivedExternalDerived = (IDerivedFromDerivedExternalDerived)managedObj;
|
||||
Assert.Equal(45, derivedFromDerivedExternalDerived.GetInt());
|
||||
Assert.True(derivedFromDerivedExternalDerived.GetBool());
|
||||
Assert.Equal("DerivedFromDerivedExternalDerived", derivedFromDerivedExternalDerived.GetName());
|
||||
Assert.Equal(3.14f, derivedFromDerivedExternalDerived.GetFloat());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public unsafe void MultipleInterfacesDerivedFromSameBase_ShareCommonVTableLayout()
|
||||
{
|
||||
IIUnknownDerivedDetails baseDetails = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy
|
||||
.GetIUnknownDerivedDetails(typeof(IExternalBase).TypeHandle);
|
||||
|
||||
IIUnknownDerivedDetails derived1Details = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy
|
||||
.GetIUnknownDerivedDetails(typeof(IDerivedExternalBase).TypeHandle);
|
||||
|
||||
IIUnknownDerivedDetails derived2Details = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy
|
||||
.GetIUnknownDerivedDetails(typeof(IDerivedExternalBase2).TypeHandle);
|
||||
|
||||
var numBaseMethods = typeof(IExternalBase).GetMethods().Length;
|
||||
var numPointersToCompare = 3 + numBaseMethods; // IUnknown (3) + base methods
|
||||
|
||||
// Both derived interfaces should have the same base vtable layout
|
||||
var baseVTable = new ReadOnlySpan<nint>(baseDetails.ManagedVirtualMethodTable, numPointersToCompare);
|
||||
var derived1BaseVTable = new ReadOnlySpan<nint>(derived1Details.ManagedVirtualMethodTable, numPointersToCompare);
|
||||
var derived2BaseVTable = new ReadOnlySpan<nint>(derived2Details.ManagedVirtualMethodTable, numPointersToCompare);
|
||||
|
||||
Assert.True(baseVTable.SequenceEqual(derived1BaseVTable),
|
||||
"IDerivedExternalBase should have consistent base vtable layout");
|
||||
Assert.True(baseVTable.SequenceEqual(derived2BaseVTable),
|
||||
"IDerivedExternalBase2 should have consistent base vtable layout");
|
||||
Assert.True(derived1BaseVTable.SequenceEqual(derived2BaseVTable),
|
||||
"Both derived interfaces should have identical base vtable layouts");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public unsafe void CrossAssemblyInheritance_VTableLayoutIsCorrect()
|
||||
{
|
||||
IIUnknownDerivedDetails baseInterfaceDetails = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy
|
||||
.GetIUnknownDerivedDetails(typeof(IExternalBase).TypeHandle);
|
||||
|
||||
IIUnknownDerivedDetails derivedInterfaceDetails = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy
|
||||
.GetIUnknownDerivedDetails(typeof(IDerivedExternalBase).TypeHandle);
|
||||
|
||||
var numBaseMethods = typeof(IExternalBase).GetMethods().Length;
|
||||
var numPointersToCompare = 3 + numBaseMethods;
|
||||
|
||||
// The first part of the vtable should match between base and derived
|
||||
var expectedBaseVTable = new ReadOnlySpan<nint>(baseInterfaceDetails.ManagedVirtualMethodTable, numPointersToCompare);
|
||||
var actualDerivedVTable = new ReadOnlySpan<nint>(derivedInterfaceDetails.ManagedVirtualMethodTable, numPointersToCompare);
|
||||
|
||||
Assert.True(expectedBaseVTable.SequenceEqual(actualDerivedVTable),
|
||||
"Base interface methods should have the same vtable entries in derived interface");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public unsafe void IDerivedFromDerivedExternalDerived_VTableLayoutIsCorrect()
|
||||
{
|
||||
var baseDetails = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy
|
||||
.GetIUnknownDerivedDetails(typeof(IExternalBase).TypeHandle);
|
||||
var externalDerivedDetails = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy
|
||||
.GetIUnknownDerivedDetails(typeof(IExternalDerived).TypeHandle);
|
||||
var derivedFromExternalDerivedDetails = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy
|
||||
.GetIUnknownDerivedDetails(typeof(IDerivedFromExternalDerived).TypeHandle);
|
||||
var deepDerivedDetails = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy
|
||||
.GetIUnknownDerivedDetails(typeof(IDerivedFromDerivedExternalDerived).TypeHandle);
|
||||
|
||||
var baseMethods = typeof(IExternalBase).GetMethods().Length;
|
||||
var baseVTableSize = 3 + baseMethods;
|
||||
|
||||
var baseVTable = new ReadOnlySpan<nint>(baseDetails.ManagedVirtualMethodTable, baseVTableSize);
|
||||
var externalDerivedBaseVTable = new ReadOnlySpan<nint>(externalDerivedDetails.ManagedVirtualMethodTable, baseVTableSize);
|
||||
var derivedFromExternalDerivedBaseVTable = new ReadOnlySpan<nint>(derivedFromExternalDerivedDetails.ManagedVirtualMethodTable, baseVTableSize);
|
||||
var deepDerivedBaseVTable = new ReadOnlySpan<nint>(deepDerivedDetails.ManagedVirtualMethodTable, baseVTableSize);
|
||||
|
||||
Assert.True(baseVTable.SequenceEqual(externalDerivedBaseVTable),
|
||||
"IExternalDerived should have consistent base vtable layout");
|
||||
Assert.True(baseVTable.SequenceEqual(derivedFromExternalDerivedBaseVTable),
|
||||
"IDerivedFromExternalDerived should have consistent base vtable layout");
|
||||
Assert.True(baseVTable.SequenceEqual(deepDerivedBaseVTable),
|
||||
"IDerivedFromDerivedExternalDerived should have consistent base vtable layout");
|
||||
|
||||
var externalDerivedMethods = typeof(IExternalDerived).GetMethods().Length;
|
||||
var externalDerivedVTableSize = 3 + externalDerivedMethods;
|
||||
|
||||
var externalDerivedVTable = new ReadOnlySpan<nint>(externalDerivedDetails.ManagedVirtualMethodTable, externalDerivedVTableSize);
|
||||
var derivedFromExternalDerivedIntermediateVTable = new ReadOnlySpan<nint>(derivedFromExternalDerivedDetails.ManagedVirtualMethodTable, externalDerivedVTableSize);
|
||||
var deepDerivedIntermediateVTable = new ReadOnlySpan<nint>(deepDerivedDetails.ManagedVirtualMethodTable, externalDerivedVTableSize);
|
||||
|
||||
Assert.True(externalDerivedVTable.SequenceEqual(derivedFromExternalDerivedIntermediateVTable),
|
||||
"IDerivedFromExternalDerived should have consistent IExternalDerived vtable layout");
|
||||
Assert.True(externalDerivedVTable.SequenceEqual(deepDerivedIntermediateVTable),
|
||||
"IDerivedFromDerivedExternalDerived should have consistent IExternalDerived vtable layout");
|
||||
|
||||
var derivedFromExternalDerivedMethods = typeof(IDerivedFromExternalDerived).GetMethods().Length;
|
||||
var derivedFromExternalDerivedVTableSize = 3 + derivedFromExternalDerivedMethods;
|
||||
var derivedFromExternalDerivedVTable = new ReadOnlySpan<nint>(derivedFromExternalDerivedDetails.ManagedVirtualMethodTable, derivedFromExternalDerivedVTableSize);
|
||||
var deepDerivedVTable = new ReadOnlySpan<nint>(deepDerivedDetails.ManagedVirtualMethodTable, derivedFromExternalDerivedVTableSize);
|
||||
Assert.True(derivedFromExternalDerivedVTable.SequenceEqual(deepDerivedVTable),
|
||||
"IDerivedFromDerivedExternalDerived should have consistent IDerivedFromExternalDerived vtable layout");
|
||||
}
|
||||
}
|
||||
}
|
|
@ -13,6 +13,56 @@ namespace ComInterfaceGenerator.Tests
|
|||
{
|
||||
public partial class IDerivedTests
|
||||
{
|
||||
[GeneratedComInterface]
|
||||
[Guid("7F0DB364-3C04-4487-9193-4BB05DC7B654")]
|
||||
internal partial interface IDerivedFromSharedType2 : SharedTypes.ComInterfaces.IGetAndSetInt
|
||||
{
|
||||
int GetTwoTimesInt();
|
||||
}
|
||||
|
||||
[GeneratedComInterface]
|
||||
[Guid("7F0DB364-3C04-4487-9194-4BB05DC7B654")]
|
||||
#pragma warning disable SYSLIB1230 // Specifying 'GeneratedComInterfaceAttribute' on an interface that has a base interface defined in another assembly is not supported
|
||||
internal partial interface IDerivedFromSharedType : SharedTypes.ComInterfaces.IGetAndSetInt
|
||||
#pragma warning restore SYSLIB1230
|
||||
{
|
||||
int GetIntPlusOne();
|
||||
}
|
||||
|
||||
[GeneratedComClass]
|
||||
[Guid("7F0DB364-3C04-4487-9195-4BB05DC7B654")]
|
||||
internal partial class DerivedFromSharedTypeImpl : IDerivedFromSharedType, IDerivedFromSharedType2
|
||||
{
|
||||
int _value = 42;
|
||||
|
||||
public int GetInt() => _value;
|
||||
public int GetIntPlusOne() => _value + 1;
|
||||
public int GetTwoTimesInt() => _value * 2;
|
||||
public void SetInt(int value) { _value = value; }
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public unsafe void TypesDerivedFromSharedTypeHaveCorrectVTableSize()
|
||||
{
|
||||
var managedSourceObject = new DerivedFromSharedTypeImpl();
|
||||
var cw = new StrategyBasedComWrappers();
|
||||
var nativeObj = cw.GetOrCreateComInterfaceForObject(managedSourceObject, CreateComInterfaceFlags.None);
|
||||
object managedObj = cw.GetOrCreateObjectForComInstance(nativeObj, CreateObjectFlags.None);
|
||||
IGetAndSetInt getAndSetInt = (IGetAndSetInt)managedObj;
|
||||
IDerivedFromSharedType derivedFromSharedType = (IDerivedFromSharedType)managedObj;
|
||||
IDerivedFromSharedType2 derivedFromSharedType2 = (IDerivedFromSharedType2)managedObj;
|
||||
|
||||
Assert.Equal(42, getAndSetInt.GetInt());
|
||||
Assert.Equal(42, derivedFromSharedType.GetInt());
|
||||
Assert.Equal(42, derivedFromSharedType2.GetInt());
|
||||
|
||||
getAndSetInt.SetInt(100);
|
||||
Assert.Equal(100, getAndSetInt.GetInt());
|
||||
Assert.Equal(101, derivedFromSharedType.GetIntPlusOne());
|
||||
Assert.Equal(200, derivedFromSharedType2.GetTwoTimesInt());
|
||||
}
|
||||
|
||||
|
||||
[Fact]
|
||||
public unsafe void DerivedInterfaceTypeProvidesBaseInterfaceUnmanagedToManagedMembers()
|
||||
{
|
||||
|
@ -36,16 +86,16 @@ namespace ComInterfaceGenerator.Tests
|
|||
public unsafe void CallBaseInterfaceMethod_EnsureQiCalledOnce()
|
||||
{
|
||||
var cw = new SingleQIComWrapper();
|
||||
var derivedImpl = new DerivedImpl();
|
||||
var derivedImpl = new Derived();
|
||||
var nativeObj = cw.GetOrCreateComInterfaceForObject(derivedImpl, CreateComInterfaceFlags.None);
|
||||
var obj = cw.GetOrCreateObjectForComInstance(nativeObj, CreateObjectFlags.None);
|
||||
IDerived iface = (IDerived)obj;
|
||||
|
||||
Assert.Equal(3, iface.GetInt());
|
||||
Assert.Equal(0, iface.GetInt());
|
||||
iface.SetInt(5);
|
||||
Assert.Equal(5, iface.GetInt());
|
||||
|
||||
Assert.Equal("myName", iface.GetName());
|
||||
Assert.Equal("hello", iface.GetName());
|
||||
iface.SetName("updated");
|
||||
Assert.Equal("updated", iface.GetName());
|
||||
|
||||
|
@ -58,22 +108,6 @@ namespace ComInterfaceGenerator.Tests
|
|||
Assert.Equal(1, countQi.QiCallCount);
|
||||
}
|
||||
|
||||
[GeneratedComClass]
|
||||
partial class DerivedImpl : IDerived
|
||||
{
|
||||
int data = 3;
|
||||
string myName = "myName";
|
||||
public void DoThingWithString(string name) => throw new NotImplementedException();
|
||||
|
||||
public int GetInt() => data;
|
||||
|
||||
public string GetName() => myName;
|
||||
|
||||
public void SetInt(int n) => data = n;
|
||||
|
||||
public void SetName(string name) => myName = name;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Used to ensure that QI is only called once when calling base methods on a derived COM interface
|
||||
/// </summary>
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Runtime.InteropServices.Marshalling;
|
||||
using SharedTypes.ComInterfaces;
|
||||
|
||||
[GeneratedComInterface(StringMarshalling = StringMarshalling.Utf16)]
|
||||
[Guid("da8eed10-f3f4-42c3-86de-04f2dc56514e")]
|
||||
#pragma warning disable SYSLIB1230 // Specifying 'GeneratedComInterfaceAttribute' on an interface that has a base interface defined in another assembly is not supported
|
||||
internal partial interface IDerivedExternalBase : IExternalBase
|
||||
#pragma warning restore SYSLIB1230
|
||||
{
|
||||
string GetName();
|
||||
}
|
||||
|
||||
[GeneratedComInterface(StringMarshalling = StringMarshalling.Utf16)]
|
||||
[Guid("c3d3990e-5b05-4a9b-adc4-58c521700ece")]
|
||||
#pragma warning disable SYSLIB1230 // Specifying 'GeneratedComInterfaceAttribute' on an interface that has a base interface defined in another assembly is not supported
|
||||
internal partial interface IDerivedExternalBase2 : IExternalBase
|
||||
#pragma warning restore SYSLIB1230
|
||||
{
|
||||
string GetName();
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Runtime.InteropServices.Marshalling;
|
||||
using SharedTypes.ComInterfaces;
|
||||
|
||||
[GeneratedComInterface(StringMarshalling = StringMarshalling.Utf16)]
|
||||
[Guid("f252bddd-aac0-4004-acfd-b39f73fb9791")]
|
||||
#pragma warning disable SYSLIB1230 // Specifying 'GeneratedComInterfaceAttribute' on an interface that has a base interface defined in another assembly is not supported
|
||||
internal partial interface IDerivedFromExternalDerived : IExternalDerived
|
||||
#pragma warning restore SYSLIB1230
|
||||
{
|
||||
string GetName();
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Runtime.InteropServices.Marshalling;
|
||||
using SharedTypes.ComInterfaces;
|
||||
|
||||
[GeneratedComInterface(StringMarshalling = StringMarshalling.Utf16)]
|
||||
[Guid("b158aaf2-85a3-40e7-805f-5797580a05f2")]
|
||||
#pragma warning disable SYSLIB1230 // Specifying 'GeneratedComInterfaceAttribute' on an interface that has a base interface defined in another assembly is not supported
|
||||
internal partial interface IDerivedFromDerivedExternalDerived : IDerivedFromExternalDerived
|
||||
#pragma warning restore SYSLIB1230
|
||||
{
|
||||
float GetFloat();
|
||||
}
|
|
@ -8,8 +8,10 @@ using System.Reflection.Metadata;
|
|||
using System.Threading.Tasks;
|
||||
using Microsoft.CodeAnalysis;
|
||||
using Microsoft.CodeAnalysis.CSharp;
|
||||
using Microsoft.CodeAnalysis.CSharp.Syntax;
|
||||
using Microsoft.CodeAnalysis.Operations;
|
||||
using Microsoft.CodeAnalysis.Testing;
|
||||
using Microsoft.CodeAnalysis.Text;
|
||||
using Microsoft.Interop;
|
||||
using Xunit;
|
||||
|
||||
|
@ -317,7 +319,7 @@ namespace ComInterfaceGenerator.Unit.Tests
|
|||
using System.Runtime.CompilerServices;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Runtime.InteropServices.Marshalling;
|
||||
|
||||
|
||||
[GeneratedComInterface]
|
||||
[Guid("0A617667-4961-4F90-B74F-6DC368E9817A")]
|
||||
partial interface {|#1:IComInterface2|} : IComInterface
|
||||
|
@ -342,6 +344,74 @@ namespace ComInterfaceGenerator.Unit.Tests
|
|||
VerifyCS.DiagnosticWithArguments(GeneratorDiagnostics.BaseInterfaceDefinedInOtherAssembly, "IComInterface2", "IComInterface").WithLocation(1).WithSeverity(DiagnosticSeverity.Warning));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ComInterfacesInheritingFromTheSameInterfaceAcrossCompilationsCalculatesCorrectVTableIndex()
|
||||
{
|
||||
string baseSource = $$"""
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Runtime.InteropServices.Marshalling;
|
||||
|
||||
[GeneratedComInterface]
|
||||
[Guid("0A617667-4961-4F90-B74F-6DC368E98179")]
|
||||
public partial interface IComInterface
|
||||
{
|
||||
void Method();
|
||||
}
|
||||
""";
|
||||
|
||||
string derivedSource = $$"""
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Runtime.InteropServices.Marshalling;
|
||||
|
||||
[GeneratedComInterface]
|
||||
[Guid("0A617667-4961-4F90-B74F-6DC368E9817A")]
|
||||
internal partial interface {|#1:IComInterface2|} : IComInterface
|
||||
{
|
||||
void DerivedMethod();
|
||||
}
|
||||
|
||||
[GeneratedComInterface]
|
||||
[Guid("0951f7b7-a700-4de4-930e-0b1fbc4684a9")]
|
||||
internal partial interface {|#2:IComInterface3|} : IComInterface
|
||||
{
|
||||
void DerivedMethod();
|
||||
}
|
||||
""";
|
||||
|
||||
await VerifyInvocationWithMultipleProjectsAsync(
|
||||
derivedSource,
|
||||
baseSource,
|
||||
"IComInterface2",
|
||||
"DerivedMethod",
|
||||
(newComp, _) =>
|
||||
{
|
||||
// Validate VTable sizes for interfaces inheriting from the same base
|
||||
// IUnknown has 3 methods, IComInterface adds 1 method = 4 total
|
||||
// Both IComInterface2 and IComInterface3 inherit from IComInterface and add 1 method each = 5 total
|
||||
ValidateInterface("IComInterface2", 5);
|
||||
ValidateInterface("IComInterface3", 5);
|
||||
|
||||
void ValidateInterface(string name, int expectedVTableSize)
|
||||
{
|
||||
INamedTypeSymbol? userDefinedInterface = newComp.Assembly.GetTypeByMetadataName(name);
|
||||
Assert.NotNull(userDefinedInterface);
|
||||
ITypeSymbol vtableType = new ComInterfaceImplementationLocator().FindVTableStructType(newComp, userDefinedInterface);
|
||||
int actualVTableSize = vtableType.GetMembers().OfType<IFieldSymbol>().Count();
|
||||
|
||||
if (expectedVTableSize != actualVTableSize)
|
||||
{
|
||||
Assert.Fail($"VTable size mismatch for {name}. Expected: {expectedVTableSize}, Actual: {actualVTableSize}. VTable structure:\n{vtableType.DeclaringSyntaxReferences[0].GetSyntax().SyntaxTree.GetText()}");
|
||||
}
|
||||
|
||||
Assert.Equal(expectedVTableSize, actualVTableSize);
|
||||
}
|
||||
},
|
||||
VerifyCS.DiagnosticWithArguments(GeneratorDiagnostics.BaseInterfaceDefinedInOtherAssembly, "IComInterface2", "IComInterface").WithLocation(1).WithSeverity(DiagnosticSeverity.Warning),
|
||||
VerifyCS.DiagnosticWithArguments(GeneratorDiagnostics.BaseInterfaceDefinedInOtherAssembly, "IComInterface3", "IComInterface").WithLocation(2).WithSeverity(DiagnosticSeverity.Warning));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ComInterfaceInheritingAcrossCompilationsChainInBaseCalculatesCorrectVTableIndex()
|
||||
{
|
||||
|
@ -369,7 +439,7 @@ namespace ComInterfaceGenerator.Unit.Tests
|
|||
using System.Runtime.CompilerServices;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Runtime.InteropServices.Marshalling;
|
||||
|
||||
|
||||
[GeneratedComInterface]
|
||||
[Guid("0A617667-4961-4F90-B74F-6DC368E9817A")]
|
||||
partial interface {|#1:IComInterface3|} : IComInterface2
|
||||
|
@ -446,6 +516,84 @@ namespace ComInterfaceGenerator.Unit.Tests
|
|||
VerifyCS.DiagnosticWithArguments(GeneratorDiagnostics.BaseInterfaceDefinedInOtherAssembly, "IComInterface2", "IComInterface").WithLocation(1).WithSeverity(DiagnosticSeverity.Warning));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ComInterfaceDeepInheritanceChainCalculatesCorrectVTableSizes()
|
||||
{
|
||||
string baseSource = $$"""
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Runtime.InteropServices.Marshalling;
|
||||
|
||||
[GeneratedComInterface]
|
||||
[Guid("0A617667-4961-4F90-B74F-6DC368E98179")]
|
||||
public partial interface IComInterface
|
||||
{
|
||||
void BaseMethod();
|
||||
}
|
||||
|
||||
[GeneratedComInterface]
|
||||
[Guid("0A617667-4961-4F90-B74F-6DC368E98178")]
|
||||
public partial interface IComInterface2 : IComInterface
|
||||
{
|
||||
void MiddleMethod();
|
||||
}
|
||||
""";
|
||||
|
||||
string derivedSource = $$"""
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Runtime.InteropServices.Marshalling;
|
||||
|
||||
[GeneratedComInterface]
|
||||
[Guid("0A617667-4961-4F90-B74F-6DC368E9817A")]
|
||||
partial interface {|#1:IComInterface3|} : IComInterface2
|
||||
{
|
||||
void DerivedMethod();
|
||||
}
|
||||
|
||||
[GeneratedComInterface]
|
||||
[Guid("0A617667-4961-4F90-B74F-6DC368E9817B")]
|
||||
partial interface IComInterface4 : IComInterface3
|
||||
{
|
||||
void DeepDerivedMethod();
|
||||
}
|
||||
""";
|
||||
|
||||
await VerifyInvocationWithMultipleProjectsAsync(
|
||||
derivedSource,
|
||||
baseSource,
|
||||
"IComInterface4",
|
||||
"DeepDerivedMethod",
|
||||
(newComp, _) =>
|
||||
{
|
||||
// Validate VTable sizes for deep inheritance chain
|
||||
// IUnknown has 3 methods (QueryInterface=0, AddRef=1, Release=2)
|
||||
// IComInterface: IUnknown (3) + BaseMethod (1) = 4 total
|
||||
// IComInterface2: IComInterface (4) + MiddleMethod (1) = 5 total
|
||||
// IComInterface3: IComInterface2 (5) + DerivedMethod (1) = 6 total
|
||||
// IComInterface4: IComInterface3 (6) + DeepDerivedMethod (1) = 7 total
|
||||
|
||||
ValidateInterface("IComInterface3", 6);
|
||||
ValidateInterface("IComInterface4", 7);
|
||||
|
||||
void ValidateInterface(string name, int expectedVTableSize)
|
||||
{
|
||||
INamedTypeSymbol? userDefinedInterface = newComp.Assembly.GetTypeByMetadataName(name);
|
||||
Assert.NotNull(userDefinedInterface);
|
||||
ITypeSymbol vtableType = new ComInterfaceImplementationLocator().FindVTableStructType(newComp, userDefinedInterface);
|
||||
int actualVTableSize = vtableType.GetMembers().OfType<IFieldSymbol>().Count();
|
||||
|
||||
if (expectedVTableSize != actualVTableSize)
|
||||
{
|
||||
Assert.Fail($"VTable size mismatch for {name}. Expected: {expectedVTableSize}, Actual: {actualVTableSize}. VTable structure:\n{vtableType.DeclaringSyntaxReferences[0].GetSyntax().SyntaxTree.GetText()}");
|
||||
}
|
||||
|
||||
Assert.Equal(expectedVTableSize, actualVTableSize);
|
||||
}
|
||||
},
|
||||
VerifyCS.DiagnosticWithArguments(GeneratorDiagnostics.BaseInterfaceDefinedInOtherAssembly, "IComInterface3", "IComInterface2").WithLocation(1).WithSeverity(DiagnosticSeverity.Warning));
|
||||
}
|
||||
|
||||
private static async Task VerifyInvocationWithMultipleProjectsAsync(
|
||||
string thisSource,
|
||||
string baseSource,
|
||||
|
@ -581,6 +729,14 @@ namespace ComInterfaceGenerator.Unit.Tests
|
|||
|
||||
return (INamedTypeSymbol)iUnknownDerivedAttribute.AttributeClass!.TypeArguments[1];
|
||||
}
|
||||
|
||||
public ITypeSymbol FindVTableStructType(Compilation compilation, INamedTypeSymbol userDefinedInterface)
|
||||
{
|
||||
INamedTypeSymbol? implementationInterface = FindImplementationInterface(compilation, userDefinedInterface);
|
||||
var vtableField = implementationInterface.GetMembers("Vtable").OfType<IFieldSymbol>().SingleOrDefault();
|
||||
Assert.NotNull(vtableField);
|
||||
return vtableField.Type;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,6 +18,15 @@ namespace SharedTypes.ComInterfaces
|
|||
|
||||
internal new const string IID = "7F0DB364-3C04-4487-9193-4BB05DC7B654";
|
||||
}
|
||||
[GeneratedComInterface]
|
||||
[Guid("D38D8B40-54A4-4685-B048-D04E215E6A93")]
|
||||
internal partial interface IDerivedBool : IBool
|
||||
{
|
||||
void SetName([MarshalUsing(typeof(Utf16StringMarshaller))] string name);
|
||||
|
||||
[return: MarshalUsing(typeof(Utf16StringMarshaller))]
|
||||
string GetName();
|
||||
}
|
||||
|
||||
[GeneratedComClass]
|
||||
internal partial class Derived : GetAndSetInt, IDerived
|
|
@ -0,0 +1,28 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
|
||||
using System;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Runtime.InteropServices.Marshalling;
|
||||
|
||||
namespace SharedTypes.ComInterfaces
|
||||
{
|
||||
[GeneratedComInterface]
|
||||
[Guid(IID)]
|
||||
internal partial interface IDerivedDerived : IDerived
|
||||
{
|
||||
void SetFloat(float name);
|
||||
|
||||
float GetFloat();
|
||||
|
||||
internal new const string IID = "7F0DB364-3C04-4487-9193-4BB05DC7B654";
|
||||
}
|
||||
|
||||
[GeneratedComClass]
|
||||
internal partial class DerivedDerived : Derived, IDerivedDerived
|
||||
{
|
||||
float _data = 0;
|
||||
public float GetFloat() => _data;
|
||||
public void SetFloat(float name) => _data = name;
|
||||
}
|
||||
}
|
|
@ -17,6 +17,7 @@ namespace SharedTypes.ComInterfaces
|
|||
|
||||
public const string IID = "2c3f9903-b586-46b1-881b-adfce9af47b1";
|
||||
}
|
||||
|
||||
[GeneratedComClass]
|
||||
internal partial class GetAndSetInt : IGetAndSetInt
|
||||
{
|
|
@ -20,7 +20,7 @@
|
|||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<Compile Include="..\..\TestAssets\SharedTypes\ComInterfaces\**\*.cs" Link="ComInterfaceGenerator\ComInterfaces\%(RecursiveDir)\%(FileName).cs" />
|
||||
<Compile Include="..\..\Common\ComInterfaces\**\*.cs" Link="ComInterfaceGenerator\ComInterfaces\%(RecursiveDir)\%(FileName).cs" />
|
||||
<Compile Include="$(CommonPath)DisableRuntimeMarshalling.cs" Link="Common\DisableRuntimeMarshalling.cs" />
|
||||
</ItemGroup>
|
||||
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
|
||||
using System;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Runtime.InteropServices.Marshalling;
|
||||
|
||||
namespace SharedTypes.ComInterfaces
|
||||
{
|
||||
[GeneratedComInterface]
|
||||
[Guid(IID)]
|
||||
public partial interface IExternalBase
|
||||
{
|
||||
public int GetInt();
|
||||
public void SetInt(int x);
|
||||
public const string IID = "2c3f9903-b586-46b1-881b-adfce9af47b1";
|
||||
}
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
|
||||
using System;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Runtime.InteropServices.Marshalling;
|
||||
|
||||
namespace SharedTypes.ComInterfaces
|
||||
{
|
||||
[GeneratedComInterface]
|
||||
[Guid(IID)]
|
||||
public partial interface IExternalDerived : IExternalBase
|
||||
{
|
||||
[return: MarshalAs(UnmanagedType.Bool)]
|
||||
bool GetBool();
|
||||
public void SetBool([MarshalAs(UnmanagedType.Bool)] bool x);
|
||||
new public const string IID = "594DF2B9-66CE-490D-9D05-34646675B188";
|
||||
}
|
||||
}
|
|
@ -11,8 +11,6 @@
|
|||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\Ancillary.Interop\Ancillary.Interop.csproj" />
|
||||
<!-- This project just has shared types, but shouldn't generate code for the types -->
|
||||
<EnabledGenerators Remove="@(EnabledGenerators)" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
|
|
Loading…
Reference in New Issue