Skip to content

Commit

Permalink
Use generated constant buffer type in LoadDispatchData
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergio0694 committed Sep 23, 2023
1 parent 5eac1c2 commit 53becce
Showing 1 changed file with 57 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ public static MethodDeclarationSyntax GetSyntax(HierarchyInfo hierarchyInfo, Dis

// This code produces a method declaration as follows:
//
// readonly void global::ComputeSharp.D2D1.__Internals.ID2D1Shader.LoadDispatchData<TLoader>(ref TLoader loader)
// readonly unsafe void global::ComputeSharp.D2D1.__Internals.ID2D1Shader.LoadDispatchData<TLoader>(ref TLoader loader)
// {
// <BODY>
// }
return
MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), Identifier(nameof(LoadDispatchData)))
.WithExplicitInterfaceSpecifier(ExplicitInterfaceSpecifier(IdentifierName($"global::ComputeSharp.D2D1.__Internals.{nameof(ID2D1Shader)}")))
.AddModifiers(Token(SyntaxKind.ReadOnlyKeyword))
.AddModifiers(Token(SyntaxKind.ReadOnlyKeyword), Token(SyntaxKind.UnsafeKeyword))
.AddTypeParameterListParameters(TypeParameter(Identifier("TLoader")))
.AddParameterListParameters(Parameter(Identifier("loader")).AddModifiers(Token(SyntaxKind.RefKeyword)).WithType(IdentifierName("TLoader")))
.WithBody(Block(GetDispatchDataLoadingStatements(dispatchInfo.FieldInfos, dispatchInfo.ConstantBufferSizeInBytes)));
Expand Down Expand Up @@ -178,10 +178,11 @@ void AppendFieldDeclaration(
/// <returns>The sequence of <see cref="StatementSyntax"/> instances to load the shader dispatch data.</returns>
private static ImmutableArray<StatementSyntax> GetDispatchDataLoadingStatements(ImmutableArray<FieldInfo> fieldInfos, int constantBufferSizeInBytes)
{
// If there are no fields, just load an empty buffer
// If there are no fields, just load an empty buffer:
//
// loader.LoadConstantBuffer(default);
if (fieldInfos.IsEmpty)
{
// loader.LoadConstantBuffer(default);
return
ImmutableArray.Create<StatementSyntax>(
ExpressionStatement(
Expand All @@ -198,103 +199,89 @@ private static ImmutableArray<StatementSyntax> GetDispatchDataLoadingStatements(

using ImmutableArrayBuilder<StatementSyntax> statements = ImmutableArrayBuilder<StatementSyntax>.Rent();

// global::System.Span<byte> data = stackalloc byte[<CONSTANT_BUFFER_SIZE>];
statements.Add(
LocalDeclarationStatement(
VariableDeclaration(
GenericName(Identifier("global::System.Span"))
.AddTypeArgumentListArguments(PredefinedType(Token(SyntaxKind.ByteKeyword))))
.AddVariables(
VariableDeclarator(Identifier("data"))
.WithInitializer(EqualsValueClause(
StackAllocArrayCreationExpression(
ArrayType(PredefinedType(Token(SyntaxKind.ByteKeyword)))
.AddRankSpecifiers(
ArrayRankSpecifier(SingletonSeparatedList<ExpressionSyntax>(
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(constantBufferSizeInBytes)))))))))));

// ref byte r0 = ref data[0];
// ConstantBuffer data;
statements.Add(
LocalDeclarationStatement(
VariableDeclaration(RefType(PredefinedType(Token(SyntaxKind.ByteKeyword))))
.AddVariables(
VariableDeclarator(Identifier("r0"))
.WithInitializer(EqualsValueClause(
RefExpression(
ElementAccessExpression(IdentifierName("data"))
.AddArgumentListArguments(Argument(
LiteralExpression(
SyntaxKind.NumericLiteralExpression,
Literal(0))))))))));
VariableDeclaration(IdentifierName("ConstantBuffer"))
.AddVariables(VariableDeclarator(Identifier("data")))));

// Generate loading statements for each captured field
foreach (FieldInfo fieldInfo in fieldInfos)
{
switch (fieldInfo)
{
case FieldInfo.Primitive { TypeName: "System.Boolean" } primitive:

// Read a boolean value and cast it to Bool first, which will apply the correct size expansion. This will generate the following:
//
// global::System.Runtime.CompilerServices.Unsafe.As<byte, global::ComputeSharp.Bool>(
// ref global::System.Runtime.CompilerServices.Unsafe.AddByteOffset(ref r0, (nint)<OFFSET>)) = (global::ComputeSharp.Bool)<FIELD_PATH>
statements.Add(ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
ParseExpression($"global::System.Runtime.CompilerServices.Unsafe.As<byte, global::ComputeSharp.Bool>(ref global::System.Runtime.CompilerServices.Unsafe.AddByteOffset(ref r0, (nint){primitive.Offset}))"),
ParseExpression($"(global::ComputeSharp.Bool){string.Join(".", primitive.FieldPath)}"))));
break;
case FieldInfo.Primitive primitive:

// Read a primitive value and serialize it into the target buffer. This will generate:
// Assign a primitive value:
//
// global::System.Runtime.CompilerServices.Unsafe.As<byte, global::<TYPE_NAME>>(
// ref global::System.Runtime.CompilerServices.Unsafe.AddByteOffset(ref r0, (nint)<OFFSET>)) = <FIELD_PATH>
statements.Add(ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
ParseExpression($"global::System.Runtime.CompilerServices.Unsafe.As<byte, global::{primitive.TypeName}>(ref global::System.Runtime.CompilerServices.Unsafe.AddByteOffset(ref r0, (nint){primitive.Offset}))"),
ParseExpression($"{string.Join(".", primitive.FieldPath)}"))));
// data.<CONSTANT_BUFFER_PATH> = this.<FIELD_PATH>;
statements.Add(
ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName("data"),
IdentifierName(string.Join("_", primitive.FieldPath))),
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
ThisExpression(),
IdentifierName(string.Join(".", primitive.FieldPath))))));
break;

case FieldInfo.NonLinearMatrix matrix:
string rowTypeName = $"global::ComputeSharp.{matrix.ElementName}{matrix.Columns}";
string rowLocalName = $"__{string.Join("_", matrix.FieldPath)}__row0";

// Declare a local to index into individual rows. This will generate:
//
// ref <ROW_TYPE> <ROW_NAME> = ref global::System.Runtime.CompilerServices.Unsafe.As<global::<TYPE_NAME>, <ROW_TYPE_NAME>>(
// ref global::System.Runtime.CompilerServices.Unsafe.AsRef(in <FIELD_PATH>));
statements.Add(ParseStatement($"ref {rowTypeName} {rowLocalName} = ref global::System.Runtime.CompilerServices.Unsafe.As<global::{matrix.TypeName}, {rowTypeName}>(ref global::System.Runtime.CompilerServices.Unsafe.AsRef(in {string.Join(".", matrix.FieldPath)}));"));
string fieldPath = string.Join(".", matrix.FieldPath);
string fieldNamePrefix = string.Join("_", matrix.FieldPath);

// Generate the loading code for each individual row, with proper alignment.
// This will result in the following (assuming Float2x3 m):
// Assign all rows of a given matrix type:
//
// ref global::ComputeSharp.Float3 __m__row0 = ref global::System.Runtime.CompilerServices.Unsafe.As<global::ComputeSharp.Float2x3, global::ComputeSharp.Float3>(ref global::System.Runtime.CompilerServices.Unsafe.AsRef(in m));
// global::System.Runtime.CompilerServices.Unsafe.As<byte, global::ComputeSharp.Float3>(ref global::System.Runtime.CompilerServices.Unsafe.AddByteOffset(ref r0, (nint)rawDataOffset)) = global::System.Runtime.CompilerServices.Unsafe.Add(ref __m__row0, 0);
// global::System.Runtime.CompilerServices.Unsafe.As<byte, global::ComputeSharp.Float3>(ref global::System.Runtime.CompilerServices.Unsafe.AddByteOffset(ref r0, (nint)(rawDataOffset + 16))) = global::System.Runtime.CompilerServices.Unsafe.Add(ref __m__row0, 1);
// data.<CONSTANT_BUFFER_ROW_0_PATH> = this.<FIELD_PATH>[0];
// data.<CONSTANT_BUFFER_ROW_1_PATH> = this.<FIELD_PATH>[1];
// ...
// data.<CONSTANT_BUFFER_ROW_N_PATH> = this.<FIELD_PATH>[N];
for (int j = 0; j < matrix.Rows; j++)
{
statements.Add(ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
ParseExpression($"global::System.Runtime.CompilerServices.Unsafe.As<byte, {rowTypeName}>(ref global::System.Runtime.CompilerServices.Unsafe.AddByteOffset(ref r0, (nint){matrix.Offsets[j]}))"),
ParseExpression($"global::System.Runtime.CompilerServices.Unsafe.Add(ref {rowLocalName}, {j})"))));
statements.Add(
ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName("data"),
IdentifierName($"{fieldNamePrefix}_{j}")),
ElementAccessExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
ThisExpression(),
IdentifierName(fieldPath)))
.AddArgumentListArguments(
Argument(LiteralExpression(
SyntaxKind.NumericLiteralExpression,
Literal(j)))))));
}

break;
}
}

// loader.LoadConstantBuffer(data);
// loader.LoadConstantBuffer(new global::System.ReadOnlySpan<byte>(&data, sizeof(ConstantBuffer)));
statements.Add(
ExpressionStatement(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName("loader"),
IdentifierName("LoadConstantBuffer")))
.AddArgumentListArguments(Argument(IdentifierName("data")))));
.AddArgumentListArguments(Argument(
ObjectCreationExpression(
GenericName(Identifier("global::System.ReadOnlySpan"))
.AddTypeArgumentListArguments(PredefinedType(Token(SyntaxKind.ByteKeyword))))
.AddArgumentListArguments(
Argument(
PrefixUnaryExpression(
SyntaxKind.AddressOfExpression,
IdentifierName("data"))),
Argument(SizeOfExpression(IdentifierName("ConstantBuffer"))))))));

return statements.ToImmutable();
}
Expand Down

0 comments on commit 53becce

Please sign in to comment.