diff --git a/src/ComputeSharp.SourceGeneration/Extensions/ITypeSymbolExtensions.cs b/src/ComputeSharp.SourceGeneration/Extensions/ITypeSymbolExtensions.cs index 9eaf22f8f..2307f5f04 100644 --- a/src/ComputeSharp.SourceGeneration/Extensions/ITypeSymbolExtensions.cs +++ b/src/ComputeSharp.SourceGeneration/Extensions/ITypeSymbolExtensions.cs @@ -9,6 +9,25 @@ namespace ComputeSharp.SourceGeneration.Extensions; /// internal static class ITypeSymbolExtensions { + /// + /// Gets the method of this symbol that have a particular name. + /// + /// The input instance to check. + /// The name of the method to find. + /// The target method, if present. + public static IMethodSymbol? GetMethod(this ITypeSymbol symbol, string name) + { + foreach (IMethodSymbol methodSymbol in symbol.GetMembers(name)) + { + if (methodSymbol.Name == name) + { + return methodSymbol; + } + } + + return null; + } + /// /// Checks whether or not a given type symbol has a specified fully qualified metadata name. /// @@ -28,7 +47,7 @@ public static bool HasFullyQualifiedMetadataName(this ITypeSymbol symbol, string /// Checks whether or not a given implements an interface of a specified type. /// /// The target instance to check. - /// The instane to check for inheritance from. + /// The instance to check for inheritance from. /// Whether or not has an interface of type . public static bool HasInterfaceWithType(this ITypeSymbol typeSymbol, ITypeSymbol interfaceSymbol) { diff --git a/src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.Helpers.cs b/src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.Helpers.cs index 2eb8d6a56..a16a9bf8e 100644 --- a/src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.Helpers.cs +++ b/src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.Helpers.cs @@ -1,3 +1,4 @@ +using System.Diagnostics.CodeAnalysis; using Microsoft.CodeAnalysis; namespace ComputeSharp.SourceGenerators; @@ -10,9 +11,14 @@ partial class ComputeShaderDescriptorGenerator /// /// The input instance to check. /// The instance currently in use. + /// The (constructed) shader interface type implemented by the shader type. /// Whether is a "pixel shader like" type. /// Whether is a compute shader type at all. - private static bool TryGetIsPixelShaderLike(INamedTypeSymbol typeSymbol, Compilation compilation, out bool isPixelShaderLike) + private static bool TryGetIsPixelShaderLike( + INamedTypeSymbol typeSymbol, + Compilation compilation, + [NotNullWhen(true)] out INamedTypeSymbol? shaderInterfaceType, + out bool isPixelShaderLike) { INamedTypeSymbol computeShaderSymbol = compilation.GetTypeByMetadataName("ComputeSharp.IComputeShader")!; INamedTypeSymbol pixelShaderSymbol = compilation.GetTypeByMetadataName("ComputeSharp.IComputeShader`1")!; @@ -21,18 +27,21 @@ private static bool TryGetIsPixelShaderLike(INamedTypeSymbol typeSymbol, Compila { if (SymbolEqualityComparer.Default.Equals(interfaceSymbol, computeShaderSymbol)) { + shaderInterfaceType = interfaceSymbol; isPixelShaderLike = false; return true; } else if (SymbolEqualityComparer.Default.Equals(interfaceSymbol.ConstructedFrom, pixelShaderSymbol)) { + shaderInterfaceType = interfaceSymbol; isPixelShaderLike = true; return true; } } + shaderInterfaceType = null; isPixelShaderLike = false; return false; diff --git a/src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.HlslSource.cs b/src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.HlslSource.cs index a44e4464a..e276df9aa 100644 --- a/src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.HlslSource.cs +++ b/src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.HlslSource.cs @@ -31,6 +31,8 @@ internal static partial class HlslSource /// The collection of produced instances. /// The input object currently in use. /// The for the shader type. + /// The shader interface type implemented by the shader type. + /// Whether is a "pixel shader like" type. /// The thread ids value for the X axis. /// The thread ids value for the Y axis. /// The thread ids value for the Z axis. @@ -42,6 +44,8 @@ public static void GetInfo( ImmutableArrayBuilder diagnostics, Compilation compilation, INamedTypeSymbol structDeclarationSymbol, + INamedTypeSymbol shaderInterfaceType, + bool isPixelShaderLike, int threadsX, int threadsY, int threadsZ, @@ -53,6 +57,8 @@ public static void GetInfo( // Detect any invalid properties HlslDefinitionsSyntaxProcessor.DetectAndReportInvalidPropertyDeclarations(diagnostics, structDeclarationSymbol); + token.ThrowIfCancellationRequested(); + // We need to sets to track all discovered custom types and static methods HashSet discoveredTypes = new(SymbolEqualityComparer.Default); Dictionary staticMethods = new(SymbolEqualityComparer.Default); @@ -62,9 +68,8 @@ public static void GetInfo( Dictionary staticFieldDefinitions = new(SymbolEqualityComparer.Default); // Setup the semantic model and basic properties - INamedTypeSymbol? pixelShaderSymbol = structDeclarationSymbol.AllInterfaces.FirstOrDefault(static interfaceSymbol => interfaceSymbol is { IsGenericType: true, Name: "IComputeShader" }); - bool isComputeShader = pixelShaderSymbol is null; - string? implicitTextureType = isComputeShader ? null : HlslKnownTypes.GetMappedNameForPixelShaderType(pixelShaderSymbol!); + bool isComputeShader = !isPixelShaderLike; + string? implicitTextureType = HlslKnownTypes.GetMappedNameForPixelShaderType(shaderInterfaceType); token.ThrowIfCancellationRequested(); @@ -90,6 +95,7 @@ public static void GetInfo( (string entryPoint, ImmutableArray processedMethods, isSamplerUsed) = GetProcessedMethods( diagnostics, structDeclarationSymbol, + shaderInterfaceType, semanticModelProvider, discoveredTypes, staticMethods, @@ -360,6 +366,7 @@ private static ImmutableArray GetSharedBuffers( /// /// The collection of produced instances. /// The type symbol for the shader type. + /// The shader interface type implemented by the shader type. /// The instance for the type to process. /// The collection of currently discovered types. /// The set of discovered and processed static methods. @@ -373,6 +380,7 @@ private static ImmutableArray GetSharedBuffers( private static (string EntryPoint, ImmutableArray Methods, bool IsSamplerUser) GetProcessedMethods( ImmutableArrayBuilder diagnostics, INamedTypeSymbol structDeclarationSymbol, + INamedTypeSymbol shaderInterfaceType, SemanticModelProvider semanticModel, ICollection discoveredTypes, IDictionary staticMethods, @@ -385,6 +393,7 @@ private static (string EntryPoint, ImmutableArray Methods, bool IsSa { using ImmutableArrayBuilder methods = new(); + IMethodSymbol entryPointInterfaceMethod = shaderInterfaceType.GetMethod("Execute")!; string? entryPoint = null; bool isSamplerUsed = false; @@ -396,22 +405,15 @@ private static (string EntryPoint, ImmutableArray Methods, bool IsSa continue; } + // Ensure that we have accessible source information if (!methodSymbol.TryGetSyntaxNode(token, out MethodDeclarationSyntax? methodDeclaration)) { continue; } - bool isShaderEntryPoint = - (isComputeShader && - methodSymbol.Name == "Execute" && - methodSymbol.ReturnsVoid && - methodSymbol.TypeParameters.Length == 0 && - methodSymbol.Parameters.Length == 0) || - (!isComputeShader && - methodSymbol.Name == "Execute" && - methodSymbol.ReturnType is not null && // TODO: match for pixel type - methodSymbol.TypeParameters.Length == 0 && - methodSymbol.Parameters.Length == 0); + // Check whether the current method is the entry point (ie. it's implementing 'Execute'). We use + // 'FindImplementationForInterfaceMember' to handle explicit interface implementations as well. + bool isShaderEntryPoint = structDeclarationSymbol.FindImplementationForInterfaceMember(entryPointInterfaceMethod) is not null; // Except for the entry point, ignore explicit interface implementations if (!isShaderEntryPoint && !methodSymbol.ExplicitInterfaceImplementations.IsDefaultOrEmpty) diff --git a/src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.cs b/src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.cs index 9fbceeb7d..1ed999402 100644 --- a/src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.cs +++ b/src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.cs @@ -54,7 +54,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context) } // Check whether type is a compute shader, and if so, if it's pixel shader like - if (!TryGetIsPixelShaderLike(typeSymbol, context.SemanticModel.Compilation, out bool isPixelShaderLike)) + if (!TryGetIsPixelShaderLike( + typeSymbol, + context.SemanticModel.Compilation, + out INamedTypeSymbol? shaderInterfaceType, + out bool isPixelShaderLike)) { return default; } @@ -91,6 +95,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context) diagnostics, context.SemanticModel.Compilation, typeSymbol, + shaderInterfaceType, + isPixelShaderLike, threadsX, threadsY, threadsZ, diff --git a/src/ComputeSharp.SourceGenerators/Mappings/HlslKnownTypes.cs b/src/ComputeSharp.SourceGenerators/Mappings/HlslKnownTypes.cs index 0eb7cb216..a1b013c18 100644 --- a/src/ComputeSharp.SourceGenerators/Mappings/HlslKnownTypes.cs +++ b/src/ComputeSharp.SourceGenerators/Mappings/HlslKnownTypes.cs @@ -208,11 +208,18 @@ public static partial string GetMappedName(INamedTypeSymbol typeSymbol) /// /// Gets the mapped HLSL-compatible type name for the output texture of a pixel shader. /// - /// The pixel shader type to map. + /// The shader type to map. /// The HLSL-compatible type name that can be used in an HLSL shader. - public static string GetMappedNameForPixelShaderType(INamedTypeSymbol typeSymbol) + public static string? GetMappedNameForPixelShaderType(INamedTypeSymbol typeSymbol) { - string genericArgumentName = ((INamedTypeSymbol)typeSymbol.TypeArguments.First()).GetFullyQualifiedMetadataName(); + // If the shader type is not a pixel shader type (ie. it has a type argument), stop here. + // At this point the input is guaranteed to either be 'IComputeShader' or 'IComputeShader'. + if (typeSymbol.TypeArguments is not [INamedTypeSymbol pixelShaderType]) + { + return null; + } + + string genericArgumentName = pixelShaderType.GetFullyQualifiedMetadataName(); // If the current type is a custom type, format it as needed if (!KnownHlslTypeMetadataNames.TryGetValue(genericArgumentName, out string? mappedElementType))