From 0d30aadc9da3fdf22a9bde8564592bbc56b96e36 Mon Sep 17 00:00:00 2001 From: "Lifeng Lu (from Dev Box)" Date: Sun, 4 Feb 2024 17:39:40 -0900 Subject: [PATCH 1/7] Implements OnMainThreadBlocked. --- .../JoinableTaskContext.cs | 140 ++++++++ .../net472/PublicAPI.Unshipped.txt | 3 +- .../net6.0-windows/PublicAPI.Unshipped.txt | 3 +- .../net6.0/PublicAPI.Unshipped.txt | 3 +- .../netstandard2.0/PublicAPI.Unshipped.txt | 3 +- .../JoinableTaskContextTests.cs | 328 ++++++++++++++++++ 6 files changed, 476 insertions(+), 4 deletions(-) diff --git a/src/Microsoft.VisualStudio.Threading/JoinableTaskContext.cs b/src/Microsoft.VisualStudio.Threading/JoinableTaskContext.cs index 3eb58fbcb..b0746c2c3 100644 --- a/src/Microsoft.VisualStudio.Threading/JoinableTaskContext.cs +++ b/src/Microsoft.VisualStudio.Threading/JoinableTaskContext.cs @@ -131,6 +131,8 @@ public partial class JoinableTaskContext : IDisposable /// private readonly string contextId = Guid.NewGuid().ToString("n"); + private readonly NonPostJoinableTaskFactory nonPostJoinableTaskFactory; + /// /// The next unique ID to assign to a for which a token is required. /// @@ -156,6 +158,7 @@ public partial class JoinableTaskContext : IDisposable public JoinableTaskContext() : this(Thread.CurrentThread, SynchronizationContext.Current) { + this.nonPostJoinableTaskFactory = new NonPostJoinableTaskFactory(this); } /// @@ -385,6 +388,55 @@ public bool IsMainThreadMaybeBlocked() return false; } + /// + /// Registers a callback when the current JoinableTask is blocking the UI thread. + /// + /// The type of state used by the callback. + /// A callback method. + /// A state passing to the callback method. + /// A disposable which can be used to unregister the callback. + public IDisposable OnMainThreadBlocked(Action action, TState state) + { + Requires.NotNull(action, nameof(action)); + + JoinableTask? ambientTask = this.AmbientTask; + if (ambientTask is null) + { + // when it is called outside of a JoinableTask, it would never block main thread in the future, so we would not want to waste time further. + return EmptyDisposable.Instance; + } + + var cancellation = new DisposeToCancel(); + CancellationToken cancellationToken = cancellation.CancellationToken; + + // chain a task to ensure we would clean up all things when the ambient task is completed. + // A completed task would not block the main thread further. + _ = ambientTask.Task.ContinueWith( + static (_, s) => ((DisposeToCancel)s!).Dispose(), + cancellation, + cancellationToken, + TaskContinuationOptions.ExecuteSynchronously, + TaskScheduler.Default); + + _ = this.nonPostJoinableTaskFactory.WhenBlockingMainThreadAsync(cancellationToken) + .ContinueWith( + static (_, s) => + { + (JoinableTaskContext me, Action callback, TState callState) = ((JoinableTaskContext, Action, TState))s!; + JoinableTask? ambientTask = me.AmbientTask; + if (ambientTask?.IsCompleted == false) + { + callback(callState); + } + }, + (this, action, state), + cancellationToken, + TaskContinuationOptions.OnlyOnRanToCompletion | TaskContinuationOptions.LazyCancellation, + TaskScheduler.Default); + + return cancellation; + } + /// /// Creates a joinable task factory that automatically adds all created tasks /// to a collection that can be jointly joined. @@ -908,4 +960,92 @@ public void Dispose() } } } + + /// + /// Represents a disposable which does nothing. + /// + private class EmptyDisposable : IDisposable, IDisposableObservable + { + public bool IsDisposed => true; + + internal static IDisposable Instance { get; } = new EmptyDisposable(); + + public void Dispose() + { + } + } + + /// + /// Implements a disposable which triggers a cancellation token when it is disposed. + /// + private class DisposeToCancel : IDisposable + { + private readonly CancellationTokenSource cancellationTokenSource = new(); + + internal CancellationToken CancellationToken => this.cancellationTokenSource.Token; + + public void Dispose() + { + this.cancellationTokenSource.Cancel(); + this.cancellationTokenSource.Dispose(); + } + } + + /// + /// A special JoinableTaskFactory, which does not allow tasks to be scheduled when the UI thread is pumping messages. + /// This allows us to detect whether a task is blocking UI thread, because it would only be able to run under this condition. + /// + private class NonPostJoinableTaskFactory : JoinableTaskFactory + { + internal NonPostJoinableTaskFactory(JoinableTaskContext owner) + : base(owner) + { + } + + internal Task WhenBlockingMainThreadAsync(CancellationToken cancellationToken) + { + TaskCompletionSource taskCompletion = new(); + JoinableTask detectionTask; + + // By default SwitchToMainThreadAsync would not use the exact JoinableTaskFactory we want to use here, but would use the one from ambient task, so we have to do extra mile to create the child task. + // However, we must suppress relevance, otherwise, the SwitchToMainThreadAsync would post to the UI thread queue twice through both factories. This will defeat all the reason we create the special factory. + using (this.Context.SuppressRelevance()) + { + detectionTask = this.RunAsync(() => + { + // a simpler code inside is just to await this.SwitchToMainThreadAsync(alwaysYield: true, cancellationToken) + // alwaysYield is necessary to ensure we don't execute immediately when it is called on the main thread. + // however, this also leads an extra yield to handle cancellation token. Also, this would throw an extra cancellation exception. + // While we can suppress the exception with NoThrowAwaitable(), it would lead us to queue the continuation one more time in the case the cancellation is triggered, which is a waste. + // This leads the extra code done below. + MainThreadAwaiter awaiter = this.SwitchToMainThreadAsync(cancellationToken).NoThrowAwaitable().GetAwaiter(); + + awaiter.UnsafeOnCompleted(() => + { + if (cancellationToken.IsCancellationRequested) + { + taskCompletion.SetCanceled(); + } + + taskCompletion.SetResult(true); + + // ensure the cancellation registration is disposed. + awaiter.GetResult(); + }); + + return taskCompletion.Task; + }); + } + + // the detection task must be joined to the current task to ensure JTF dependencies to work (after we suppress relevance earlier). + _ = detectionTask.JoinAsync(); + + return taskCompletion.Task; + } + + protected internal override void PostToUnderlyingSynchronizationContext(SendOrPostCallback callback, object state) + { + return; + } + } } diff --git a/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt index e099d0429..ced34f519 100644 --- a/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt +++ b/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt @@ -15,4 +15,5 @@ Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance.Dispose() -> void Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.get -> bool Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.init -> void -Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance \ No newline at end of file +Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance +Microsoft.VisualStudio.Threading.JoinableTaskContext.OnMainThreadBlocked(System.Action! action, TState state) -> System.IDisposable! \ No newline at end of file diff --git a/src/Microsoft.VisualStudio.Threading/net6.0-windows/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/net6.0-windows/PublicAPI.Unshipped.txt index e099d0429..ced34f519 100644 --- a/src/Microsoft.VisualStudio.Threading/net6.0-windows/PublicAPI.Unshipped.txt +++ b/src/Microsoft.VisualStudio.Threading/net6.0-windows/PublicAPI.Unshipped.txt @@ -15,4 +15,5 @@ Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance.Dispose() -> void Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.get -> bool Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.init -> void -Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance \ No newline at end of file +Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance +Microsoft.VisualStudio.Threading.JoinableTaskContext.OnMainThreadBlocked(System.Action! action, TState state) -> System.IDisposable! \ No newline at end of file diff --git a/src/Microsoft.VisualStudio.Threading/net6.0/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/net6.0/PublicAPI.Unshipped.txt index e099d0429..ced34f519 100644 --- a/src/Microsoft.VisualStudio.Threading/net6.0/PublicAPI.Unshipped.txt +++ b/src/Microsoft.VisualStudio.Threading/net6.0/PublicAPI.Unshipped.txt @@ -15,4 +15,5 @@ Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance.Dispose() -> void Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.get -> bool Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.init -> void -Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance \ No newline at end of file +Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance +Microsoft.VisualStudio.Threading.JoinableTaskContext.OnMainThreadBlocked(System.Action! action, TState state) -> System.IDisposable! \ No newline at end of file diff --git a/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt index e099d0429..ced34f519 100644 --- a/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt +++ b/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt @@ -15,4 +15,5 @@ Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance.Dispose() -> void Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.get -> bool Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.init -> void -Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance \ No newline at end of file +Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance +Microsoft.VisualStudio.Threading.JoinableTaskContext.OnMainThreadBlocked(System.Action! action, TState state) -> System.IDisposable! \ No newline at end of file diff --git a/test/Microsoft.VisualStudio.Threading.Tests/JoinableTaskContextTests.cs b/test/Microsoft.VisualStudio.Threading.Tests/JoinableTaskContextTests.cs index e50d64078..dee034957 100644 --- a/test/Microsoft.VisualStudio.Threading.Tests/JoinableTaskContextTests.cs +++ b/test/Microsoft.VisualStudio.Threading.Tests/JoinableTaskContextTests.cs @@ -4,9 +4,11 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using System.Xml.Linq; +using Microsoft; using Microsoft.VisualStudio.Threading; using Xunit; using Xunit.Abstractions; @@ -704,6 +706,332 @@ public void IsMainThreadBlockedFalseWhenTaskIsCompleted() checkTask!.Wait(); } + [Fact] + public void OnMainThreadBlockedReturnsDisposedObjectWithNoTask() + { + IDisposable disposable = this.Context.OnMainThreadBlocked(_ => { }, null); + Assert.NotNull(disposable); + Assert.True(disposable is IDisposableObservable { IsDisposed: true }); + } + + [Fact] + public void OnMainThreadBlockedNotCalledWhenMainThreadNotBlocked() + { + bool isCalled = false; + Task? spinOff = null; + + this.SimulateUIThread(() => + { + spinOff = Task.Run(async () => + { + await this.Context.Factory.RunAsync(async () => + { + IDisposable disposable = this.Context.OnMainThreadBlocked(_ => { isCalled = true; }, null); + await Task.Delay(10); + }); + }); + + return Task.CompletedTask; + }); + + spinOff!.Wait(); + Assert.False(isCalled); + } + + [Fact] + public void OnMainThreadBlockedInJoinableTaskRun() + { + bool isCalled = false; + + this.SimulateUIThread(() => + { + this.Context.Factory.Run(async () => + { + var taskCompletionSource = new TaskCompletionSource(); + IDisposable disposable = this.Context.OnMainThreadBlocked( + s => + { + isCalled = true; + s.SetResult(true); + }, + taskCompletionSource); + await taskCompletionSource.Task; + }); + + return Task.CompletedTask; + }); + + Assert.True(isCalled); + } + + [Fact] + public void OnMainThreadBlockedInJoinableTaskRunChildTask() + { + bool isCalled = false; + + this.SimulateUIThread(() => + { + this.Context.Factory.Run(async () => + { + var taskCompletionSource = new TaskCompletionSource(); + await this.Context.Factory.RunAsync(() => + { + _ = this.Context.OnMainThreadBlocked( + s => + { + isCalled = true; + s.SetResult(true); + }, + taskCompletionSource); + return taskCompletionSource.Task; + }); + }); + + return Task.CompletedTask; + }); + + Assert.True(isCalled); + } + + [Fact] + public void OnMainThreadBlockedInJoinableTaskRunChildTaskOnBackgroundThread() + { + bool isCalled = false; + + this.SimulateUIThread(() => + { + this.Context.Factory.Run(async () => + { + await TaskScheduler.Default; + var taskCompletionSource = new TaskCompletionSource(); + await this.Context.Factory.RunAsync(() => + { + _ = this.Context.OnMainThreadBlocked( + s => + { + isCalled = true; + s.SetResult(true); + }, + taskCompletionSource); + return taskCompletionSource.Task; + }); + }); + + return Task.CompletedTask; + }); + + Assert.True(isCalled); + } + + [Fact] + public void OnMainThreadBlockedInJoinableTaskRunChildTaskWhenWaited() + { + bool isCalled = false; + + this.SimulateUIThread(() => + { + this.Context.Factory.Run(async () => + { + JoinableTask? childTask = null; + var taskCompletionSource = new TaskCompletionSource(); + + using (this.Context.SuppressRelevance()) + { + childTask = this.Context.Factory.RunAsync(async () => + { + await TaskScheduler.Default; + _ = this.Context.OnMainThreadBlocked( + s => + { + isCalled = true; + s.SetResult(true); + }, + taskCompletionSource); + await taskCompletionSource.Task; + }); + } + + await Task.Delay(20); + Assert.False(isCalled); + + await childTask; + }); + + return Task.CompletedTask; + }); + + Assert.True(isCalled); + } + + [Fact] + public void OnMainThreadBlockedInJoinableTaskRunChildTaskWhenWaited2() + { + bool isCalled = false; + + var taskCompletionSource = new TaskCompletionSource(); + + JoinableTask spinOffTask; + spinOffTask = this.Context.Factory.RunAsync(async () => + { + await TaskScheduler.Default; + _ = this.Context.OnMainThreadBlocked( + s => + { + isCalled = true; + s.SetResult(true); + }, + taskCompletionSource); + await taskCompletionSource.Task; + }); + + this.SimulateUIThread(async () => + { + await Task.Delay(20); + Assert.False(isCalled); + + this.Context.Factory.Run(async () => + { + Assert.False(isCalled); + + await spinOffTask; + }); + }); + + Assert.True(isCalled); + } + + [Fact] + public void OnMainThreadBlockedNotCalledAfterDisposing() + { + bool isCalled = false; + + this.SimulateUIThread(() => + { + this.Context.Factory.Run(async () => + { + JoinableTask? childTask = null; + var taskCompletionSource = new TaskCompletionSource(); + IDisposable? registration = null; + + using (this.Context.SuppressRelevance()) + { + childTask = this.Context.Factory.RunAsync(async () => + { + registration = this.Context.OnMainThreadBlocked( + _ => + { + isCalled = true; + }, + null); + await taskCompletionSource.Task; + }); + } + + await Task.Delay(20); + Assert.False(isCalled); + + registration?.Dispose(); + await Task.Delay(20); + + taskCompletionSource.TrySetResult(true); + await childTask; + }); + + return Task.CompletedTask; + }); + + Assert.False(isCalled); + } + + [Fact, Trait("GC", "true")] + public void OnMainThreadBlockedNotLeakingObject() + { + TaskCompletionSource stopTask = new TaskCompletionSource(); + + WeakReference state = SpinOffTask(stopTask.Task, out JoinableTask newTask); + + // MainThreadAwaiter always queues the cancelation callback through ThreadPool.QueueUserWorkItem to make it difficult to control when the cancellation is handled. + Thread.Sleep(20); + + GC.Collect(); + Assert.False(state.IsAlive); + + stopTask.SetResult(true); + newTask.Join(); + + [MethodImpl(MethodImplOptions.NoInlining)] // mem leak detection requires literally popping locals with strong refs off the stack + WeakReference SpinOffTask(Task waitingTask, out JoinableTask newTask) + { + object? state = new(); + WeakReference weakState = new WeakReference(state); + IDisposable? registration = null; + var nowBlocking = new AsyncManualResetEvent(); + + newTask = this.Context.Factory.RunAsync(async () => + { + await TaskScheduler.Default; + registration = this.Context.OnMainThreadBlocked(s => { }, state); + state = null; + + Thread.MemoryBarrier(); + nowBlocking.Set(); + + await waitingTask; + }); + + nowBlocking.WaitAsync().Wait(); + + state = null; + registration?.Dispose(); + + Thread.MemoryBarrier(); + return weakState; + } + } + + [Fact, Trait("GC", "true")] + public void OnMainThreadBlockedNotLeakingObjectWhenTaskIsCompleted() + { + WeakReference state = SpinOffTask(out JoinableTask newTask); + + // MainThreadAwaiter always queues the cancelation callback through ThreadPool.QueueUserWorkItem to make it difficult to control when the cancellation is handled. + Thread.Sleep(20); + + GC.Collect(); + Assert.False(state.IsAlive); + + newTask.Join(); + + [MethodImpl(MethodImplOptions.NoInlining)] // mem leak detection requires literally popping locals with strong refs off the stack + WeakReference SpinOffTask(out JoinableTask newTask) + { + object? state = new(); + WeakReference weakState = new WeakReference(state); + var callbackIsSet = new AsyncManualResetEvent(); + + newTask = this.Context.Factory.RunAsync(async () => + { + await TaskScheduler.Default; + + _ = this.Context.OnMainThreadBlocked(s => { }, state); + state = null; + Thread.MemoryBarrier(); + + callbackIsSet.Set(); + + await Task.Yield(); + }); + + callbackIsSet.WaitAsync().Wait(); + + state = null; + Thread.MemoryBarrier(); + + newTask.Join(); + return weakState; + } + } + [Fact] public void RevertRelevanceDefaultValue() { From 1c2ab6d1d3fd392055496229c8a69f0fc5a2d2c4 Mon Sep 17 00:00:00 2001 From: "Lifeng Lu (from Dev Box)" Date: Sun, 4 Feb 2024 17:50:22 -0900 Subject: [PATCH 2/7] missed to update one constuctor. --- src/Microsoft.VisualStudio.Threading/JoinableTaskContext.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.VisualStudio.Threading/JoinableTaskContext.cs b/src/Microsoft.VisualStudio.Threading/JoinableTaskContext.cs index b0746c2c3..185328961 100644 --- a/src/Microsoft.VisualStudio.Threading/JoinableTaskContext.cs +++ b/src/Microsoft.VisualStudio.Threading/JoinableTaskContext.cs @@ -158,7 +158,6 @@ public partial class JoinableTaskContext : IDisposable public JoinableTaskContext() : this(Thread.CurrentThread, SynchronizationContext.Current) { - this.nonPostJoinableTaskFactory = new NonPostJoinableTaskFactory(this); } /// @@ -176,6 +175,7 @@ public JoinableTaskContext(Thread? mainThread = null, SynchronizationContext? sy this.MainThread = mainThread ?? Thread.CurrentThread; this.mainThreadManagedThreadId = this.MainThread.ManagedThreadId; this.UnderlyingSynchronizationContext = synchronizationContext ?? SynchronizationContext.Current; // may still be null after this. + this.nonPostJoinableTaskFactory = new NonPostJoinableTaskFactory(this); } /// From c03bdece011d818f6fb09a31af64ac308ac3e080 Mon Sep 17 00:00:00 2001 From: "Lifeng Lu (from Dev Box)" Date: Mon, 5 Feb 2024 09:26:11 -0900 Subject: [PATCH 3/7] Additional change to prevent leading tasks. --- .../JoinableTaskContext.cs | 30 +++++-- .../JoinableTaskContextTests.cs | 82 +++++++++++++++++++ 2 files changed, 103 insertions(+), 9 deletions(-) diff --git a/src/Microsoft.VisualStudio.Threading/JoinableTaskContext.cs b/src/Microsoft.VisualStudio.Threading/JoinableTaskContext.cs index 185328961..392a38b46 100644 --- a/src/Microsoft.VisualStudio.Threading/JoinableTaskContext.cs +++ b/src/Microsoft.VisualStudio.Threading/JoinableTaskContext.cs @@ -131,8 +131,6 @@ public partial class JoinableTaskContext : IDisposable /// private readonly string contextId = Guid.NewGuid().ToString("n"); - private readonly NonPostJoinableTaskFactory nonPostJoinableTaskFactory; - /// /// The next unique ID to assign to a for which a token is required. /// @@ -149,6 +147,12 @@ public partial class JoinableTaskContext : IDisposable [DebuggerBrowsable(DebuggerBrowsableState.Never)] private JoinableTaskFactory? nonJoinableFactory; + /// + /// A special JoinableTaskFactory to detect main thread blocking tasks. + /// + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private NonPostJoinableTaskFactory? nonPostJoinableTaskFactory; + /// /// Initializes a new instance of the class /// assuming the current thread is the main thread and @@ -175,7 +179,6 @@ public JoinableTaskContext(Thread? mainThread = null, SynchronizationContext? sy this.MainThread = mainThread ?? Thread.CurrentThread; this.mainThreadManagedThreadId = this.MainThread.ManagedThreadId; this.UnderlyingSynchronizationContext = synchronizationContext ?? SynchronizationContext.Current; // may still be null after this. - this.nonPostJoinableTaskFactory = new NonPostJoinableTaskFactory(this); } /// @@ -418,18 +421,24 @@ public IDisposable OnMainThreadBlocked(Action action, TState sta TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); + if (this.nonPostJoinableTaskFactory is null) + { + Interlocked.CompareExchange(ref this.nonPostJoinableTaskFactory, new NonPostJoinableTaskFactory(this), null); + } + _ = this.nonPostJoinableTaskFactory.WhenBlockingMainThreadAsync(cancellationToken) .ContinueWith( static (_, s) => { - (JoinableTaskContext me, Action callback, TState callState) = ((JoinableTaskContext, Action, TState))s!; + (JoinableTaskContext me, DisposeToCancel cancellationSource, Action callback, TState callState) = ((JoinableTaskContext, DisposeToCancel, Action, TState))s!; JoinableTask? ambientTask = me.AmbientTask; if (ambientTask?.IsCompleted == false) { + cancellationSource.Dispose(); callback(callState); } }, - (this, action, state), + (this, cancellation, action, state), cancellationToken, TaskContinuationOptions.OnlyOnRanToCompletion | TaskContinuationOptions.LazyCancellation, TaskScheduler.Default); @@ -980,14 +989,17 @@ public void Dispose() /// private class DisposeToCancel : IDisposable { - private readonly CancellationTokenSource cancellationTokenSource = new(); + private CancellationTokenSource? cancellationTokenSource = new(); - internal CancellationToken CancellationToken => this.cancellationTokenSource.Token; + internal CancellationToken CancellationToken => this.cancellationTokenSource?.Token ?? throw new ObjectDisposedException(nameof(DisposeToCancel)); public void Dispose() { - this.cancellationTokenSource.Cancel(); - this.cancellationTokenSource.Dispose(); + if (Interlocked.Exchange(ref this.cancellationTokenSource, null) is CancellationTokenSource cancellationTokenSource) + { + cancellationTokenSource.Cancel(); + cancellationTokenSource.Dispose(); + } } } diff --git a/test/Microsoft.VisualStudio.Threading.Tests/JoinableTaskContextTests.cs b/test/Microsoft.VisualStudio.Threading.Tests/JoinableTaskContextTests.cs index dee034957..a3f6eca38 100644 --- a/test/Microsoft.VisualStudio.Threading.Tests/JoinableTaskContextTests.cs +++ b/test/Microsoft.VisualStudio.Threading.Tests/JoinableTaskContextTests.cs @@ -943,6 +943,55 @@ public void OnMainThreadBlockedNotCalledAfterDisposing() Assert.False(isCalled); } + [Fact] + public void OnMainThreadBlockedFineToDisposeMultipleTimes() + { + bool isCalled = false; + + this.SimulateUIThread(() => + { + this.Context.Factory.Run(async () => + { + JoinableTask? childTask = null; + var taskCompletionSource = new TaskCompletionSource(); + IDisposable? registration = null; + + using (this.Context.SuppressRelevance()) + { + childTask = this.Context.Factory.RunAsync(async () => + { + registration = this.Context.OnMainThreadBlocked( + _ => + { + isCalled = true; + }, + null); + await taskCompletionSource.Task; + }); + } + + await Task.Delay(20); + Assert.False(isCalled); + + Assert.NotNull(registration); + + registration.Dispose(); + registration.Dispose(); + + await Task.Delay(20); + + registration.Dispose(); + + taskCompletionSource.TrySetResult(true); + await childTask; + }); + + return Task.CompletedTask; + }); + + Assert.False(isCalled); + } + [Fact, Trait("GC", "true")] public void OnMainThreadBlockedNotLeakingObject() { @@ -1032,6 +1081,39 @@ WeakReference SpinOffTask(out JoinableTask newTask) } } + [Fact, Trait("GC", "true")] + public void OnMainThreadBlockedNotLeakingDisposeCancellation() + { + WeakReference state = SpinOffTask(out JoinableTask newTask); + + // MainThreadAwaiter always queues the cancelation callback through ThreadPool.QueueUserWorkItem to make it difficult to control when the cancellation is handled. + Thread.Sleep(20); + + GC.Collect(); + Assert.False(state.IsAlive); + + newTask.Join(); + + [MethodImpl(MethodImplOptions.NoInlining)] // mem leak detection requires literally popping locals with strong refs off the stack + WeakReference SpinOffTask(out JoinableTask newTask) + { + WeakReference? weakState = null; + newTask = this.Context.Factory.RunAsync(async () => + { + await TaskScheduler.Default; + + weakState = new WeakReference(this.Context.OnMainThreadBlocked(s => { }, null)); + + await Task.Yield(); + }); + + newTask.Join(); + + Assert.NotNull(weakState); + return weakState; + } + } + [Fact] public void RevertRelevanceDefaultValue() { From 4bcb5f6700830200a84114a3cd2219869ee8ffab Mon Sep 17 00:00:00 2001 From: "Lifeng Lu (from Dev Box)" Date: Mon, 5 Feb 2024 11:29:48 -0900 Subject: [PATCH 4/7] Fix a mistake when handling cancellationToken. --- src/Microsoft.VisualStudio.Threading/JoinableTaskContext.cs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.VisualStudio.Threading/JoinableTaskContext.cs b/src/Microsoft.VisualStudio.Threading/JoinableTaskContext.cs index 392a38b46..f2ce2a92b 100644 --- a/src/Microsoft.VisualStudio.Threading/JoinableTaskContext.cs +++ b/src/Microsoft.VisualStudio.Threading/JoinableTaskContext.cs @@ -1038,8 +1038,10 @@ internal Task WhenBlockingMainThreadAsync(CancellationToken cancellationToken) { taskCompletion.SetCanceled(); } - - taskCompletion.SetResult(true); + else + { + taskCompletion.SetResult(true); + } // ensure the cancellation registration is disposed. awaiter.GetResult(); From e705cb4a8b93068faa09c8018cd59000536ff453 Mon Sep 17 00:00:00 2001 From: "Lifeng Lu (from Dev Box)" Date: Mon, 5 Feb 2024 12:42:49 -0900 Subject: [PATCH 5/7] add WaitUnlessBlockingMainThreadAsync extension method this is to make the common usage easier. --- .../ThreadingTools.cs | 61 +++++ .../net472/PublicAPI.Unshipped.txt | 3 +- .../net6.0-windows/PublicAPI.Unshipped.txt | 3 +- .../net6.0/PublicAPI.Unshipped.txt | 3 +- .../netstandard2.0/PublicAPI.Unshipped.txt | 3 +- .../JoinableTaskContextTests.cs | 213 ++++++++++++++++++ 6 files changed, 282 insertions(+), 4 deletions(-) diff --git a/src/Microsoft.VisualStudio.Threading/ThreadingTools.cs b/src/Microsoft.VisualStudio.Threading/ThreadingTools.cs index d6d87f01c..6cc5616e7 100644 --- a/src/Microsoft.VisualStudio.Threading/ThreadingTools.cs +++ b/src/Microsoft.VisualStudio.Threading/ThreadingTools.cs @@ -143,6 +143,67 @@ public static Task WithCancellation(this Task task, CancellationToken cancellati return WithCancellationSlow(task, continueOnCapturedContext: false, cancellationToken: cancellationToken); } + /// + /// Wait a long running or later finishing task, but abort if this work is blocking the main thread. + /// + /// The JoinableTaskContext. + /// A slow task to wait. + /// An optional cancellation token. + /// A task is completed either the slow task is completed, or the input cancellation token is triggered, or the context task blocks the main thread (inside JTF.Run). + /// Throw when the cancellation token is triggered or the current task blocks the main thread. + public static Task WaitUnlessBlockingMainThreadAsync(this JoinableTaskContext context, Task slowTask, CancellationToken cancellationToken = default) + { + Requires.NotNull(context, nameof(context)); + Requires.NotNull(slowTask, nameof(slowTask)); + + if (slowTask.IsCompleted) + { + return slowTask; + } + + if (context.AmbientTask is null) + { + return slowTask.WithCancellation(cancellationToken); + } + + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } + + return WaitUnlessBlockingMainThreadSlowAsync(context, slowTask, cancellationToken); + + async Task WaitUnlessBlockingMainThreadSlowAsync(JoinableTaskContext context, Task slowTask, CancellationToken cancellationToken) + { + var taskCompletionSource = new TaskCompletionSource(slowTask); + using (context.OnMainThreadBlocked( + static s => + { + // prefer to complete normally if the inner task is completed. + if (!((Task)s.Task.AsyncState!).IsCompleted) + { + s.TrySetResult(false); + } + }, + taskCompletionSource)) + { + using (cancellationToken.Register(s => ((TaskCompletionSource)s!).TrySetResult(true), taskCompletionSource)) + { + if (slowTask != await Task.WhenAny(slowTask, taskCompletionSource.Task).ConfigureAwait(false)) + { + cancellationToken.ThrowIfCancellationRequested(); + + // throw when the main thread is blocked + throw new OperationCanceledException(); + } + } + + // Rethrow any fault/cancellation exception, the task should be completed. + await slowTask; + } + } + } + /// /// Applies the specified to the caller's context. /// diff --git a/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt index ced34f519..746b93445 100644 --- a/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt +++ b/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt @@ -16,4 +16,5 @@ Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance.Dispose() -> void Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.get -> bool Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.init -> void Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance -Microsoft.VisualStudio.Threading.JoinableTaskContext.OnMainThreadBlocked(System.Action! action, TState state) -> System.IDisposable! \ No newline at end of file +Microsoft.VisualStudio.Threading.JoinableTaskContext.OnMainThreadBlocked(System.Action! action, TState state) -> System.IDisposable! +static Microsoft.VisualStudio.Threading.ThreadingTools.WaitUnlessBlockingMainThreadAsync(this Microsoft.VisualStudio.Threading.JoinableTaskContext! context, System.Threading.Tasks.Task! slowTask, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! \ No newline at end of file diff --git a/src/Microsoft.VisualStudio.Threading/net6.0-windows/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/net6.0-windows/PublicAPI.Unshipped.txt index ced34f519..746b93445 100644 --- a/src/Microsoft.VisualStudio.Threading/net6.0-windows/PublicAPI.Unshipped.txt +++ b/src/Microsoft.VisualStudio.Threading/net6.0-windows/PublicAPI.Unshipped.txt @@ -16,4 +16,5 @@ Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance.Dispose() -> void Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.get -> bool Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.init -> void Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance -Microsoft.VisualStudio.Threading.JoinableTaskContext.OnMainThreadBlocked(System.Action! action, TState state) -> System.IDisposable! \ No newline at end of file +Microsoft.VisualStudio.Threading.JoinableTaskContext.OnMainThreadBlocked(System.Action! action, TState state) -> System.IDisposable! +static Microsoft.VisualStudio.Threading.ThreadingTools.WaitUnlessBlockingMainThreadAsync(this Microsoft.VisualStudio.Threading.JoinableTaskContext! context, System.Threading.Tasks.Task! slowTask, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! \ No newline at end of file diff --git a/src/Microsoft.VisualStudio.Threading/net6.0/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/net6.0/PublicAPI.Unshipped.txt index ced34f519..746b93445 100644 --- a/src/Microsoft.VisualStudio.Threading/net6.0/PublicAPI.Unshipped.txt +++ b/src/Microsoft.VisualStudio.Threading/net6.0/PublicAPI.Unshipped.txt @@ -16,4 +16,5 @@ Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance.Dispose() -> void Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.get -> bool Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.init -> void Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance -Microsoft.VisualStudio.Threading.JoinableTaskContext.OnMainThreadBlocked(System.Action! action, TState state) -> System.IDisposable! \ No newline at end of file +Microsoft.VisualStudio.Threading.JoinableTaskContext.OnMainThreadBlocked(System.Action! action, TState state) -> System.IDisposable! +static Microsoft.VisualStudio.Threading.ThreadingTools.WaitUnlessBlockingMainThreadAsync(this Microsoft.VisualStudio.Threading.JoinableTaskContext! context, System.Threading.Tasks.Task! slowTask, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! \ No newline at end of file diff --git a/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt index ced34f519..746b93445 100644 --- a/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt +++ b/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt @@ -16,4 +16,5 @@ Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance.Dispose() -> void Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.get -> bool Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.init -> void Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance -Microsoft.VisualStudio.Threading.JoinableTaskContext.OnMainThreadBlocked(System.Action! action, TState state) -> System.IDisposable! \ No newline at end of file +Microsoft.VisualStudio.Threading.JoinableTaskContext.OnMainThreadBlocked(System.Action! action, TState state) -> System.IDisposable! +static Microsoft.VisualStudio.Threading.ThreadingTools.WaitUnlessBlockingMainThreadAsync(this Microsoft.VisualStudio.Threading.JoinableTaskContext! context, System.Threading.Tasks.Task! slowTask, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! \ No newline at end of file diff --git a/test/Microsoft.VisualStudio.Threading.Tests/JoinableTaskContextTests.cs b/test/Microsoft.VisualStudio.Threading.Tests/JoinableTaskContextTests.cs index a3f6eca38..8d22718ae 100644 --- a/test/Microsoft.VisualStudio.Threading.Tests/JoinableTaskContextTests.cs +++ b/test/Microsoft.VisualStudio.Threading.Tests/JoinableTaskContextTests.cs @@ -1114,6 +1114,219 @@ WeakReference SpinOffTask(out JoinableTask newTask) } } + [Fact] + public void WaitUnlessBlockingMainThreadAsyncReturnsCompletedTask() + { + var taskCompletionSource = new TaskCompletionSource(); + taskCompletionSource.SetResult(true); + + var cancellationSource = new CancellationTokenSource(); + Assert.Equal(this.Context.WaitUnlessBlockingMainThreadAsync(taskCompletionSource.Task, cancellationSource.Token), taskCompletionSource.Task); + } + + [Fact] + public void WaitUnlessBlockingMainThreadAsyncReturnsTaskUnderSimpleCondition() + { + var taskCompletionSource = new TaskCompletionSource(); + Assert.Equal(this.Context.WaitUnlessBlockingMainThreadAsync(taskCompletionSource.Task, CancellationToken.None), taskCompletionSource.Task); + } + + [Fact] + public void WaitUnlessBlockingMainThreadAsyncCancellable() + { + var taskCompletionSource = new TaskCompletionSource(); + + var cancellationSource = new CancellationTokenSource(); + Task waitingTask = this.Context.WaitUnlessBlockingMainThreadAsync(taskCompletionSource.Task, cancellationSource.Token); + + Assert.False(waitingTask.IsCompleted); + cancellationSource.Cancel(); + + try + { + waitingTask.GetAwaiter().GetResult(); + Assert.Fail("Expect to throw."); + } + catch (OperationCanceledException) + { + } + } + + [Fact] + public void WaitUnlessBlockingMainThreadAsyncCancellableInsideJTF() + { + JoinableTask task = this.Context.Factory.RunAsync(async () => + { + var taskCompletionSource = new TaskCompletionSource(); + + var cancellationSource = new CancellationTokenSource(); + Task waitingTask = this.Context.WaitUnlessBlockingMainThreadAsync(taskCompletionSource.Task, cancellationSource.Token); + + Assert.False(waitingTask.IsCompleted); + cancellationSource.Cancel(); + + try + { + await waitingTask; + Assert.Fail("Expect to throw."); + } + catch (OperationCanceledException) + { + } + }); + + task.Join(); + } + + [Fact] + public void WaitUnlessBlockingMainThreadAsyncCanComplete() + { + var taskCompletionSource = new TaskCompletionSource(); + + var cancellationSource = new CancellationTokenSource(); + Task waitingTask = this.Context.WaitUnlessBlockingMainThreadAsync(taskCompletionSource.Task, cancellationSource.Token); + + Assert.False(waitingTask.IsCompleted); + taskCompletionSource.SetResult(true); + + waitingTask.Wait(TestTimeout); + } + + [Fact] + public void WaitUnlessBlockingMainThreadAsyncCanCompleteInsideJTF() + { + this.Context.Factory.Run(async () => + { + var taskCompletionSource = new TaskCompletionSource(); + + var cancellationSource = new CancellationTokenSource(); + Task waitingTask = this.Context.WaitUnlessBlockingMainThreadAsync(taskCompletionSource.Task, cancellationSource.Token); + + Assert.False(waitingTask.IsCompleted); + taskCompletionSource.SetResult(true); + + await waitingTask; + }); + } + + [Fact] + public void WaitUnlessBlockingMainThreadAsyncAbortWithinMainThreadBlockingStack() + { + this.SimulateUIThread(() => + { + this.Context.Factory.Run(async () => + { + var taskCompletionSource = new TaskCompletionSource(); + try + { + await this.Context.WaitUnlessBlockingMainThreadAsync(taskCompletionSource.Task, CancellationToken.None); + Assert.Fail("Expect to throw."); + } + catch (OperationCanceledException) + { + } + }); + + return Task.CompletedTask; + }); + } + + [Fact] + public void WaitUnlessBlockingMainThreadAsyncAbortWithinMainThreadBlockingStack2() + { + this.SimulateUIThread(() => + { + this.Context.Factory.Run(async () => + { + var taskCompletionSource = new TaskCompletionSource(); + var cancellationSource = new CancellationTokenSource(); + try + { + await this.Context.WaitUnlessBlockingMainThreadAsync(taskCompletionSource.Task, cancellationSource.Token); + Assert.Fail("Expect to throw."); + } + catch (OperationCanceledException) + { + } + }); + + return Task.CompletedTask; + }); + } + + [Fact] + public void WaitUnlessBlockingMainThreadAsyncAbortWhenMainThreadingBlockedLater() + { + JoinableTask firstTask = this.Context.Factory.RunAsync(async () => + { + var taskCompletionSource = new TaskCompletionSource(); + var cancellationSource = new CancellationTokenSource(); + try + { + await this.Context.WaitUnlessBlockingMainThreadAsync(taskCompletionSource.Task, cancellationSource.Token); + Assert.Fail("Expect to throw."); + } + catch (OperationCanceledException) + { + } + }); + + Assert.False(firstTask.IsCompleted); + + this.SimulateUIThread(() => + { + firstTask.Join(); + return Task.CompletedTask; + }); + } + + [Fact] + public void WaitUnlessBlockingMainThreadAsyncAbortWhenMainThreadingBlockedLaterAwaited() + { + JoinableTask firstTask = this.Context.Factory.RunAsync(async () => + { + var taskCompletionSource = new TaskCompletionSource(); + var cancellationSource = new CancellationTokenSource(); + try + { + await this.Context.WaitUnlessBlockingMainThreadAsync(taskCompletionSource.Task, cancellationSource.Token); + Assert.Fail("Expect to throw."); + } + catch (OperationCanceledException) + { + } + }); + + Assert.False(firstTask.IsCompleted); + + this.SimulateUIThread(() => + { + this.Context.Factory.Run(async () => await firstTask); + return Task.CompletedTask; + }); + } + + [Fact] + public void WaitUnlessBlockingMainThreadAsyncTaskCompletedFirst() + { + var taskCompletionSource = new TaskCompletionSource(); + var cancellationSource = new CancellationTokenSource(); + + JoinableTask firstTask = this.Context.Factory.RunAsync(async () => + { + await this.Context.WaitUnlessBlockingMainThreadAsync(taskCompletionSource.Task, cancellationSource.Token); + }); + + Assert.False(firstTask.IsCompleted); + + this.SimulateUIThread(() => + { + taskCompletionSource.SetResult(true); + this.Context.Factory.Run(async () => await firstTask); + return Task.CompletedTask; + }); + } + [Fact] public void RevertRelevanceDefaultValue() { From 5c7830e6658d1f3d2dc6ead0d023addec0b3165f Mon Sep 17 00:00:00 2001 From: "Lifeng Lu (from Dev Box)" Date: Mon, 5 Feb 2024 12:46:30 -0900 Subject: [PATCH 6/7] mark callback method static --- src/Microsoft.VisualStudio.Threading/ThreadingTools.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.VisualStudio.Threading/ThreadingTools.cs b/src/Microsoft.VisualStudio.Threading/ThreadingTools.cs index 6cc5616e7..398fd320f 100644 --- a/src/Microsoft.VisualStudio.Threading/ThreadingTools.cs +++ b/src/Microsoft.VisualStudio.Threading/ThreadingTools.cs @@ -187,7 +187,7 @@ async Task WaitUnlessBlockingMainThreadSlowAsync(JoinableTaskContext context, Ta }, taskCompletionSource)) { - using (cancellationToken.Register(s => ((TaskCompletionSource)s!).TrySetResult(true), taskCompletionSource)) + using (cancellationToken.Register(static s => ((TaskCompletionSource)s!).TrySetResult(true), taskCompletionSource)) { if (slowTask != await Task.WhenAny(slowTask, taskCompletionSource.Task).ConfigureAwait(false)) { From 5b637e30c4aed02e13a5a26f30b6b7727bbe22ac Mon Sep 17 00:00:00 2001 From: "Lifeng Lu (from Dev Box)" Date: Mon, 5 Feb 2024 16:16:01 -0800 Subject: [PATCH 7/7] Update the order of two parameters. --- .../ThreadingTools.cs | 4 ++-- .../net472/PublicAPI.Unshipped.txt | 2 +- .../net6.0-windows/PublicAPI.Unshipped.txt | 2 +- .../net6.0/PublicAPI.Unshipped.txt | 2 +- .../netstandard2.0/PublicAPI.Unshipped.txt | 2 +- .../JoinableTaskContextTests.cs | 22 +++++++++---------- 6 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/Microsoft.VisualStudio.Threading/ThreadingTools.cs b/src/Microsoft.VisualStudio.Threading/ThreadingTools.cs index 398fd320f..6e84f628e 100644 --- a/src/Microsoft.VisualStudio.Threading/ThreadingTools.cs +++ b/src/Microsoft.VisualStudio.Threading/ThreadingTools.cs @@ -146,12 +146,12 @@ public static Task WithCancellation(this Task task, CancellationToken cancellati /// /// Wait a long running or later finishing task, but abort if this work is blocking the main thread. /// - /// The JoinableTaskContext. /// A slow task to wait. + /// The JoinableTaskContext. /// An optional cancellation token. /// A task is completed either the slow task is completed, or the input cancellation token is triggered, or the context task blocks the main thread (inside JTF.Run). /// Throw when the cancellation token is triggered or the current task blocks the main thread. - public static Task WaitUnlessBlockingMainThreadAsync(this JoinableTaskContext context, Task slowTask, CancellationToken cancellationToken = default) + public static Task WaitUnlessBlockingMainThreadAsync(this Task slowTask, JoinableTaskContext context, CancellationToken cancellationToken = default) { Requires.NotNull(context, nameof(context)); Requires.NotNull(slowTask, nameof(slowTask)); diff --git a/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt index 746b93445..e2cd04d78 100644 --- a/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt +++ b/src/Microsoft.VisualStudio.Threading/net472/PublicAPI.Unshipped.txt @@ -17,4 +17,4 @@ Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection. Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.init -> void Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance Microsoft.VisualStudio.Threading.JoinableTaskContext.OnMainThreadBlocked(System.Action! action, TState state) -> System.IDisposable! -static Microsoft.VisualStudio.Threading.ThreadingTools.WaitUnlessBlockingMainThreadAsync(this Microsoft.VisualStudio.Threading.JoinableTaskContext! context, System.Threading.Tasks.Task! slowTask, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! \ No newline at end of file +static Microsoft.VisualStudio.Threading.ThreadingTools.WaitUnlessBlockingMainThreadAsync(this System.Threading.Tasks.Task! slowTask, Microsoft.VisualStudio.Threading.JoinableTaskContext! context, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! \ No newline at end of file diff --git a/src/Microsoft.VisualStudio.Threading/net6.0-windows/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/net6.0-windows/PublicAPI.Unshipped.txt index 746b93445..e2cd04d78 100644 --- a/src/Microsoft.VisualStudio.Threading/net6.0-windows/PublicAPI.Unshipped.txt +++ b/src/Microsoft.VisualStudio.Threading/net6.0-windows/PublicAPI.Unshipped.txt @@ -17,4 +17,4 @@ Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection. Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.init -> void Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance Microsoft.VisualStudio.Threading.JoinableTaskContext.OnMainThreadBlocked(System.Action! action, TState state) -> System.IDisposable! -static Microsoft.VisualStudio.Threading.ThreadingTools.WaitUnlessBlockingMainThreadAsync(this Microsoft.VisualStudio.Threading.JoinableTaskContext! context, System.Threading.Tasks.Task! slowTask, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! \ No newline at end of file +static Microsoft.VisualStudio.Threading.ThreadingTools.WaitUnlessBlockingMainThreadAsync(this System.Threading.Tasks.Task! slowTask, Microsoft.VisualStudio.Threading.JoinableTaskContext! context, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! \ No newline at end of file diff --git a/src/Microsoft.VisualStudio.Threading/net6.0/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/net6.0/PublicAPI.Unshipped.txt index 746b93445..e2cd04d78 100644 --- a/src/Microsoft.VisualStudio.Threading/net6.0/PublicAPI.Unshipped.txt +++ b/src/Microsoft.VisualStudio.Threading/net6.0/PublicAPI.Unshipped.txt @@ -17,4 +17,4 @@ Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection. Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.init -> void Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance Microsoft.VisualStudio.Threading.JoinableTaskContext.OnMainThreadBlocked(System.Action! action, TState state) -> System.IDisposable! -static Microsoft.VisualStudio.Threading.ThreadingTools.WaitUnlessBlockingMainThreadAsync(this Microsoft.VisualStudio.Threading.JoinableTaskContext! context, System.Threading.Tasks.Task! slowTask, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! \ No newline at end of file +static Microsoft.VisualStudio.Threading.ThreadingTools.WaitUnlessBlockingMainThreadAsync(this System.Threading.Tasks.Task! slowTask, Microsoft.VisualStudio.Threading.JoinableTaskContext! context, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! \ No newline at end of file diff --git a/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt b/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt index 746b93445..e2cd04d78 100644 --- a/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt +++ b/src/Microsoft.VisualStudio.Threading/netstandard2.0/PublicAPI.Unshipped.txt @@ -17,4 +17,4 @@ Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection. Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRecursiveFactoryDetection.init -> void Microsoft.VisualStudio.Threading.AsyncLazy.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy.RevertRelevance Microsoft.VisualStudio.Threading.JoinableTaskContext.OnMainThreadBlocked(System.Action! action, TState state) -> System.IDisposable! -static Microsoft.VisualStudio.Threading.ThreadingTools.WaitUnlessBlockingMainThreadAsync(this Microsoft.VisualStudio.Threading.JoinableTaskContext! context, System.Threading.Tasks.Task! slowTask, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! \ No newline at end of file +static Microsoft.VisualStudio.Threading.ThreadingTools.WaitUnlessBlockingMainThreadAsync(this System.Threading.Tasks.Task! slowTask, Microsoft.VisualStudio.Threading.JoinableTaskContext! context, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! \ No newline at end of file diff --git a/test/Microsoft.VisualStudio.Threading.Tests/JoinableTaskContextTests.cs b/test/Microsoft.VisualStudio.Threading.Tests/JoinableTaskContextTests.cs index 8d22718ae..ce7ba2241 100644 --- a/test/Microsoft.VisualStudio.Threading.Tests/JoinableTaskContextTests.cs +++ b/test/Microsoft.VisualStudio.Threading.Tests/JoinableTaskContextTests.cs @@ -1121,14 +1121,14 @@ public void WaitUnlessBlockingMainThreadAsyncReturnsCompletedTask() taskCompletionSource.SetResult(true); var cancellationSource = new CancellationTokenSource(); - Assert.Equal(this.Context.WaitUnlessBlockingMainThreadAsync(taskCompletionSource.Task, cancellationSource.Token), taskCompletionSource.Task); + Assert.Equal(taskCompletionSource.Task.WaitUnlessBlockingMainThreadAsync(this.Context, cancellationSource.Token), taskCompletionSource.Task); } [Fact] public void WaitUnlessBlockingMainThreadAsyncReturnsTaskUnderSimpleCondition() { var taskCompletionSource = new TaskCompletionSource(); - Assert.Equal(this.Context.WaitUnlessBlockingMainThreadAsync(taskCompletionSource.Task, CancellationToken.None), taskCompletionSource.Task); + Assert.Equal(taskCompletionSource.Task.WaitUnlessBlockingMainThreadAsync(this.Context, CancellationToken.None), taskCompletionSource.Task); } [Fact] @@ -1137,7 +1137,7 @@ public void WaitUnlessBlockingMainThreadAsyncCancellable() var taskCompletionSource = new TaskCompletionSource(); var cancellationSource = new CancellationTokenSource(); - Task waitingTask = this.Context.WaitUnlessBlockingMainThreadAsync(taskCompletionSource.Task, cancellationSource.Token); + Task waitingTask = taskCompletionSource.Task.WaitUnlessBlockingMainThreadAsync(this.Context, cancellationSource.Token); Assert.False(waitingTask.IsCompleted); cancellationSource.Cancel(); @@ -1160,7 +1160,7 @@ public void WaitUnlessBlockingMainThreadAsyncCancellableInsideJTF() var taskCompletionSource = new TaskCompletionSource(); var cancellationSource = new CancellationTokenSource(); - Task waitingTask = this.Context.WaitUnlessBlockingMainThreadAsync(taskCompletionSource.Task, cancellationSource.Token); + Task waitingTask = taskCompletionSource.Task.WaitUnlessBlockingMainThreadAsync(this.Context, cancellationSource.Token); Assert.False(waitingTask.IsCompleted); cancellationSource.Cancel(); @@ -1184,7 +1184,7 @@ public void WaitUnlessBlockingMainThreadAsyncCanComplete() var taskCompletionSource = new TaskCompletionSource(); var cancellationSource = new CancellationTokenSource(); - Task waitingTask = this.Context.WaitUnlessBlockingMainThreadAsync(taskCompletionSource.Task, cancellationSource.Token); + Task waitingTask = taskCompletionSource.Task.WaitUnlessBlockingMainThreadAsync(this.Context, cancellationSource.Token); Assert.False(waitingTask.IsCompleted); taskCompletionSource.SetResult(true); @@ -1200,7 +1200,7 @@ public void WaitUnlessBlockingMainThreadAsyncCanCompleteInsideJTF() var taskCompletionSource = new TaskCompletionSource(); var cancellationSource = new CancellationTokenSource(); - Task waitingTask = this.Context.WaitUnlessBlockingMainThreadAsync(taskCompletionSource.Task, cancellationSource.Token); + Task waitingTask = taskCompletionSource.Task.WaitUnlessBlockingMainThreadAsync(this.Context, cancellationSource.Token); Assert.False(waitingTask.IsCompleted); taskCompletionSource.SetResult(true); @@ -1219,7 +1219,7 @@ public void WaitUnlessBlockingMainThreadAsyncAbortWithinMainThreadBlockingStack( var taskCompletionSource = new TaskCompletionSource(); try { - await this.Context.WaitUnlessBlockingMainThreadAsync(taskCompletionSource.Task, CancellationToken.None); + await taskCompletionSource.Task.WaitUnlessBlockingMainThreadAsync(this.Context, CancellationToken.None); Assert.Fail("Expect to throw."); } catch (OperationCanceledException) @@ -1242,7 +1242,7 @@ public void WaitUnlessBlockingMainThreadAsyncAbortWithinMainThreadBlockingStack2 var cancellationSource = new CancellationTokenSource(); try { - await this.Context.WaitUnlessBlockingMainThreadAsync(taskCompletionSource.Task, cancellationSource.Token); + await taskCompletionSource.Task.WaitUnlessBlockingMainThreadAsync(this.Context, cancellationSource.Token); Assert.Fail("Expect to throw."); } catch (OperationCanceledException) @@ -1263,7 +1263,7 @@ public void WaitUnlessBlockingMainThreadAsyncAbortWhenMainThreadingBlockedLater( var cancellationSource = new CancellationTokenSource(); try { - await this.Context.WaitUnlessBlockingMainThreadAsync(taskCompletionSource.Task, cancellationSource.Token); + await taskCompletionSource.Task.WaitUnlessBlockingMainThreadAsync(this.Context, cancellationSource.Token); Assert.Fail("Expect to throw."); } catch (OperationCanceledException) @@ -1289,7 +1289,7 @@ public void WaitUnlessBlockingMainThreadAsyncAbortWhenMainThreadingBlockedLaterA var cancellationSource = new CancellationTokenSource(); try { - await this.Context.WaitUnlessBlockingMainThreadAsync(taskCompletionSource.Task, cancellationSource.Token); + await taskCompletionSource.Task.WaitUnlessBlockingMainThreadAsync(this.Context, cancellationSource.Token); Assert.Fail("Expect to throw."); } catch (OperationCanceledException) @@ -1314,7 +1314,7 @@ public void WaitUnlessBlockingMainThreadAsyncTaskCompletedFirst() JoinableTask firstTask = this.Context.Factory.RunAsync(async () => { - await this.Context.WaitUnlessBlockingMainThreadAsync(taskCompletionSource.Task, cancellationSource.Token); + await taskCompletionSource.Task.WaitUnlessBlockingMainThreadAsync(this.Context, cancellationSource.Token); }); Assert.False(firstTask.IsCompleted);