Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements OnMainThreadBlocked. #1280

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions src/Microsoft.VisualStudio.Threading/JoinableTaskContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,12 @@ public partial class JoinableTaskContext : IDisposable
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
private JoinableTaskFactory? nonJoinableFactory;

/// <summary>
/// A special JoinableTaskFactory to detect main thread blocking tasks.
/// </summary>
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
private NonPostJoinableTaskFactory? nonPostJoinableTaskFactory;

/// <summary>
/// Initializes a new instance of the <see cref="JoinableTaskContext"/> class
/// assuming the current thread is the main thread and
Expand Down Expand Up @@ -385,6 +391,61 @@ public bool IsMainThreadMaybeBlocked()
return false;
}

/// <summary>
/// Registers a callback when the current JoinableTask is blocking the UI thread.
/// </summary>
/// <typeparam name="TState">The type of state used by the callback.</typeparam>
/// <param name="action">A callback method.</param>
/// <param name="state">A state passing to the callback method.</param>
/// <returns>A disposable which can be used to unregister the callback.</returns>
public IDisposable OnMainThreadBlocked<TState>(Action<TState> action, TState state)
{
Requires.NotNull(action, nameof(action));

JoinableTask? ambientTask = this.AmbientTask;
if (ambientTask is null)
{
// when it is called outside of a JoinableTask, it would never block main thread in the future, so we would not want to waste time further.
return EmptyDisposable.Instance;
}

var cancellation = new DisposeToCancel();
CancellationToken cancellationToken = cancellation.CancellationToken;

// chain a task to ensure we would clean up all things when the ambient task is completed.
// A completed task would not block the main thread further.
_ = ambientTask.Task.ContinueWith(
static (_, s) => ((DisposeToCancel)s!).Dispose(),
cancellation,
cancellationToken,
TaskContinuationOptions.ExecuteSynchronously,
TaskScheduler.Default);

if (this.nonPostJoinableTaskFactory is null)
{
Interlocked.CompareExchange(ref this.nonPostJoinableTaskFactory, new NonPostJoinableTaskFactory(this), null);
}

_ = this.nonPostJoinableTaskFactory.WhenBlockingMainThreadAsync(cancellationToken)
.ContinueWith(
static (_, s) =>
{
(JoinableTaskContext me, DisposeToCancel cancellationSource, Action<TState> callback, TState callState) = ((JoinableTaskContext, DisposeToCancel, Action<TState>, TState))s!;
JoinableTask? ambientTask = me.AmbientTask;
if (ambientTask?.IsCompleted == false)
{
cancellationSource.Dispose();
callback(callState);
}
},
(this, cancellation, action, state),
cancellationToken,
TaskContinuationOptions.OnlyOnRanToCompletion | TaskContinuationOptions.LazyCancellation,
TaskScheduler.Default);

return cancellation;
}

/// <summary>
/// Creates a joinable task factory that automatically adds all created tasks
/// to a collection that can be jointly joined.
Expand Down Expand Up @@ -908,4 +969,97 @@ public void Dispose()
}
}
}

/// <summary>
/// Represents a disposable which does nothing.
/// </summary>
private class EmptyDisposable : IDisposable, IDisposableObservable
{
public bool IsDisposed => true;

internal static IDisposable Instance { get; } = new EmptyDisposable();

public void Dispose()
{
}
}

/// <summary>
/// Implements a disposable which triggers a cancellation token when it is disposed.
/// </summary>
private class DisposeToCancel : IDisposable
{
private CancellationTokenSource? cancellationTokenSource = new();

internal CancellationToken CancellationToken => this.cancellationTokenSource?.Token ?? throw new ObjectDisposedException(nameof(DisposeToCancel));

public void Dispose()
{
if (Interlocked.Exchange(ref this.cancellationTokenSource, null) is CancellationTokenSource cancellationTokenSource)
{
cancellationTokenSource.Cancel();
cancellationTokenSource.Dispose();
}
}
}

/// <summary>
/// A special JoinableTaskFactory, which does not allow tasks to be scheduled when the UI thread is pumping messages.
/// This allows us to detect whether a task is blocking UI thread, because it would only be able to run under this condition.
/// </summary>
private class NonPostJoinableTaskFactory : JoinableTaskFactory
{
internal NonPostJoinableTaskFactory(JoinableTaskContext owner)
: base(owner)
{
}

internal Task WhenBlockingMainThreadAsync(CancellationToken cancellationToken)
{
TaskCompletionSource<bool> taskCompletion = new();
JoinableTask detectionTask;

// By default SwitchToMainThreadAsync would not use the exact JoinableTaskFactory we want to use here, but would use the one from ambient task, so we have to do extra mile to create the child task.
// However, we must suppress relevance, otherwise, the SwitchToMainThreadAsync would post to the UI thread queue twice through both factories. This will defeat all the reason we create the special factory.
using (this.Context.SuppressRelevance())
{
detectionTask = this.RunAsync(() =>
{
// a simpler code inside is just to await this.SwitchToMainThreadAsync(alwaysYield: true, cancellationToken)
// alwaysYield is necessary to ensure we don't execute immediately when it is called on the main thread.
// however, this also leads an extra yield to handle cancellation token. Also, this would throw an extra cancellation exception.
// While we can suppress the exception with NoThrowAwaitable(), it would lead us to queue the continuation one more time in the case the cancellation is triggered, which is a waste.
// This leads the extra code done below.
MainThreadAwaiter awaiter = this.SwitchToMainThreadAsync(cancellationToken).NoThrowAwaitable().GetAwaiter();

awaiter.UnsafeOnCompleted(() =>
{
if (cancellationToken.IsCancellationRequested)
{
taskCompletion.SetCanceled();
}
else
{
taskCompletion.SetResult(true);
}

// ensure the cancellation registration is disposed.
awaiter.GetResult();
});

return taskCompletion.Task;
});
}

// the detection task must be joined to the current task to ensure JTF dependencies to work (after we suppress relevance earlier).
_ = detectionTask.JoinAsync();

return taskCompletion.Task;
}

protected internal override void PostToUnderlyingSynchronizationContext(SendOrPostCallback callback, object state)
{
return;
}
}
}
61 changes: 61 additions & 0 deletions src/Microsoft.VisualStudio.Threading/ThreadingTools.cs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,67 @@ public static Task WithCancellation(this Task task, CancellationToken cancellati
return WithCancellationSlow(task, continueOnCapturedContext: false, cancellationToken: cancellationToken);
}

/// <summary>
/// Wait a long running or later finishing task, but abort if this work is blocking the main thread.
/// </summary>
/// <param name="context">The JoinableTaskContext.</param>
/// <param name="slowTask">A slow task to wait.</param>
/// <param name="cancellationToken">An optional cancellation token.</param>
/// <returns>A task is completed either the slow task is completed, or the input cancellation token is triggered, or the context task blocks the main thread (inside JTF.Run).</returns>
/// <exception cref="OperationCanceledException">Throw when the cancellation token is triggered or the current task blocks the main thread.</exception>
public static Task WaitUnlessBlockingMainThreadAsync(this JoinableTaskContext context, Task slowTask, CancellationToken cancellationToken = default)
Copy link
Member

@richardstanton richardstanton Feb 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider a slightly different signature to be similar to WithCancellation():
CancelIfMainThreadBlockedAsync(this Task slowTask, JoinableTaskContext context, CT)
That would associate it with the Task in Intellisense and imply that it throws an OperationCanceledEx like it currently does.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the order of parameters suggested by Richard, and updated the PR. I would like to hear more feedback on the name of the method itself, so I haven't updated it yet.

{
Requires.NotNull(context, nameof(context));
Requires.NotNull(slowTask, nameof(slowTask));

if (slowTask.IsCompleted)
{
return slowTask;
}

if (context.AmbientTask is null)
{
return slowTask.WithCancellation(cancellationToken);
}

if (cancellationToken.IsCancellationRequested)
{
return Task.FromCanceled(cancellationToken);
}

return WaitUnlessBlockingMainThreadSlowAsync(context, slowTask, cancellationToken);

async Task WaitUnlessBlockingMainThreadSlowAsync(JoinableTaskContext context, Task slowTask, CancellationToken cancellationToken)
{
var taskCompletionSource = new TaskCompletionSource<bool>(slowTask);
using (context.OnMainThreadBlocked(
static s =>
{
// prefer to complete normally if the inner task is completed.
if (!((Task)s.Task.AsyncState!).IsCompleted)
{
s.TrySetResult(false);
}
},
taskCompletionSource))
{
using (cancellationToken.Register(static s => ((TaskCompletionSource<bool>)s!).TrySetResult(true), taskCompletionSource))
{
if (slowTask != await Task.WhenAny(slowTask, taskCompletionSource.Task).ConfigureAwait(false))
{
cancellationToken.ThrowIfCancellationRequested();

// throw when the main thread is blocked
throw new OperationCanceledException();
}
}

// Rethrow any fault/cancellation exception, the task should be completed.
await slowTask;
}
}
}

/// <summary>
/// Applies the specified <see cref="SynchronizationContext"/> to the caller's context.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@ Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance.Dispose() -> void
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRecursiveFactoryDetection.get -> bool
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRecursiveFactoryDetection.init -> void
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Microsoft.VisualStudio.Threading.JoinableTaskContext.OnMainThreadBlocked<TState>(System.Action<TState>! action, TState state) -> System.IDisposable!
static Microsoft.VisualStudio.Threading.ThreadingTools.WaitUnlessBlockingMainThreadAsync(this Microsoft.VisualStudio.Threading.JoinableTaskContext! context, System.Threading.Tasks.Task! slowTask, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task!
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@ Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance.Dispose() -> void
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRecursiveFactoryDetection.get -> bool
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRecursiveFactoryDetection.init -> void
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Microsoft.VisualStudio.Threading.JoinableTaskContext.OnMainThreadBlocked<TState>(System.Action<TState>! action, TState state) -> System.IDisposable!
static Microsoft.VisualStudio.Threading.ThreadingTools.WaitUnlessBlockingMainThreadAsync(this Microsoft.VisualStudio.Threading.JoinableTaskContext! context, System.Threading.Tasks.Task! slowTask, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task!
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@ Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance.Dispose() -> void
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRecursiveFactoryDetection.get -> bool
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRecursiveFactoryDetection.init -> void
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Microsoft.VisualStudio.Threading.JoinableTaskContext.OnMainThreadBlocked<TState>(System.Action<TState>! action, TState state) -> System.IDisposable!
static Microsoft.VisualStudio.Threading.ThreadingTools.WaitUnlessBlockingMainThreadAsync(this Microsoft.VisualStudio.Threading.JoinableTaskContext! context, System.Threading.Tasks.Task! slowTask, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task!
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@ Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance.Dispose() -> void
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRecursiveFactoryDetection.get -> bool
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRecursiveFactoryDetection.init -> void
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Microsoft.VisualStudio.Threading.AsyncLazy<T>.SuppressRelevance() -> Microsoft.VisualStudio.Threading.AsyncLazy<T>.RevertRelevance
Microsoft.VisualStudio.Threading.JoinableTaskContext.OnMainThreadBlocked<TState>(System.Action<TState>! action, TState state) -> System.IDisposable!
static Microsoft.VisualStudio.Threading.ThreadingTools.WaitUnlessBlockingMainThreadAsync(this Microsoft.VisualStudio.Threading.JoinableTaskContext! context, System.Threading.Tasks.Task! slowTask, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task!
Loading