diff --git a/lib/proc-macros/src/content_hash.rs b/lib/proc-macros/src/content_hash.rs index f03a8b3e74..faa4ea04ec 100644 --- a/lib/proc-macros/src/content_hash.rs +++ b/lib/proc-macros/src/content_hash.rs @@ -1,7 +1,7 @@ -use proc_macro2::TokenStream; -use quote::{quote, quote_spanned}; +use proc_macro2::{Ident, TokenStream}; +use quote::{format_ident, quote, quote_spanned}; use syn::spanned::Spanned; -use syn::{parse_quote, Data, Fields, GenericParam, Generics, Index}; +use syn::{parse_quote, Data, Field, Fields, GenericParam, Generics, Index}; pub fn add_trait_bounds(mut generics: Generics) -> Generics { for param in &mut generics.params { @@ -46,6 +46,66 @@ pub fn generate_hash_impl(data: &Data) -> TokenStream { quote! {} } }, - _ => unimplemented!("ContentHash can only be derived for structs."), + // Generates a match statement with a match arm and hash implementation + // for each of the variants in the enum. + Data::Enum(ref data) => { + let match_hash_statements = data.variants.iter().enumerate().map(|(i, v)| { + let variant_id = &v.ident; + match &v.fields { + Fields::Named(fields) => { + let bindings = enum_bindings(fields.named.iter()); + let ix = index_to_ordinal(i); + quote_spanned! {v.span() => + Self::#variant_id{ #(#bindings),* } => { + ::jj_lib::content_hash::ContentHash::hash(&#ix, state); + #( ::jj_lib::content_hash::ContentHash::hash(#bindings, state); )* + } + } + } + Fields::Unnamed(fields) => { + let bindings = enum_bindings(fields.unnamed.iter()); + let ix = index_to_ordinal(i); + quote_spanned! {v.span() => + Self::#variant_id( #(#bindings),* ) => { + ::jj_lib::content_hash::ContentHash::hash(&#ix, state); + #( ::jj_lib::content_hash::ContentHash::hash(#bindings, state); )* + } + } + } + Fields::Unit => { + let ix = index_to_ordinal(i); + quote_spanned! {v.span() => + Self::#variant_id => { + ::jj_lib::content_hash::ContentHash::hash(&#ix, state); + } + } + } + } + }); + quote! { + match self { + #(#match_hash_statements)* + } + } + } + Data::Union(_) => unimplemented!("ContentHash cannot be derived for unions."), } } + +// The documentation for `ContentHash` specifies that the hash impl for each +// enum variant should hash the ordinal number of the enum variant as a little +// endian u32 before hashing the variant's fields, if any. +fn index_to_ordinal(ix: usize) -> u32 { + u32::try_from(ix).expect("The number of enum variants overflows a u32.") +} + +fn enum_bindings<'a>(fields: impl IntoIterator) -> Vec { + fields + .into_iter() + .enumerate() + .map(|(i, f)| { + // If the field is named, use the name, otherwise generate a placeholder name. + f.ident.clone().unwrap_or(format_ident!("field_{}", i)) + }) + .collect::>() +} diff --git a/lib/src/content_hash.rs b/lib/src/content_hash.rs index df990eab09..4de41ee82f 100644 --- a/lib/src/content_hash.rs +++ b/lib/src/content_hash.rs @@ -264,6 +264,18 @@ mod tests { ); } + // Test that the derived version of `ContentHash` matches the that's + // manually implemented for `std::Option`. + #[test] + fn derive_for_enum() { + #[derive(ContentHash)] + enum MyOption { + None, + Some(T), + } + assert_eq!(hash(&Option::::None), hash(&MyOption::::None)); + assert_eq!(hash(&Some(1)), hash(&MyOption::Some(1))); + } // This will be removed once all uses of content_hash! are replaced by the // derive version. #[test]