diff --git a/static-xml-derive/src/common.rs b/static-xml-derive/src/common.rs index 3b07ad6..2c74af8 100644 --- a/static-xml-derive/src/common.rs +++ b/static-xml-derive/src/common.rs @@ -6,7 +6,7 @@ use std::{cell::RefCell, collections::BTreeMap}; use proc_macro2::TokenStream; use quote::{quote, quote_spanned, ToTokens}; use static_xml::{de::WhiteSpace, ExpandedNameRef}; -use syn::{spanned::Spanned, DeriveInput, Fields, Lit, LitStr, Meta, MetaNameValue, NestedMeta}; +use syn::{DeriveInput, Fields, Lit, LitStr, Meta, MetaNameValue, NestedMeta}; // See serde/serde_derive/src/internals/attr.rs and yaserde_derive/src/common/field.rs @@ -249,7 +249,6 @@ pub(crate) enum TextVariantMode { pub(crate) struct ElementAttr<'a> { pub(crate) name: Name<'a>, pub(crate) namespaces: Namespaces, - pub(crate) direct: bool, // TODO: rename_all. } @@ -382,7 +381,6 @@ impl<'a> ElementAttr<'a> { let mut name = Name::from_ident(&input.ident); let mut namespaces = Namespaces::default(); let mut prefix_nv = None; - let mut direct = false; for item in get_meta_items(errors, &input.attrs) { match item { NestedMeta::Meta(Meta::NameValue(nv)) => { @@ -403,28 +401,13 @@ impl<'a> ElementAttr<'a> { errors.push(syn::Error::new_spanned(nv, "item not understood")); } } - NestedMeta::Meta(Meta::Path(p)) => { - if let Some(id) = p.get_ident() { - if id == "direct" { - direct = true; - } else { - errors.push(syn::Error::new_spanned(p, "item not understood")); - } - } else { - errors.push(syn::Error::new_spanned(p, "item not understood")); - } - } i => errors.push(syn::Error::new_spanned(i, "item not understood")), } } if let Some(ref nv) = prefix_nv { namespaces.process_prefix(errors, nv, &mut name); } - ElementAttr { - name, - namespaces, - direct, - } + ElementAttr { name, namespaces } } } @@ -506,20 +489,6 @@ impl<'a> ElementStruct<'a> { text_field_pos, }) } - - pub(crate) fn quote_flatten_fields(&self) -> Vec { - self.fields - .iter() - .filter_map(|f| { - if matches!(f.mode, ElementFieldMode::Flatten) { - let ident = &f.inner.ident; - Some(quote_spanned! { f.inner.ident.span() => &mut self.#ident }) - } else { - None - } - }) - .collect() - } } const STATIC_XML: &str = "static_xml"; @@ -553,20 +522,10 @@ pub(crate) enum ElementFieldMode { Flatten, } -impl ElementFieldMode { - pub(crate) fn quote_deserialize_trait(self) -> TokenStream { - match self { - ElementFieldMode::Element { .. } => quote! { ::static_xml::de::DeserializeField }, - ElementFieldMode::Attribute { .. } => quote! { ::static_xml::de::DeserializeAttr }, - ElementFieldMode::Text => panic!("text is different"), - ElementFieldMode::Flatten => quote! { ::static_xml::de::DeserializeFlatten }, - } - } -} - /// Field within an `ElementStruct`. pub(crate) struct ElementField<'a> { pub(crate) inner: &'a syn::Field, + pub(crate) ident: &'a syn::Ident, pub(crate) mode: ElementFieldMode, pub(crate) default: bool, pub(crate) name: Name<'a>, @@ -579,8 +538,8 @@ impl<'a> ElementField<'a> { let mut default = false; let mut flatten = false; let mut text = false; - let mut name = - Name::from_ident(inner.ident.as_ref().expect("struct fields should be named")); + let ident = inner.ident.as_ref().expect("struct fields should be named"); + let mut name = Name::from_ident(ident); for item in get_meta_items(errors, &inner.attrs) { match &item { NestedMeta::Meta(Meta::Path(p)) if p.is_ident("attribute") => attribute = true, @@ -631,6 +590,7 @@ impl<'a> ElementField<'a> { }; Ok(ElementField { inner, + ident, default, mode, name, diff --git a/static-xml-derive/src/deserialize.rs b/static-xml-derive/src/deserialize.rs index 42a738b..b52fa5c 100644 --- a/static-xml-derive/src/deserialize.rs +++ b/static-xml-derive/src/deserialize.rs @@ -7,20 +7,46 @@ use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned}; use syn::{spanned::Spanned, Data}; -use crate::common::{ElementEnum, ElementFieldMode, ElementStruct, Errors}; +use crate::common::{ElementEnum, ElementField, ElementFieldMode, ElementStruct, Errors}; + +pub(crate) fn quote_flatten_visitors(struct_: &ElementStruct) -> Vec { + struct_ + .fields + .iter() + .filter_map(|f| { + if matches!(f.mode, ElementFieldMode::Flatten) { + let field_visitor = format_ident!("{}_visitor", f.ident); + Some(quote_spanned! { f.inner.ident.span() => &mut self.#field_visitor }) + } else { + None + } + }) + .collect() +} + +fn field_mut(field: &ElementField) -> TokenStream { + let ident = field.ident; + quote! { unsafe { ::std::ptr::addr_of_mut!((*self.out).#ident) } } +} fn visitor_field_definitions(struct_: &ElementStruct) -> Vec { struct_ .fields .iter() - .map(|field| { - let field_name = field.inner.ident.as_ref().unwrap(); + .filter_map(|field| { + let span = field.inner.span(); + let field_present = format_ident!("{}_present", field.ident); + let field_visitor = format_ident!("{}_visitor", field.ident); let ty = &field.inner.ty; match field.mode { - ElementFieldMode::Text => quote! { _text_buf: String }, - _ => { - let trait_ = field.mode.quote_deserialize_trait(); - quote_spanned! { field.inner.span() => #field_name: <#ty as #trait_>::Builder } + ElementFieldMode::Text => None, + ElementFieldMode::Flatten => Some(quote_spanned! {span=> + #field_visitor: <#ty as ::static_xml::de::RawDeserialize<'out>>::Visitor + }), + ElementFieldMode::Element { .. } | ElementFieldMode::Attribute { .. } => { + Some(quote_spanned! {span=> + #field_present: bool + }) } } }) @@ -32,14 +58,28 @@ fn visitor_initializers(struct_: &ElementStruct) -> Vec { .fields .iter() .map(|field| { - let field_name = field.inner.ident.as_ref().unwrap(); - let ty = &field.inner.ty; + let span = field.inner.span(); match field.mode { - ElementFieldMode::Text => quote! { _text_buf: String::new() }, + ElementFieldMode::Flatten => { + let field_visitor = format_ident!("{}_visitor", field.ident); + let fident = field.ident; + let ty = &field.inner.ty; + quote_spanned! {span=> + #field_visitor: <#ty as ::static_xml::de::RawDeserialize>::Visitor::new( + // SAFETY: out points to valid, uninitialized memory. + unsafe { + &mut *( + ::std::ptr::addr_of_mut!((*out.as_mut_ptr()).#fident) + as *mut ::std::mem::MaybeUninit<#ty> + ) + } + ) + } + } _ => { - let trait_ = field.mode.quote_deserialize_trait(); - quote_spanned! { - field.inner.span() => #field_name: <#ty as #trait_>::init() + let field_present = format_ident!("{}_present", field.ident); + quote_spanned! {span=> + #field_present: false } } } @@ -47,36 +87,60 @@ fn visitor_initializers(struct_: &ElementStruct) -> Vec { .collect() } -fn value_fields_from_visitor(struct_: &ElementStruct) -> Vec { +fn finalize_visitor_fields(struct_: &ElementStruct) -> Vec { struct_.fields.iter().map(|field| { - let field_name = field.inner.ident.as_ref().unwrap(); + let span = field.inner.span(); let ty = &field.inner.ty; - let default = if field.default { - quote! { Some(Default::default) } - } else { - quote! { None } - }; + let default = field.default.then(|| { + Some(quote_spanned! {span=> + <#ty as ::std::default::Default>::default + }) + }); match field.mode { ElementFieldMode::Element { sorted_elements_pos: p } => { - quote_spanned! { - field.inner.span() => #field_name: <#ty as ::static_xml::de::DeserializeField>::finalize(self.#field_name, &ELEMENTS[#p], #default)? + let field_present = format_ident!("{}_present", field.ident); + let field_mut = field_mut(field); + let value = match default { + Some(d) => quote! { #d() }, + None => quote_spanned! {span=> + <#ty as ::static_xml::de::DeserializeElementField>::missing(&ELEMENTS[#p])? + } + }; + quote_spanned! {span=> + if !self.#field_present { + // SAFETY: #field_mut is a valid pointer to uninitialized memory. + unsafe { ::std::ptr::write(#field_mut as *mut #ty, #value) }; + } } } ElementFieldMode::Attribute { sorted_attributes_pos: p } => { - quote_spanned! { - field.inner.span() => #field_name: <#ty as ::static_xml::de::DeserializeAttr>::finalize(self.#field_name, &ATTRIBUTES[#p], #default)? + let field_present = format_ident!("{}_present", field.ident); + let field_mut = field_mut(field); + let value = match default { + Some(d) => quote! { #d() }, + None => quote_spanned! {span=> + <#ty as ::static_xml::de::DeserializeAttrField>::missing(&ATTRIBUTES[#p])? + } + }; + quote_spanned! {span=> + if !self.#field_present { + // SAFETY: #field_mut is a valid pointer to uninitialized memory. + unsafe { ::std::ptr::write(#field_mut /*as *mut #ty*/, #value) }; + } } } ElementFieldMode::Flatten => { - quote_spanned! { - field.inner.span() => #field_name: <#ty as ::static_xml::de::DeserializeFlatten>::finalize(self.#field_name)? - } - } - ElementFieldMode::Text => { - quote_spanned! { - field.inner.span() => #field_name: <#ty as ::static_xml::de::ParseText>::parse(self._text_buf)? + let field_visitor = format_ident!("{}_visitor", field.ident); + let d = if let Some(d) = default { + quote! { Some(#d) } + } else { + quote! { None } + }; + quote_spanned! {span=> + <#ty as ::static_xml::de::RawDeserialize>::Visitor::finalize(self.#field_visitor, #d)?; } } + ElementFieldMode::Text => todo!(), } }).collect() } @@ -104,10 +168,21 @@ fn attribute_match_branches(struct_: &ElementStruct) -> Vec { .enumerate() .map(|(i, &p)| { let field = &struct_.fields[p]; - let ident = field.inner.ident.as_ref().unwrap(); + let field_present = format_ident!("{}_present", field.ident); + let field_mut = field_mut(field); quote! { + Some(#i) if self.#field_present => { + Err(::static_xml::de::VisitorError::duplicate_attribute(name)) + } Some(#i) => { - ::static_xml::de::DeserializeAttrBuilder::attr(&mut self.#ident, name, value)?; + // SAFETY: #field_mut is a valid pointer to uninitialized memory. + unsafe { + ::std::ptr::write( + #field_mut, + ::static_xml::de::DeserializeAttrField::init(value)?, + ); + } + self.#field_present = true; Ok(None) } } @@ -122,12 +197,26 @@ fn element_match_branches(struct_: &ElementStruct) -> Vec { .enumerate() .map(|(i, &p)| { let field = &struct_.fields[p]; - let ident = field.inner.ident.as_ref().unwrap(); - let span = ident.span(); + let span = field.inner.span(); + let field_present = format_ident!("{}_present", &field.ident); + let field_mut = field_mut(field); quote_spanned! {span=> - Some(#i) => { - ::static_xml::de::DeserializeFieldBuilder::element(&mut self.#ident, child)?; - return Ok(None) + Some(#i) if self.#field_present => { + ::static_xml::de::DeserializeElementField::update( + // SAFETY: the field is initialized when field_present is true. + unsafe { &mut *#field_mut }, + child, + )?; + Ok(None) + } + Some(#i) => unsafe { + // SAFETY: #field_mut is a valid pointer to uninitialized memory. + ::std::ptr::write( + #field_mut, + ::static_xml::de::DeserializeElementField::init(child)?, + ); + self.#field_present = true; + Ok(None) } } }) @@ -141,79 +230,21 @@ fn do_struct(struct_: &ElementStruct) -> TokenStream { let element_match_branches = element_match_branches(&struct_); let ident = &struct_.input.ident; - let (impl_generics, ty_generics, where_clause) = struct_.input.generics.split_for_impl(); - let (visitor_type, visitor_defs, finalize, self_asserts); - if struct_.attr.direct { - visitor_type = struct_.input.ident.clone(); - visitor_defs = struct_ - .fields - .iter() - .filter_map(|f| { - let fident = f.inner.ident.as_ref().unwrap(); - let span = fident.span(); - let ty = &f.inner.ty; - let assert_ident = format_ident!("{}_{}_Assertion", ident, fident); - let trait_ = match f.mode { - ElementFieldMode::Attribute { .. } => { - quote! { ::static_xml::de::DeserializeAttrBuilder } - } - ElementFieldMode::Element { .. } => { - quote! { ::static_xml::de::DeserializeFieldBuilder } - } - _ => return None, - }; - Some(quote_spanned! {span=> - // https://docs.rs/quote/1.0.10/quote/macro.quote_spanned.html#example - #[allow(non_camel_case_types)] - struct #assert_ident where #ty: #trait_; - }) - }) - .collect(); - finalize = quote! { Ok(builder) }; - let assert_ident = format_ident!("{}_Assertion", ident); - self_asserts = quote_spanned! {ident.span()=> - // https://docs.rs/quote/1.0.10/quote/macro.quote_spanned.html#example - #[allow(non_camel_case_types)] - struct #assert_ident where #ident: Default; - }; - } else { - visitor_type = format_ident!("{}Visitor", &struct_.input.ident); - let visitor_field_definitions = visitor_field_definitions(&struct_); - let visitor_initializers = visitor_initializers(&struct_); - let value_fields_from_visitor = value_fields_from_visitor(&struct_); - visitor_defs = quote! { - pub struct #visitor_type { - #(#visitor_field_definitions, )* - }; - - impl Default for #visitor_type { - fn default() -> Self { - Self { - #(#visitor_initializers, )* - } - } - } - - impl #visitor_type { - fn finalize(self) -> Result<#ident, ::static_xml::de::VisitorError> { - Ok(#ident { #(#value_fields_from_visitor, )* }) - } - } - }; - finalize = quote! { builder.finalize() }; - self_asserts = TokenStream::new(); - } - let flatten_fields = struct_.quote_flatten_fields(); + let visitor_type = format_ident!("{}Visitor", &struct_.input.ident); + let visitor_field_definitions = visitor_field_definitions(&struct_); + let visitor_initializers = visitor_initializers(&struct_); + let finalize_visitor_fields = finalize_visitor_fields(&struct_); + let flatten_visitors = quote_flatten_visitors(&struct_); let (attribute_fallthrough, element_fallthrough); - if flatten_fields.is_empty() { + if flatten_visitors.is_empty() { attribute_fallthrough = quote! { Ok(Some(value)) }; element_fallthrough = quote! { Ok(Some(child)) }; } else { attribute_fallthrough = quote! { - ::static_xml::de::delegate_attribute(&mut [#(#flatten_fields),*], name, value) + ::static_xml::de::delegate_attribute(&mut [#(#flatten_visitors),*], name, value) }; element_fallthrough = quote! { - ::static_xml::de::delegate_element(&mut [#(#flatten_fields),*], child) + ::static_xml::de::delegate_element(&mut [#(#flatten_visitors),*], child) }; } let visitor_characters = if struct_.text_field_pos.is_some() { @@ -223,10 +254,10 @@ fn do_struct(struct_: &ElementStruct) -> TokenStream { Ok(None) } } - } else if !flatten_fields.is_empty() { + } else if !flatten_visitors.is_empty() { quote! { fn characters(&mut self, s: String, p: ::static_xml::TextPosition) -> Result, ::static_xml::BoxedStdError> { - ::static_xml::de::delegate_characters(&mut [#(#flatten_fields),*], s, p) + ::static_xml::de::delegate_characters(&mut [#(#flatten_visitors),*], s, p) } } } else { @@ -236,10 +267,21 @@ fn do_struct(struct_: &ElementStruct) -> TokenStream { const ATTRIBUTES: &[::static_xml::ExpandedNameRef] = &[#(#attributes,)*]; const ELEMENTS: &[::static_xml::ExpandedNameRef] = &[#(#elements,)*]; - #self_asserts; - #visitor_defs; + // If there's an underlying field named e.g. `foo_`, then there will be + // a generated field name e.g. `foo__required` or `foo__visitor`. Don't + // complain about this. + #[allow(non_snake_case)] + struct #visitor_type<'out> { + // This can't be &mut MaybeUninit<#type> because flattened fields + // (if any) get a &mut MaybeUninit<#type> that aliases it. So + // use a raw pointer and addr_of! to access individual fields, and + // PhantomData for the lifetime. + out: *mut #ident, + _phantom: ::std::marker::PhantomData<&'out mut #ident>, + #(#visitor_field_definitions, )* + } - impl ::static_xml::de::ElementVisitor for #visitor_type { + impl<'out> ::static_xml::de::ElementVisitor for #visitor_type<'out> { fn element<'a>( &mut self, child: ::static_xml::de::ElementReader<'a>, @@ -264,178 +306,72 @@ fn do_struct(struct_: &ElementStruct) -> TokenStream { #visitor_characters } - impl #impl_generics ::static_xml::de::Deserialize for #ident #ty_generics - #where_clause { - fn deserialize( - element: ::static_xml::de::ElementReader<'_>, - ) -> Result { - let mut builder = #visitor_type::default(); - element.read_to(&mut builder)?; - #finalize - } - } + unsafe impl<'out> ::static_xml::de::RawDeserializeVisitor<'out> for #visitor_type<'out> { + type Out = #ident; - impl #impl_generics ::static_xml::de::DeserializeFlatten for #ident #ty_generics - #where_clause { - type Builder = #visitor_type; - - fn init() -> Self::Builder { - #visitor_type::default() + fn new(out: &'out mut ::std::mem::MaybeUninit) -> Self { + Self { + out: out.as_mut_ptr(), + _phantom: ::std::marker::PhantomData, + #(#visitor_initializers, )* + } } - fn finalize(builder: Self::Builder) -> Result { - #finalize + fn finalize(self, _default: Option Self::Out>) -> Result<(), ::static_xml::de::VisitorError> { + // Note _default is currently unsupported for structs. + + #(#finalize_visitor_fields)* + // SAFETY: returning `Ok` guarantees `self.out` is fully initialized. + // finalize_visitor_fields has guaranteed that each field is fully initialized. + // The padding doesn't matter. This is similar to this example + // https://doc.rust-lang.org/stable/std/mem/union.MaybeUninit.html#initializing-a-struct-field-by-field + Ok(()) } } + + ::static_xml::impl_deserialize_via_raw!(#ident, #visitor_type); } } -fn do_enum_indirect(enum_: &ElementEnum) -> TokenStream { +fn do_enum(enum_: &ElementEnum) -> TokenStream { let ident = &enum_.input.ident; + let visitor_type = format_ident!("{}Visitor", &enum_.input.ident); let elements: Vec<_> = enum_ .variants .iter() .map(|v| v.name.quote_expanded()) .collect(); - let visitor_variants: Vec = enum_ + let initialized_match_arms: Vec = enum_ .variants - .iter() - .map(|v| { - let vident = v.ident; - match v.ty { - None => quote_spanned! {vident.span()=> #vident }, - Some(ty) => quote_spanned! { - vident.span() => - #vident(<#ty as ::static_xml::de::DeserializeField>::Builder) - }, - } - }) - .collect(); - let finalize_match_arms: Vec = enum_.variants .iter() .enumerate() .map(|(i, v)| { let vident = v.ident; + let span = vident.span(); + match v.ty { - None => quote_spanned! { vident.span() => VisitorInner::#vident => { Ok(#ident::#vident) } }, - Some(ty) => { - quote_spanned! { - vident.span() => - VisitorInner::#vident(builder) => { - <#ty as ::static_xml::de::DeserializeField>::finalize( - builder, - &ELEMENTS[#i], // unused? - Some(|| unreachable!()), - ).map(#ident::#vident) + None => { + quote_spanned! {span=> + #ident::#vident if element_i == Some(#i) => { + return Ok(None); } + #ident::#vident => #i } } - } - }) - .collect(); - let element_match_arms: Vec = enum_ - .variants - .iter() - .enumerate() - .map(|(i, v)| { - let vident = v.ident; - match v.ty { - None => quote_spanned! {vident.span()=> - Some(VisitorInner::#vident) if i == Some(#i) => return Ok(None), - Some(VisitorInner::#vident) if i.is_some() => #i, - None if i == Some(#i) => { - self.0 = Some(VisitorInner::#vident); - return Ok(None); - } - }, - Some(ty) => quote_spanned! {vident.span()=> - Some(VisitorInner::#vident(builder)) if i == Some(#i) => { - ::static_xml::de::DeserializeFieldBuilder::element(builder, child)?; - return Ok(None); - } - Some(VisitorInner::#vident(_)) if i.is_some() => #i, - None if i == Some(#i) => { - let mut builder = <#ty as ::static_xml::de::DeserializeField>::init(); - ::static_xml::de::DeserializeFieldBuilder::element(&mut builder, child)?; - self.0 = Some(VisitorInner::#vident(builder)); - return Ok(None); + Some(_) => { + quote_spanned! {span=> + #ident::#vident(f) if element_i == Some(#i) => { + ::static_xml::de::DeserializeElementField::update(f, child)?; + return Ok(None); + } + #ident::#vident(_) => #i } - }, + } } }) .collect(); - - quote! { - const ELEMENTS: &[::static_xml::ExpandedNameRef] = &[#(#elements,)*]; - - enum VisitorInner { - #(#visitor_variants,)* - } - - impl VisitorInner { - fn finalize(self) -> Result<#ident, ::static_xml::de::VisitorError> { - match self { - #(#finalize_match_arms,)* - } - } - } - - pub struct Visitor(Option); - - impl ::static_xml::de::ElementVisitor for Visitor { - fn element<'a>( - &mut self, - mut child: ::static_xml::de::ElementReader<'a>, - ) -> Result>, ::static_xml::de::VisitorError> { - let name = child.expanded_name(); - let i = ::static_xml::de::find(&name, ELEMENTS); - let expected_i = match &mut self.0 { - #(#element_match_arms)* - _ => return Ok(Some(child)) - }; - Err(::static_xml::de::VisitorError::unexpected_element( - &name, - &ELEMENTS[expected_i], - )) - } - } - - impl ::static_xml::de::Deserialize for #ident { - fn deserialize(element: ::static_xml::de::ElementReader<'_>) -> Result { - let mut visitor: Visitor = Visitor(None); - element.read_to(&mut visitor)?; - visitor - .0 - .ok_or_else(|| static_xml::de::VisitorError::cant_be_empty(stringify!(#ident)))? - .finalize() - } - } - - impl ::static_xml::de::DeserializeFlatten for #ident { - type Builder = Visitor; - - fn init() -> Self::Builder { Visitor(None) } - fn finalize(builder: Self::Builder) -> Result { - builder - .0 - .ok_or_else(|| static_xml::de::VisitorError::cant_be_empty(stringify!(#ident)))? - .finalize() - } - } - - // can't derive DeserializeFlatten for Option<#ident> and Vec<#ident> because of the - // orphan rule. :-( - } -} - -fn do_enum_direct(enum_: &ElementEnum) -> TokenStream { - let ident = &enum_.input.ident; - let elements: Vec<_> = enum_ + let uninitialized_match_arms: Vec = enum_ .variants - .iter() - .map(|v| v.name.quote_expanded()) - .collect(); - let match_arms: Vec = enum_.variants .iter() .enumerate() .map(|(i, v)| { @@ -443,19 +379,13 @@ fn do_enum_direct(enum_: &ElementEnum) -> TokenStream { match v.ty { None => { quote_spanned! { - vident.span() => Some(#i) => { *self = #ident::#vident; } + vident.span() => Some(#i) => { #ident::#vident } } } - Some(ty) => { + Some(_) => { quote_spanned! { vident.span() => Some(#i) => { - let mut builder = <#ty as ::static_xml::de::DeserializeField>::init(); - ::static_xml::de::DeserializeFieldBuilder::element(&mut builder, child)?; - *self = #ident::#vident(<#ty as ::static_xml::de::DeserializeField>::finalize( - builder, - &ELEMENTS[#i], // unused? - Some(|| unreachable!()), - )?); + #ident::#vident(::static_xml::de::DeserializeElementField::init(child)?) } } } @@ -463,55 +393,85 @@ fn do_enum_direct(enum_: &ElementEnum) -> TokenStream { }) .collect(); + // #ident and #visitor_type's visibilities must exactly match due to a trait reference cycle: + // * #ident refers to #visitor_type via RawDeserialize::Visitor + // * #visitor_type refers to #ident via RawDeserializeVisitor::Out + let vis = &enum_.input.vis; + quote! { const ELEMENTS: &[::static_xml::ExpandedNameRef] = &[#(#elements,)*]; - impl ::static_xml::de::ElementVisitor for #ident { + #vis struct #visitor_type<'out> { + out: &'out mut ::std::mem::MaybeUninit<#ident>, + initialized: bool, + } + + impl<'out> ::static_xml::de::ElementVisitor for #visitor_type<'out> { fn element<'a>( &mut self, mut child: ::static_xml::de::ElementReader<'a>, ) -> Result>, ::static_xml::de::VisitorError> { let name = child.expanded_name(); - match ::static_xml::de::find(&name, ELEMENTS) { - #(#match_arms,)* - _ => return Ok(Some(child)), + let element_i = ::static_xml::de::find(&name, ELEMENTS); + if self.initialized { + // SAFETY: self.out is initialized when self.initialized is true. + let expected_i = match unsafe { self.out.assume_init_mut() } { + #(#initialized_match_arms,)* + }; + if let Some(element_i) = element_i { + return Err(::static_xml::de::VisitorError::unexpected_element( + &name, + &ELEMENTS[expected_i], + )); + } + return Ok(None); } + self.out.write(match element_i { + #(#uninitialized_match_arms,)* + _ => return Ok(Some(child)), + }); + self.initialized = true; Ok(None) } } - impl ::static_xml::de::Deserialize for #ident { - fn deserialize(element: ::static_xml::de::ElementReader<'_>) -> Result { - let mut visitor: #ident = Default::default(); - element.read_to(&mut visitor)?; - Ok(visitor) - } - } + unsafe impl<'out> ::static_xml::de::RawDeserializeVisitor<'out> for #visitor_type<'out> { + type Out = #ident; - impl ::static_xml::de::DeserializeFlatten for #ident { - type Builder = #ident; + fn new(out: &'out mut ::std::mem::MaybeUninit) -> Self { + Self { + out, + initialized: false, + } + } - fn init() -> Self::Builder { Default::default() } - fn finalize(builder: Self::Builder) -> Result { - Ok(builder) + fn finalize( + self, + default: Option Self::Out>, + ) -> Result<(), ::static_xml::de::VisitorError> { + if !self.initialized { + if let Some(d) = default { + self.out.write(d()); + } else { + return Err(static_xml::de::VisitorError::cant_be_empty(stringify!(#ident))); + } + } + // SAFETY: returning `Ok` guarantees `self.out` is fully initialized. + Ok(()) } } - // can't derive DeserializeFlatten for Option<#ident> and Vec<#ident> because of the - // orphan rule. :-( + impl<'out> ::static_xml::de::RawDeserialize<'out> for #ident { + type Visitor = #visitor_type<'out>; + } + + ::static_xml::impl_deserialize_via_raw!(#ident, #visitor_type); } } pub(crate) fn derive(errors: &Errors, input: syn::DeriveInput) -> Result { match input.data { - Data::Enum(ref data) => { - let enum_ = ElementEnum::new(&errors, &input, data); - if enum_.attr.direct { - Ok(do_enum_direct(&enum_)) - } else { - Ok(do_enum_indirect(&enum_)) - } - } + Data::Enum(ref data) => Ok(do_enum(&ElementEnum::new(&errors, &input, data))), Data::Struct(ref data) => ElementStruct::new(&errors, &input, data).map(|s| do_struct(&s)), _ => { errors.push(syn::Error::new_spanned( diff --git a/static-xml/src/de/mod.rs b/static-xml/src/de/mod.rs index 8f44082..d6d204f 100644 --- a/static-xml/src/de/mod.rs +++ b/static-xml/src/de/mod.rs @@ -3,8 +3,8 @@ //! Deserialization from XML to Rust types. -use std::fmt::Write; use std::sync::Arc; +use std::{fmt::Write, mem::MaybeUninit}; use log::trace; use xml::{ @@ -52,6 +52,16 @@ impl VisitorError { )))) } + // xml-rs might detect this anyway, but static-xml-derive shouldn't rely + // on that for avoiding memory leaks, and it needs an error to return. + #[doc(hidden)] + pub fn duplicate_attribute(attribute: &ExpandedNameRef) -> Self { + Self::Wrap(Box::new(SimpleError(format!( + "Duplicate attribute {}", + attribute + )))) + } + #[doc(hidden)] pub fn duplicate_element(element: &ExpandedNameRef) -> Self { Self::Wrap(Box::new(SimpleError(format!( @@ -672,31 +682,78 @@ pub trait Deserialize: Sized { /// `minOccurs="0" maxOccurs="1". /// 3. `Vec`, for repeated fields. In XML Schema terms, /// `minOccurs="0" maxOccurs="unbounded"`. -pub trait DeserializeField: Sized { - type Builder: DeserializeFieldBuilder; - - fn init() -> Self::Builder; - fn finalize( - builder: Self::Builder, - expected: &ExpandedNameRef<'_>, - default: Option Self>, - ) -> Result; +pub trait DeserializeElementField: Sized { + /// Called on the first occurrence of this field's element within the parent. + fn init(element: ElementReader<'_>) -> Result; + + /// Called on subsequent occurrences of this field's element within the parent. + /// + /// `self` was previously returned by `init` and has been through zero or more prior `update` calls. + fn update(&mut self, element: ElementReader<'_>) -> Result<(), VisitorError>; + + /// Called iff this field's element was not found within the parent. + fn missing(expected: &ExpandedNameRef<'_>) -> Result; } -/// Builder used by [`DeserializeField`]. -pub trait DeserializeFieldBuilder { - /// Handles a single occurrence of this element; called zero or more times. - fn element<'a>(&mut self, element: ElementReader<'_>) -> Result<(), VisitorError>; +/// Deserializes an attribute into a field. +/// +/// This is implemented via [`ParseText`] as noted there. +pub trait DeserializeAttrField: Sized { + /// Called iff this field's attribute was found within the parent. + fn init(value: String) -> Result; + + /// Called iff this field's attribute was not found within the parent. + fn missing(expected: &ExpandedNameRef<'_>) -> Result; } -/// Deserializes this type when "flattened" into another. +#[doc(hidden)] +pub unsafe trait RawDeserializeVisitor<'out>: Sized + ElementVisitor { + type Out; + + /// Returns a visitor that can be used to populate `this`. + fn new(out: &'out mut MaybeUninit) -> Self; + + /// Finalizes `out`. + /// + /// An `Ok` return guarantees `out.assume_init()` is valid. + fn finalize(self, default: Option Self::Out>) -> Result<(), VisitorError>; +} + +/// Raw, unsafe implementation of [`Deserialize`], for use by the macros. /// /// With `static-xml-derive`, this can be used via `#[static_xml(flatten)]`. -pub trait DeserializeFlatten: Sized { - type Builder: ElementVisitor; +/// +/// Implementing this type automatically implements `Deserialize`. +#[doc(hidden)] +pub trait RawDeserialize<'out>: Sized { + type Visitor: RawDeserializeVisitor<'out, Out = Self>; +} - fn init() -> Self::Builder; - fn finalize(builder: Self::Builder) -> Result; +/// Implements [`Deserialize`] via [`RawDeserializeVisitor`]. +/// +/// The type opts into [`Deserialize`] via this macro rather than by +/// implementing a (hypothetical) `RawDeserialize`. The latter approach +/// doesn't work because `impl Deserialize for T` and +/// `impl Deserialize for T` would +/// [conflict](https://doc.rust-lang.org/error-index.html#E0119). +#[doc(hidden)] +#[macro_export] +macro_rules! impl_deserialize_via_raw { + ( $t:ident, $visitor:ident ) => { + impl ::static_xml::de::Deserialize for $t { + fn deserialize( + element: ::static_xml::de::ElementReader<'_>, + ) -> Result { + let mut out = ::std::mem::MaybeUninit::uninit(); + let mut visitor = + <$visitor as ::static_xml::de::RawDeserializeVisitor>::new(&mut out); + element.read_to(&mut visitor)?; + ::static_xml::de::RawDeserializeVisitor::finalize(visitor, None)?; + // SAFETY: finalize's contract guarantees assume_init is safe. + Ok(unsafe { out.assume_init() }) + } + } + }; } /// Deserializes text data, whether character nodes or attribute values. @@ -716,30 +773,6 @@ pub trait ParseText: Sized { fn parse(text: String) -> Result; } -/// Deserializes an attribute into a field. -/// -/// This is implemented via [`ParseText`] as noted there. -pub trait DeserializeAttr: Sized { - type Builder: DeserializeAttrBuilder; - - fn init() -> Self::Builder; - fn finalize( - builder: Self::Builder, - expected: &ExpandedNameRef<'_>, - default: Option Self>, - ) -> Result; -} - -/// Builder used by [`DeserializeAttr`]. -pub trait DeserializeAttrBuilder { - /// May be called zero or one time with the relevant attribute. - fn attr<'a>( - &mut self, - name: &ExpandedNameRef, - value: String, - ) -> Result, VisitorError>; -} - /// Visitor used within [`Deserialize`]. pub trait ElementVisitor { /// Processes a given attribute of this element's start tag. @@ -922,164 +955,71 @@ pub fn find(name: &ExpandedNameRef<'_>, sorted_slice: &[ExpandedNameRef<'_>]) -> sorted_slice.binary_search(name).ok() } -impl DeserializeField for T { - type Builder = Option; - - #[inline] - fn init() -> Self::Builder { - None - } - - fn finalize( - builder: Self::Builder, - expected: &ExpandedNameRef<'_>, - default: Option Self>, - ) -> Result { - if let Some(f) = builder { - Ok(f) - } else if let Some(d) = default { - Ok(d()) - } else { - Err(VisitorError::missing_element(expected)) - } +impl DeserializeElementField for T { + fn init(element: ElementReader<'_>) -> Result { + T::deserialize(element) } -} - -impl DeserializeField for Option { - type Builder = Self; - #[inline] - fn init() -> Self::Builder { - None - } - - fn finalize( - builder: Self::Builder, - _expected: &ExpandedNameRef<'_>, - default: Option Self>, - ) -> Result { - if let Some(f) = builder { - Ok(Some(f)) - } else if let Some(d) = default { - Ok(d()) - } else { - Ok(None) - } + fn update(&mut self, element: ElementReader<'_>) -> Result<(), VisitorError> { + Err(VisitorError::duplicate_element(&element.expanded_name())) } -} -impl DeserializeFieldBuilder for T { - #[inline] - fn element<'a>(&mut self, element: ElementReader<'_>) -> Result<(), VisitorError> { - *self = T::deserialize(element)?; - Ok(()) + fn missing(expected: &ExpandedNameRef<'_>) -> Result { + Err(VisitorError::missing_element(expected)) } } -impl DeserializeFieldBuilder for Option { - fn element<'a>(&mut self, element: ElementReader<'_>) -> Result<(), VisitorError> { - if self.is_some() { - return Err(VisitorError::duplicate_element(&element.expanded_name())); - } - *self = Some(T::deserialize(element)?); - Ok(()) +impl DeserializeElementField for Option { + fn init(element: ElementReader<'_>) -> Result { + T::deserialize(element).map(Some) } -} - -/// Deserializes into a `Vec`, adding an element. -impl DeserializeField for Vec { - type Builder = Self; - #[inline] - fn init() -> Self::Builder { - Vec::new() + fn update(&mut self, element: ElementReader<'_>) -> Result<(), VisitorError> { + Err(VisitorError::duplicate_element(&element.expanded_name())) } - fn finalize( - builder: Self::Builder, - _expected: &ExpandedNameRef<'_>, - _default: Option Self>, - ) -> Result { - Ok(builder) + fn missing(_expected: &ExpandedNameRef<'_>) -> Result { + Ok(None) } } -impl DeserializeFieldBuilder for Vec { - fn element<'a>(&mut self, element: ElementReader<'_>) -> Result<(), VisitorError> { +/// Deserializes into a `Vec`, adding an element. +impl DeserializeElementField for Vec { + fn init(element: ElementReader<'_>) -> Result { + Ok(vec![T::deserialize(element)?]) + } + + fn update(&mut self, element: ElementReader<'_>) -> Result<(), VisitorError> { self.push(T::deserialize(element)?); Ok(()) } -} -impl DeserializeAttr for T { - type Builder = Option; - - #[inline] - fn init() -> Self::Builder { - None - } - - fn finalize( - builder: Self::Builder, - expected: &ExpandedNameRef, - default: Option Self>, - ) -> Result { - if let Some(b) = builder { - Ok(b) - } else if let Some(d) = default { - Ok(d()) - } else { - Err(VisitorError::missing_attribute(expected)) - } + fn missing(_expected: &ExpandedNameRef<'_>) -> Result { + Ok(Vec::new()) } } -impl DeserializeAttr for Option { - type Builder = Option; +impl DeserializeAttrField for T { + fn init(value: String) -> Result { + Ok(T::parse(value).map_err(VisitorError::Wrap)?) + } - #[inline] - fn init() -> Self::Builder { - None - } - - fn finalize( - builder: Self::Builder, - _expected: &ExpandedNameRef, - default: Option Self>, - ) -> Result { - if let Some(b) = builder { - Ok(Some(b)) - } else if let Some(d) = default { - Ok(d()) - } else { - Ok(None) - } + fn missing(expected: &ExpandedNameRef<'_>) -> Result { + Err(VisitorError::missing_attribute(expected)) } } -impl DeserializeAttrBuilder for T { - fn attr<'a>( - &mut self, - _name: &ExpandedNameRef, - value: String, - ) -> Result, VisitorError> { - *self = T::parse(value).map_err(VisitorError::Wrap)?; - Ok(None) +impl DeserializeAttrField for Option { + fn init(value: String) -> Result { + Ok(Some(T::parse(value).map_err(VisitorError::Wrap)?)) } -} -impl DeserializeAttrBuilder for Option { - fn attr<'a>( - &mut self, - _name: &ExpandedNameRef, - value: String, - ) -> Result, VisitorError> { - debug_assert!(self.is_none()); - *self = Some(T::parse(value).map_err(VisitorError::Wrap)?); + fn missing(_expected: &ExpandedNameRef<'_>) -> Result { Ok(None) } } + impl ParseText for bool { fn parse(text: String) -> Result { // [https://www.w3.org/TR/xmlschema11-2/#boolean] "For all ·atomic· datatypes other than @@ -1132,7 +1072,7 @@ impl Deserialize for T { } } -#[cfg(test)] +/*#[cfg(test)] mod tests { use super::*; @@ -1154,23 +1094,29 @@ mod tests { } } + /// An element which expects a single attribute of arbitrary name and + /// value type `T`. #[derive(Debug, Default, Eq, PartialEq)] - struct AttrWrapper(T); + struct AttrWrapper(T); - struct AttrWrapperVisitor(T::Builder); + struct AttrWrapperVisitor(T); - impl ElementVisitor for AttrWrapperVisitor { + impl ElementVisitor for AttrWrapperVisitor { fn attribute( &mut self, name: &ExpandedNameRef<'_>, value: String, ) -> Result, VisitorError> { - self.0.attr(name, value) + if self.0.is_some() { + return Err(VisitorError::duplicate_attribute(name)); + } + self.0 = Some(T::init(value)?); + Ok(None) } } - impl Deserialize for AttrWrapper { + impl Deserialize for AttrWrapper { fn deserialize(element: ElementReader<'_>) -> Result { - let mut visitor = AttrWrapperVisitor::(T::init()); + let mut visitor = AttrWrapperVisitor(None); element.read_to(&mut visitor)?; Ok(AttrWrapper(T::finalize( visitor.0, @@ -1261,4 +1207,4 @@ mod tests { } // TODO: test exercising return_to_depth. -} +}*/ diff --git a/test-suite/tests/basic.rs b/test-suite/tests/basic.rs index ca6a3f6..8b64a07 100644 --- a/test-suite/tests/basic.rs +++ b/test-suite/tests/basic.rs @@ -3,7 +3,7 @@ use static_xml_derive::{Deserialize, ParseText, Serialize, ToText}; -#[derive(Debug, Deserialize, Eq, PartialEq, Serialize)] +#[derive(Debug, Default, Deserialize, Eq, PartialEq, Serialize)] #[static_xml( namespace = "foo: http://example.com/foo", namespace = "bar: http://example.com/bar", @@ -16,9 +16,8 @@ struct Foo { #[static_xml(prefix = "bar", rename = "blah")] string: Vec, - #[static_xml(flatten)] - bar: Bar, - + //#[static_xml(flatten)] + //bar: Bar, text: String, constrained: ConstrainedString, @@ -35,6 +34,11 @@ enum ConstrainedString { #[static_xml(rename = "BAZ")] Baz, } +impl Default for ConstrainedString { + fn default() -> Self { + ConstrainedString::Foo + } +} #[derive(Debug, ParseText, Eq, PartialEq, ToText)] #[static_xml(mode = "restriction", whitespace = "collapse")] @@ -55,6 +59,11 @@ enum MyChoice { Baz(String), UnitValue, } +impl Default for MyChoice { + fn default() -> Self { + MyChoice::UnitValue + } +} #[derive(Debug, Eq, PartialEq, ParseText, ToText)] #[static_xml(mode = "union")] @@ -64,7 +73,7 @@ enum MyUnion { String(String), } -#[derive(Debug, Deserialize, Eq, PartialEq, Serialize)] +#[derive(Debug, Default, Deserialize, Eq, PartialEq, Serialize)] struct Bar { more: String, } @@ -91,9 +100,9 @@ fn deserialize() { Foo { mybool: true, string: vec!["foo".to_owned(), "bar".to_owned()], - bar: Bar { - more: "more".to_owned() - }, + //bar: Bar { + // more: "more".to_owned() + //}, text: "asdf".to_owned(), constrained: ConstrainedString::Foo, choice: MyChoice::Foo("blah".to_owned()), @@ -107,9 +116,9 @@ fn round_trip() { let original = Foo { mybool: true, string: vec!["foo".to_owned(), "bar".to_owned()], - bar: Bar { - more: "more".to_owned(), - }, + //bar: Bar { + // more: "more".to_owned(), + //}, text: "asdf".to_owned(), constrained: ConstrainedString::Foo, choice: MyChoice::Foo("blah".to_owned()), diff --git a/test-suite/tests/element_enum.rs b/test-suite/tests/element_enum.rs index 8e78ced..c8d1d5b 100644 --- a/test-suite/tests/element_enum.rs +++ b/test-suite/tests/element_enum.rs @@ -3,52 +3,19 @@ use static_xml_derive::{Deserialize, Serialize}; -#[derive(Default, Debug, Deserialize, Serialize, PartialEq, Eq)] -#[static_xml(direct)] -struct DirectHolder { - #[static_xml(flatten)] - direct_enum: DirectEnum, - #[static_xml(flatten)] - other_flatten: OtherDirectFlatten, -} - #[derive(Debug, Deserialize, Serialize, PartialEq, Eq)] -#[static_xml(direct)] -enum DirectEnum { - Simple(String), - Vec(Vec), - Unit, - #[static_xml(skip)] - Skipped(String), -} -impl Default for DirectEnum { - fn default() -> Self { - DirectEnum::Skipped("default".to_owned()) - } -} - -#[derive(Debug, Default, Deserialize, Serialize, PartialEq, Eq)] -#[static_xml(direct)] -struct OtherDirectFlatten { - field: Vec, -} - -#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)] -struct IndirectHolder { - #[static_xml(flatten)] - indirect_enum: IndirectEnum, +struct Holder { #[static_xml(flatten)] - other_flatten: OtherIndirectFlatten, + enum_: Enum, + //#[static_xml(flatten)] + //other_flatten: OtherFlatten, } #[derive(Debug, Deserialize, Serialize, PartialEq, Eq)] -enum IndirectEnum { +enum Enum { Simple(String), Vec(Vec), Unit, - #[allow(dead_code)] - #[static_xml(skip)] - Skipped, } #[derive(Debug, Deserialize, Serialize, PartialEq, Eq)] @@ -57,9 +24,9 @@ struct OtherIndirectFlatten { } #[test] -fn deserialize_indirect_simple() { +fn deserialize_simple() { let _ = env_logger::Builder::new().is_test(true).try_init(); - let indirect: IndirectHolder = static_xml::de::from_str( + let holder: Holder = static_xml::de::from_str( r#" @@ -71,20 +38,20 @@ fn deserialize_indirect_simple() { ) .unwrap(); assert_eq!( - indirect, - IndirectHolder { - indirect_enum: IndirectEnum::Simple("asdf".to_owned()), - other_flatten: OtherIndirectFlatten { + holder, + Holder { + enum_: Enum::Simple("asdf".to_owned()), + /*other_flatten: OtherIndirectFlatten { field: vec!["before".to_owned(), "after".to_owned()] - }, + },*/ } ); } #[test] -fn deserialize_indirect_vec() { +fn deserialize_vec() { let _ = env_logger::Builder::new().is_test(true).try_init(); - let indirect: IndirectHolder = static_xml::de::from_str( + let holder: Holder = static_xml::de::from_str( r#" @@ -97,20 +64,20 @@ fn deserialize_indirect_vec() { ) .unwrap(); assert_eq!( - indirect, - IndirectHolder { - indirect_enum: IndirectEnum::Vec(vec!["foo".to_owned(), "bar".to_owned()]), - other_flatten: OtherIndirectFlatten { + holder, + Holder { + enum_: Enum::Vec(vec!["foo".to_owned(), "bar".to_owned()]), + /*other_flatten: OtherIndirectFlatten { field: vec!["before".to_owned(), "after".to_owned()] - }, + },*/ } ); } #[test] -fn deserialize_indirect_mix_error() { +fn deserialize_mix_error() { let _ = env_logger::Builder::new().is_test(true).try_init(); - let e = static_xml::de::from_str::( + let e = static_xml::de::from_str::( r#"