diff --git a/src/coreclr/vm/comcallablewrapper.cpp b/src/coreclr/vm/comcallablewrapper.cpp
index 1ca1460454c19d..59e05e31c54d4f 100644
--- a/src/coreclr/vm/comcallablewrapper.cpp
+++ b/src/coreclr/vm/comcallablewrapper.cpp
@@ -3587,6 +3587,20 @@ BOOL ComMethodTable::LayOutInterfaceMethodTable(MethodTable* pClsMT)
if (pClassMD != NULL)
{
pNewMD->InitMethod(pClassMD, pIntfMD);
+
+ // Restore the vtable slot on parent MethodTables that may share this ComMethodTable.
+ // This ensures that we do not end up with a NULL slot during COM dispatch due to lazy
+ // entry point allocation.
+ if (pClassMD->IsVirtual())
+ {
+ DWORD slot = pClassMD->GetSlot();
+ MethodTable *pParentWalk = pClsMT->GetParentMethodTable();
+ while (pParentWalk != NULL && slot < pParentWalk->GetNumVirtuals())
+ {
+ pParentWalk->GetRestoredSlot(slot);
+ pParentWalk = pParentWalk->GetParentMethodTable();
+ }
+ }
}
else
{
diff --git a/src/tests/Interop/COM/VirtualMethodOverride/VirtualMethodOverrideTest.cs b/src/tests/Interop/COM/VirtualMethodOverride/VirtualMethodOverrideTest.cs
new file mode 100644
index 00000000000000..fb15cf17ac7463
--- /dev/null
+++ b/src/tests/Interop/COM/VirtualMethodOverride/VirtualMethodOverrideTest.cs
@@ -0,0 +1,120 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Runtime.InteropServices;
+using Xunit;
+
+[ComVisible(true)]
+[Guid("A1111111-0000-0000-0000-000000000001")]
+public interface IFoo
+{
+ void DoWork();
+}
+
+[ComVisible(true)]
+[Guid("A1111111-0000-0000-0000-000000000002")]
+[ComDefaultInterface(typeof(IFoo))]
+public class Foo : IFoo
+{
+ public virtual void DoWork() => VirtualMethodOverrideTest.LastCalledType = nameof(Foo);
+}
+
+[ComVisible(true)]
+[Guid("A1111111-0000-0000-0000-000000000003")]
+[ComDefaultInterface(typeof(IFoo))]
+public class FooDerived : Foo
+{
+ public override void DoWork() => VirtualMethodOverrideTest.LastCalledType = nameof(FooDerived);
+}
+
+[ComVisible(true)]
+[Guid("B2222222-0000-0000-0000-000000000001")]
+public interface IBar
+{
+ void DoWork();
+}
+
+[ComVisible(true)]
+[Guid("B2222222-0000-0000-0000-000000000002")]
+[ComDefaultInterface(typeof(IBar))]
+public class Bar : IBar
+{
+ public virtual void DoWork() => VirtualMethodOverrideTest.LastCalledType = nameof(Bar);
+}
+
+[ComVisible(true)]
+[Guid("B2222222-0000-0000-0000-000000000003")]
+[ComDefaultInterface(typeof(IBar))]
+public class BarDerived : Bar
+{
+ public override void DoWork() => VirtualMethodOverrideTest.LastCalledType = nameof(BarDerived);
+}
+
+///
+/// Tests that COM-to-CLR dispatch correctly resolves virtual method overrides
+/// regardless of whether the base or derived class is accessed via COM first.
+///
+public class VirtualMethodOverrideTest
+{
+ internal static string? LastCalledType;
+
+ [UnmanagedFunctionPointer(CallingConvention.StdCall)]
+ delegate int DoWorkDelegate(IntPtr pThis);
+
+ private static int CallDoWork(IntPtr pInterface, int slot)
+ {
+ IntPtr vtbl = Marshal.ReadIntPtr(pInterface);
+ IntPtr fnPtr = Marshal.ReadIntPtr(vtbl, slot * IntPtr.Size);
+ Assert.NotEqual(IntPtr.Zero, fnPtr);
+
+ var fn = Marshal.GetDelegateForFunctionPointer(fnPtr);
+ return fn(pInterface);
+ }
+
+ [Fact]
+ public static void DerivedFirst()
+ {
+ int doWorkSlot = Marshal.GetStartComSlot(typeof(IFoo));
+ IntPtr pDerived = Marshal.GetComInterfaceForObject(new FooDerived(), typeof(IFoo));
+ IntPtr pBase = Marshal.GetComInterfaceForObject(new Foo(), typeof(IFoo));
+ try
+ {
+ LastCalledType = null;
+ Assert.True(CallDoWork(pDerived, doWorkSlot) >= 0);
+ Assert.Equal(nameof(FooDerived), LastCalledType);
+
+ LastCalledType = null;
+ Assert.True(CallDoWork(pBase, doWorkSlot) >= 0);
+ Assert.Equal(nameof(Foo), LastCalledType);
+ }
+ finally
+ {
+ Marshal.Release(pDerived);
+ Marshal.Release(pBase);
+ }
+ }
+
+ [Fact]
+ public static void BaseFirst()
+ {
+ int doWorkSlot = Marshal.GetStartComSlot(typeof(IBar));
+ IntPtr pBase = Marshal.GetComInterfaceForObject(new Bar(), typeof(IBar));
+ IntPtr pDerived = Marshal.GetComInterfaceForObject(new BarDerived(), typeof(IBar));
+ try
+ {
+ LastCalledType = null;
+ Assert.True(CallDoWork(pBase, doWorkSlot) >= 0);
+ Assert.Equal(nameof(Bar), LastCalledType);
+
+ LastCalledType = null;
+ Assert.True(CallDoWork(pDerived, doWorkSlot) >= 0);
+ Assert.Equal(nameof(BarDerived), LastCalledType);
+ }
+ finally
+ {
+ Marshal.Release(pBase);
+ Marshal.Release(pDerived);
+ }
+ }
+}
diff --git a/src/tests/Interop/COM/VirtualMethodOverride/VirtualMethodOverrideTest.csproj b/src/tests/Interop/COM/VirtualMethodOverride/VirtualMethodOverrideTest.csproj
new file mode 100644
index 00000000000000..412688de2f640c
--- /dev/null
+++ b/src/tests/Interop/COM/VirtualMethodOverride/VirtualMethodOverrideTest.csproj
@@ -0,0 +1,10 @@
+
+
+ true
+ true
+ true
+
+
+
+
+