diff --git a/lib/proc-macros/src/content_hash.rs b/lib/proc-macros/src/content_hash.rs index 65573af086..f03a8b3e74 100644 --- a/lib/proc-macros/src/content_hash.rs +++ b/lib/proc-macros/src/content_hash.rs @@ -1,7 +1,18 @@ use proc_macro2::TokenStream; use quote::{quote, quote_spanned}; use syn::spanned::Spanned; -use syn::{Data, Fields, Index}; +use syn::{parse_quote, Data, Fields, GenericParam, Generics, Index}; + +pub fn add_trait_bounds(mut generics: Generics) -> Generics { + for param in &mut generics.params { + if let GenericParam::Type(ref mut type_param) = *param { + type_param + .bounds + .push(parse_quote!(::jj_lib::content_hash::ContentHash)); + } + } + generics +} pub fn generate_hash_impl(data: &Data) -> TokenStream { match *data { diff --git a/lib/proc-macros/src/lib.rs b/lib/proc-macros/src/lib.rs index 8e8c128413..2505e83b97 100644 --- a/lib/proc-macros/src/lib.rs +++ b/lib/proc-macros/src/lib.rs @@ -18,10 +18,15 @@ pub fn derive_content_hash(input: proc_macro::TokenStream) -> proc_macro::TokenS // Generate an expression to hash each of the fields in the struct. let hash_impl = content_hash::generate_hash_impl(&input.data); + // Handle structs and enums with generics. + let generics = content_hash::add_trait_bounds(input.generics); + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let expanded = quote! { #[automatically_derived] - impl ::jj_lib::content_hash::ContentHash for #name { - fn hash(&self, state: &mut impl ::jj_lib::content_hash::DigestUpdate) { + impl #impl_generics ::jj_lib::content_hash::ContentHash for #name #ty_generics + #where_clause { + fn hash(&self, state: &mut impl digest::Update) { #hash_impl } } diff --git a/lib/src/content_hash.rs b/lib/src/content_hash.rs index 881e907300..df990eab09 100644 --- a/lib/src/content_hash.rs +++ b/lib/src/content_hash.rs @@ -240,12 +240,27 @@ mod tests { content_hash! { struct Foo { x: Vec>, y: i64 } } + let foo_hash = hex::encode(hash(&Foo { + x: vec![None, Some(42)], + y: 17, + })); insta::assert_snapshot!( - hex::encode(hash(&Foo { + foo_hash, + @"e33c423b4b774b1353c414e0f9ef108822fde2fd5113fcd53bf7bd9e74e3206690b96af96373f268ed95dd020c7cbe171c7b7a6947fcaf5703ff6c8e208cefd4" + ); + + // Try again with an equivalent generic struct deriving ContentHash. + #[derive(ContentHash)] + struct GenericFoo { + x: X, + y: Y, + } + assert_eq!( + hex::encode(hash(&GenericFoo { x: vec![None, Some(42)], - y: 17 + y: 17i64 })), - @"e33c423b4b774b1353c414e0f9ef108822fde2fd5113fcd53bf7bd9e74e3206690b96af96373f268ed95dd020c7cbe171c7b7a6947fcaf5703ff6c8e208cefd4" + foo_hash ); }