Skip to content

Commit

Permalink
Added solver tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
LPeter1997 committed Nov 4, 2023
1 parent bb80f03 commit d4cab88
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 10 deletions.
3 changes: 2 additions & 1 deletion src/Draco.Compiler/Internal/Solver/Constraint.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using Draco.Compiler.Api.Diagnostics;
using Draco.Compiler.Internal.Binding.Tasks;
using Draco.Compiler.Internal.Solver.Tasks;

namespace Draco.Compiler.Internal.Solver;

Expand All @@ -10,7 +11,7 @@ namespace Draco.Compiler.Internal.Solver;
/// <typeparam name="TResult">The result type.</typeparam>
internal abstract class Constraint<TResult> : IConstraint<TResult>
{
public BindingTaskCompletionSource<TResult> CompletionSource { get; }
public SolverTaskCompletionSource<TResult> CompletionSource { get; }
public ConstraintLocator Locator { get; }

protected Constraint(ConstraintSolver solver, ConstraintLocator locator)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Linq;
using Draco.Compiler.Api.Syntax;
using Draco.Compiler.Internal.Binding.Tasks;
using Draco.Compiler.Internal.Solver.Tasks;
using Draco.Compiler.Internal.Symbols;
using Draco.Compiler.Internal.Utilities;

Expand Down Expand Up @@ -72,7 +73,7 @@ public bool TryDequeue<TConstraint>(
/// <param name="second">The type that is constrained to be the same as <paramref name="first"/>.</param>
/// <param name="syntax">The syntax that the constraint originates from.</param>
/// <returns>The promise for the constraint added.</returns>
public BindingTask<Unit> SameType(TypeSymbol first, TypeSymbol second, SyntaxNode syntax)
public SolverTask<Unit> SameType(TypeSymbol first, TypeSymbol second, SyntaxNode syntax)
{
var constraint = new SameTypeConstraint(this, ImmutableArray.Create(first, second), ConstraintLocator.Syntax(syntax));
this.Add(constraint);
Expand All @@ -86,7 +87,7 @@ public BindingTask<Unit> SameType(TypeSymbol first, TypeSymbol second, SyntaxNod
/// <param name="assignedType">The type assigned.</param>
/// <param name="syntax">The syntax that the constraint originates from.</param>
/// <returns>The promise for the constraint added.</returns>
public BindingTask<Unit> Assignable(TypeSymbol targetType, TypeSymbol assignedType, SyntaxNode syntax) =>
public SolverTask<Unit> Assignable(TypeSymbol targetType, TypeSymbol assignedType, SyntaxNode syntax) =>
this.Assignable(targetType, assignedType, ConstraintLocator.Syntax(syntax));

/// <summary>
Expand All @@ -96,7 +97,7 @@ public BindingTask<Unit> Assignable(TypeSymbol targetType, TypeSymbol assignedTy
/// <param name="assignedType">The type assigned.</param>
/// <param name="locator">The locator for the constraint.</param>
/// <returns>The promise for the constraint added.</returns>
public BindingTask<Unit> Assignable(TypeSymbol targetType, TypeSymbol assignedType, ConstraintLocator locator)
public SolverTask<Unit> Assignable(TypeSymbol targetType, TypeSymbol assignedType, ConstraintLocator locator)
{
var constraint = new AssignableConstraint(this, targetType, assignedType, locator);
this.Add(constraint);
Expand All @@ -110,7 +111,7 @@ public BindingTask<Unit> Assignable(TypeSymbol targetType, TypeSymbol assignedTy
/// <param name="alternativeTypes">The alternative types to find the common type of.</param>
/// <param name="syntax">The syntax that the constraint originates from.</param>
/// <returns>The promise of the constraint added.</returns>
public BindingTask<Unit> CommonType(
public SolverTask<Unit> CommonType(
TypeSymbol commonType,
ImmutableArray<TypeSymbol> alternativeTypes,
SyntaxNode syntax) => this.CommonType(commonType, alternativeTypes, ConstraintLocator.Syntax(syntax));
Expand All @@ -122,7 +123,7 @@ public BindingTask<Unit> CommonType(
/// <param name="alternativeTypes">The alternative types to find the common type of.</param>
/// <param name="locator">The locator for this constraint.</param>
/// <returns>The promise of the constraint added.</returns>
public BindingTask<Unit> CommonType(
public SolverTask<Unit> CommonType(
TypeSymbol commonType,
ImmutableArray<TypeSymbol> alternativeTypes,
ConstraintLocator locator)
Expand All @@ -140,7 +141,7 @@ public BindingTask<Unit> CommonType(
/// <param name="memberType">The type of the member.</param>
/// <param name="syntax">The syntax that the constraint originates from.</param>
/// <returns>The promise of the accessed member symbol.</returns>
public BindingTask<Symbol> Member(
public SolverTask<Symbol> Member(
TypeSymbol accessedType,
string memberName,
out TypeSymbol memberType,
Expand All @@ -160,7 +161,7 @@ public BindingTask<Symbol> Member(
/// <param name="returnType">The return type.</param>
/// <param name="syntax">The syntax that the constraint originates from.</param>
/// <returns>The promise of the constraint.</returns>
public BindingTask<Unit> Call(
public SolverTask<Unit> Call(
TypeSymbol calledType,
ImmutableArray<object> args,
out TypeSymbol returnType,
Expand All @@ -181,7 +182,7 @@ public BindingTask<Unit> Call(
/// <param name="returnType">The return type of the call.</param>
/// <param name="syntax">The syntax that the constraint originates from.</param>
/// <returns>The promise for the resolved overload.</returns>
public BindingTask<FunctionSymbol> Overload(
public SolverTask<FunctionSymbol> Overload(
string name,
ImmutableArray<FunctionSymbol> functions,
ImmutableArray<object> args,
Expand Down
3 changes: 2 additions & 1 deletion src/Draco.Compiler/Internal/Solver/IConstraint.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Draco.Compiler.Internal.Binding.Tasks;
using Draco.Compiler.Internal.Solver.Tasks;

namespace Draco.Compiler.Internal.Solver;

Expand All @@ -22,5 +23,5 @@ internal interface IConstraint<TResult> : IConstraint
/// <summary>
/// The completion source of this constraint.
/// </summary>
public BindingTaskCompletionSource<TResult> CompletionSource { get; }
public SolverTaskCompletionSource<TResult> CompletionSource { get; }
}
34 changes: 34 additions & 0 deletions src/Draco.Compiler/Internal/Solver/Tasks/SolverTask.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Runtime.CompilerServices;
using Draco.Compiler.Internal.Solver;

namespace Draco.Compiler.Internal.Solver.Tasks;

internal static class SolverTask
{
public static SolverTask<T> FromResult<T>(ConstraintSolver solver, T result)
{
var task = new SolverTask<T>();
task.Awaiter.Solver = solver;
task.Awaiter.SetResult(result, null);
return task;
}

public static async SolverTask<ImmutableArray<T>> WhenAll<T>(IEnumerable<SolverTask<T>> tasks)
{
var result = ImmutableArray.CreateBuilder<T>();
foreach (var task in tasks) result.Add(await task);
return result.ToImmutable();
}
}

[AsyncMethodBuilder(typeof(SolverTaskMethodBuilder<>))]
internal struct SolverTask<T>
{
internal SolverTaskAwaiter<T> Awaiter;
internal readonly ConstraintSolver Solver => this.Awaiter.Solver;
public readonly bool IsCompleted => this.Awaiter.IsCompleted;
public readonly T Result => this.Awaiter.GetResult();
public readonly SolverTaskAwaiter<T> GetAwaiter() => this.Awaiter;
}
48 changes: 48 additions & 0 deletions src/Draco.Compiler/Internal/Solver/Tasks/SolverTaskAwaiter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using Draco.Compiler.Internal.Binding.Tasks;
using Draco.Compiler.Internal.Solver;

namespace Draco.Compiler.Internal.Solver.Tasks;

internal struct SolverTaskAwaiter<T> : INotifyCompletion, IBindingTaskAwaiter
{
public bool IsCompleted { get; private set; }
public ConstraintSolver Solver { get; set; }

private T? result;
private Exception? exception;
private List<Action>? completions;

internal void SetResult(T? result, Exception? exception)
{
this.IsCompleted = true;
this.result = result;
this.exception = exception;
foreach (var completion in this.completions ?? Enumerable.Empty<Action>())
{
completion();
}
}

public readonly T GetResult()
{
if (this.exception is not null) throw this.exception;
return this.result!;
}

public void OnCompleted(Action completion)
{
if (this.IsCompleted)
{
completion();
}
else
{
this.completions ??= new();
this.completions.Add(completion);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using System;
using Draco.Compiler.Internal.Solver;

namespace Draco.Compiler.Internal.Solver.Tasks;

internal sealed class SolverTaskCompletionSource<T>
{
public SolverTask<T> Task
{
get
{
var task = new SolverTask<T>();
task.Awaiter = this.Awaiter;
return task;
}
}
public bool IsCompleted => this.Awaiter.IsCompleted;
public T Result => this.Awaiter.GetResult();
internal ConstraintSolver Solver => this.Awaiter.Solver;

internal SolverTaskAwaiter<T> Awaiter;

internal SolverTaskCompletionSource(ConstraintSolver solver)
{
this.Awaiter.Solver = solver;
}

public SolverTaskAwaiter<T> GetAwaiter() => this.Awaiter;
public void SetResult(T result) => this.Awaiter.SetResult(result, null);
public void SetException(Exception exception) => this.Awaiter.SetResult(default, exception);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
using System;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using Draco.Compiler.Internal.Binding.Tasks;

namespace Draco.Compiler.Internal.Solver.Tasks;

internal sealed class SolverTaskMethodBuilder<T>
{
public SolverTask<T> Task => this.task;
private SolverTask<T> task;

public static SolverTaskMethodBuilder<T> Create() => new();

public void Start<TStateMachine>(ref TStateMachine stateMachine)
where TStateMachine : IAsyncStateMachine => stateMachine.MoveNext();

public void SetStateMachine(IAsyncStateMachine _) => Debug.Fail("Unused");

public void SetException(Exception exception) => this.Task.Awaiter.SetResult(default, exception);
public void SetResult(T result) => this.Task.Awaiter.SetResult(result, null);

public void AwaitOnCompleted<TAwaiter, TStateMachine>(
ref TAwaiter awaiter, ref TStateMachine stateMachine)
where TAwaiter : INotifyCompletion
where TStateMachine : IAsyncStateMachine
{
if (awaiter is not IBindingTaskAwaiter syncAwaiter)
{
throw new NotSupportedException("Only supporting BindingTask.");
}
this.task.Awaiter.Solver = syncAwaiter.Solver;
awaiter.OnCompleted(stateMachine.MoveNext);
}

public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(
ref TAwaiter awaiter, ref TStateMachine stateMachine)
where TAwaiter : ICriticalNotifyCompletion
where TStateMachine : IAsyncStateMachine
{
if (awaiter is not IBindingTaskAwaiter syncAwaiter)
{
throw new NotSupportedException("Only supporting BindingTask.");
}
this.task.Awaiter.Solver = syncAwaiter.Solver;
awaiter.OnCompleted(stateMachine.MoveNext);
}
}

0 comments on commit d4cab88

Please sign in to comment.