diff --git a/lib/proc-macros/Cargo.toml b/lib/proc-macros/Cargo.toml index 3b526c3396..15c74e5047 100644 --- a/lib/proc-macros/Cargo.toml +++ b/lib/proc-macros/Cargo.toml @@ -13,3 +13,4 @@ proc-macro = true proc-macro2 = { workspace=true } quote = { workspace=true } syn = { workspace=true } + diff --git a/lib/proc-macros/src/content_hash.rs b/lib/proc-macros/src/content_hash.rs index 1f286f4348..4b0dd27a4f 100644 --- a/lib/proc-macros/src/content_hash.rs +++ b/lib/proc-macros/src/content_hash.rs @@ -1,14 +1,17 @@ -use proc_macro2::TokenStream; -use quote::{quote, quote_spanned}; +use proc_macro2::{Ident, TokenStream}; +use quote::{format_ident, quote, quote_spanned}; use syn::spanned::Spanned; -use syn::{parse_quote, Data, Fields, GenericParam, Generics, Index}; +use syn::{ + parse_quote, Data, Field, Fields, FieldsNamed, FieldsUnnamed, GenericParam, Generics, Index, + Variant, +}; pub fn add_trait_bounds(mut generics: Generics) -> Generics { for param in &mut generics.params { if let GenericParam::Type(ref mut type_param) = *param { type_param .bounds - .push(parse_quote!(crate::content_hash::ContentHash)); + .push(parse_quote!(::jj_lib::content_hash::ContentHash)); } } generics @@ -43,6 +46,77 @@ pub fn generate_hash_impl(data: &Data) -> TokenStream { quote! {} } }, - _ => unimplemented!("ContentHash can only be derived for structs."), + // Generates a match statement with a match arm and hash implementation + // for each of the variants in the enum. + Data::Enum(ref data) => { + let match_hash_statements = + data.variants + .iter() + .enumerate() + .map(|(i, v)| match &v.fields { + Fields::Named(fields) => named_fields_match_arm(i, v, fields), + Fields::Unnamed(fields) => anonymous_fields_match_arm(i, v, fields), + Fields::Unit => { + let ix = index_to_ordinal(i); + quote_spanned! {v.span() => + Self::#v => {state.update(&#ix.to_le_bytes());} + } + } + }); + quote! { + match self { + #(#match_hash_statements)* + } + } + } + Data::Union(_) => unimplemented!("ContentHash cannot be derived for unions."), + } +} + +// The documentation for `ContentHash` specifies that the hash impl for each +// enum variant should hash the ordinal number of the enum variant as a little +// endian u32 before hashing the variant's fields, if any. +fn index_to_ordinal(ix: usize) -> u32 { + u32::try_from(ix).expect("The number of enum variants overflows a u32.") +} + +fn enum_bindings<'a>(fields: impl IntoIterator) -> Vec { + fields + .into_iter() + .enumerate() + .map(|(i, f)| { + // If the field is named, use the name, otherwise generate a placeholder name. + f.ident.clone().unwrap_or(format_ident!("field_{}", i)) + }) + .collect::>() +} + +// Variants with named fields use {} braces for the match arm whereas variants +// with unnamed fields use () parentheses. Thus, we have two functions. +fn named_fields_match_arm(index: usize, variant: &Variant, fields: &FieldsNamed) -> TokenStream { + let bindings = enum_bindings(fields.named.iter()); + let variant_id = &variant.ident; + let ix = index_to_ordinal(index); + quote_spanned! {variant.span() => + Self::#variant_id{ #(#bindings),* } => { + ::jj_lib::content_hash::ContentHash::hash(&#ix, state); + #( ::jj_lib::content_hash::ContentHash::hash(#bindings, state); )* + } + } +} + +fn anonymous_fields_match_arm( + index: usize, + variant: &Variant, + fields: &FieldsUnnamed, +) -> TokenStream { + let bindings = enum_bindings(fields.unnamed.iter()); + let variant_id = &variant.ident; + let ix = index_to_ordinal(index); + quote_spanned! {variant.span() => + Self::#variant_id( #(#bindings),* ) => { + ::jj_lib::content_hash::ContentHash::hash(&#ix, state); + #( ::jj_lib::content_hash::ContentHash::hash(#bindings, state); )* + } } } diff --git a/lib/src/content_hash.rs b/lib/src/content_hash.rs index df990eab09..4de41ee82f 100644 --- a/lib/src/content_hash.rs +++ b/lib/src/content_hash.rs @@ -264,6 +264,18 @@ mod tests { ); } + // Test that the derived version of `ContentHash` matches the that's + // manually implemented for `std::Option`. + #[test] + fn derive_for_enum() { + #[derive(ContentHash)] + enum MyOption { + None, + Some(T), + } + assert_eq!(hash(&Option::::None), hash(&MyOption::::None)); + assert_eq!(hash(&Some(1)), hash(&MyOption::Some(1))); + } // This will be removed once all uses of content_hash! are replaced by the // derive version. #[test]