Skip to content

Commit

Permalink
Update 'D2D.GetInput' to support all input types
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergio0694 committed Nov 24, 2024
1 parent 984f834 commit 6386c4c
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public override void Initialize(AnalysisContext context)
context.RegisterCompilationStartAction(static context =>
{
// If we can't get the D2D methods map, we have to stop right away
if (!TryBuildMethodSymbolMap(context.Compilation, out ImmutableDictionary<IMethodSymbol, D2D1PixelShaderInputType>? methodSymbols))
if (!TryBuildMethodSymbolMap(context.Compilation, out ImmutableDictionary<IMethodSymbol, D2D1PixelShaderInputType?>? methodSymbols))
{
return;
}
Expand Down Expand Up @@ -65,7 +65,7 @@ public override void Initialize(AnalysisContext context)
}
// Validate that the target method is one of the ones we care about, and get the target input type
if (!methodSymbols.TryGetValue(targetMethodSymbol, out D2D1PixelShaderInputType targetInputType))
if (!methodSymbols.TryGetValue(targetMethodSymbol, out D2D1PixelShaderInputType? targetInputType))
{
return;
}
Expand Down Expand Up @@ -101,7 +101,7 @@ public override void Initialize(AnalysisContext context)
typeSymbol,
inputCount));
}
else if ((D2D1PixelShaderInputType)inputTypes[index] != targetInputType)
else if (targetInputType is not null && (D2D1PixelShaderInputType)inputTypes[index] != targetInputType)
{
// Second validation: the input type must match
context.ReportDiagnostic(Diagnostic.Create(
Expand All @@ -120,7 +120,7 @@ public override void Initialize(AnalysisContext context)
/// <param name="compilation">The <see cref="Compilation"/> to consider for analysis.</param>
/// <param name="methodSymbols">The resulting mapping of resolved <see cref="IMethodSymbol"/> instances.</param>
/// <returns>Whether all requested <see cref="IMethodSymbol"/> instances could be resolved.</returns>
private static bool TryBuildMethodSymbolMap(Compilation compilation, [NotNullWhen(true)] out ImmutableDictionary<IMethodSymbol, D2D1PixelShaderInputType>? methodSymbols)
private static bool TryBuildMethodSymbolMap(Compilation compilation, [NotNullWhen(true)] out ImmutableDictionary<IMethodSymbol, D2D1PixelShaderInputType?>? methodSymbols)
{
// Get the 'D2D' symbol, to get methods from it
if (compilation.GetTypeByMetadataName("ComputeSharp.D2D1.D2D") is not { } d2DSymbol)
Expand All @@ -140,7 +140,7 @@ private static bool TryBuildMethodSymbolMap(Compilation compilation, [NotNullWhe
d2DSymbol.GetMethod(nameof(D2D.SampleInputAtPosition))
];

ImmutableDictionary<IMethodSymbol, D2D1PixelShaderInputType>.Builder inputTypeMethodMap = ImmutableDictionary.CreateBuilder<IMethodSymbol, D2D1PixelShaderInputType>(SymbolEqualityComparer.Default);
ImmutableDictionary<IMethodSymbol, D2D1PixelShaderInputType?>.Builder inputTypeMethodMap = ImmutableDictionary.CreateBuilder<IMethodSymbol, D2D1PixelShaderInputType?>(SymbolEqualityComparer.Default);

// Validate all methods and build the map
foreach (IMethodSymbol? d2DMethodSymbol in d2DMethodSymbols)
Expand All @@ -154,9 +154,16 @@ private static bool TryBuildMethodSymbolMap(Compilation compilation, [NotNullWhe
}

// Lookup the attribute to get the D2D input type (the attribute only exists on the 'D2D' type loaded in the analyzer)
D2D1PixelShaderInputType inputType = typeof(D2D).GetMethod(d2DMethodSymbol.Name).GetCustomAttribute<HlslD2DIntrinsicInputTypeAttribute>()!.InputType;

inputTypeMethodMap.Add(d2DMethodSymbol, (D2D1PixelShaderInputType)inputType);
if (typeof(D2D).GetMethod(d2DMethodSymbol.Name).GetCustomAttribute<HlslD2DIntrinsicInputTypeAttribute>() is { } hlslD2DIntrinsicInputTypeAttribute)
{
inputTypeMethodMap.Add(d2DMethodSymbol, hlslD2DIntrinsicInputTypeAttribute.InputType);
}
else
{
// If the method is not annotated, we stil track it, but we will not indicate any exclusive input type.
// This means that the input index validation logic will still work, but we'll skip the input type checks.
inputTypeMethodMap.Add(d2DMethodSymbol, null);
}
}

methodSymbols = inputTypeMethodMap.ToImmutable();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ namespace ComputeSharp.D2D1.Intrinsics;
/// An attribute indicating the input type associated with a given D2D HLSL intrinsic.
/// </summary>
/// <param name="inputType">The input type for the current instance.</param>
/// <remarks>
/// <para>If this attribute is not present, methods will be considered as supporting all input types.</para>
/// <para>This matches the behavior for <see cref="Core.Intrinsics.HlslIntrinsicNameAttribute"/>, when not present.</para>
/// </remarks>
[AttributeUsage(AttributeTargets.Method, AllowMultiple = true)]
[Conditional("SOURCE_GENERATOR")]
internal sealed class HlslD2DIntrinsicInputTypeAttribute(D2D1PixelShaderInputType inputType) : Attribute
Expand Down
2 changes: 0 additions & 2 deletions src/ComputeSharp.D2D1/Intrinsics/D2D.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ public static class D2D
/// </summary>
/// <param name="index">The index of the input texture to get the input from.</param>
/// <returns>The color from the target input at the current coordinate, in <c>INPUTN</c> format.</returns>
/// <remarks>This method is only available for simple inputs.</remarks>
[HlslIntrinsicName("D2DGetInput")]
[HlslD2DIntrinsicInputType(D2D1PixelShaderInputType.Simple)]
public static Float4 GetInput(int index) => throw new InvalidExecutionContextException($"{typeof(D2D)}.{nameof(GetInput)}({typeof(int)})");

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1354,6 +1354,7 @@ public async Task InvalidInputTypeForD2DIntrinsic_ValidArguments_DoesNotWarn()
public float4 Execute()
{
D2D.GetInput(0);
D2D.GetInput(1);
D2D.GetInputCoordinate(1);
D2D.SampleInput(1, 0);
D2D.SampleInputAtOffset(1, 0);
Expand Down Expand Up @@ -1381,7 +1382,6 @@ public async Task InvalidInputTypeForD2DIntrinsic_InvalidArguments_Warns()
{
public float4 Execute()
{
{|CMPSD2D0084:D2D.GetInput(1)|};
{|CMPSD2D0084:D2D.GetInputCoordinate(0)|};
{|CMPSD2D0084:D2D.SampleInput(0, 0)|};
{|CMPSD2D0084:D2D.SampleInputAtOffset(0, 0)|};
Expand Down

0 comments on commit 6386c4c

Please sign in to comment.