diff --git a/src/CenterEdge.Async.UnitTests/AsyncHelperTests.cs b/src/CenterEdge.Async.UnitTests/AsyncHelperTests.cs index 7ea5b98..72f8294 100644 --- a/src/CenterEdge.Async.UnitTests/AsyncHelperTests.cs +++ b/src/CenterEdge.Async.UnitTests/AsyncHelperTests.cs @@ -1,6 +1,7 @@ using System; using System.Threading; using System.Threading.Tasks; +using System.Collections.Generic; using Moq; using Xunit; @@ -1276,6 +1277,156 @@ public void RunSyncWithState_ValueTaskT_DanglingContinuations_HandledOnParentSyn #endregion + #region SyncContext Restoration Tests + + [Fact] + public void RunSync_Task_SyncCompletionWithContinuations_RestoresSyncContextBeforeProcessingQueue() + { + // This test simulates the bug where a reentrant sync context (like WinForms) + // might process posted messages before RunSync returns, causing continuations + // to see the ExclusiveSynchronizationContext instead of the original context. + + // Arrange + SynchronizationContext? capturedContext = null; + var reentrantSync = new ReentrantSynchronizationContext(); + SynchronizationContext.SetSynchronizationContext(reentrantSync); + + try + { + // Act + AsyncHelper.RunSync(() => + { + // Start a continuation that will be queued +#pragma warning disable CS4014 + Task.Run(() => + { + // This continuation will be posted back to the sync context + capturedContext = SynchronizationContext.Current; + }); +#pragma warning restore CS4014 + + // Return a completed task so RunAlreadyComplete is called + return Task.CompletedTask; + }); + + // The reentrant context processes its queue before returning + reentrantSync.ProcessQueue(); + + // Assert + Assert.Equal(reentrantSync, capturedContext); + } + finally + { + SynchronizationContext.SetSynchronizationContext(null); + } + } + + [Fact] + public void RunSync_ValueTask_SyncCompletionWithContinuations_RestoresSyncContextBeforeProcessingQueue() + { + // Arrange + SynchronizationContext? capturedContext = null; + var reentrantSync = new ReentrantSynchronizationContext(); + SynchronizationContext.SetSynchronizationContext(reentrantSync); + + try + { + // Act + AsyncHelper.RunSync(() => + { +#pragma warning disable CS4014 + Task.Run(() => + { + capturedContext = SynchronizationContext.Current; + }); +#pragma warning restore CS4014 + + return new ValueTask(); + }); + + reentrantSync.ProcessQueue(); + + // Assert + Assert.Equal(reentrantSync, capturedContext); + } + finally + { + SynchronizationContext.SetSynchronizationContext(null); + } + } + + [Fact] + public void RunSync_TaskT_SyncCompletionWithContinuations_RestoresSyncContextBeforeProcessingQueue() + { + // Arrange + SynchronizationContext? capturedContext = null; + var reentrantSync = new ReentrantSynchronizationContext(); + SynchronizationContext.SetSynchronizationContext(reentrantSync); + + try + { + // Act + var result = AsyncHelper.RunSync(() => + { +#pragma warning disable CS4014 + Task.Run(() => + { + capturedContext = SynchronizationContext.Current; + }); +#pragma warning restore CS4014 + + return Task.FromResult(42); + }); + + reentrantSync.ProcessQueue(); + + // Assert + Assert.Equal(42, result); + Assert.Equal(reentrantSync, capturedContext); + } + finally + { + SynchronizationContext.SetSynchronizationContext(null); + } + } + + [Fact] + public void RunSync_ValueTaskT_SyncCompletionWithContinuations_RestoresSyncContextBeforeProcessingQueue() + { + // Arrange + SynchronizationContext? capturedContext = null; + var reentrantSync = new ReentrantSynchronizationContext(); + SynchronizationContext.SetSynchronizationContext(reentrantSync); + + try + { + // Act + var result = AsyncHelper.RunSync(() => + { +#pragma warning disable CS4014 + Task.Run(() => + { + capturedContext = SynchronizationContext.Current; + }); +#pragma warning restore CS4014 + + return new ValueTask(42); + }); + + reentrantSync.ProcessQueue(); + + // Assert + Assert.Equal(42, result); + Assert.Equal(reentrantSync, capturedContext); + } + finally + { + SynchronizationContext.SetSynchronizationContext(null); + } + } + + #endregion + #region Helpers private static readonly AsyncLocal asyncLocalField = new(); @@ -1287,6 +1438,27 @@ private static async Task DelayedActionAsync(TimeSpan delay, Action action) action.Invoke(); } + // Simulates a reentrant synchronization context like WinForms + // that processes its message queue synchronously when asked + private class ReentrantSynchronizationContext : SynchronizationContext + { + private readonly Queue<(SendOrPostCallback, object?)> _queue = new(); + + public override void Post(SendOrPostCallback d, object? state) + { + _queue.Enqueue((d, state)); + } + + public void ProcessQueue() + { + while (_queue.Count > 0) + { + var (callback, state) = _queue.Dequeue(); + callback(state); + } + } + } + #endregion } } diff --git a/src/CenterEdge.Async/AsyncHelper.cs b/src/CenterEdge.Async/AsyncHelper.cs index 8b1e3dc..4bb50cf 100644 --- a/src/CenterEdge.Async/AsyncHelper.cs +++ b/src/CenterEdge.Async/AsyncHelper.cs @@ -88,6 +88,10 @@ public static void RunSync(Func task, TState state) } else { + // Restore the sync context before processing any queued continuations + // to avoid reentrancy issues where the parent context might process + // posted messages before we return + SynchronizationContext.SetSynchronizationContext(oldContext); synch.RunAlreadyComplete(); } @@ -168,6 +172,10 @@ public static void RunSync(Func task, TState state) } else { + // Restore the sync context before processing any queued continuations + // to avoid reentrancy issues where the parent context might process + // posted messages before we return + SynchronizationContext.SetSynchronizationContext(oldContext); synch.RunAlreadyComplete(); } @@ -245,6 +253,10 @@ public static T RunSync(Func> task, TState state) } else { + // Restore the sync context before processing any queued continuations + // to avoid reentrancy issues where the parent context might process + // posted messages before we return + SynchronizationContext.SetSynchronizationContext(oldContext); synch.RunAlreadyComplete(); } @@ -326,6 +338,10 @@ public static T RunSync(Func> task, TState state } else { + // Restore the sync context before processing any queued continuations + // to avoid reentrancy issues where the parent context might process + // posted messages before we return + SynchronizationContext.SetSynchronizationContext(oldContext); synch.RunAlreadyComplete(); }