Skip to content

Commit

Permalink
Generate input types as ReadOnlyMemory<T>
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergio0694 committed Sep 20, 2023
1 parent 86dc363 commit 2145a92
Show file tree
Hide file tree
Showing 8 changed files with 266 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
<Compile Include="..\ComputeSharp.D2D1\Interfaces\__Internals\ID2D1Shader.cs" Link="ComputeSharp.D2D1\Interfaces\__Internals\ID2D1Shader.cs" />
<Compile Include="..\ComputeSharp.D2D1\Intrinsics\D2D.cs" Link="ComputeSharp.D2D1\Intrinsics\D2D.cs" />
<Compile Include="..\ComputeSharp.D2D1\Shaders\Exceptions\FxcCompilationException.cs" Link="ComputeSharp.D2D1\Shaders\Exceptions\FxcCompilationException.cs" />
<Compile Include="..\ComputeSharp.D2D1\Shaders\Interop\D2D1PixelShaderInputType.cs" Link="ComputeSharp.D2D1\Shaders\Interop\D2D1PixelShaderInputType.cs" />
<Compile Include="..\ComputeSharp.D2D1\Shaders\Translation\D3DCompiler.cs" Link="ComputeSharp.D2D1\Shaders\Translation\D3DCompiler.cs" />
<Compile Include="..\ComputeSharp.D2D1\Shaders\Translation\D3DCompiler.ID3DInclude.cs" Link="ComputeSharp.D2D1\Shaders\Translation\D3DCompiler.ID3DInclude.cs" />
<Compile Include="..\ComputeSharp.D2D1\Shaders\Translation\Headers\d2d1effecthelpers.hlsli.cs" Link="ComputeSharp.D2D1\Shaders\Translation\Headers\d2d1effecthelpers.hlsli.cs" />
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using System.Collections.Immutable;
using ComputeSharp.D2D1.__Internals;
using ComputeSharp.SourceGeneration.Helpers;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
Expand All @@ -18,21 +20,218 @@ partial class GetInputType
/// Creates a <see cref="MethodDeclarationSyntax"/> instance for the <c>GetInputTypeMethod</c> method.
/// </summary>
/// <param name="inputTypes">The input types for the shader.</param>
/// <returns>The resulting <see cref="MethodDeclarationSyntax"/> instance for the <c>GetInputTypeMethod</c> method.</returns>
public static MethodDeclarationSyntax GetSyntax(ImmutableArray<uint> inputTypes)
/// <returns>The resulting <see cref="MethodDeclarationSyntax"/> instance for the <c
/// >GetInputTypeMethod</c> method.</returns>
public static (MemberDeclarationSyntax Member, MemberDeclarationSyntax Type) GetSyntax(ImmutableArray<uint> inputTypes)
{
// This code produces a method declaration as follows:
//
// readonly uint global::ComputeSharp.D2D1.__Internals.ID2D1Shader.GetInputType(uint index)
// {
// return <INPUT_TYPE>;
// }
return
MethodDeclaration(PredefinedType(Token(SyntaxKind.UIntKeyword)), Identifier(nameof(GetInputType)))
// readonly uint global::ComputeSharp.D2D1.__Internals.ID2D1Shader.InputTypes => InputTypesMemoryManager.Memory;
MemberDeclarationSyntax member =
PropertyDeclaration(
GenericName(Identifier("global::System.ReadOnlyMemory"))
.AddTypeArgumentListArguments(IdentifierName("global::ComputeSharp.D2D1.Interop.D2D1PixelShaderInputType")),
Identifier("InputTypes")) // TODO: use nameof()
.WithExplicitInterfaceSpecifier(ExplicitInterfaceSpecifier(IdentifierName($"global::ComputeSharp.D2D1.__Internals.{nameof(ID2D1Shader)}")))
.AddModifiers(Token(SyntaxKind.ReadOnlyKeyword))
.AddParameterListParameters(Parameter(Identifier("index")).WithType(PredefinedType(Token(SyntaxKind.UIntKeyword))))
.WithBody(Block(ReturnStatement(GetInputTypesSwitchExpression(inputTypes))));
.WithExpressionBody(ArrowExpressionClause(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName("InputTypesMemoryManager"),
IdentifierName("Memory"))))
.WithSemicolonToken(Token(SyntaxKind.SemicolonToken));

return (member, GetMemoryManagerDeclaration(inputTypes));
}

/// <summary>
/// Gets the memory manager declaration for the input type data.
/// </summary>
/// <param name="inputTypes">The input types for the shader.</param>
/// <returns>The memory manager declaration for the input type data.</returns>
private static TypeDeclarationSyntax GetMemoryManagerDeclaration(ImmutableArray<uint> inputTypes)
{
// Create the MemoryManager<T> declaration:
//
// /// <summary>
// /// <see cref="global::System.Buffers.MemoryManager{T}"/> implementation to get the input types.
// /// </summary>
// [global::System.CodeDom.Compiler.GeneratedCode("...", "...")]
// [global::System.Diagnostics.DebuggerNonUserCode]
// [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
// file sealed InputTypesMemoryManager : global::System.Buffers.MemoryManager<global::ComputeSharp.D2D1.Interop.D2D1PixelShaderInputType>
// {
// }
TypeDeclarationSyntax typeDeclaration =
ClassDeclaration("InputTypesMemoryManager")
.AddModifiers(Token(SyntaxKind.FileKeyword), Token(SyntaxKind.SealedKeyword))
.AddBaseListTypes(SimpleBaseType(
GenericName(Identifier("global::System.Buffers.MemoryManager"))
.AddTypeArgumentListArguments(IdentifierName("global::ComputeSharp.D2D1.Interop.D2D1PixelShaderInputType"))))
.AddAttributeLists(
AttributeList(SingletonSeparatedList(
Attribute(IdentifierName("global::System.CodeDom.Compiler.GeneratedCode")).AddArgumentListArguments(
AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ID2D1ShaderGenerator).FullName))),
AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ID2D1ShaderGenerator).Assembly.GetName().Version.ToString())))))),
AttributeList(SingletonSeparatedList(Attribute(IdentifierName("global::System.Diagnostics.DebuggerNonUserCode")))),
AttributeList(SingletonSeparatedList(Attribute(IdentifierName("global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage")))))
.WithLeadingTrivia(
Comment("/// <summary>"),
Comment("""/// A <see cref="global::System.Buffers.MemoryManager{T}"/> implementation to get the input types."""),
Comment("/// </summary>"));

using ImmutableArrayBuilder<MemberDeclarationSyntax> memberDeclarations = ImmutableArrayBuilder<MemberDeclarationSyntax>.Rent();

// Declare the singleton property to get the memory instance:
//
// /// <summary>The singleton <see cref="global::System.ReadOnlyMemory{T}"/> instance for the memory manager.</summary>
// public static new readonly global::System.ReadOnlyMemory<global::ComputeSharp.D2D1.Interop.D2D1PixelShaderInputType> Memory = new InputTypesMemoryManager().CreateMemory(<INPUT_COUNT>);
memberDeclarations.Add(
FieldDeclaration(
VariableDeclaration(
GenericName(Identifier("global::System.ReadOnlyMemory"))
.AddTypeArgumentListArguments(IdentifierName("global::ComputeSharp.D2D1.Interop.D2D1PixelShaderInputType")))
.AddVariables(
VariableDeclarator(Identifier("Memory"))
.WithInitializer(
EqualsValueClause(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
ObjectCreationExpression(IdentifierName("InputTypesMemoryManager")).WithArgumentList(ArgumentList()),
IdentifierName("CreateMemory")))
.AddArgumentListArguments(Argument(
LiteralExpression(
SyntaxKind.NumericLiteralExpression,
Literal(inputTypes.Length))))))))
.AddModifiers(
Token(SyntaxKind.PublicKeyword),
Token(SyntaxKind.StaticKeyword),
Token(SyntaxKind.NewKeyword),
Token(SyntaxKind.ReadOnlyKeyword))
.WithLeadingTrivia(Comment("""/// <summary>The singleton <see cref="global::System.ReadOnlyMemory{T}"/> instance for the memory manager.</summary>""")));

using (ImmutableArrayBuilder<ExpressionSyntax> inputTypeExpressions = ImmutableArrayBuilder<ExpressionSyntax>.Rent())
{
// Build the sequence of expressions for all input types
foreach (uint inputType in inputTypes)
{
inputTypeExpressions.Add(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName("global::ComputeSharp.D2D1.Interop.D2D1PixelShaderInputType"),
IdentifierName(inputType == 0 ? "Simple" : "Complex")));
}

// Construct the RVA span property:
//
// /// <summary>The RVA data with the input type info.</summary>
// private static global::System.ReadOnlySpan<global::ComputeSharp.D2D1.Interop.D2D1PixelShaderInputType> Data => new[] { <INPUT_TYPES> };
memberDeclarations.Add(
PropertyDeclaration(
GenericName(Identifier("global::System.ReadOnlySpan"))
.AddTypeArgumentListArguments(IdentifierName("global::ComputeSharp.D2D1.Interop.D2D1PixelShaderInputType")),
Identifier("Data"))
.AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.StaticKeyword))
.WithExpressionBody(
ArrowExpressionClause(
ImplicitArrayCreationExpression(
InitializerExpression(
SyntaxKind.ArrayInitializerExpression,
SeparatedList(inputTypeExpressions.ToArray())))))
.WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
.WithLeadingTrivia(Comment("/// <summary>The RVA data with the input type info.</summary>")));
}

// Add the GetSpan() method:
//
// /// <inheritdoc/>
// public override unsafe global::System.Span<global::ComputeSharp.D2D1.Interop.D2D1PixelShaderInputType> GetSpan
// {
// return new(global::System.Runtime.CompilerServices.Unsafe.AsPointer(ref global::System.Runtime.InteropServices.MemoryMarshal(Data)), <INPUT_COUNT>);
// }
memberDeclarations.Add(
MethodDeclaration(
GenericName(Identifier("global::System.Span"))
.AddTypeArgumentListArguments(IdentifierName("global::ComputeSharp.D2D1.Interop.D2D1PixelShaderInputType")),
Identifier("GetSpan"))
.AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword), Token(SyntaxKind.UnsafeKeyword))
.AddBodyStatements(
ReturnStatement(
ImplicitObjectCreationExpression()
.AddArgumentListArguments(
Argument(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName("global::System.Runtime.CompilerServices.Unsafe"),
IdentifierName("AsPointer")))
.AddArgumentListArguments(
Argument(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName("global::System.Runtime.InteropServices.MemoryMarshal"),
IdentifierName("GetReference")))
.AddArgumentListArguments(Argument(IdentifierName("Data"))))
.WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)))),
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(inputTypes.Length))))))
.WithLeadingTrivia(Comment("/// <inheritdoc/>")));

// Add the Pin(int elementIndex) method:
//
// /// <inheritdoc/>
// public override unsafe global::System.Buffers.MemoryHandle Pin(int elementIndex)
// {
// return new(Unsafe.AsPointer(ref Unsafe.AsRef(in Data[elementIndex])), pinnable: this);
// }
memberDeclarations.Add(
MethodDeclaration(IdentifierName("global::System.Buffers.MemoryHandle"), Identifier("Pin"))
.AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword), Token(SyntaxKind.UnsafeKeyword))
.AddParameterListParameters(Parameter(Identifier("elementIndex")).WithType(PredefinedType(Token(SyntaxKind.IntKeyword))))
.AddBodyStatements(
ReturnStatement(
ImplicitObjectCreationExpression()
.AddArgumentListArguments(
Argument(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName("global::System.Runtime.CompilerServices.Unsafe"),
IdentifierName("AsPointer")))
.AddArgumentListArguments(
Argument(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName("global::System.Runtime.CompilerServices.Unsafe"),
IdentifierName("AsRef")))
.AddArgumentListArguments(
Argument(
ElementAccessExpression(IdentifierName("Data"))
.AddArgumentListArguments(Argument(IdentifierName("elementIndex"))))
.WithRefOrOutKeyword(Token(SyntaxKind.InKeyword))))
.WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)))),
Argument(ThisExpression()).WithNameColon(NameColon(IdentifierName("pinnable"))))))
.WithLeadingTrivia(Comment("/// <inheritdoc/>")));

// Add the empty Unpin() method
memberDeclarations.Add(
MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), Identifier("Unpin"))
.AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))
.WithBody(Block())
.WithLeadingTrivia(Comment("/// <inheritdoc/>")));

// Add the empty Dispose(bool disposing) method
memberDeclarations.Add(
MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), Identifier("Dispose"))
.AddModifiers(Token(SyntaxKind.ProtectedKeyword), Token(SyntaxKind.OverrideKeyword))
.AddParameterListParameters(
Parameter(Identifier("disposing")).WithType(PredefinedType(Token(SyntaxKind.BoolKeyword))))
.WithBody(Block())
.WithLeadingTrivia(Comment("/// <inheritdoc/>")));

return typeDeclaration.AddMembers(memberDeclarations.ToArray());
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,13 @@ public static bool IsD2D1PixelShaderType(INamedTypeSymbol typeSymbol, Compilatio
/// <param name="hierarchyInfo">The <see cref="HierarchyInfo"/> instance for the current type.</param>
/// <param name="methodDeclaration">The <see cref="MethodDeclarationSyntax"/> item to insert.</param>
/// <param name="canUseSkipLocalsInit">Whether <c>[SkipLocalsInit]</c> can be used.</param>
/// <param name="additionalMemberDeclarations">Additional member declarations to also emit, if any.</param>
/// <returns>A <see cref="CompilationUnitSyntax"/> object wrapping <paramref name="methodDeclaration"/>.</returns>
private static CompilationUnitSyntax GetCompilationUnitFromMethod(
HierarchyInfo hierarchyInfo,
MethodDeclarationSyntax methodDeclaration,
bool canUseSkipLocalsInit)
MemberDeclarationSyntax methodDeclaration,
bool canUseSkipLocalsInit,
params MemberDeclarationSyntax[] additionalMemberDeclarations)
{
// Method attributes
List<AttributeListSyntax> attributes = new()
Expand All @@ -71,6 +73,6 @@ private static CompilationUnitSyntax GetCompilationUnitFromMethod(
attributes.Add(AttributeList(SingletonSeparatedList(Attribute(IdentifierName("global::System.Runtime.CompilerServices.SkipLocalsInit")))));
}

return hierarchyInfo.GetSyntax(methodDeclaration.AddAttributeLists(attributes.ToArray()));
return hierarchyInfo.GetSyntax(new MemberDeclarationSyntax[] { methodDeclaration.AddAttributeLists(attributes.ToArray()) }, additionalMemberDeclarations);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
// Generate the GetInputType() methods
context.RegisterSourceOutput(inputTypesInfo, static (context, item) =>
{
MethodDeclarationSyntax getInputTypeMethod = GetInputType.GetSyntax(item.InputTypes.InputTypes);
CompilationUnitSyntax compilationUnit = GetCompilationUnitFromMethod(item.Hierarchy, getInputTypeMethod, canUseSkipLocalsInit: false);
(MemberDeclarationSyntax property, MemberDeclarationSyntax type) = GetInputType.GetSyntax(item.InputTypes.InputTypes);
CompilationUnitSyntax compilationUnit = GetCompilationUnitFromMethod(item.Hierarchy, property, canUseSkipLocalsInit: false, additionalMemberDeclarations: type);
context.AddSource($"{item.Hierarchy.FullyQualifiedMetadataName}.{nameof(GetInputType)}.g.cs", compilationUnit.GetText(Encoding.UTF8));
});
Expand Down
Loading

0 comments on commit 2145a92

Please sign in to comment.