diff --git a/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.CreateEffectIdProperty.cs b/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.CreateEffectIdProperty.cs
index 831069bc3..f1e0d2e8a 100644
--- a/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.CreateEffectIdProperty.cs
+++ b/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.CreateEffectIdProperty.cs
@@ -28,22 +28,72 @@ private static partial class EffectId
/// Extracts the effect id info for the current shader.
///
/// The collection of produced instances.
+ /// The input object currently in use.
/// The for the shader type in use.
/// The resulting effect id.
- public static ImmutableArray GetInfo(ImmutableArrayBuilder diagnostics, INamedTypeSymbol structDeclarationSymbol)
+ public static ImmutableArray GetInfo(
+ ImmutableArrayBuilder diagnostics,
+ Compilation compilation,
+ INamedTypeSymbol structDeclarationSymbol)
+ {
+ if (TryGetDefinedEffectId(compilation, structDeclarationSymbol, out ImmutableArray effectId))
+ {
+ return effectId;
+ }
+
+ return CreateDefaultEffectId(structDeclarationSymbol);
+ }
+
+ ///
+ /// Tries to get the defined effect id for a given shader type.
+ ///
+ /// The input object currently in use.
+ /// The input instance.
+ /// The resulting defined effect id, if found.
+ /// Whether or not a defined effect id could be found.
+ private static bool TryGetDefinedEffectId(Compilation compilation, INamedTypeSymbol typeSymbol, out ImmutableArray effectId)
+ {
+ INamedTypeSymbol effectIdAttributeSymbol = compilation.GetTypeByMetadataName("ComputeSharp.D2D1.D2DEffectIdAttribute")!;
+
+ foreach (AttributeData attributeData in typeSymbol.GetAttributes())
+ {
+ // Check that the attribute is [D2DEffectId] and with a valid parameter
+ if (SymbolEqualityComparer.Default.Equals(attributeData.AttributeClass, effectIdAttributeSymbol) &&
+ attributeData.ConstructorArguments is [{ Value: string value }] &&
+ Guid.TryParse(value, out Guid guid))
+ {
+ byte[] bytes = guid.ToByteArray();
+
+ effectId = Unsafe.As>(ref bytes);
+
+ return true;
+ }
+ }
+
+ effectId = default;
+
+ return false;
+ }
+
+ ///
+ /// Creates the default effect id for a given type symbol.
+ ///
+ /// The input instance.
+ /// The resulting effect id.
+ private static ImmutableArray CreateDefaultEffectId(INamedTypeSymbol typeSymbol)
{
// Initialize an instance using the MD5 algorithm. We use this for several reasons:
// - We don't really need security, this is just to uniquely identify types
// - The hash size is 128 bits, which is exactly the size of a GUID.
IncrementalHash incrementalHash = EffectId.incrementalHash ??= IncrementalHash.CreateHash(HashAlgorithmName.MD5);
- string assemblyName = structDeclarationSymbol.ContainingAssembly?.Name ?? string.Empty;
+ string assemblyName = typeSymbol.ContainingAssembly?.Name ?? string.Empty;
using ImmutableArrayBuilder byteBuffer = ImmutableArrayBuilder.Rent();
using ImmutableArrayBuilder charBuffer = ImmutableArrayBuilder.Rent();
// Format the fully qualified name into a pooled builder to avoid the string allocation
- structDeclarationSymbol.AppendFullyQualifiedMetadataName(in charBuffer);
+ typeSymbol.AppendFullyQualifiedMetadataName(in charBuffer);
int maxTypeNameCharsLength = Encoding.UTF8.GetMaxByteCount(charBuffer.Count);
int maxAssemblyNameCharsLength = Encoding.UTF8.GetMaxByteCount(assemblyName.Length);
diff --git a/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.cs b/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.cs
index 61e2fead8..c7811456d 100644
--- a/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.cs
+++ b/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.cs
@@ -53,7 +53,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
using ImmutableArrayBuilder diagnostics = ImmutableArrayBuilder.Rent();
// EffectId info
- ImmutableArray effectId = EffectId.GetInfo(diagnostics, typeSymbol);
+ ImmutableArray effectId = EffectId.GetInfo(diagnostics, context.SemanticModel.Compilation, typeSymbol);
token.ThrowIfCancellationRequested();