This commit is contained in:
Copilot 2025-07-30 07:16:39 -07:00 committed by GitHub
commit 82892e8373
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 188 additions and 4 deletions

View File

@ -444,8 +444,28 @@ namespace Microsoft.Interop
string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(TypeInfo, context);
// Generate numElements expression with null check for pointer types
ExpressionSyntax numElementsExpression = ElementsMarshalling.GenerateNumElementsExpression(countInfo, castCountInfo, CodeContext, context);
// If the marshalling direction is unmanaged-to-managed and we have a native pointer type,
// we need to check if the native pointer is null before using the size parameter to avoid
// allocating arrays for null pointers.
if (CodeContext.Direction == MarshalDirection.UnmanagedToManaged &&
innerMarshaller.NativeType is PointerTypeInfo)
{
string nativeIdentifier = context.GetIdentifiers(TypeInfo).native;
// Generate: nativePointer == null ? 0 : sizeExpression
numElementsExpression = ConditionalExpression(
BinaryExpression(SyntaxKind.EqualsExpression,
IdentifierName(nativeIdentifier),
LiteralExpression(SyntaxKind.NullLiteralExpression)),
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0)),
numElementsExpression);
}
// <numElements> = <numElementsExpression>;
yield return AssignmentStatement(IdentifierName(numElementsIdentifier), ElementsMarshalling.GenerateNumElementsExpression(countInfo, castCountInfo, CodeContext, context));
yield return AssignmentStatement(IdentifierName(numElementsIdentifier), numElementsExpression);
yield return elementsMarshalling.GenerateUnmarshalStatement(context);

View File

@ -322,13 +322,31 @@ namespace Microsoft.Interop
if (MarshallerHelpers.GetMarshalDirection(TypeInfo, CodeContext) != MarshalDirection.ManagedToUnmanaged)
{
// If we are marshalling from unmanaged to managed, we need to get the number of elements again.
string nativeIdentifier = context.GetIdentifiers(TypeInfo).native;
string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(TypeInfo, context);
// Generate numElements expression with null check for pointer types
ExpressionSyntax numElementsExpression = ElementsMarshalling.GenerateNumElementsExpression(countInfo, countInfoRequiresCast, CodeContext, context);
// If we have a native pointer type, we need to check if the native pointer is null
// before using the size parameter to avoid allocating arrays for null pointers.
if (NativeType is PointerTypeInfo)
{
// Generate: nativePointer == null ? 0 : sizeExpression
numElementsExpression = ConditionalExpression(
BinaryExpression(SyntaxKind.EqualsExpression,
IdentifierName(nativeIdentifier),
LiteralExpression(SyntaxKind.NullLiteralExpression)),
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0)),
numElementsExpression);
}
// <numElements> = <numElementsExpression>;
yield return ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(numElementsIdentifier),
ElementsMarshalling.GenerateNumElementsExpression(countInfo, countInfoRequiresCast, CodeContext, context)));
numElementsExpression));
}
}
@ -360,11 +378,28 @@ namespace Microsoft.Interop
(string managedIdentifier, string nativeIdentifier) = context.GetIdentifiers(TypeInfo);
string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(TypeInfo, context);
// Generate numElements expression with null check for pointer types
ExpressionSyntax numElementsExpression = ElementsMarshalling.GenerateNumElementsExpression(countInfo, countInfoRequiresCast, CodeContext, context);
// If we have a native pointer type, we need to check if the native pointer is null
// before using the size parameter to avoid allocating arrays for null pointers.
// This method only runs for unmarshal scenarios (not ManagedToUnmanaged).
if (NativeType is PointerTypeInfo)
{
// Generate: nativePointer == null ? 0 : sizeExpression
numElementsExpression = ConditionalExpression(
BinaryExpression(SyntaxKind.EqualsExpression,
IdentifierName(nativeIdentifier),
LiteralExpression(SyntaxKind.NullLiteralExpression)),
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0)),
numElementsExpression);
}
yield return ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(numElementsIdentifier),
ElementsMarshalling.GenerateNumElementsExpression(countInfo, countInfoRequiresCast, CodeContext, context)));
numElementsExpression));
// <managedIdentifier> = <marshallerType>.AllocateContainerForManagedElementsFinally(<nativeIdentifier>, <numElements>);
yield return ExpressionStatement(
@ -448,11 +483,29 @@ namespace Microsoft.Interop
(string managedIdentifier, string nativeIdentifier) = context.GetIdentifiers(TypeInfo);
string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(TypeInfo, context);
// Generate numElements expression with null check for pointer types
ExpressionSyntax numElementsExpression = ElementsMarshalling.GenerateNumElementsExpression(countInfo, countInfoRequiresCast, CodeContext, context);
// If the marshalling direction is unmanaged-to-managed and we have a native pointer type,
// we need to check if the native pointer is null before using the size parameter to avoid
// allocating arrays for null pointers.
if (CodeContext.Direction == MarshalDirection.UnmanagedToManaged &&
NativeType is PointerTypeInfo)
{
// Generate: nativePointer == null ? 0 : sizeExpression
numElementsExpression = ConditionalExpression(
BinaryExpression(SyntaxKind.EqualsExpression,
IdentifierName(nativeIdentifier),
LiteralExpression(SyntaxKind.NullLiteralExpression)),
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0)),
numElementsExpression);
}
yield return ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(numElementsIdentifier),
ElementsMarshalling.GenerateNumElementsExpression(countInfo, countInfoRequiresCast, CodeContext, context)));
numElementsExpression));
// <managedIdentifier> = <marshallerType>.AllocateContainerForManagedElements(<nativeIdentifier>, <numElements>);
yield return ExpressionStatement(

View File

@ -0,0 +1,111 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
using Xunit;
namespace ComInterfaceGenerator.Tests;
public partial class ArrayBufferMarshallingTests
{
[GeneratedComInterface]
[Guid("8A2AF35B-D028-4191-A01F-3422AB0CF724")]
public partial interface ITestInterface
{
void TestMethod(
int bufferSize,
[MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 0), Out] int[]? buffer1,
[MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 0), Out] int[]? buffer2);
void TestMethodWithRef(
int bufferSize,
[MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 0)] ref int[]? buffer);
}
private class TestImplementation : ITestInterface
{
public void TestMethod(int bufferSize, int[]? buffer1, int[]? buffer2)
{
// Fill buffer1 if not null
if (buffer1 != null)
{
for (int i = 0; i < Math.Min(bufferSize, buffer1.Length); i++)
{
buffer1[i] = i;
}
}
// Fill buffer2 if not null
if (buffer2 != null)
{
for (int i = 0; i < Math.Min(bufferSize, buffer2.Length); i++)
{
buffer2[i] = i * 2;
}
}
}
public void TestMethodWithRef(int bufferSize, ref int[]? buffer)
{
// If buffer is null, allocate it with the specified size
if (buffer is null && bufferSize > 0)
{
buffer = new int[bufferSize];
}
// Fill buffer if not null
if (buffer != null)
{
for (int i = 0; i < Math.Min(bufferSize, buffer.Length); i++)
{
buffer[i] = i * 3;
}
}
}
}
[Fact]
public void TestGeneratedCodeCompilation()
{
// This test ensures the COM interface with array parameters is generated without compilation errors
// The specific issue is that the generated code should handle null array pointers correctly
// when calculating the number of elements in unmanaged-to-managed stubs.
var testImpl = new TestImplementation();
var cw = new StrategyBasedComWrappers();
nint ptr = cw.GetOrCreateComInterfaceForObject(testImpl, CreateComInterfaceFlags.None);
try
{
// The main test is that this interface can be created and the generated code compiles
Assert.NotEqual(0, (int)ptr);
}
finally
{
Marshal.Release(ptr);
}
}
[Fact]
public void TestGeneratedCodeCompilationWithRefArrays()
{
// This test ensures that COM interfaces with ref array parameters compile correctly
// and handle null pointer scenarios properly
var testImpl = new TestImplementation();
var cw = new StrategyBasedComWrappers();
nint ptr = cw.GetOrCreateComInterfaceForObject(testImpl, CreateComInterfaceFlags.None);
try
{
// The main test is that this interface with ref arrays can be created and the generated code compiles
Assert.NotEqual(0, (int)ptr);
}
finally
{
Marshal.Release(ptr);
}
}
}