Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/coreclr/vm/comcallablewrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3587,6 +3587,20 @@ BOOL ComMethodTable::LayOutInterfaceMethodTable(MethodTable* pClsMT)
if (pClassMD != NULL)
{
pNewMD->InitMethod(pClassMD, pIntfMD);

// Restore the vtable slot on parent MethodTables that may share this ComMethodTable.
// This ensures that we do not end up with a NULL slot during COM dispatch due to lazy
// entry point allocation.
if (pClassMD->IsVirtual())
{
DWORD slot = pClassMD->GetSlot();
MethodTable *pParentWalk = pClsMT->GetParentMethodTable();
while (pParentWalk != NULL && slot < pParentWalk->GetNumVirtuals())
{
pParentWalk->GetRestoredSlot(slot);
pParentWalk = pParentWalk->GetParentMethodTable();
}
}
}
else
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Runtime.InteropServices;
using Xunit;

[ComVisible(true)]
[Guid("A1111111-0000-0000-0000-000000000001")]
public interface IFoo
{
void DoWork();
}

[ComVisible(true)]
[Guid("A1111111-0000-0000-0000-000000000002")]
[ComDefaultInterface(typeof(IFoo))]
public class Foo : IFoo
{
public virtual void DoWork() => VirtualMethodOverrideTest.LastCalledType = nameof(Foo);
}

[ComVisible(true)]
[Guid("A1111111-0000-0000-0000-000000000003")]
[ComDefaultInterface(typeof(IFoo))]
public class FooDerived : Foo
{
public override void DoWork() => VirtualMethodOverrideTest.LastCalledType = nameof(FooDerived);
}

[ComVisible(true)]
[Guid("B2222222-0000-0000-0000-000000000001")]
public interface IBar
{
void DoWork();
}

[ComVisible(true)]
[Guid("B2222222-0000-0000-0000-000000000002")]
[ComDefaultInterface(typeof(IBar))]
public class Bar : IBar
{
public virtual void DoWork() => VirtualMethodOverrideTest.LastCalledType = nameof(Bar);
}

[ComVisible(true)]
[Guid("B2222222-0000-0000-0000-000000000003")]
[ComDefaultInterface(typeof(IBar))]
public class BarDerived : Bar
{
public override void DoWork() => VirtualMethodOverrideTest.LastCalledType = nameof(BarDerived);
}

/// <summary>
/// Tests that COM-to-CLR dispatch correctly resolves virtual method overrides
/// regardless of whether the base or derived class is accessed via COM first.
/// </summary>
public class VirtualMethodOverrideTest
{
internal static string? LastCalledType;

[UnmanagedFunctionPointer(CallingConvention.StdCall)]
delegate int DoWorkDelegate(IntPtr pThis);

private static int CallDoWork(IntPtr pInterface, int slot)
{
IntPtr vtbl = Marshal.ReadIntPtr(pInterface);
IntPtr fnPtr = Marshal.ReadIntPtr(vtbl, slot * IntPtr.Size);
Assert.NotEqual(IntPtr.Zero, fnPtr);

var fn = Marshal.GetDelegateForFunctionPointer<DoWorkDelegate>(fnPtr);
return fn(pInterface);
}

[Fact]
public static void DerivedFirst()
{
int doWorkSlot = Marshal.GetStartComSlot(typeof(IFoo));
IntPtr pDerived = Marshal.GetComInterfaceForObject(new FooDerived(), typeof(IFoo));
IntPtr pBase = Marshal.GetComInterfaceForObject(new Foo(), typeof(IFoo));
try
{
LastCalledType = null;
Assert.True(CallDoWork(pDerived, doWorkSlot) >= 0);
Assert.Equal(nameof(FooDerived), LastCalledType);

LastCalledType = null;
Assert.True(CallDoWork(pBase, doWorkSlot) >= 0);
Assert.Equal(nameof(Foo), LastCalledType);
}
finally
{
Marshal.Release(pDerived);
Marshal.Release(pBase);
}
}

[Fact]
public static void BaseFirst()
{
int doWorkSlot = Marshal.GetStartComSlot(typeof(IBar));
IntPtr pBase = Marshal.GetComInterfaceForObject(new Bar(), typeof(IBar));
IntPtr pDerived = Marshal.GetComInterfaceForObject(new BarDerived(), typeof(IBar));
try
{
LastCalledType = null;
Assert.True(CallDoWork(pBase, doWorkSlot) >= 0);
Assert.Equal(nameof(Bar), LastCalledType);

LastCalledType = null;
Assert.True(CallDoWork(pDerived, doWorkSlot) >= 0);
Assert.Equal(nameof(BarDerived), LastCalledType);
}
finally
{
Marshal.Release(pBase);
Marshal.Release(pDerived);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<RequiresProcessIsolation>true</RequiresProcessIsolation>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<NativeAotIncompatible>true</NativeAotIncompatible>
</PropertyGroup>
<ItemGroup>
<Compile Include="VirtualMethodOverrideTest.cs" />
</ItemGroup>
</Project>
Loading