diff --git a/static-xml-derive/src/common.rs b/static-xml-derive/src/common.rs index 97753be..3a62a9d 100644 --- a/static-xml-derive/src/common.rs +++ b/static-xml-derive/src/common.rs @@ -6,7 +6,7 @@ use std::{cell::RefCell, collections::BTreeMap}; use proc_macro2::TokenStream; use quote::{quote, quote_spanned, ToTokens}; use static_xml::{de::WhiteSpace, ExpandedNameRef}; -use syn::{spanned::Spanned, DeriveInput, Fields, Lit, LitStr, Meta, MetaNameValue, NestedMeta}; +use syn::{DeriveInput, Fields, Lit, LitStr, Meta, MetaNameValue, NestedMeta}; // See serde/serde_derive/src/internals/attr.rs and yaserde_derive/src/common/field.rs @@ -213,7 +213,6 @@ pub(crate) enum TextVariantMode { pub(crate) struct ElementAttr<'a> { pub(crate) name: Name<'a>, pub(crate) namespaces: Namespaces, - pub(crate) direct: bool, // TODO: rename_all. } @@ -346,7 +345,6 @@ impl<'a> ElementAttr<'a> { let mut name = Name::from_ident(&input.ident); let mut namespaces = Namespaces::default(); let mut prefix_nv = None; - let mut direct = false; for item in get_meta_items(errors, &input.attrs) { match item { NestedMeta::Meta(Meta::NameValue(nv)) => { @@ -367,28 +365,13 @@ impl<'a> ElementAttr<'a> { errors.push(syn::Error::new_spanned(nv, "item not understood")); } } - NestedMeta::Meta(Meta::Path(p)) => { - if let Some(id) = p.get_ident() { - if id == "direct" { - direct = true; - } else { - errors.push(syn::Error::new_spanned(p, "item not understood")); - } - } else { - errors.push(syn::Error::new_spanned(p, "item not understood")); - } - } i => errors.push(syn::Error::new_spanned(i, "item not understood")), } } if let Some(ref nv) = prefix_nv { namespaces.process_prefix(errors, nv, &mut name); } - ElementAttr { - name, - namespaces, - direct, - } + ElementAttr { name, namespaces } } } @@ -470,20 +453,6 @@ impl<'a> ElementStruct<'a> { text_field_pos, }) } - - pub(crate) fn quote_flatten_fields(&self) -> Vec { - self.fields - .iter() - .filter_map(|f| { - if matches!(f.mode, ElementFieldMode::Flatten) { - let ident = &f.inner.ident; - Some(quote_spanned! { f.inner.ident.span() => &mut self.out.#ident }) - } else { - None - } - }) - .collect() - } } const STATIC_XML: &str = "static_xml"; @@ -517,20 +486,10 @@ pub(crate) enum ElementFieldMode { Flatten, } -impl ElementFieldMode { - pub(crate) fn quote_deserialize_trait(self) -> TokenStream { - match self { - ElementFieldMode::Element { .. } => quote! { ::static_xml::de::DeserializeField }, - ElementFieldMode::Attribute { .. } => quote! { ::static_xml::de::DeserializeAttr }, - ElementFieldMode::Text => panic!("text is different"), - ElementFieldMode::Flatten => quote! { ::static_xml::de::DeserializeFlatten }, - } - } -} - /// Field within an `ElementStruct`. pub(crate) struct ElementField<'a> { pub(crate) inner: &'a syn::Field, + pub(crate) ident: &'a syn::Ident, pub(crate) mode: ElementFieldMode, pub(crate) default: bool, pub(crate) name: Name<'a>, @@ -543,8 +502,8 @@ impl<'a> ElementField<'a> { let mut default = false; let mut flatten = false; let mut text = false; - let mut name = - Name::from_ident(inner.ident.as_ref().expect("struct fields should be named")); + let ident = inner.ident.as_ref().expect("struct fields should be named"); + let mut name = Name::from_ident(ident); for item in get_meta_items(errors, &inner.attrs) { match &item { NestedMeta::Meta(Meta::Path(p)) if p.is_ident("attribute") => attribute = true, @@ -595,6 +554,7 @@ impl<'a> ElementField<'a> { }; Ok(ElementField { inner, + ident, default, mode, name, diff --git a/static-xml-derive/src/deserialize.rs b/static-xml-derive/src/deserialize.rs index 0a1c94b..b52fa5c 100644 --- a/static-xml-derive/src/deserialize.rs +++ b/static-xml-derive/src/deserialize.rs @@ -7,7 +7,27 @@ use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned}; use syn::{spanned::Spanned, Data}; -use crate::common::{ElementEnum, ElementFieldMode, ElementStruct, Errors}; +use crate::common::{ElementEnum, ElementField, ElementFieldMode, ElementStruct, Errors}; + +pub(crate) fn quote_flatten_visitors(struct_: &ElementStruct) -> Vec { + struct_ + .fields + .iter() + .filter_map(|f| { + if matches!(f.mode, ElementFieldMode::Flatten) { + let field_visitor = format_ident!("{}_visitor", f.ident); + Some(quote_spanned! { f.inner.ident.span() => &mut self.#field_visitor }) + } else { + None + } + }) + .collect() +} + +fn field_mut(field: &ElementField) -> TokenStream { + let ident = field.ident; + quote! { unsafe { ::std::ptr::addr_of_mut!((*self.out).#ident) } } +} fn visitor_field_definitions(struct_: &ElementStruct) -> Vec { struct_ @@ -15,12 +35,19 @@ fn visitor_field_definitions(struct_: &ElementStruct) -> Vec { .iter() .filter_map(|field| { let span = field.inner.span(); - let field_present = format_ident!("{}_present", &field.inner.ident.as_ref().unwrap()); + let field_present = format_ident!("{}_present", field.ident); + let field_visitor = format_ident!("{}_visitor", field.ident); + let ty = &field.inner.ty; match field.mode { ElementFieldMode::Text => None, - _ => Some(quote_spanned! {span=> - #field_present: bool + ElementFieldMode::Flatten => Some(quote_spanned! {span=> + #field_visitor: <#ty as ::static_xml::de::RawDeserialize<'out>>::Visitor }), + ElementFieldMode::Element { .. } | ElementFieldMode::Attribute { .. } => { + Some(quote_spanned! {span=> + #field_present: bool + }) + } } }) .collect() @@ -30,41 +57,90 @@ fn visitor_initializers(struct_: &ElementStruct) -> Vec { struct_ .fields .iter() - .filter_map(|field| { + .map(|field| { let span = field.inner.span(); - let field_present = format_ident!("{}_present", &field.inner.ident.as_ref().unwrap()); match field.mode { - ElementFieldMode::Text => None, - _ => Some(quote_spanned! {span=> - #field_present: false - }), + ElementFieldMode::Flatten => { + let field_visitor = format_ident!("{}_visitor", field.ident); + let fident = field.ident; + let ty = &field.inner.ty; + quote_spanned! {span=> + #field_visitor: <#ty as ::static_xml::de::RawDeserialize>::Visitor::new( + // SAFETY: out points to valid, uninitialized memory. + unsafe { + &mut *( + ::std::ptr::addr_of_mut!((*out.as_mut_ptr()).#fident) + as *mut ::std::mem::MaybeUninit<#ty> + ) + } + ) + } + } + _ => { + let field_present = format_ident!("{}_present", field.ident); + quote_spanned! {span=> + #field_present: false + } + } } }) .collect() } fn finalize_visitor_fields(struct_: &ElementStruct) -> Vec { - struct_.fields.iter().filter_map(|field| { + struct_.fields.iter().map(|field| { let span = field.inner.span(); - let field_present = format_ident!("{}_present", &field.inner.ident.as_ref().unwrap()); let ty = &field.inner.ty; + let default = field.default.then(|| { + Some(quote_spanned! {span=> + <#ty as ::std::default::Default>::default + }) + }); match field.mode { - ElementFieldMode::Element { sorted_elements_pos: p } if !field.default => { - Some(quote_spanned! {span=> - <#ty as ::static_xml::de::DeserializeField>::finalize(self.#field_present, &ELEMENTS[#p])?; - }) + ElementFieldMode::Element { sorted_elements_pos: p } => { + let field_present = format_ident!("{}_present", field.ident); + let field_mut = field_mut(field); + let value = match default { + Some(d) => quote! { #d() }, + None => quote_spanned! {span=> + <#ty as ::static_xml::de::DeserializeElementField>::missing(&ELEMENTS[#p])? + } + }; + quote_spanned! {span=> + if !self.#field_present { + // SAFETY: #field_mut is a valid pointer to uninitialized memory. + unsafe { ::std::ptr::write(#field_mut as *mut #ty, #value) }; + } + } } - ElementFieldMode::Attribute { sorted_attributes_pos: p } if !field.default => { - Some(quote_spanned! {span=> - <#ty as ::static_xml::de::DeserializeAttr>::finalize(self.#field_present, &ATTRIBUTES[#p])?; - }) + ElementFieldMode::Attribute { sorted_attributes_pos: p } => { + let field_present = format_ident!("{}_present", field.ident); + let field_mut = field_mut(field); + let value = match default { + Some(d) => quote! { #d() }, + None => quote_spanned! {span=> + <#ty as ::static_xml::de::DeserializeAttrField>::missing(&ATTRIBUTES[#p])? + } + }; + quote_spanned! {span=> + if !self.#field_present { + // SAFETY: #field_mut is a valid pointer to uninitialized memory. + unsafe { ::std::ptr::write(#field_mut /*as *mut #ty*/, #value) }; + } + } } - /*ElementFieldMode::Flatten if !field.default => { - Some(quote_spanned! {span=> - <#ty as ::static_xml::de::DeserializeFlatten>::finalize(self.#field_present)?; - }) - }*/ - _ => None + ElementFieldMode::Flatten => { + let field_visitor = format_ident!("{}_visitor", field.ident); + let d = if let Some(d) = default { + quote! { Some(#d) } + } else { + quote! { None } + }; + quote_spanned! {span=> + <#ty as ::static_xml::de::RawDeserialize>::Visitor::finalize(self.#field_visitor, #d)?; + } + } + ElementFieldMode::Text => todo!(), } }).collect() } @@ -92,11 +168,21 @@ fn attribute_match_branches(struct_: &ElementStruct) -> Vec { .enumerate() .map(|(i, &p)| { let field = &struct_.fields[p]; - let field_present = format_ident!("{}_present", &field.inner.ident.as_ref().unwrap()); - let ident = field.inner.ident.as_ref().unwrap(); + let field_present = format_ident!("{}_present", field.ident); + let field_mut = field_mut(field); quote! { + Some(#i) if self.#field_present => { + Err(::static_xml::de::VisitorError::duplicate_attribute(name)) + } Some(#i) => { - ::static_xml::de::DeserializeAttr::attr(&mut self.out.#ident, &mut self.#field_present, name, value)?; + // SAFETY: #field_mut is a valid pointer to uninitialized memory. + unsafe { + ::std::ptr::write( + #field_mut, + ::static_xml::de::DeserializeAttrField::init(value)?, + ); + } + self.#field_present = true; Ok(None) } } @@ -111,13 +197,26 @@ fn element_match_branches(struct_: &ElementStruct) -> Vec { .enumerate() .map(|(i, &p)| { let field = &struct_.fields[p]; - let field_present = format_ident!("{}_present", &field.inner.ident.as_ref().unwrap()); - let ident = field.inner.ident.as_ref().unwrap(); - let span = ident.span(); + let span = field.inner.span(); + let field_present = format_ident!("{}_present", &field.ident); + let field_mut = field_mut(field); quote_spanned! {span=> - Some(#i) => { - ::static_xml::de::DeserializeField::element(&mut self.out.#ident, &mut self.#field_present, child)?; - return Ok(None) + Some(#i) if self.#field_present => { + ::static_xml::de::DeserializeElementField::update( + // SAFETY: the field is initialized when field_present is true. + unsafe { &mut *#field_mut }, + child, + )?; + Ok(None) + } + Some(#i) => unsafe { + // SAFETY: #field_mut is a valid pointer to uninitialized memory. + ::std::ptr::write( + #field_mut, + ::static_xml::de::DeserializeElementField::init(child)?, + ); + self.#field_present = true; + Ok(None) } } }) @@ -131,42 +230,21 @@ fn do_struct(struct_: &ElementStruct) -> TokenStream { let element_match_branches = element_match_branches(&struct_); let ident = &struct_.input.ident; - let (impl_generics, ty_generics, where_clause) = struct_.input.generics.split_for_impl(); let visitor_type = format_ident!("{}Visitor", &struct_.input.ident); let visitor_field_definitions = visitor_field_definitions(&struct_); let visitor_initializers = visitor_initializers(&struct_); let finalize_visitor_fields = finalize_visitor_fields(&struct_); - let visitor_defs = quote! { - pub struct #visitor_type<'out> { - out: &'out mut #ident, - #(#visitor_field_definitions, )* - } - - impl<'out> #visitor_type<'out> { - fn new(out: &'out mut #ident) -> Self { - Self { - out, - #(#visitor_initializers, )* - } - } - - fn finalize(self) -> Result<(), ::static_xml::de::VisitorError> { - #(#finalize_visitor_fields)* - Ok(()) - } - } - }; - let flatten_fields = struct_.quote_flatten_fields(); + let flatten_visitors = quote_flatten_visitors(&struct_); let (attribute_fallthrough, element_fallthrough); - if flatten_fields.is_empty() { + if flatten_visitors.is_empty() { attribute_fallthrough = quote! { Ok(Some(value)) }; element_fallthrough = quote! { Ok(Some(child)) }; } else { attribute_fallthrough = quote! { - ::static_xml::de::delegate_attribute(&mut [#(#flatten_fields),*], name, value) + ::static_xml::de::delegate_attribute(&mut [#(#flatten_visitors),*], name, value) }; element_fallthrough = quote! { - ::static_xml::de::delegate_element(&mut [#(#flatten_fields),*], child) + ::static_xml::de::delegate_element(&mut [#(#flatten_visitors),*], child) }; } let visitor_characters = if struct_.text_field_pos.is_some() { @@ -176,10 +254,10 @@ fn do_struct(struct_: &ElementStruct) -> TokenStream { Ok(None) } } - } else if !flatten_fields.is_empty() { + } else if !flatten_visitors.is_empty() { quote! { fn characters(&mut self, s: String, p: ::static_xml::TextPosition) -> Result, ::static_xml::BoxedStdError> { - ::static_xml::de::delegate_characters(&mut [#(#flatten_fields),*], s, p) + ::static_xml::de::delegate_characters(&mut [#(#flatten_visitors),*], s, p) } } } else { @@ -189,7 +267,19 @@ fn do_struct(struct_: &ElementStruct) -> TokenStream { const ATTRIBUTES: &[::static_xml::ExpandedNameRef] = &[#(#attributes,)*]; const ELEMENTS: &[::static_xml::ExpandedNameRef] = &[#(#elements,)*]; - #visitor_defs + // If there's an underlying field named e.g. `foo_`, then there will be + // a generated field name e.g. `foo__required` or `foo__visitor`. Don't + // complain about this. + #[allow(non_snake_case)] + struct #visitor_type<'out> { + // This can't be &mut MaybeUninit<#type> because flattened fields + // (if any) get a &mut MaybeUninit<#type> that aliases it. So + // use a raw pointer and addr_of! to access individual fields, and + // PhantomData for the lifetime. + out: *mut #ident, + _phantom: ::std::marker::PhantomData<&'out mut #ident>, + #(#visitor_field_definitions, )* + } impl<'out> ::static_xml::de::ElementVisitor for #visitor_type<'out> { fn element<'a>( @@ -216,180 +306,72 @@ fn do_struct(struct_: &ElementStruct) -> TokenStream { #visitor_characters } - impl #impl_generics ::static_xml::de::Deserialize for #ident #ty_generics - #where_clause { - fn deserialize( - element: ::static_xml::de::ElementReader<'_>, - ) -> Result { - let mut out = <#ident as Default>::default(); - let mut builder = #visitor_type::new(&mut out); - element.read_to(&mut builder)?; - builder.finalize()?; - Ok(out) + unsafe impl<'out> ::static_xml::de::RawDeserializeVisitor<'out> for #visitor_type<'out> { + type Out = #ident; + + fn new(out: &'out mut ::std::mem::MaybeUninit) -> Self { + Self { + out: out.as_mut_ptr(), + _phantom: ::std::marker::PhantomData, + #(#visitor_initializers, )* + } } - } - /*impl #impl_generics ::static_xml::de::DeserializeFlatten for #ident #ty_generics - #where_clause { - type Builder = #visitor_type; + fn finalize(self, _default: Option Self::Out>) -> Result<(), ::static_xml::de::VisitorError> { + // Note _default is currently unsupported for structs. - fn init() -> Self::Builder { - #visitor_type::default() + #(#finalize_visitor_fields)* + // SAFETY: returning `Ok` guarantees `self.out` is fully initialized. + // finalize_visitor_fields has guaranteed that each field is fully initialized. + // The padding doesn't matter. This is similar to this example + // https://doc.rust-lang.org/stable/std/mem/union.MaybeUninit.html#initializing-a-struct-field-by-field + Ok(()) } + } - fn finalize(builder: Self::Builder) -> Result { - #finalize - } - }*/ + ::static_xml::impl_deserialize_via_raw!(#ident, #visitor_type); } } -/*fn do_enum_indirect(enum_: &ElementEnum) -> TokenStream { +fn do_enum(enum_: &ElementEnum) -> TokenStream { let ident = &enum_.input.ident; + let visitor_type = format_ident!("{}Visitor", &enum_.input.ident); let elements: Vec<_> = enum_ .variants .iter() .map(|v| v.name.quote_expanded()) .collect(); - let visitor_variants: Vec = enum_ + let initialized_match_arms: Vec = enum_ .variants - .iter() - .map(|v| { - let vident = v.ident; - match v.ty { - None => quote_spanned! {vident.span()=> #vident }, - Some(ty) => quote_spanned! { - vident.span() => - #vident(<#ty as ::static_xml::de::DeserializeField>::Builder) - }, - } - }) - .collect(); - let finalize_match_arms: Vec = enum_.variants .iter() .enumerate() .map(|(i, v)| { let vident = v.ident; + let span = vident.span(); + match v.ty { - None => quote_spanned! { vident.span() => VisitorInner::#vident => { Ok(#ident::#vident) } }, - Some(ty) => { - quote_spanned! { - vident.span() => - VisitorInner::#vident(builder) => { - <#ty as ::static_xml::de::DeserializeField>::finalize( - builder, - &ELEMENTS[#i], // unused? - Some(|| unreachable!()), - ).map(#ident::#vident) + None => { + quote_spanned! {span=> + #ident::#vident if element_i == Some(#i) => { + return Ok(None); } + #ident::#vident => #i } } - } - }) - .collect(); - let element_match_arms: Vec = enum_ - .variants - .iter() - .enumerate() - .map(|(i, v)| { - let vident = v.ident; - match v.ty { - None => quote_spanned! {vident.span()=> - Some(VisitorInner::#vident) if i == Some(#i) => return Ok(None), - Some(VisitorInner::#vident) if i.is_some() => #i, - None if i == Some(#i) => { - self.0 = Some(VisitorInner::#vident); - return Ok(None); - } - }, - Some(ty) => quote_spanned! {vident.span()=> - Some(VisitorInner::#vident(builder)) if i == Some(#i) => { - ::static_xml::de::DeserializeFieldBuilder::element(builder, child)?; - return Ok(None); - } - Some(VisitorInner::#vident(_)) if i.is_some() => #i, - None if i == Some(#i) => { - let mut builder = <#ty as ::static_xml::de::DeserializeField>::init(); - ::static_xml::de::DeserializeFieldBuilder::element(&mut builder, child)?; - self.0 = Some(VisitorInner::#vident(builder)); - return Ok(None); + Some(_) => { + quote_spanned! {span=> + #ident::#vident(f) if element_i == Some(#i) => { + ::static_xml::de::DeserializeElementField::update(f, child)?; + return Ok(None); + } + #ident::#vident(_) => #i } - }, + } } }) .collect(); - - quote! { - const ELEMENTS: &[::static_xml::ExpandedNameRef] = &[#(#elements,)*]; - - enum VisitorInner { - #(#visitor_variants,)* - } - - impl VisitorInner { - fn finalize(self) -> Result<#ident, ::static_xml::de::VisitorError> { - match self { - #(#finalize_match_arms,)* - } - } - } - - pub struct Visitor(Option); - - impl ::static_xml::de::ElementVisitor for Visitor { - fn element<'a>( - &mut self, - mut child: ::static_xml::de::ElementReader<'a>, - ) -> Result>, ::static_xml::de::VisitorError> { - let name = child.expanded_name(); - let i = ::static_xml::de::find(&name, ELEMENTS); - let expected_i = match &mut self.0 { - #(#element_match_arms)* - _ => return Ok(Some(child)) - }; - Err(::static_xml::de::VisitorError::unexpected_element( - &name, - &ELEMENTS[expected_i], - )) - } - } - - impl ::static_xml::de::Deserialize for #ident { - fn deserialize(element: ::static_xml::de::ElementReader<'_>) -> Result { - let mut visitor: Visitor = Visitor(None); - element.read_to(&mut visitor)?; - visitor - .0 - .ok_or_else(|| static_xml::de::VisitorError::cant_be_empty(stringify!(#ident)))? - .finalize() - } - } - - impl ::static_xml::de::DeserializeFlatten for #ident { - type Builder = Visitor; - - fn init() -> Self::Builder { Visitor(None) } - fn finalize(builder: Self::Builder) -> Result { - builder - .0 - .ok_or_else(|| static_xml::de::VisitorError::cant_be_empty(stringify!(#ident)))? - .finalize() - } - } - - // can't derive DeserializeFlatten for Option<#ident> and Vec<#ident> because of the - // orphan rule. :-( - } -}*/ - -fn do_enum_direct(enum_: &ElementEnum) -> TokenStream { - let ident = &enum_.input.ident; - let elements: Vec<_> = enum_ + let uninitialized_match_arms: Vec = enum_ .variants - .iter() - .map(|v| v.name.quote_expanded()) - .collect(); - let match_arms: Vec = enum_.variants .iter() .enumerate() .map(|(i, v)| { @@ -397,16 +379,13 @@ fn do_enum_direct(enum_: &ElementEnum) -> TokenStream { match v.ty { None => { quote_spanned! { - vident.span() => Some(#i) => { *self = #ident::#vident; } + vident.span() => Some(#i) => { #ident::#vident } } } - Some(ty) => { + Some(_) => { quote_spanned! { vident.span() => Some(#i) => { - let mut builder = <#ty as Default>::default(); - let mut present = false; - ::static_xml::de::DeserializeField::element(&mut builder, &mut present, child)?; - *self = #ident::#vident(builder); + #ident::#vident(::static_xml::de::DeserializeElementField::init(child)?) } } } @@ -414,55 +393,85 @@ fn do_enum_direct(enum_: &ElementEnum) -> TokenStream { }) .collect(); + // #ident and #visitor_type's visibilities must exactly match due to a trait reference cycle: + // * #ident refers to #visitor_type via RawDeserialize::Visitor + // * #visitor_type refers to #ident via RawDeserializeVisitor::Out + let vis = &enum_.input.vis; + quote! { const ELEMENTS: &[::static_xml::ExpandedNameRef] = &[#(#elements,)*]; - impl ::static_xml::de::ElementVisitor for #ident { + #vis struct #visitor_type<'out> { + out: &'out mut ::std::mem::MaybeUninit<#ident>, + initialized: bool, + } + + impl<'out> ::static_xml::de::ElementVisitor for #visitor_type<'out> { fn element<'a>( &mut self, mut child: ::static_xml::de::ElementReader<'a>, ) -> Result>, ::static_xml::de::VisitorError> { let name = child.expanded_name(); - match ::static_xml::de::find(&name, ELEMENTS) { - #(#match_arms,)* - _ => return Ok(Some(child)), + let element_i = ::static_xml::de::find(&name, ELEMENTS); + if self.initialized { + // SAFETY: self.out is initialized when self.initialized is true. + let expected_i = match unsafe { self.out.assume_init_mut() } { + #(#initialized_match_arms,)* + }; + if let Some(element_i) = element_i { + return Err(::static_xml::de::VisitorError::unexpected_element( + &name, + &ELEMENTS[expected_i], + )); + } + return Ok(None); } + self.out.write(match element_i { + #(#uninitialized_match_arms,)* + _ => return Ok(Some(child)), + }); + self.initialized = true; Ok(None) } } - impl ::static_xml::de::Deserialize for #ident { - fn deserialize(element: ::static_xml::de::ElementReader<'_>) -> Result { - let mut visitor: #ident = Default::default(); - element.read_to(&mut visitor)?; - Ok(visitor) - } - } + unsafe impl<'out> ::static_xml::de::RawDeserializeVisitor<'out> for #visitor_type<'out> { + type Out = #ident; - impl ::static_xml::de::DeserializeFlatten for #ident { - type Builder = #ident; + fn new(out: &'out mut ::std::mem::MaybeUninit) -> Self { + Self { + out, + initialized: false, + } + } - fn init() -> Self::Builder { Default::default() } - fn finalize(builder: Self::Builder) -> Result { - Ok(builder) + fn finalize( + self, + default: Option Self::Out>, + ) -> Result<(), ::static_xml::de::VisitorError> { + if !self.initialized { + if let Some(d) = default { + self.out.write(d()); + } else { + return Err(static_xml::de::VisitorError::cant_be_empty(stringify!(#ident))); + } + } + // SAFETY: returning `Ok` guarantees `self.out` is fully initialized. + Ok(()) } } - // can't derive DeserializeFlatten for Option<#ident> and Vec<#ident> because of the - // orphan rule. :-( + impl<'out> ::static_xml::de::RawDeserialize<'out> for #ident { + type Visitor = #visitor_type<'out>; + } + + ::static_xml::impl_deserialize_via_raw!(#ident, #visitor_type); } } pub(crate) fn derive(errors: &Errors, input: syn::DeriveInput) -> Result { match input.data { - Data::Enum(ref data) => { - let enum_ = ElementEnum::new(&errors, &input, data); - //if enum_.attr.direct { - Ok(do_enum_direct(&enum_)) - //} else { - // Ok(do_enum_indirect(&enum_)) - //} - } + Data::Enum(ref data) => Ok(do_enum(&ElementEnum::new(&errors, &input, data))), Data::Struct(ref data) => ElementStruct::new(&errors, &input, data).map(|s| do_struct(&s)), _ => { errors.push(syn::Error::new_spanned( diff --git a/static-xml/src/de/mod.rs b/static-xml/src/de/mod.rs index ce77429..a748567 100644 --- a/static-xml/src/de/mod.rs +++ b/static-xml/src/de/mod.rs @@ -3,8 +3,8 @@ //! Deserialization from XML to Rust types. -use std::fmt::Write; use std::sync::Arc; +use std::{fmt::Write, mem::MaybeUninit}; use log::trace; use xml::{ @@ -52,6 +52,16 @@ impl VisitorError { )))) } + // xml-rs might detect this anyway, but static-xml-derive shouldn't rely + // on that for avoiding memory leaks, and it needs an error to return. + #[doc(hidden)] + pub fn duplicate_attribute(attribute: &ExpandedNameRef) -> Self { + Self::Wrap(Box::new(SimpleError(format!( + "Duplicate attribute {}", + attribute + )))) + } + #[doc(hidden)] pub fn duplicate_element(element: &ExpandedNameRef) -> Self { Self::Wrap(Box::new(SimpleError(format!( @@ -662,7 +672,7 @@ pub trait Deserialize: Sized { /// /// Typically used only via derive macros. /// -/// For any `T` that implements [`Deserialize`,, there are three implementations +/// For any `T` that implements [`Deserialize`], there are three implementations /// of this trait: /// /// 1. `T`, for mandatory singleton fields. In XML Schema terms, this @@ -671,25 +681,131 @@ pub trait Deserialize: Sized { /// `minOccurs="0" maxOccurs="1". /// 3. `Vec`, for repeated fields. In XML Schema terms, /// `minOccurs="0" maxOccurs="unbounded"`. -pub trait DeserializeField: Default { - /// Handles a single occurrence of this element; called zero or more times. - fn element<'a>(&mut self, present: &mut bool, element: ElementReader<'_>) -> Result<(), VisitorError>; - - fn finalize( - present: bool, - expected: &ExpandedNameRef<'_>, - ) -> Result<(), VisitorError>; +pub trait DeserializeElementField: Sized { + /// Called on the first occurrence of this field's element within the parent. + fn init(element: ElementReader<'_>) -> Result; + + /// Called on subsequent occurrences of this field's element within the parent. + /// + /// `self` was previously returned by `init` and has been through zero or more prior `update` calls. + fn update(&mut self, element: ElementReader<'_>) -> Result<(), VisitorError>; + + /// Called iff this field's element was not found within the parent. + fn missing(expected: &ExpandedNameRef<'_>) -> Result; +} + +/// Deserializes an attribute into a field. +/// +/// This is implemented via [`ParseText`] as noted there. +pub trait DeserializeAttrField: Sized { + /// Called iff this field's attribute was found within the parent. + fn init(value: String) -> Result; + + /// Called iff this field's attribute was not found within the parent. + fn missing(expected: &ExpandedNameRef<'_>) -> Result; +} + +#[doc(hidden)] +pub unsafe trait RawDeserializeVisitor<'out>: Sized + ElementVisitor { + type Out; + + /// Returns a visitor that can be used to populate `this`. + fn new(out: &'out mut MaybeUninit) -> Self; + + /// Finalizes `out`. + /// + /// An `Ok` return guarantees `out.assume_init()` is valid. + fn finalize(self, default: Option Self::Out>) -> Result<(), VisitorError>; } -/// Deserializes this type when "flattened" into another. +/// Raw, unsafe implementation of [`Deserialize`], for use by the macros. /// /// With `static-xml-derive`, this can be used via `#[static_xml(flatten)]`. -pub trait DeserializeFlatten: Sized { - type Builder: ElementVisitor; +/// +/// Implementing this type automatically implements `Deserialize`. +#[doc(hidden)] +pub trait RawDeserialize<'out>: Sized { + type Visitor: RawDeserializeVisitor<'out, Out = Self>; +} + +/* +/// Raw, unsafe implementation of [`Deserialize`]. +#[doc(hidden)] +pub unsafe trait RawDeserialize<'out>: Sized { + type Visitor: ElementVisitor; + + /// Returns a visitor that can be used to populate `this`. + fn new(out: &'out mut MaybeUninit) -> Self::Visitor; + + /// Finalizes `out`. + /// + /// An `Ok` return guarantees `out.assume_init()` is valid. + fn finalize(visitor: Self::Visitor, default: Option Self>) -> Result<(), VisitorError>; +} +*/ + +/* +impl<'a, T: RawDeserialize<'a>> Deserialize for T { + fn deserialize(element: ElementReader<'_>) -> Result { + let mut out = ::std::mem::MaybeUninit::uninit(); + let mut visitor = Self::Visitor::new(&mut out); + element.read_to(&mut visitor)?; + visitor.finalize()?; + // SAFETY: finalize()'s contract guarantees this `assume_init` is safe. + Ok(unsafe { + out.assume_init() + }) + } +} +*/ + +/// Implements [`Deserialize`] via [`RawDeserializeVisitor`]. +/// +/// The type opts into [`Deserialize`] via this macro rather than by +/// implementing a (hypothetical) `RawDeserialize`. The latter approach +/// doesn't work because `impl Deserialize for T` and +/// `impl Deserialize for T` would +/// [conflict](https://doc.rust-lang.org/error-index.html#E0119). +#[doc(hidden)] +#[macro_export] +macro_rules! impl_deserialize_via_raw { + ( $t:ident, $visitor:ident ) => { + impl ::static_xml::de::Deserialize for $t { + fn deserialize( + element: ::static_xml::de::ElementReader<'_>, + ) -> Result { + let mut out = ::std::mem::MaybeUninit::uninit(); + let mut visitor = + <$visitor as ::static_xml::de::RawDeserializeVisitor>::new(&mut out); + element.read_to(&mut visitor)?; + ::static_xml::de::RawDeserializeVisitor::finalize(visitor, None)?; + // SAFETY: finalize's contract guarantees assume_init is safe. + Ok(unsafe { out.assume_init() }) + } + } + }; +} - fn init() -> Self::Builder; - fn finalize(builder: Self::Builder) -> Result; +/* +#[doc(hidden)] +#[macro_export] +macro_rules! impl_deserialize_via_raw { + ( $t:ident ) => { + impl Deserialize for $t { + fn deserialize(element: ElementReader<'_>) -> Result { + let mut out = ::std::mem::MaybeUninit::uninit(); + let mut visitor = Self::Visitor::new(&mut out); + element.read_to(&mut visitor)?; + visitor.finalize()?; + // SAFETY: finalize()'s contract guarantees this `assume_init` is safe. + Ok(unsafe { + out.assume_init() + }) + } + } + }; } +*/ /// Deserializes text data, whether character nodes or attribute values. /// @@ -708,24 +824,6 @@ pub trait ParseText: Sized { fn parse(text: String) -> Result; } -/// Deserializes an attribute into a field. -/// -/// This is implemented via [`ParseText`] as noted there. -pub trait DeserializeAttr: Sized { - /// May be called zero or one time with the relevant attribute. - fn attr( - &mut self, - present: &mut bool, - name: &ExpandedNameRef, - value: String, - ) -> Result<(), VisitorError>; - - fn finalize( - present: bool, - expected: &ExpandedNameRef<'_>, - ) -> Result<(), VisitorError>; -} - /// Visitor used within [`Deserialize`]. pub trait ElementVisitor { /// Processes a given attribute of this element's start tag. @@ -908,103 +1006,201 @@ pub fn find(name: &ExpandedNameRef<'_>, sorted_slice: &[ExpandedNameRef<'_>]) -> sorted_slice.binary_search(name).ok() } -impl DeserializeField for T { - fn element<'a>(&mut self, present: &mut bool, element: ElementReader<'_>) -> Result<(), VisitorError> { - if *present { - return Err(VisitorError::duplicate_element(&element.expanded_name())); - } - *self = T::deserialize(element)?; - *present = true; - Ok(()) +impl DeserializeElementField for T { + fn init(element: ElementReader<'_>) -> Result { + T::deserialize(element) } - fn finalize( - present: bool, - expected: &ExpandedNameRef<'_>, - ) -> Result<(), VisitorError> { - if !present { - return Err(VisitorError::missing_element(expected)); - } - Ok(()) + fn update(&mut self, element: ElementReader<'_>) -> Result<(), VisitorError> { + Err(VisitorError::duplicate_element(&element.expanded_name())) + } + + fn missing(expected: &ExpandedNameRef<'_>) -> Result { + Err(VisitorError::missing_element(expected)) } } -impl DeserializeField for Option { - fn element<'a>(&mut self, present: &mut bool, element: ElementReader<'_>) -> Result<(), VisitorError> { - if *present { - return Err(VisitorError::duplicate_element(&element.expanded_name())); - } - *self = Some(T::deserialize(element)?); - *present = true; - Ok(()) +impl DeserializeElementField for Option { + fn init(element: ElementReader<'_>) -> Result { + T::deserialize(element).map(Some) } - fn finalize( - _present: bool, - _expected: &ExpandedNameRef<'_>, - ) -> Result<(), VisitorError> { - Ok(()) + fn update(&mut self, element: ElementReader<'_>) -> Result<(), VisitorError> { + Err(VisitorError::duplicate_element(&element.expanded_name())) + } + + fn missing(_expected: &ExpandedNameRef<'_>) -> Result { + Ok(None) } } /// Deserializes into a `Vec`, adding an element. -impl DeserializeField for Vec { - fn element<'a>(&mut self, present: &mut bool, element: ElementReader<'_>) -> Result<(), VisitorError> { +impl DeserializeElementField for Vec { + fn init(element: ElementReader<'_>) -> Result { + Ok(vec![T::deserialize(element)?]) + } + + fn update(&mut self, element: ElementReader<'_>) -> Result<(), VisitorError> { self.push(T::deserialize(element)?); - *present = true; Ok(()) } - fn finalize( - _present: bool, - _expected: &ExpandedNameRef<'_>, - ) -> Result<(), VisitorError> { - Ok(()) + fn missing(_expected: &ExpandedNameRef<'_>) -> Result { + Ok(Vec::new()) + } +} + +impl DeserializeAttrField for T { + fn init(value: String) -> Result { + Ok(T::parse(value).map_err(VisitorError::Wrap)?) + } + + fn missing(expected: &ExpandedNameRef<'_>) -> Result { + Err(VisitorError::missing_attribute(expected)) } } -impl DeserializeAttr for T { - fn attr( +impl DeserializeAttrField for Option { + fn init(value: String) -> Result { + Ok(Some(T::parse(value).map_err(VisitorError::Wrap)?)) + } + + fn missing(_expected: &ExpandedNameRef<'_>) -> Result { + Ok(None) + } +} + +/* +#[doc(hidden)] +pub enum SingleContainerResult<'a, T> { + Inited(T), + Pass(ElementReader<'a>), +} + +/// A type which is initialized after a single child element. +/// +/// `impl SingleChildElement for T` automatically implements [`Deserialize`] +/// for `T`, `Option`, and `Vec`. This is particularly handy for flatten +/// on `enum` types. +#[doc(hidden)] +pub trait SingleElementContainer { + fn element<'a>(child: ElementReader<'a>) -> Result, VisitorError>; + fn missing() -> VisitorError; +} + +impl<'out, T: SingleElementContainer> RawDeserialize<'out> for T { + type Visitor = RequiredVisitor<'out, T>; +} + +struct RequiredVisitor<'out, T: SingleElementContainer> { + out: &'out mut MaybeUninit, + initialized: bool, +} + +// SAFETY: RequiredVisitor fulfills the contract of guaranteeing `out` is +// initialized when `finalize` returns `Ok`. +unsafe impl<'out, T: SingleElementContainer> RawDeserializeVisitor<'out> for RequiredVisitor<'out, T> { + type Out = T; + + fn new(out: &'out mut MaybeUninit) -> Self { + Self { + out, + initialized: false, + } + } + + fn finalize(self, default: Option Self>) -> Result<(), VisitorError> { + match (self.initialized, default) { + (false, Some(d)) => { + self.out.write(d()); + Ok(()) + } + (false, None) => Err(T::missing()), + (true, _) => Ok(()) + } + } +} + +impl<'out, T: SingleElementContainer> ElementVisitor for RequiredVisitor<'out, T> { + fn element<'a>( &mut self, - present: &mut bool, - _name: &ExpandedNameRef, - value: String, - ) -> Result<(), VisitorError> { - *self = T::parse(value).map_err(VisitorError::Wrap)?; - *present = true; - Ok(()) + child: ElementReader<'a>, + ) -> Result>, VisitorError> { + if self.initialized { + return Err(VisitorError::duplicate_element(child)); + } + Ok(match SingleElementContainer::element(child)? { + SingleContainerResult::Inited(c) => { + self.out.write(c); + self.initialized = true; + None + } + SingleContainerResult::Pass(c) => Some(c), + }) + } +} + +impl<'out, T: SingleElementContainer> RawDeserialize<'out> for Option { + type Visitor = OptionalVisitor<'out, T>; +} + +struct OptionalVisitor<'out, T: SingleElementContainer>(&'out mut Option); + +// SAFETY: OptionVisitor fulfills the contract of guaranteeing `out` is +// initialized when `finalize` returns `Ok`. I expect writing `None` to be cheap +// (in code size and CPU time), so it just does this immediately. +unsafe impl<'out, T: SingleElementContainer> RawDeserializeVisitor<'out> for OptionalVisitor<'out, T> { + type Out = Option; + + fn new(out: &'out mut MaybeUninit) -> Self { + Self(out.write(None)) } - fn finalize( - present: bool, - expected: &ExpandedNameRef<'_>, - ) -> Result<(), VisitorError> { - if !present { - return Err(VisitorError::missing_attribute(expected)); + fn finalize(self, default: Option Self>) -> Result<(), VisitorError> { + if let (None, Some(d)) = (self.out, default) { + self.out = Some(d()); } Ok(()) } } -impl DeserializeAttr for Option { - fn attr( +impl<'out, T: SingleElementContainer> ElementVisitor for OptionalVisitor<'out, T> { + fn element<'a>( &mut self, - present: &mut bool, - _name: &ExpandedNameRef, - value: String, - ) -> Result<(), VisitorError> { - *self = Some(T::parse(value).map_err(VisitorError::Wrap)?); - *present = true; - Ok(()) + child: ElementReader<'a>, + ) -> Result>, VisitorError> { + if self.out.is_some() { + return Err(VisitorError::duplicate_element(child.name())); + } + *self.out = Some(T::deserialize(child)?); + Ok(None) } +} - fn finalize( - _present: bool, - _expected: &ExpandedNameRef<'_>, - ) -> Result<(), VisitorError> { +impl<'out, T: SingleElementContainer> RawDeserialize<'out> for Vec { + type Visitor = MultiVisitor<'out, T>; +} + +struct MultiVisitor<'out, T: SingleElementContainer>(&'out mut Vec); + +// SAFETY: OptionVisitor fulfills the contract of guaranteeing `out` is +// initialized when `finalize` returns `Ok`. I expect writing `None` to be cheap +// (in code size and CPU time), so it just does this immediately. +unsafe impl<'out, T: SingleElementContainer> RawDeserializeVisitor<'out> for MultiVisitor<'out, T> { + type Out = Vec; + + fn new(out: &'out mut MaybeUninit) -> Self { + Self(out.write(Vec::new())) + } + + fn finalize(self, default: Option Self>) -> Result<(), VisitorError> { + if let (true, Some(d)) = (self.out.is_empty(), default) { + *self.out = d(); + } Ok(()) } } +*/ impl ParseText for bool { fn parse(text: String) -> Result { @@ -1057,7 +1253,6 @@ impl Deserialize for T { T::parse(str).map_err(VisitorError::Wrap) } } - /*#[cfg(test)] mod tests { use super::*; diff --git a/test-suite/tests/basic.rs b/test-suite/tests/basic.rs index 24d7cbb..7e3143d 100644 --- a/test-suite/tests/basic.rs +++ b/test-suite/tests/basic.rs @@ -18,7 +18,6 @@ struct Foo { //#[static_xml(flatten)] //bar: Bar, - text: String, constrained: ConstrainedString, diff --git a/test-suite/tests/element_enum.rs b/test-suite/tests/element_enum.rs index 73c2a18..c8d1d5b 100644 --- a/test-suite/tests/element_enum.rs +++ b/test-suite/tests/element_enum.rs @@ -3,68 +3,30 @@ use static_xml_derive::{Deserialize, Serialize}; -#[derive(Default, Debug, Deserialize, Serialize, PartialEq, Eq)] -#[static_xml(direct)] -struct DirectHolder { - #[static_xml(flatten)] - direct_enum: DirectEnum, - //#[static_xml(flatten)] - //other_flatten: OtherDirectFlatten, -} - #[derive(Debug, Deserialize, Serialize, PartialEq, Eq)] -#[static_xml(direct)] -enum DirectEnum { - Simple(String), - Vec(Vec), - Unit, - #[static_xml(skip)] - Skipped(String), -} -impl Default for DirectEnum { - fn default() -> Self { - DirectEnum::Skipped("default".to_owned()) - } -} - -#[derive(Debug, Default, Deserialize, Serialize, PartialEq, Eq)] -#[static_xml(direct)] -struct OtherDirectFlatten { - field: Vec, -} - -#[derive(Debug, Default, Deserialize, Serialize, PartialEq, Eq)] -struct IndirectHolder { +struct Holder { #[static_xml(flatten)] - indirect_enum: IndirectEnum, + enum_: Enum, //#[static_xml(flatten)] - //other_flatten: OtherIndirectFlatten, + //other_flatten: OtherFlatten, } #[derive(Debug, Deserialize, Serialize, PartialEq, Eq)] -enum IndirectEnum { +enum Enum { Simple(String), Vec(Vec), Unit, - #[allow(dead_code)] - #[static_xml(skip)] - Skipped, -} -impl Default for IndirectEnum { - fn default() -> Self { - IndirectEnum::Skipped - } } -#[derive(Debug, Default, Deserialize, Serialize, PartialEq, Eq)] +#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)] struct OtherIndirectFlatten { field: Vec, } #[test] -fn deserialize_indirect_simple() { +fn deserialize_simple() { let _ = env_logger::Builder::new().is_test(true).try_init(); - let indirect: IndirectHolder = static_xml::de::from_str( + let holder: Holder = static_xml::de::from_str( r#" @@ -76,9 +38,9 @@ fn deserialize_indirect_simple() { ) .unwrap(); assert_eq!( - indirect, - IndirectHolder { - indirect_enum: IndirectEnum::Simple("asdf".to_owned()), + holder, + Holder { + enum_: Enum::Simple("asdf".to_owned()), /*other_flatten: OtherIndirectFlatten { field: vec!["before".to_owned(), "after".to_owned()] },*/ @@ -87,9 +49,9 @@ fn deserialize_indirect_simple() { } #[test] -fn deserialize_indirect_vec() { +fn deserialize_vec() { let _ = env_logger::Builder::new().is_test(true).try_init(); - let indirect: IndirectHolder = static_xml::de::from_str( + let holder: Holder = static_xml::de::from_str( r#" @@ -102,9 +64,9 @@ fn deserialize_indirect_vec() { ) .unwrap(); assert_eq!( - indirect, - IndirectHolder { - indirect_enum: IndirectEnum::Vec(vec!["foo".to_owned(), "bar".to_owned()]), + holder, + Holder { + enum_: Enum::Vec(vec!["foo".to_owned(), "bar".to_owned()]), /*other_flatten: OtherIndirectFlatten { field: vec!["before".to_owned(), "after".to_owned()] },*/ @@ -113,9 +75,9 @@ fn deserialize_indirect_vec() { } #[test] -fn deserialize_indirect_mix_error() { +fn deserialize_mix_error() { let _ = env_logger::Builder::new().is_test(true).try_init(); - let e = static_xml::de::from_str::( + let e = static_xml::de::from_str::( r#"