diff --git a/src/ComputeSharp.SourceGeneration/Extensions/ITypeSymbolExtensions.cs b/src/ComputeSharp.SourceGeneration/Extensions/ITypeSymbolExtensions.cs
index 9eaf22f8f..2307f5f04 100644
--- a/src/ComputeSharp.SourceGeneration/Extensions/ITypeSymbolExtensions.cs
+++ b/src/ComputeSharp.SourceGeneration/Extensions/ITypeSymbolExtensions.cs
@@ -9,6 +9,25 @@ namespace ComputeSharp.SourceGeneration.Extensions;
///
internal static class ITypeSymbolExtensions
{
+ ///
+ /// Gets the method of this symbol that have a particular name.
+ ///
+ /// The input instance to check.
+ /// The name of the method to find.
+ /// The target method, if present.
+ public static IMethodSymbol? GetMethod(this ITypeSymbol symbol, string name)
+ {
+ foreach (IMethodSymbol methodSymbol in symbol.GetMembers(name))
+ {
+ if (methodSymbol.Name == name)
+ {
+ return methodSymbol;
+ }
+ }
+
+ return null;
+ }
+
///
/// Checks whether or not a given type symbol has a specified fully qualified metadata name.
///
@@ -28,7 +47,7 @@ public static bool HasFullyQualifiedMetadataName(this ITypeSymbol symbol, string
/// Checks whether or not a given implements an interface of a specified type.
///
/// The target instance to check.
- /// The instane to check for inheritance from.
+ /// The instance to check for inheritance from.
/// Whether or not has an interface of type .
public static bool HasInterfaceWithType(this ITypeSymbol typeSymbol, ITypeSymbol interfaceSymbol)
{
diff --git a/src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.Helpers.cs b/src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.Helpers.cs
index 2eb8d6a56..a16a9bf8e 100644
--- a/src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.Helpers.cs
+++ b/src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.Helpers.cs
@@ -1,3 +1,4 @@
+using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis;
namespace ComputeSharp.SourceGenerators;
@@ -10,9 +11,14 @@ partial class ComputeShaderDescriptorGenerator
///
/// The input instance to check.
/// The instance currently in use.
+ /// The (constructed) shader interface type implemented by the shader type.
/// Whether is a "pixel shader like" type.
/// Whether is a compute shader type at all.
- private static bool TryGetIsPixelShaderLike(INamedTypeSymbol typeSymbol, Compilation compilation, out bool isPixelShaderLike)
+ private static bool TryGetIsPixelShaderLike(
+ INamedTypeSymbol typeSymbol,
+ Compilation compilation,
+ [NotNullWhen(true)] out INamedTypeSymbol? shaderInterfaceType,
+ out bool isPixelShaderLike)
{
INamedTypeSymbol computeShaderSymbol = compilation.GetTypeByMetadataName("ComputeSharp.IComputeShader")!;
INamedTypeSymbol pixelShaderSymbol = compilation.GetTypeByMetadataName("ComputeSharp.IComputeShader`1")!;
@@ -21,18 +27,21 @@ private static bool TryGetIsPixelShaderLike(INamedTypeSymbol typeSymbol, Compila
{
if (SymbolEqualityComparer.Default.Equals(interfaceSymbol, computeShaderSymbol))
{
+ shaderInterfaceType = interfaceSymbol;
isPixelShaderLike = false;
return true;
}
else if (SymbolEqualityComparer.Default.Equals(interfaceSymbol.ConstructedFrom, pixelShaderSymbol))
{
+ shaderInterfaceType = interfaceSymbol;
isPixelShaderLike = true;
return true;
}
}
+ shaderInterfaceType = null;
isPixelShaderLike = false;
return false;
diff --git a/src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.HlslSource.cs b/src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.HlslSource.cs
index a44e4464a..e276df9aa 100644
--- a/src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.HlslSource.cs
+++ b/src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.HlslSource.cs
@@ -31,6 +31,8 @@ internal static partial class HlslSource
/// The collection of produced instances.
/// The input object currently in use.
/// The for the shader type.
+ /// The shader interface type implemented by the shader type.
+ /// Whether is a "pixel shader like" type.
/// The thread ids value for the X axis.
/// The thread ids value for the Y axis.
/// The thread ids value for the Z axis.
@@ -42,6 +44,8 @@ public static void GetInfo(
ImmutableArrayBuilder diagnostics,
Compilation compilation,
INamedTypeSymbol structDeclarationSymbol,
+ INamedTypeSymbol shaderInterfaceType,
+ bool isPixelShaderLike,
int threadsX,
int threadsY,
int threadsZ,
@@ -53,6 +57,8 @@ public static void GetInfo(
// Detect any invalid properties
HlslDefinitionsSyntaxProcessor.DetectAndReportInvalidPropertyDeclarations(diagnostics, structDeclarationSymbol);
+ token.ThrowIfCancellationRequested();
+
// We need to sets to track all discovered custom types and static methods
HashSet discoveredTypes = new(SymbolEqualityComparer.Default);
Dictionary staticMethods = new(SymbolEqualityComparer.Default);
@@ -62,9 +68,8 @@ public static void GetInfo(
Dictionary staticFieldDefinitions = new(SymbolEqualityComparer.Default);
// Setup the semantic model and basic properties
- INamedTypeSymbol? pixelShaderSymbol = structDeclarationSymbol.AllInterfaces.FirstOrDefault(static interfaceSymbol => interfaceSymbol is { IsGenericType: true, Name: "IComputeShader" });
- bool isComputeShader = pixelShaderSymbol is null;
- string? implicitTextureType = isComputeShader ? null : HlslKnownTypes.GetMappedNameForPixelShaderType(pixelShaderSymbol!);
+ bool isComputeShader = !isPixelShaderLike;
+ string? implicitTextureType = HlslKnownTypes.GetMappedNameForPixelShaderType(shaderInterfaceType);
token.ThrowIfCancellationRequested();
@@ -90,6 +95,7 @@ public static void GetInfo(
(string entryPoint, ImmutableArray processedMethods, isSamplerUsed) = GetProcessedMethods(
diagnostics,
structDeclarationSymbol,
+ shaderInterfaceType,
semanticModelProvider,
discoveredTypes,
staticMethods,
@@ -360,6 +366,7 @@ private static ImmutableArray GetSharedBuffers(
///
/// The collection of produced instances.
/// The type symbol for the shader type.
+ /// The shader interface type implemented by the shader type.
/// The instance for the type to process.
/// The collection of currently discovered types.
/// The set of discovered and processed static methods.
@@ -373,6 +380,7 @@ private static ImmutableArray GetSharedBuffers(
private static (string EntryPoint, ImmutableArray Methods, bool IsSamplerUser) GetProcessedMethods(
ImmutableArrayBuilder diagnostics,
INamedTypeSymbol structDeclarationSymbol,
+ INamedTypeSymbol shaderInterfaceType,
SemanticModelProvider semanticModel,
ICollection discoveredTypes,
IDictionary staticMethods,
@@ -385,6 +393,7 @@ private static (string EntryPoint, ImmutableArray Methods, bool IsSa
{
using ImmutableArrayBuilder methods = new();
+ IMethodSymbol entryPointInterfaceMethod = shaderInterfaceType.GetMethod("Execute")!;
string? entryPoint = null;
bool isSamplerUsed = false;
@@ -396,22 +405,15 @@ private static (string EntryPoint, ImmutableArray Methods, bool IsSa
continue;
}
+ // Ensure that we have accessible source information
if (!methodSymbol.TryGetSyntaxNode(token, out MethodDeclarationSyntax? methodDeclaration))
{
continue;
}
- bool isShaderEntryPoint =
- (isComputeShader &&
- methodSymbol.Name == "Execute" &&
- methodSymbol.ReturnsVoid &&
- methodSymbol.TypeParameters.Length == 0 &&
- methodSymbol.Parameters.Length == 0) ||
- (!isComputeShader &&
- methodSymbol.Name == "Execute" &&
- methodSymbol.ReturnType is not null && // TODO: match for pixel type
- methodSymbol.TypeParameters.Length == 0 &&
- methodSymbol.Parameters.Length == 0);
+ // Check whether the current method is the entry point (ie. it's implementing 'Execute'). We use
+ // 'FindImplementationForInterfaceMember' to handle explicit interface implementations as well.
+ bool isShaderEntryPoint = structDeclarationSymbol.FindImplementationForInterfaceMember(entryPointInterfaceMethod) is not null;
// Except for the entry point, ignore explicit interface implementations
if (!isShaderEntryPoint && !methodSymbol.ExplicitInterfaceImplementations.IsDefaultOrEmpty)
diff --git a/src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.cs b/src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.cs
index 9fbceeb7d..1ed999402 100644
--- a/src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.cs
+++ b/src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.cs
@@ -54,7 +54,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
}
// Check whether type is a compute shader, and if so, if it's pixel shader like
- if (!TryGetIsPixelShaderLike(typeSymbol, context.SemanticModel.Compilation, out bool isPixelShaderLike))
+ if (!TryGetIsPixelShaderLike(
+ typeSymbol,
+ context.SemanticModel.Compilation,
+ out INamedTypeSymbol? shaderInterfaceType,
+ out bool isPixelShaderLike))
{
return default;
}
@@ -91,6 +95,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
diagnostics,
context.SemanticModel.Compilation,
typeSymbol,
+ shaderInterfaceType,
+ isPixelShaderLike,
threadsX,
threadsY,
threadsZ,
diff --git a/src/ComputeSharp.SourceGenerators/Mappings/HlslKnownTypes.cs b/src/ComputeSharp.SourceGenerators/Mappings/HlslKnownTypes.cs
index 0eb7cb216..a1b013c18 100644
--- a/src/ComputeSharp.SourceGenerators/Mappings/HlslKnownTypes.cs
+++ b/src/ComputeSharp.SourceGenerators/Mappings/HlslKnownTypes.cs
@@ -208,11 +208,18 @@ public static partial string GetMappedName(INamedTypeSymbol typeSymbol)
///
/// Gets the mapped HLSL-compatible type name for the output texture of a pixel shader.
///
- /// The pixel shader type to map.
+ /// The shader type to map.
/// The HLSL-compatible type name that can be used in an HLSL shader.
- public static string GetMappedNameForPixelShaderType(INamedTypeSymbol typeSymbol)
+ public static string? GetMappedNameForPixelShaderType(INamedTypeSymbol typeSymbol)
{
- string genericArgumentName = ((INamedTypeSymbol)typeSymbol.TypeArguments.First()).GetFullyQualifiedMetadataName();
+ // If the shader type is not a pixel shader type (ie. it has a type argument), stop here.
+ // At this point the input is guaranteed to either be 'IComputeShader' or 'IComputeShader'.
+ if (typeSymbol.TypeArguments is not [INamedTypeSymbol pixelShaderType])
+ {
+ return null;
+ }
+
+ string genericArgumentName = pixelShaderType.GetFullyQualifiedMetadataName();
// If the current type is a custom type, format it as needed
if (!KnownHlslTypeMetadataNames.TryGetValue(genericArgumentName, out string? mappedElementType))