diff --git a/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.CreateEffectIdProperty.cs b/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.CreateEffectIdProperty.cs index 831069bc3..f1e0d2e8a 100644 --- a/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.CreateEffectIdProperty.cs +++ b/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.CreateEffectIdProperty.cs @@ -28,22 +28,72 @@ private static partial class EffectId /// Extracts the effect id info for the current shader. /// /// The collection of produced instances. + /// The input object currently in use. /// The for the shader type in use. /// The resulting effect id. - public static ImmutableArray GetInfo(ImmutableArrayBuilder diagnostics, INamedTypeSymbol structDeclarationSymbol) + public static ImmutableArray GetInfo( + ImmutableArrayBuilder diagnostics, + Compilation compilation, + INamedTypeSymbol structDeclarationSymbol) + { + if (TryGetDefinedEffectId(compilation, structDeclarationSymbol, out ImmutableArray effectId)) + { + return effectId; + } + + return CreateDefaultEffectId(structDeclarationSymbol); + } + + /// + /// Tries to get the defined effect id for a given shader type. + /// + /// The input object currently in use. + /// The input instance. + /// The resulting defined effect id, if found. + /// Whether or not a defined effect id could be found. + private static bool TryGetDefinedEffectId(Compilation compilation, INamedTypeSymbol typeSymbol, out ImmutableArray effectId) + { + INamedTypeSymbol effectIdAttributeSymbol = compilation.GetTypeByMetadataName("ComputeSharp.D2D1.D2DEffectIdAttribute")!; + + foreach (AttributeData attributeData in typeSymbol.GetAttributes()) + { + // Check that the attribute is [D2DEffectId] and with a valid parameter + if (SymbolEqualityComparer.Default.Equals(attributeData.AttributeClass, effectIdAttributeSymbol) && + attributeData.ConstructorArguments is [{ Value: string value }] && + Guid.TryParse(value, out Guid guid)) + { + byte[] bytes = guid.ToByteArray(); + + effectId = Unsafe.As>(ref bytes); + + return true; + } + } + + effectId = default; + + return false; + } + + /// + /// Creates the default effect id for a given type symbol. + /// + /// The input instance. + /// The resulting effect id. + private static ImmutableArray CreateDefaultEffectId(INamedTypeSymbol typeSymbol) { // Initialize an instance using the MD5 algorithm. We use this for several reasons: // - We don't really need security, this is just to uniquely identify types // - The hash size is 128 bits, which is exactly the size of a GUID. IncrementalHash incrementalHash = EffectId.incrementalHash ??= IncrementalHash.CreateHash(HashAlgorithmName.MD5); - string assemblyName = structDeclarationSymbol.ContainingAssembly?.Name ?? string.Empty; + string assemblyName = typeSymbol.ContainingAssembly?.Name ?? string.Empty; using ImmutableArrayBuilder byteBuffer = ImmutableArrayBuilder.Rent(); using ImmutableArrayBuilder charBuffer = ImmutableArrayBuilder.Rent(); // Format the fully qualified name into a pooled builder to avoid the string allocation - structDeclarationSymbol.AppendFullyQualifiedMetadataName(in charBuffer); + typeSymbol.AppendFullyQualifiedMetadataName(in charBuffer); int maxTypeNameCharsLength = Encoding.UTF8.GetMaxByteCount(charBuffer.Count); int maxAssemblyNameCharsLength = Encoding.UTF8.GetMaxByteCount(assemblyName.Length); diff --git a/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.cs b/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.cs index 61e2fead8..c7811456d 100644 --- a/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.cs +++ b/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.cs @@ -53,7 +53,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) using ImmutableArrayBuilder diagnostics = ImmutableArrayBuilder.Rent(); // EffectId info - ImmutableArray effectId = EffectId.GetInfo(diagnostics, typeSymbol); + ImmutableArray effectId = EffectId.GetInfo(diagnostics, context.SemanticModel.Compilation, typeSymbol); token.ThrowIfCancellationRequested();