diff --git a/src/coreclr/vm/comcallablewrapper.cpp b/src/coreclr/vm/comcallablewrapper.cpp index 1ca1460454c19d..59e05e31c54d4f 100644 --- a/src/coreclr/vm/comcallablewrapper.cpp +++ b/src/coreclr/vm/comcallablewrapper.cpp @@ -3587,6 +3587,20 @@ BOOL ComMethodTable::LayOutInterfaceMethodTable(MethodTable* pClsMT) if (pClassMD != NULL) { pNewMD->InitMethod(pClassMD, pIntfMD); + + // Restore the vtable slot on parent MethodTables that may share this ComMethodTable. + // This ensures that we do not end up with a NULL slot during COM dispatch due to lazy + // entry point allocation. + if (pClassMD->IsVirtual()) + { + DWORD slot = pClassMD->GetSlot(); + MethodTable *pParentWalk = pClsMT->GetParentMethodTable(); + while (pParentWalk != NULL && slot < pParentWalk->GetNumVirtuals()) + { + pParentWalk->GetRestoredSlot(slot); + pParentWalk = pParentWalk->GetParentMethodTable(); + } + } } else { diff --git a/src/tests/Interop/COM/VirtualMethodOverride/VirtualMethodOverrideTest.cs b/src/tests/Interop/COM/VirtualMethodOverride/VirtualMethodOverrideTest.cs new file mode 100644 index 00000000000000..fb15cf17ac7463 --- /dev/null +++ b/src/tests/Interop/COM/VirtualMethodOverride/VirtualMethodOverrideTest.cs @@ -0,0 +1,120 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.InteropServices; +using Xunit; + +[ComVisible(true)] +[Guid("A1111111-0000-0000-0000-000000000001")] +public interface IFoo +{ + void DoWork(); +} + +[ComVisible(true)] +[Guid("A1111111-0000-0000-0000-000000000002")] +[ComDefaultInterface(typeof(IFoo))] +public class Foo : IFoo +{ + public virtual void DoWork() => VirtualMethodOverrideTest.LastCalledType = nameof(Foo); +} + +[ComVisible(true)] +[Guid("A1111111-0000-0000-0000-000000000003")] +[ComDefaultInterface(typeof(IFoo))] +public class FooDerived : Foo +{ + public override void DoWork() => VirtualMethodOverrideTest.LastCalledType = nameof(FooDerived); +} + +[ComVisible(true)] +[Guid("B2222222-0000-0000-0000-000000000001")] +public interface IBar +{ + void DoWork(); +} + +[ComVisible(true)] +[Guid("B2222222-0000-0000-0000-000000000002")] +[ComDefaultInterface(typeof(IBar))] +public class Bar : IBar +{ + public virtual void DoWork() => VirtualMethodOverrideTest.LastCalledType = nameof(Bar); +} + +[ComVisible(true)] +[Guid("B2222222-0000-0000-0000-000000000003")] +[ComDefaultInterface(typeof(IBar))] +public class BarDerived : Bar +{ + public override void DoWork() => VirtualMethodOverrideTest.LastCalledType = nameof(BarDerived); +} + +/// +/// Tests that COM-to-CLR dispatch correctly resolves virtual method overrides +/// regardless of whether the base or derived class is accessed via COM first. +/// +public class VirtualMethodOverrideTest +{ + internal static string? LastCalledType; + + [UnmanagedFunctionPointer(CallingConvention.StdCall)] + delegate int DoWorkDelegate(IntPtr pThis); + + private static int CallDoWork(IntPtr pInterface, int slot) + { + IntPtr vtbl = Marshal.ReadIntPtr(pInterface); + IntPtr fnPtr = Marshal.ReadIntPtr(vtbl, slot * IntPtr.Size); + Assert.NotEqual(IntPtr.Zero, fnPtr); + + var fn = Marshal.GetDelegateForFunctionPointer(fnPtr); + return fn(pInterface); + } + + [Fact] + public static void DerivedFirst() + { + int doWorkSlot = Marshal.GetStartComSlot(typeof(IFoo)); + IntPtr pDerived = Marshal.GetComInterfaceForObject(new FooDerived(), typeof(IFoo)); + IntPtr pBase = Marshal.GetComInterfaceForObject(new Foo(), typeof(IFoo)); + try + { + LastCalledType = null; + Assert.True(CallDoWork(pDerived, doWorkSlot) >= 0); + Assert.Equal(nameof(FooDerived), LastCalledType); + + LastCalledType = null; + Assert.True(CallDoWork(pBase, doWorkSlot) >= 0); + Assert.Equal(nameof(Foo), LastCalledType); + } + finally + { + Marshal.Release(pDerived); + Marshal.Release(pBase); + } + } + + [Fact] + public static void BaseFirst() + { + int doWorkSlot = Marshal.GetStartComSlot(typeof(IBar)); + IntPtr pBase = Marshal.GetComInterfaceForObject(new Bar(), typeof(IBar)); + IntPtr pDerived = Marshal.GetComInterfaceForObject(new BarDerived(), typeof(IBar)); + try + { + LastCalledType = null; + Assert.True(CallDoWork(pBase, doWorkSlot) >= 0); + Assert.Equal(nameof(Bar), LastCalledType); + + LastCalledType = null; + Assert.True(CallDoWork(pDerived, doWorkSlot) >= 0); + Assert.Equal(nameof(BarDerived), LastCalledType); + } + finally + { + Marshal.Release(pBase); + Marshal.Release(pDerived); + } + } +} diff --git a/src/tests/Interop/COM/VirtualMethodOverride/VirtualMethodOverrideTest.csproj b/src/tests/Interop/COM/VirtualMethodOverride/VirtualMethodOverrideTest.csproj new file mode 100644 index 00000000000000..412688de2f640c --- /dev/null +++ b/src/tests/Interop/COM/VirtualMethodOverride/VirtualMethodOverrideTest.csproj @@ -0,0 +1,10 @@ + + + true + true + true + + + + +