diff --git a/zvariant/src/lib.rs b/zvariant/src/lib.rs index 161975eca..33445e84d 100644 --- a/zvariant/src/lib.rs +++ b/zvariant/src/lib.rs @@ -1611,6 +1611,26 @@ mod tests { Enum::try_from(Value::Str("Variant1".into())), Err(Error::IncorrectType) ); + + #[repr(u8)] + #[derive(Debug, Clone, Copy, PartialEq, Type, Value, OwnedValue)] + #[zvariant(signature = "s", rename_all = "snake_case")] + enum Enum4 { + FooBar, + Baz, + } + + assert_eq!(Enum4::signature(), "s"); + assert_eq!( + Enum4::try_from(Value::Str("foo_bar".into())), + Ok(Enum4::FooBar) + ); + assert_eq!(Enum4::try_from(Value::Str("baz".into())), Ok(Enum4::Baz)); + assert_eq!( + Enum4::try_from(Value::Str("foo_baz".into())), + Err(Error::IncorrectType) + ); + assert_eq!(Enum4::try_from(Value::U32(0)), Err(Error::IncorrectType)); } #[test] diff --git a/zvariant_derive/src/utils.rs b/zvariant_derive/src/utils.rs index 81b6f4e06..4e6ad56da 100644 --- a/zvariant_derive/src/utils.rs +++ b/zvariant_derive/src/utils.rs @@ -22,4 +22,8 @@ def_attrs! { pub StructAttributes("struct") { signature str, rename_all str, deny_unknown_fields none }; /// Attributes defined on fields. pub FieldAttributes("field") { rename str }; + /// Attributes defined on enumerations. + pub EnumAttributes("enum") { signature str, rename_all str }; + /// Attributes defined on variants. + pub VariantAttributes("variant") { rename str }; } diff --git a/zvariant_derive/src/value.rs b/zvariant_derive/src/value.rs index 204b2abf1..b9ce315dd 100644 --- a/zvariant_derive/src/value.rs +++ b/zvariant_derive/src/value.rs @@ -2,8 +2,9 @@ use proc_macro2::{Span, TokenStream}; use quote::{quote, ToTokens}; use syn::{ spanned::Spanned, Attribute, Data, DataEnum, DeriveInput, Error, Expr, Fields, Generics, Ident, - Lifetime, LifetimeDef, + Lifetime, LifetimeDef, Variant, }; +use zvariant_utils::case; use crate::utils::*; @@ -236,18 +237,45 @@ fn impl_enum( Some(repr_attr) => repr_attr.parse_args()?, None => quote! { u32 }, }; + let enum_attrs = EnumAttributes::parse(&attrs)?; + let str_enum = enum_attrs + .signature + .map(|sig| sig == "s") + .unwrap_or_default(); let mut variant_names = vec![]; + let mut str_values = vec![]; for variant in &data.variants { + let variant_attrs = VariantAttributes::parse(&variant.attrs)?; // Ensure all variants of the enum are unit type match variant.fields { Fields::Unit => { variant_names.push(&variant.ident); + if str_enum { + let str_value = enum_name_for_variant( + &variant, + variant_attrs.rename, + enum_attrs.rename_all.as_ref().map(AsRef::as_ref), + )?; + str_values.push(str_value); + } } _ => return Err(Error::new(variant.span(), "must be a unit variant")), } } + let into_val = if str_enum { + quote! { + match e { + #( + #name::#variant_names => #str_values, + )* + } + } + } else { + quote! { e as #repr } + }; + let (value_type, into_value) = match value_type { ValueType::Value => ( quote! { #zv::Value<'_> }, @@ -255,7 +283,7 @@ fn impl_enum( impl ::std::convert::From<#name> for #zv::Value<'_> { #[inline] fn from(e: #name) -> Self { - <#zv::Value as ::std::convert::From<_>>::from(e as #repr).into() + <#zv::Value as ::std::convert::From<_>>::from(#into_val) } } }, @@ -269,7 +297,7 @@ fn impl_enum( #[inline] fn try_from(e: #name) -> #zv::Result { <#zv::OwnedValue as ::std::convert::TryFrom<_>>::try_from( - <#zv::Value as ::std::convert::From<_>>::from(e as #repr) + <#zv::Value as ::std::convert::From<_>>::from(#into_val) ) } } @@ -277,26 +305,68 @@ fn impl_enum( ), }; + let from_val = if str_enum { + quote! { + let v: #zv::Str = ::std::convert::TryInto::try_into(value)?; + + ::std::result::Result::Ok(match v.as_str() { + #( + #str_values => #name::#variant_names, + )* + _ => return ::std::result::Result::Err(#zv::Error::IncorrectType), + }) + } + } else { + quote! { + let v: #repr = ::std::convert::TryInto::try_into(value)?; + + ::std::result::Result::Ok( + #( + if v == #name::#variant_names as #repr { + #name::#variant_names + } else + )* { + return ::std::result::Result::Err(#zv::Error::IncorrectType); + } + ) + } + }; + Ok(quote! { impl ::std::convert::TryFrom<#value_type> for #name { type Error = #zv::Error; #[inline] fn try_from(value: #value_type) -> #zv::Result { - let v: #repr = ::std::convert::TryInto::try_into(value)?; - - ::std::result::Result::Ok( - #( - if v == #name::#variant_names as #repr { - #name::#variant_names - } else - )* { - return ::std::result::Result::Err(#zv::Error::IncorrectType); - } - ) + #from_val } } #into_value }) } + +fn enum_name_for_variant( + v: &Variant, + rename_attr: Option, + rename_all_attr: Option<&str>, +) -> Result { + if let Some(name) = rename_attr { + Ok(name) + } else { + let ident = v.ident.to_string(); + + match rename_all_attr { + Some("lowercase") => Ok(ident.to_ascii_lowercase()), + Some("UPPERCASE") => Ok(ident.to_ascii_uppercase()), + Some("PascalCase") => Ok(case::pascal_or_camel_case(&ident, true)), + Some("camelCase") => Ok(case::pascal_or_camel_case(&ident, false)), + Some("snake_case") => Ok(case::snake_case(&ident)), + None => Ok(ident), + Some(other) => Err(Error::new( + v.span(), + format!("invalid `rename_all` attribute value {other}"), + )), + } + } +}