Skip to content

Commit

Permalink
Generate constant buffer native layout mapping type
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergio0694 committed Sep 23, 2023
1 parent 0437590 commit ee2e7fd
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
using System;
using System.Collections.Immutable;
using ComputeSharp.D2D1.__Internals;
using ComputeSharp.D2D1.SourceGenerators.Models;
using ComputeSharp.SourceGeneration.Helpers;
using ComputeSharp.SourceGeneration.Mappings;
using ComputeSharp.SourceGeneration.Models;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
Expand All @@ -19,10 +23,22 @@ partial class LoadDispatchData
/// <summary>
/// Creates a <see cref="MethodDeclarationSyntax"/> instance for the <c>LoadDispatchDataMethod</c> method.
/// </summary>
/// <param name="hierarchyInfo">The hiararchy info of the shader type.</param>
/// <param name="dispatchInfo">The dispatch info gathered for the current shader.</param>
/// <param name="additionalTypes">Any additional <see cref="TypeDeclarationSyntax"/> instances needed by the generated code, if needed.</param>
/// <returns>The resulting <see cref="MethodDeclarationSyntax"/> instance for the <c>LoadDispatchDataMethod</c> method.</returns>
public static MethodDeclarationSyntax GetSyntax(DispatchDataInfo dispatchInfo)
public static MethodDeclarationSyntax GetSyntax(HierarchyInfo hierarchyInfo, DispatchDataInfo dispatchInfo, out TypeDeclarationSyntax[] additionalTypes)
{
// Declare the mapping constant buffer type, if needed (ie. if the shader has at least one field)
if (dispatchInfo.FieldInfos.Length == 0)
{
additionalTypes = Array.Empty<TypeDeclarationSyntax>();
}
else
{
additionalTypes = new[] { GetConstantBufferDeclaration(hierarchyInfo, dispatchInfo.FieldInfos, dispatchInfo.ConstantBufferSizeInBytes) };
}

// This code produces a method declaration as follows:
//
// readonly void global::ComputeSharp.D2D1.__Internals.ID2D1Shader.LoadDispatchData<TLoader>(ref TLoader loader)
Expand All @@ -38,6 +54,122 @@ public static MethodDeclarationSyntax GetSyntax(DispatchDataInfo dispatchInfo)
.WithBody(Block(GetDispatchDataLoadingStatements(dispatchInfo.FieldInfos, dispatchInfo.ConstantBufferSizeInBytes)));
}

/// <summary>
/// Gets a type definition to map the constant buffer of a given shader type.
/// </summary>
/// <param name="hierarchyInfo">The hiararchy info of the shader type.</param>
/// <param name="fieldInfos">The array of <see cref="FieldInfo"/> values for all captured fields.</param>
/// <param name="constantBufferSizeInBytes">The size of the shader constant buffer.</param>
/// <returns>The <see cref="TypeDeclarationSyntax"/> object for the mapped constant buffer for the current shader type.</returns>
private static TypeDeclarationSyntax GetConstantBufferDeclaration(HierarchyInfo hierarchyInfo, ImmutableArray<FieldInfo> fieldInfos, int constantBufferSizeInBytes)
{
string fullyQualifiedTypeName = hierarchyInfo.GetFullyQualifiedTypeName();

using ImmutableArrayBuilder<FieldDeclarationSyntax> fieldDeclarations = ImmutableArrayBuilder<FieldDeclarationSyntax>.Rent();

// Appends a new field declaration for a constant buffer field:
//
// <COMMENT>
// [global::System.Runtime.InteropServices.FieldOffset(<FIELD_OFFSET>)]
// public <FIELD_TYPE> <FIELD_NAME>;
void AppendFieldDeclaration(
string comment,
TypeSyntax typeIdentifier,
string identifierName,
int fieldOffset)
{
fieldDeclarations.Add(
FieldDeclaration(
VariableDeclaration(typeIdentifier)
.AddVariables(VariableDeclarator(Identifier(identifierName))))
.AddAttributeLists(AttributeList(SingletonSeparatedList(
Attribute(IdentifierName("global::System.Runtime.InteropServices.FieldOffset"))
.AddArgumentListArguments(AttributeArgument(
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(fieldOffset)))))))
.AddModifiers(Token(SyntaxKind.PublicKeyword))
.WithLeadingTrivia(Comment(comment)));
}

// Declare fields for every mapped item from the shader layout
foreach (FieldInfo fieldInfo in fieldInfos)
{
switch (fieldInfo)
{
case FieldInfo.Primitive { TypeName: "System.Boolean" } primitive:

// Append a field as a global::ComputeSharp.Bool value (will use the implicit conversion from bool values)
AppendFieldDeclaration(
comment: $"""/// <inheritdoc cref="{fullyQualifiedTypeName}.{string.Join(".", primitive.FieldPath)}"/>""",
typeIdentifier: IdentifierName("global::ComputeSharp.Bool"),
identifierName: string.Join("_", primitive.FieldPath),
fieldOffset: primitive.Offset);
break;
case FieldInfo.Primitive primitive:

// Append primitive fields of other types with their mapped names
AppendFieldDeclaration(
comment: $"""/// <inheritdoc cref="{fullyQualifiedTypeName}.{string.Join(".", primitive.FieldPath)}"/>""",
typeIdentifier: IdentifierName(HlslKnownTypes.GetMappedName(primitive.TypeName)),
identifierName: string.Join("_", primitive.FieldPath),
fieldOffset: primitive.Offset);
break;

case FieldInfo.NonLinearMatrix matrix:
string rowTypeName = HlslKnownTypes.GetMappedName($"ComputeSharp.{matrix.ElementName}{matrix.Columns}");
string fieldNamePrefix = string.Join("_", matrix.FieldPath);

// Declare a field for every row of the matrix type
for (int j = 0; j < matrix.Rows; j++)
{
AppendFieldDeclaration(
comment: $"""/// <summary>Row {j} of <see cref="{fullyQualifiedTypeName}.{string.Join(".", matrix.FieldPath)}"/>.</summary>""",
typeIdentifier: IdentifierName(rowTypeName),
identifierName: $"{fieldNamePrefix}_{j}",
fieldOffset: matrix.Offsets[j]);
}

break;
}
}

// Create the constant buffer type:
//
// /// <summary>
// /// A type representing the constant buffer native layout for <see cref="<SHADER_TYPE"/>.
// /// </summary>
// [global::System.CodeDom.Compiler.GeneratedCode("...", "...")]
// [global::System.Diagnostics.DebuggerNonUserCode]
// [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
// [global::System.Runtime.InteropServices.StructLayout(global::System.Runtime.InteropServices.LayoutKind.Explicit, Size = <CONSTANT_BUFFER_SIZE>)]
// file struct ConstantBuffer
// {
// <FIELD_DECLARATIONS>
// }
return
StructDeclaration("ConstantBuffer")
.AddModifiers(Token(SyntaxKind.FileKeyword))
.AddAttributeLists(
AttributeList(SingletonSeparatedList(
Attribute(IdentifierName("global::System.CodeDom.Compiler.GeneratedCode")).AddArgumentListArguments(
AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ID2D1ShaderGenerator).FullName))),
AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ID2D1ShaderGenerator).Assembly.GetName().Version.ToString())))))),
AttributeList(SingletonSeparatedList(Attribute(IdentifierName("global::System.Diagnostics.DebuggerNonUserCode")))),
AttributeList(SingletonSeparatedList(Attribute(IdentifierName("global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage")))),
AttributeList(SingletonSeparatedList(
Attribute(IdentifierName("global::System.Runtime.InteropServices.StructLayout")).AddArgumentListArguments(
AttributeArgument(MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName("global::System.Runtime.InteropServices.LayoutKind"),
IdentifierName("Explicit"))),
AttributeArgument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(constantBufferSizeInBytes)))
.WithNameEquals(NameEquals(IdentifierName("Size")))))))
.AddMembers(fieldDeclarations.ToArray())
.WithLeadingTrivia(
Comment("/// <summary>"),
Comment($"""/// A type representing the constant buffer native layout for <see cref="{fullyQualifiedTypeName}"/>."""),
Comment("/// </summary>"));
}

/// <summary>
/// Gets a sequence of statements to load the dispatch data for a given shader.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
// Generate the LoadDispatchData() methods
context.RegisterSourceOutput(dispatchDataInfo, static (context, item) =>
{
MethodDeclarationSyntax loadDispatchDataMethod = LoadDispatchData.GetSyntax(item.Info.Dispatch);
CompilationUnitSyntax compilationUnit = GetCompilationUnitFromMethod(item.Info.Hierarchy, loadDispatchDataMethod, item.CanUseSkipLocalsInit);
MethodDeclarationSyntax loadDispatchDataMethod = LoadDispatchData.GetSyntax(item.Info.Hierarchy, item.Info.Dispatch, out TypeDeclarationSyntax[] additionalTypes);
CompilationUnitSyntax compilationUnit = GetCompilationUnitFromMethod(item.Info.Hierarchy, loadDispatchDataMethod, item.CanUseSkipLocalsInit, additionalMemberDeclarations: additionalTypes);
context.AddSource($"{item.Info.Hierarchy.FullyQualifiedMetadataName}.{nameof(LoadDispatchData)}.g.cs", compilationUnit.GetText(Encoding.UTF8));
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,18 @@ public static string GetMappedElementName(IArrayTypeSymbol typeSymbol)
/// Gets the mapped HLSL-compatible type name for the input type name.
/// </summary>
/// <param name="originalName">The input type name to map.</param>
/// <param name="mappedName">The resulting mapped type name, if found.</param>
/// <returns>The HLSL-compatible type name that can be used in an HLSL shader.</returns>
public static string GetMappedName(string originalName)
{
return KnownHlslTypes[originalName];
}

/// <summary>
/// Tries to get the mapped HLSL-compatible type name for the input type name.
/// </summary>
/// <param name="originalName">The input type name to map.</param>
/// <param name="mappedName">The resulting mapped type name, if found.</param>
/// <returns>Whether a mapped name was available.</returns>
public static bool TryGetMappedName(string originalName, out string? mappedName)
{
return KnownHlslTypes.TryGetValue(originalName, out mappedName);
Expand Down

0 comments on commit ee2e7fd

Please sign in to comment.