diff --git a/src/Draco.Compiler/Internal/Solver/Constraint.cs b/src/Draco.Compiler/Internal/Solver/Constraint.cs index 82193de9d..75c9b8c19 100644 --- a/src/Draco.Compiler/Internal/Solver/Constraint.cs +++ b/src/Draco.Compiler/Internal/Solver/Constraint.cs @@ -1,6 +1,7 @@ using System; using Draco.Compiler.Api.Diagnostics; using Draco.Compiler.Internal.Binding.Tasks; +using Draco.Compiler.Internal.Solver.Tasks; namespace Draco.Compiler.Internal.Solver; @@ -10,7 +11,7 @@ namespace Draco.Compiler.Internal.Solver; /// The result type. internal abstract class Constraint : IConstraint { - public BindingTaskCompletionSource CompletionSource { get; } + public SolverTaskCompletionSource CompletionSource { get; } public ConstraintLocator Locator { get; } protected Constraint(ConstraintSolver solver, ConstraintLocator locator) diff --git a/src/Draco.Compiler/Internal/Solver/ConstraintSolver_Constraints.cs b/src/Draco.Compiler/Internal/Solver/ConstraintSolver_Constraints.cs index 17f78baaf..f0ff06528 100644 --- a/src/Draco.Compiler/Internal/Solver/ConstraintSolver_Constraints.cs +++ b/src/Draco.Compiler/Internal/Solver/ConstraintSolver_Constraints.cs @@ -5,6 +5,7 @@ using System.Linq; using Draco.Compiler.Api.Syntax; using Draco.Compiler.Internal.Binding.Tasks; +using Draco.Compiler.Internal.Solver.Tasks; using Draco.Compiler.Internal.Symbols; using Draco.Compiler.Internal.Utilities; @@ -72,7 +73,7 @@ public bool TryDequeue( /// The type that is constrained to be the same as . /// The syntax that the constraint originates from. /// The promise for the constraint added. - public BindingTask SameType(TypeSymbol first, TypeSymbol second, SyntaxNode syntax) + public SolverTask SameType(TypeSymbol first, TypeSymbol second, SyntaxNode syntax) { var constraint = new SameTypeConstraint(this, ImmutableArray.Create(first, second), ConstraintLocator.Syntax(syntax)); this.Add(constraint); @@ -86,7 +87,7 @@ public BindingTask SameType(TypeSymbol first, TypeSymbol second, SyntaxNod /// The type assigned. /// The syntax that the constraint originates from. /// The promise for the constraint added. - public BindingTask Assignable(TypeSymbol targetType, TypeSymbol assignedType, SyntaxNode syntax) => + public SolverTask Assignable(TypeSymbol targetType, TypeSymbol assignedType, SyntaxNode syntax) => this.Assignable(targetType, assignedType, ConstraintLocator.Syntax(syntax)); /// @@ -96,7 +97,7 @@ public BindingTask Assignable(TypeSymbol targetType, TypeSymbol assignedTy /// The type assigned. /// The locator for the constraint. /// The promise for the constraint added. - public BindingTask Assignable(TypeSymbol targetType, TypeSymbol assignedType, ConstraintLocator locator) + public SolverTask Assignable(TypeSymbol targetType, TypeSymbol assignedType, ConstraintLocator locator) { var constraint = new AssignableConstraint(this, targetType, assignedType, locator); this.Add(constraint); @@ -110,7 +111,7 @@ public BindingTask Assignable(TypeSymbol targetType, TypeSymbol assignedTy /// The alternative types to find the common type of. /// The syntax that the constraint originates from. /// The promise of the constraint added. - public BindingTask CommonType( + public SolverTask CommonType( TypeSymbol commonType, ImmutableArray alternativeTypes, SyntaxNode syntax) => this.CommonType(commonType, alternativeTypes, ConstraintLocator.Syntax(syntax)); @@ -122,7 +123,7 @@ public BindingTask CommonType( /// The alternative types to find the common type of. /// The locator for this constraint. /// The promise of the constraint added. - public BindingTask CommonType( + public SolverTask CommonType( TypeSymbol commonType, ImmutableArray alternativeTypes, ConstraintLocator locator) @@ -140,7 +141,7 @@ public BindingTask CommonType( /// The type of the member. /// The syntax that the constraint originates from. /// The promise of the accessed member symbol. - public BindingTask Member( + public SolverTask Member( TypeSymbol accessedType, string memberName, out TypeSymbol memberType, @@ -160,7 +161,7 @@ public BindingTask Member( /// The return type. /// The syntax that the constraint originates from. /// The promise of the constraint. - public BindingTask Call( + public SolverTask Call( TypeSymbol calledType, ImmutableArray args, out TypeSymbol returnType, @@ -181,7 +182,7 @@ public BindingTask Call( /// The return type of the call. /// The syntax that the constraint originates from. /// The promise for the resolved overload. - public BindingTask Overload( + public SolverTask Overload( string name, ImmutableArray functions, ImmutableArray args, diff --git a/src/Draco.Compiler/Internal/Solver/IConstraint.cs b/src/Draco.Compiler/Internal/Solver/IConstraint.cs index e1dc339d9..3ea037dd6 100644 --- a/src/Draco.Compiler/Internal/Solver/IConstraint.cs +++ b/src/Draco.Compiler/Internal/Solver/IConstraint.cs @@ -1,4 +1,5 @@ using Draco.Compiler.Internal.Binding.Tasks; +using Draco.Compiler.Internal.Solver.Tasks; namespace Draco.Compiler.Internal.Solver; @@ -22,5 +23,5 @@ internal interface IConstraint : IConstraint /// /// The completion source of this constraint. /// - public BindingTaskCompletionSource CompletionSource { get; } + public SolverTaskCompletionSource CompletionSource { get; } } diff --git a/src/Draco.Compiler/Internal/Solver/Tasks/SolverTask.cs b/src/Draco.Compiler/Internal/Solver/Tasks/SolverTask.cs new file mode 100644 index 000000000..f25bcb642 --- /dev/null +++ b/src/Draco.Compiler/Internal/Solver/Tasks/SolverTask.cs @@ -0,0 +1,34 @@ +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Runtime.CompilerServices; +using Draco.Compiler.Internal.Solver; + +namespace Draco.Compiler.Internal.Solver.Tasks; + +internal static class SolverTask +{ + public static SolverTask FromResult(ConstraintSolver solver, T result) + { + var task = new SolverTask(); + task.Awaiter.Solver = solver; + task.Awaiter.SetResult(result, null); + return task; + } + + public static async SolverTask> WhenAll(IEnumerable> tasks) + { + var result = ImmutableArray.CreateBuilder(); + foreach (var task in tasks) result.Add(await task); + return result.ToImmutable(); + } +} + +[AsyncMethodBuilder(typeof(SolverTaskMethodBuilder<>))] +internal struct SolverTask +{ + internal SolverTaskAwaiter Awaiter; + internal readonly ConstraintSolver Solver => this.Awaiter.Solver; + public readonly bool IsCompleted => this.Awaiter.IsCompleted; + public readonly T Result => this.Awaiter.GetResult(); + public readonly SolverTaskAwaiter GetAwaiter() => this.Awaiter; +} diff --git a/src/Draco.Compiler/Internal/Solver/Tasks/SolverTaskAwaiter.cs b/src/Draco.Compiler/Internal/Solver/Tasks/SolverTaskAwaiter.cs new file mode 100644 index 000000000..be823fcb8 --- /dev/null +++ b/src/Draco.Compiler/Internal/Solver/Tasks/SolverTaskAwaiter.cs @@ -0,0 +1,48 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using Draco.Compiler.Internal.Binding.Tasks; +using Draco.Compiler.Internal.Solver; + +namespace Draco.Compiler.Internal.Solver.Tasks; + +internal struct SolverTaskAwaiter : INotifyCompletion, IBindingTaskAwaiter +{ + public bool IsCompleted { get; private set; } + public ConstraintSolver Solver { get; set; } + + private T? result; + private Exception? exception; + private List? completions; + + internal void SetResult(T? result, Exception? exception) + { + this.IsCompleted = true; + this.result = result; + this.exception = exception; + foreach (var completion in this.completions ?? Enumerable.Empty()) + { + completion(); + } + } + + public readonly T GetResult() + { + if (this.exception is not null) throw this.exception; + return this.result!; + } + + public void OnCompleted(Action completion) + { + if (this.IsCompleted) + { + completion(); + } + else + { + this.completions ??= new(); + this.completions.Add(completion); + } + } +} diff --git a/src/Draco.Compiler/Internal/Solver/Tasks/SolverTaskCompletionSource.cs b/src/Draco.Compiler/Internal/Solver/Tasks/SolverTaskCompletionSource.cs new file mode 100644 index 000000000..1e3301a55 --- /dev/null +++ b/src/Draco.Compiler/Internal/Solver/Tasks/SolverTaskCompletionSource.cs @@ -0,0 +1,31 @@ +using System; +using Draco.Compiler.Internal.Solver; + +namespace Draco.Compiler.Internal.Solver.Tasks; + +internal sealed class SolverTaskCompletionSource +{ + public SolverTask Task + { + get + { + var task = new SolverTask(); + task.Awaiter = this.Awaiter; + return task; + } + } + public bool IsCompleted => this.Awaiter.IsCompleted; + public T Result => this.Awaiter.GetResult(); + internal ConstraintSolver Solver => this.Awaiter.Solver; + + internal SolverTaskAwaiter Awaiter; + + internal SolverTaskCompletionSource(ConstraintSolver solver) + { + this.Awaiter.Solver = solver; + } + + public SolverTaskAwaiter GetAwaiter() => this.Awaiter; + public void SetResult(T result) => this.Awaiter.SetResult(result, null); + public void SetException(Exception exception) => this.Awaiter.SetResult(default, exception); +} diff --git a/src/Draco.Compiler/Internal/Solver/Tasks/SolverTaskMethodBuilder.cs b/src/Draco.Compiler/Internal/Solver/Tasks/SolverTaskMethodBuilder.cs new file mode 100644 index 000000000..6effe3652 --- /dev/null +++ b/src/Draco.Compiler/Internal/Solver/Tasks/SolverTaskMethodBuilder.cs @@ -0,0 +1,48 @@ +using System; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using Draco.Compiler.Internal.Binding.Tasks; + +namespace Draco.Compiler.Internal.Solver.Tasks; + +internal sealed class SolverTaskMethodBuilder +{ + public SolverTask Task => this.task; + private SolverTask task; + + public static SolverTaskMethodBuilder Create() => new(); + + public void Start(ref TStateMachine stateMachine) + where TStateMachine : IAsyncStateMachine => stateMachine.MoveNext(); + + public void SetStateMachine(IAsyncStateMachine _) => Debug.Fail("Unused"); + + public void SetException(Exception exception) => this.Task.Awaiter.SetResult(default, exception); + public void SetResult(T result) => this.Task.Awaiter.SetResult(result, null); + + public void AwaitOnCompleted( + ref TAwaiter awaiter, ref TStateMachine stateMachine) + where TAwaiter : INotifyCompletion + where TStateMachine : IAsyncStateMachine + { + if (awaiter is not IBindingTaskAwaiter syncAwaiter) + { + throw new NotSupportedException("Only supporting BindingTask."); + } + this.task.Awaiter.Solver = syncAwaiter.Solver; + awaiter.OnCompleted(stateMachine.MoveNext); + } + + public void AwaitUnsafeOnCompleted( + ref TAwaiter awaiter, ref TStateMachine stateMachine) + where TAwaiter : ICriticalNotifyCompletion + where TStateMachine : IAsyncStateMachine + { + if (awaiter is not IBindingTaskAwaiter syncAwaiter) + { + throw new NotSupportedException("Only supporting BindingTask."); + } + this.task.Awaiter.Solver = syncAwaiter.Solver; + awaiter.OnCompleted(stateMachine.MoveNext); + } +}