From 99e0fdff08347e78df5f6399e083acaf152f1de5 Mon Sep 17 00:00:00 2001 From: Patrick Stevens <3138005+Smaug123@users.noreply.github.com> Date: Tue, 29 Jul 2025 22:04:45 +0100 Subject: [PATCH] Make ParallelQueue surface its errors, correctly flow ExecutionContext (#278) --- WoofWare.NUnitTestRunner.Lib/Context.fs | 33 +- WoofWare.NUnitTestRunner.Lib/Exception.fs | 10 + WoofWare.NUnitTestRunner.Lib/ParallelQueue.fs | 50 ++- .../WoofWare.NUnitTestRunner.Lib.fsproj | 1 + .../TestSynchronizationContext.fs | 409 ++++++++++++++++++ .../WoofWare.NUnitTestRunner.Test.fsproj | 1 + 6 files changed, 467 insertions(+), 37 deletions(-) create mode 100644 WoofWare.NUnitTestRunner.Lib/Exception.fs create mode 100644 WoofWare.NUnitTestRunner/WoofWare.NUnitTestRunner.Test/TestSynchronizationContext.fs diff --git a/WoofWare.NUnitTestRunner.Lib/Context.fs b/WoofWare.NUnitTestRunner.Lib/Context.fs index 40420c4..ffc0bcf 100644 --- a/WoofWare.NUnitTestRunner.Lib/Context.fs +++ b/WoofWare.NUnitTestRunner.Lib/Context.fs @@ -10,20 +10,14 @@ open System.Threading type internal OutputStreamId = | OutputStreamId of Guid -type private ThreadAwareWriter - ( - local : AsyncLocal, - underlying : Dictionary, - mem : Dictionary - ) +type private ThreadAwareWriter (local : AsyncLocal, underlying : Dictionary) = inherit TextWriter () override _.get_Encoding () = Encoding.Default override this.Write (v : char) : unit = - use prev = ExecutionContext.Capture () - - (fun _ -> + lock + underlying (fun () -> match underlying.TryGetValue local.Value with | true, output -> output.Write v @@ -31,16 +25,12 @@ type private ThreadAwareWriter let wanted = underlying |> Seq.map (fun (KeyValue (a, b)) -> $"%O{a}") |> String.concat "\n" - failwith $"no such context: %O{local.Value}\nwanted:\n" + failwith $"no such context: %O{local.Value}\nwanted:\n{wanted}" ) - |> lock underlying - ) - |> fun action -> ExecutionContext.Run (prev, action, ()) override this.WriteLine (v : string) : unit = - use prev = ExecutionContext.Capture () - - (fun _ -> + lock + underlying (fun () -> match underlying.TryGetValue local.Value with | true, output -> output.WriteLine v @@ -48,16 +38,13 @@ type private ThreadAwareWriter let wanted = underlying |> Seq.map (fun (KeyValue (a, b)) -> $"%O{a}") |> String.concat "\n" - failwith $"no such context: %O{local.Value}\nwanted:\n" + failwith $"no such context: %O{local.Value}\nwanted:\n{wanted}" ) - |> lock underlying - ) - |> fun action -> ExecutionContext.Run (prev, action, ()) /// Wraps up the necessary context to intercept global state. [] type TestContexts = - private + internal { /// Accesses to this must be locked on StdOutWriters. StdOuts : Dictionary @@ -77,8 +64,8 @@ type TestContexts = let stdoutWriters = Dictionary () let stderrWriters = Dictionary () let local = AsyncLocal () - let stdoutWriter = new ThreadAwareWriter (local, stdoutWriters, stdouts) - let stderrWriter = new ThreadAwareWriter (local, stderrWriters, stderrs) + let stdoutWriter = new ThreadAwareWriter (local, stdoutWriters) + let stderrWriter = new ThreadAwareWriter (local, stderrWriters) { StdOuts = stdouts diff --git a/WoofWare.NUnitTestRunner.Lib/Exception.fs b/WoofWare.NUnitTestRunner.Lib/Exception.fs new file mode 100644 index 0000000..94fbf5e --- /dev/null +++ b/WoofWare.NUnitTestRunner.Lib/Exception.fs @@ -0,0 +1,10 @@ +namespace WoofWare.NUnitTestRunner + +open System.Runtime.ExceptionServices + +[] +module internal Exception = + let reraiseWithOriginalStackTrace<'a> (e : exn) : 'a = + let edi = ExceptionDispatchInfo.Capture e + edi.Throw () + failwith "unreachable" diff --git a/WoofWare.NUnitTestRunner.Lib/ParallelQueue.fs b/WoofWare.NUnitTestRunner.Lib/ParallelQueue.fs index de7901d..cfefdf7 100644 --- a/WoofWare.NUnitTestRunner.Lib/ParallelQueue.fs +++ b/WoofWare.NUnitTestRunner.Lib/ParallelQueue.fs @@ -5,14 +5,14 @@ open System.Threading open System.Threading.Tasks type private ThunkEvaluator<'ret> = - abstract Eval<'a> : (unit -> 'a) -> AsyncReplyChannel<'a> -> 'ret + abstract Eval<'a> : (unit -> 'a) -> AsyncReplyChannel> -> 'ret type private ThunkCrate = abstract Apply<'ret> : ThunkEvaluator<'ret> -> 'ret [] module private ThunkCrate = - let make<'a> (t : unit -> 'a) (rc : AsyncReplyChannel<'a>) : ThunkCrate = + let make<'a> (t : unit -> 'a) (rc : AsyncReplyChannel>) : ThunkCrate = { new ThunkCrate with member _.Apply e = e.Eval t rc } @@ -41,7 +41,7 @@ type private MailboxMessage = | Quit of AsyncReplyChannel /// Check current state, see if we need to start more tests, etc. | Reconcile - | RunTest of within : TestFixture * Parallelizable option * test : ThunkCrate + | RunTest of within : TestFixture * Parallelizable option * test : ThunkCrate * context : ExecutionContext | BeginTestFixture of TestFixture * AsyncReplyChannel | EndTestFixture of TestFixtureTearDownToken * AsyncReplyChannel @@ -310,18 +310,23 @@ type ParallelQueue rc.Reply () m.Post MailboxMessage.Reconcile return! processTask (Running state) m - | MailboxMessage.RunTest (withinFixture, par, message) -> + | MailboxMessage.RunTest (withinFixture, par, message, capturedContext) -> let t () = { new ThunkEvaluator<_> with member _.Eval<'b> (t : unit -> 'b) rc = let tcs = TaskCompletionSource TaskCreationOptions.RunContinuationsAsynchronously - use ec = ExecutionContext.Capture () fun () -> ExecutionContext.Run ( - ec, + capturedContext, (fun _ -> - let result = t () + let result = + try + let r = t () + Ok r + with e -> + Error e + tcs.SetResult () m.Post MailboxMessage.Reconcile rc.Reply result @@ -356,9 +361,18 @@ type ParallelQueue (action : unit -> 'a) : 'a Task = - (fun rc -> MailboxMessage.RunTest (parent, scope, ThunkCrate.make action rc)) - |> mb.PostAndAsyncReply - |> Async.StartAsTask + let ec = ExecutionContext.Capture () + + task { + let! result = + (fun rc -> MailboxMessage.RunTest (parent, scope, ThunkCrate.make action rc, ec)) + |> mb.PostAndAsyncReply + |> Async.StartAsTask + + match result with + | Ok o -> return o + | Error e -> return Exception.reraiseWithOriginalStackTrace e + } /// Declare that we wish to start the given test fixture. The resulting Task will return /// when you are allowed to start running tests from that fixture. @@ -379,11 +393,15 @@ type ParallelQueue | Parallelizable.Yes _ -> Parallelizable.Yes () ) + let ec = ExecutionContext.Capture () + let! response = - (fun rc -> MailboxMessage.RunTest (parent, par, ThunkCrate.make action rc)) + (fun rc -> MailboxMessage.RunTest (parent, par, ThunkCrate.make action rc, ec)) |> mb.PostAndAsyncReply - return response, TestFixtureSetupToken parent + match response with + | Ok response -> return response, TestFixtureSetupToken parent + | Error e -> return Exception.reraiseWithOriginalStackTrace e } /// Run the given one-time tear-down for the test fixture. @@ -401,11 +419,15 @@ type ParallelQueue | Parallelizable.Yes _ -> Parallelizable.Yes () ) + let ec = ExecutionContext.Capture () + let! response = - (fun rc -> MailboxMessage.RunTest (parent, par, ThunkCrate.make action rc)) + (fun rc -> MailboxMessage.RunTest (parent, par, ThunkCrate.make action rc, ec)) |> mb.PostAndAsyncReply - return response, TestFixtureTearDownToken parent + match response with + | Ok response -> return response, TestFixtureTearDownToken parent + | Error e -> return Exception.reraiseWithOriginalStackTrace e } /// Declare that we have finished submitting requests to run in the given test fixture. diff --git a/WoofWare.NUnitTestRunner.Lib/WoofWare.NUnitTestRunner.Lib.fsproj b/WoofWare.NUnitTestRunner.Lib/WoofWare.NUnitTestRunner.Lib.fsproj index 38e536c..e1a1f72 100644 --- a/WoofWare.NUnitTestRunner.Lib/WoofWare.NUnitTestRunner.Lib.fsproj +++ b/WoofWare.NUnitTestRunner.Lib/WoofWare.NUnitTestRunner.Lib.fsproj @@ -31,6 +31,7 @@ + diff --git a/WoofWare.NUnitTestRunner/WoofWare.NUnitTestRunner.Test/TestSynchronizationContext.fs b/WoofWare.NUnitTestRunner/WoofWare.NUnitTestRunner.Test/TestSynchronizationContext.fs new file mode 100644 index 0000000..7179742 --- /dev/null +++ b/WoofWare.NUnitTestRunner/WoofWare.NUnitTestRunner.Test/TestSynchronizationContext.fs @@ -0,0 +1,409 @@ +namespace WoofWare.NUnitTestRunner.Test + +open System +open System.Text +open System.Threading +open System.Threading.Tasks +open NUnit.Framework +open FsUnitTyped +open WoofWare.NUnitTestRunner + +[] +module TestSynchronizationContext = + + [] + let ``ExecutionContext flows correctly through synchronous operations`` () = + task { + let dummyFixture = + TestFixture.Empty typeof (Some (Parallelizable.Yes ClassParallelScope.All)) [] [] + + use contexts = TestContexts.Empty () + use queue = new ParallelQueue (Some 4, None) + + // Track which context values we see during execution + let contextValues = System.Collections.Concurrent.ConcurrentBag () + + // Start the fixture + let! running = queue.StartTestFixture dummyFixture + let! _, setup = queue.RunTestSetup running (fun () -> ()) + + // Create several synchronous operations with different context values + let tasks = + [ 1..10 ] + |> List.map (fun _ -> + task { + do! Task.Yield () + // Set a unique context value + let outputId = contexts.NewOutputs () + let (OutputStreamId expectedId) = outputId + contexts.AsyncLocal.Value <- outputId + + // Run a synchronous operation that checks the context + let! actualId = + queue.Run + setup + None + (fun () -> + // Check context immediately + let immediate = contexts.AsyncLocal.Value + let (OutputStreamId immediateGuid) = immediate + contextValues.Add (expectedId, immediateGuid) + + // Do some work that might cause context issues + Thread.Sleep 10 + + // Check context after work + let afterWork = contexts.AsyncLocal.Value + let (OutputStreamId afterWorkGuid) = afterWork + contextValues.Add (expectedId, afterWorkGuid) + + // Simulate calling into framework code that might use ExecutionContext + let mutable capturedValue = Guid.Empty + + ExecutionContext.Run ( + ExecutionContext.Capture (), + (fun _ -> + let current = contexts.AsyncLocal.Value + let (OutputStreamId currentGuid) = current + capturedValue <- currentGuid + ), + () + ) + + contextValues.Add (expectedId, capturedValue) + + afterWorkGuid + ) + + // Verify the returned value matches what we set + actualId |> shouldEqual expectedId + } + ) + + // Wait for all tasks + let! results = Task.WhenAll tasks + results |> Array.iter id + + // Verify all context values were correct + let allValues = contextValues |> Seq.toList + allValues |> shouldHaveLength 30 // 3 checks per operation * 10 operations + + // Every captured value should match its expected value + allValues + |> List.iter (fun (expected, actual) -> actual |> shouldEqual expected) + + // Clean up + let! _, teardown = queue.RunTestTearDown setup (fun () -> ()) + do! queue.EndTestFixture teardown + } + + [] + let ``ExecutionContext isolation between concurrent synchronous operations`` () = + task { + let dummyFixture = + TestFixture.Empty typeof (Some (Parallelizable.Yes ClassParallelScope.All)) [] [] + + use contexts = TestContexts.Empty () + use queue = new ParallelQueue (Some 4, None) + + let! running = queue.StartTestFixture dummyFixture + let! _, setup = queue.RunTestSetup running (fun () -> ()) + + // Use a barrier to ensure operations run concurrently + let barrier = new Barrier (3) + let seenValues = System.Collections.Concurrent.ConcurrentBag () + let outputIds = System.Collections.Concurrent.ConcurrentBag () + + // Create operations that will definitely run concurrently + let tasks = + [ 1..3 ] + |> List.map (fun i -> + task { + // Each task sets its own context value + let outputId = contexts.NewOutputs () + let (OutputStreamId myId) = outputId + contexts.AsyncLocal.Value <- outputId + outputIds.Add outputId + + let! result = + queue.Run + setup + (Some (Parallelizable.Yes ())) + (fun () -> + // Wait for all tasks to reach this point + barrier.SignalAndWait () + + // Now check what value we see + let currentValue = contexts.AsyncLocal.Value + + match currentValue with + | OutputStreamId guid -> seenValues.Add (i, guid) + + // Do some synchronous work + Thread.Sleep 5 + + // Check again after work + let afterWork = contexts.AsyncLocal.Value + + match afterWork with + | OutputStreamId guid -> + // Also verify we can write to the correct streams + contexts.Stdout.WriteLine $"Task %i{i} sees context %O{guid}" + guid + ) + + // Each task should see its own value + result |> shouldEqual myId + } + ) + + let! results = Task.WhenAll tasks + results |> Array.iter id + + // Verify we saw 3 different values (one per task) + let values = seenValues |> Seq.toList + values |> shouldHaveLength 3 + + // All seen values should be different (no context bleeding) + let uniqueValues = values |> List.map snd |> List.distinct + uniqueValues |> shouldHaveLength 3 + + let! _, teardown = queue.RunTestTearDown setup (fun () -> ()) + do! queue.EndTestFixture teardown + + // Verify stdout content for each task + let collectedOutputs = outputIds |> Seq.toList + collectedOutputs |> shouldHaveLength 3 + + for outputId in collectedOutputs do + let content = contexts.DumpStdout outputId + content |> shouldNotEqual "" + let (OutputStreamId guid) = outputId + content |> shouldContainText (guid.ToString ()) + } + + [] + let ``ExecutionContext flows correctly through nested synchronous operations`` () = + task { + let dummyFixture = + TestFixture.Empty typeof (Some (Parallelizable.Yes ClassParallelScope.All)) [] [] + + use contexts = TestContexts.Empty () + use queue = new ParallelQueue (Some 4, None) + + let! running = queue.StartTestFixture dummyFixture + let! _, setup = queue.RunTestSetup running (fun () -> ()) + + // Set an initial context + let outputId = contexts.NewOutputs () + let (OutputStreamId outerGuid) = outputId + contexts.AsyncLocal.Value <- outputId + + let! result = + queue.Run + setup + None + (fun () -> + // Check we have the outer context + let outer = contexts.AsyncLocal.Value + let (OutputStreamId outerSeen) = outer + outerSeen |> shouldEqual outerGuid + + // Now change the context for a nested operation + let innerOutputId = contexts.NewOutputs () + let (OutputStreamId innerGuid) = innerOutputId + contexts.AsyncLocal.Value <- innerOutputId + + // Use Task.Run to potentially hop threads + let innerResult = + Task + .Run(fun () -> + let inner = contexts.AsyncLocal.Value + let (OutputStreamId innerSeen) = inner + innerSeen |> shouldEqual innerGuid + innerSeen + ) + .Result + + // After the nested operation, we should still have our inner context + let afterNested = contexts.AsyncLocal.Value + let (OutputStreamId afterNestedGuid) = afterNested + afterNestedGuid |> shouldEqual innerGuid + + (outerSeen, innerResult, afterNestedGuid) + ) + + // Unpack results + let seenOuter, seenInner, seenAfter = result + seenOuter |> shouldEqual outerGuid + seenInner |> shouldNotEqual outerGuid + seenAfter |> shouldEqual seenInner + + let! _, teardown = queue.RunTestTearDown setup (fun () -> ()) + do! queue.EndTestFixture teardown + } + +(* + [] + let ``ExecutionContext flows correctly through async operations`` () = + task { + // Create a test fixture + let dummyFixture = + TestFixture.Empty + typeof + (Some (Parallelizable.Yes ClassParallelScope.All)) + [] + [] + + use contexts = TestContexts.Empty () + use queue = new ParallelQueue(Some 4, None) + + // Track which context values we see during execution + let contextValues = System.Collections.Concurrent.ConcurrentBag() + + // Start the fixture + let! running = queue.StartTestFixture dummyFixture + let! _, setup = queue.RunTestSetup running (fun () -> ()) + + // Create several async operations with different context values + let tasks = + [1..10] + |> List.map (fun i -> + task { + // Set a unique context value + let expectedId = Guid.NewGuid() + let outputId = OutputStreamId expectedId + contexts.AsyncLocal.Value <- outputId + + // Run an async operation that checks the context at multiple points + let! actualId = + queue.RunAsync setup None (fun () -> + async { + // Check context immediately + let immediate = contexts.AsyncLocal.Value + let (OutputStreamId immediateGuid) = immediate + contextValues.Add(expectedId, immediateGuid) + + // Yield to allow potential context loss + do! Async.Sleep 10 + + // Check context after yield + let afterYield = contexts.AsyncLocal.Value + let (OutputStreamId afterYieldGuid) = afterYield + contextValues.Add(expectedId, afterYieldGuid) + + // Do some actual async work + do! Task.Delay(10) |> Async.AwaitTask + + // Check context after task + let afterTask = contexts.AsyncLocal.Value + let (OutputStreamId afterTaskGuid) = afterTask + contextValues.Add(expectedId, afterTaskGuid) + + return afterTaskGuid + } + ) + + // Verify the returned value matches what we set + actualId |> shouldEqual expectedId + } + ) + + // Wait for all tasks + let! results = Task.WhenAll(tasks) + results |> Array.iter id + + // Verify all context values were correct + let allValues = contextValues |> Seq.toList + allValues |> shouldHaveLength 30 // 3 checks per operation * 10 operations + + // Every captured value should match its expected value + for expected, actual in allValues do + actual |> shouldEqual expected + + // Clean up + let! _, teardown = queue.RunTestTearDown setup (fun () -> ()) + do! queue.EndTestFixture teardown + } + + [] + let ``ExecutionContext isolation between concurrent operations`` () = + task { + let dummyFixture = + TestFixture.Empty + typeof + (Some (Parallelizable.Yes ClassParallelScope.All)) + [] + [] + + use contexts = TestContexts.Empty () + use queue = new ParallelQueue(Some 4, None) + + let! running = queue.StartTestFixture dummyFixture + let! _, setup = queue.RunTestSetup running (fun () -> ()) + + // Use a barrier to ensure operations run concurrently + let barrier = new Barrier(3) + let seenValues = System.Collections.Concurrent.ConcurrentBag() + + // Create operations that will definitely run concurrently + let tasks = + [1..3] + |> List.map (fun i -> + task { + // Each task sets its own context value + let myId = Guid.NewGuid() + contexts.AsyncLocal.Value <- OutputStreamId myId + + let! result = + queue.RunAsync setup (Some (Parallelizable.Yes ())) (fun () -> + async { + // Wait for all tasks to reach this point + barrier.SignalAndWait() |> ignore + + // Now check what value we see + let currentValue = contexts.AsyncLocal.Value + match currentValue with + | OutputStreamId guid -> seenValues.Add(i, Some guid) + | _ -> seenValues.Add(i, None) + + // Do some async work + do! Async.Sleep 5 + + // Check again after async work + let afterAsync = contexts.AsyncLocal.Value + match afterAsync with + | OutputStreamId guid -> + return guid + | _ -> + return failwith "Lost context after async" + } + ) + + // Each task should see its own value + result |> shouldEqual myId + } + ) + + let! results = Task.WhenAll(tasks) + results |> Array.iter id + + // Verify we saw 3 different values (one per task) + let values = seenValues |> Seq.toList + values |> shouldHaveLength 3 + + // Each task should have seen a value + for (taskId, value) in values do + value |> shouldNotEqual None + + // All seen values should be different (no context bleeding) + let uniqueValues = + values + |> List.choose snd + |> List.distinct + uniqueValues |> shouldHaveLength 3 + + let! _, teardown = queue.RunTestTearDown setup (fun () -> ()) + do! queue.EndTestFixture teardown + } +*) diff --git a/WoofWare.NUnitTestRunner/WoofWare.NUnitTestRunner.Test/WoofWare.NUnitTestRunner.Test.fsproj b/WoofWare.NUnitTestRunner/WoofWare.NUnitTestRunner.Test/WoofWare.NUnitTestRunner.Test.fsproj index a9f6dac..22a6b62 100644 --- a/WoofWare.NUnitTestRunner/WoofWare.NUnitTestRunner.Test/WoofWare.NUnitTestRunner.Test.fsproj +++ b/WoofWare.NUnitTestRunner/WoofWare.NUnitTestRunner.Test/WoofWare.NUnitTestRunner.Test.fsproj @@ -11,6 +11,7 @@ +