From 36edf716f1eb9bf6d017c7e2c46c2ea3585ef27b Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Mon, 5 Feb 2024 19:17:03 +0800 Subject: [PATCH] feat(type): improve the `Fields` derive macro (#14934) Signed-off-by: Runji Wang --- .../fields-derive/src/gen/test_output.rs | 31 ++++- src/common/fields-derive/src/lib.rs | 108 ++++++++++++++---- src/common/src/types/fields.rs | 65 ++++++++++- src/common/src/types/mod.rs | 45 ++++++-- 4 files changed, 208 insertions(+), 41 deletions(-) diff --git a/src/common/fields-derive/src/gen/test_output.rs b/src/common/fields-derive/src/gen/test_output.rs index d8e7274b4c2ee..517dcdefc7a8c 100644 --- a/src/common/fields-derive/src/gen/test_output.rs +++ b/src/common/fields-derive/src/gen/test_output.rs @@ -6,7 +6,36 @@ impl ::risingwave_common::types::Fields for Data { ::risingwave_common::types::WithDataType > ::default_data_type()), ("v3", < bool as ::risingwave_common::types::WithDataType > ::default_data_type()), ("v4", < Serial as ::risingwave_common::types::WithDataType > - ::default_data_type()) + ::default_data_type()), ("type", < i32 as + ::risingwave_common::types::WithDataType > ::default_data_type()) ] } + fn into_owned_row(self) -> ::risingwave_common::row::OwnedRow { + ::risingwave_common::row::OwnedRow::new( + vec![ + ::risingwave_common::types::ToOwnedDatum::to_owned_datum(self.v1), + ::risingwave_common::types::ToOwnedDatum::to_owned_datum(self.v2), + ::risingwave_common::types::ToOwnedDatum::to_owned_datum(self.v3), + ::risingwave_common::types::ToOwnedDatum::to_owned_datum(self.v4), + ::risingwave_common::types::ToOwnedDatum::to_owned_datum(self.r#type) + ], + ) + } + fn primary_key() -> &'static [usize] { + &[1usize, 0usize] + } +} +impl From for ::risingwave_common::types::ScalarImpl { + fn from(v: Data) -> Self { + ::risingwave_common::types::StructValue::new( + vec![ + ::risingwave_common::types::ToOwnedDatum::to_owned_datum(v.v1), + ::risingwave_common::types::ToOwnedDatum::to_owned_datum(v.v2), + ::risingwave_common::types::ToOwnedDatum::to_owned_datum(v.v3), + ::risingwave_common::types::ToOwnedDatum::to_owned_datum(v.v4), + ::risingwave_common::types::ToOwnedDatum::to_owned_datum(v.r#type) + ], + ) + .into() + } } diff --git a/src/common/fields-derive/src/lib.rs b/src/common/fields-derive/src/lib.rs index 9b30c4ee72419..86fa229a5adcd 100644 --- a/src/common/fields-derive/src/lib.rs +++ b/src/common/fields-derive/src/lib.rs @@ -14,9 +14,9 @@ use proc_macro2::TokenStream; use quote::quote; -use syn::{Data, DeriveInput, Field, Result}; +use syn::{Data, DeriveInput, Result}; -#[proc_macro_derive(Fields)] +#[proc_macro_derive(Fields, attributes(primary_key))] pub fn fields(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream { inner(tokens.into()).into() } @@ -31,51 +31,107 @@ fn inner(tokens: TokenStream) -> TokenStream { fn gen(tokens: TokenStream) -> Result { let input: DeriveInput = syn::parse2(tokens)?; - let DeriveInput { - attrs: _attrs, - vis: _vis, - ident, - generics, - data, - } = input; - if !generics.params.is_empty() { + let ident = &input.ident; + if !input.generics.params.is_empty() { return Err(syn::Error::new_spanned( - generics, + input.generics, "generics are not supported", )); } - let Data::Struct(r#struct) = data else { - return Err(syn::Error::new_spanned(ident, "only structs are supported")); + let Data::Struct(struct_) = &input.data else { + return Err(syn::Error::new_spanned( + input.ident, + "only structs are supported", + )); }; - let fields_rs = r#struct.fields; - let fields_rw: Vec = fields_rs - .into_iter() - .map(|field_rs| { - let Field { - // We can support #[field(ignore)] or other useful attributes here. - attrs: _attrs, - ident: name, - ty, - .. - } = field_rs; - let name = name.map_or("".to_string(), |name| name.to_string()); + let fields_rw: Vec = struct_ + .fields + .iter() + .map(|field| { + let mut name = field.ident.as_ref().expect("field no name").to_string(); + // strip leading `r#` + if name.starts_with("r#") { + name = name[2..].to_string(); + } + let ty = &field.ty; quote! { (#name, <#ty as ::risingwave_common::types::WithDataType>::default_data_type()) } }) .collect(); + let names = struct_ + .fields + .iter() + .map(|field| field.ident.as_ref().expect("field no name")) + .collect::>(); + let primary_key = get_primary_key(&input).map(|indices| { + quote! { + fn primary_key() -> &'static [usize] { + &[#(#indices),*] + } + } + }); Ok(quote! { impl ::risingwave_common::types::Fields for #ident { fn fields() -> Vec<(&'static str, ::risingwave_common::types::DataType)> { vec![#(#fields_rw),*] } + fn into_owned_row(self) -> ::risingwave_common::row::OwnedRow { + ::risingwave_common::row::OwnedRow::new(vec![#( + ::risingwave_common::types::ToOwnedDatum::to_owned_datum(self.#names) + ),*]) + } + #primary_key + } + impl From<#ident> for ::risingwave_common::types::ScalarImpl { + fn from(v: #ident) -> Self { + ::risingwave_common::types::StructValue::new(vec![#( + ::risingwave_common::types::ToOwnedDatum::to_owned_datum(v.#names) + ),*]).into() + } } }) } +/// Get primary key indices from `#[primary_key]` attribute. +fn get_primary_key(input: &syn::DeriveInput) -> Option> { + let syn::Data::Struct(struct_) = &input.data else { + return None; + }; + // find `#[primary_key(k1, k2, ...)]` on struct + let composite = input.attrs.iter().find_map(|attr| match &attr.meta { + syn::Meta::List(list) if list.path.is_ident("primary_key") => Some(&list.tokens), + _ => None, + }); + if let Some(keys) = composite { + let index = |name: &str| { + struct_ + .fields + .iter() + .position(|f| f.ident.as_ref().map_or(false, |i| i == name)) + .expect("primary key not found") + }; + return Some( + keys.to_string() + .split(',') + .map(|s| index(s.trim())) + .collect(), + ); + } + // find `#[primary_key]` on fields + for (i, field) in struct_.fields.iter().enumerate() { + for attr in &field.attrs { + if matches!(&attr.meta, syn::Meta::Path(path) if path.is_ident("primary_key")) { + return Some(vec![i]); + } + } + } + None +} + #[cfg(test)] mod tests { use indoc::indoc; @@ -91,11 +147,13 @@ mod tests { fn test_gen() { let code = indoc! {r#" #[derive(Fields)] + #[primary_key(v2, v1)] struct Data { v1: i16, v2: std::primitive::i32, v3: bool, v4: Serial, + r#type: i32, } "#}; diff --git a/src/common/src/types/fields.rs b/src/common/src/types/fields.rs index 0e5912ead0dc4..f52717297792e 100644 --- a/src/common/src/types/fields.rs +++ b/src/common/src/types/fields.rs @@ -12,13 +12,70 @@ // See the License for the specific language governing permissions and // limitations under the License. use super::DataType; +use crate::row::OwnedRow; +use crate::util::chunk_coalesce::DataChunkBuilder; /// A struct can implements `Fields` when if can be represented as a relational Row. /// -/// Can be automatically derived with [`#[derive(Fields)]`](derive@super::Fields). +/// # Derivable +/// +/// This trait can be automatically derived with [`#[derive(Fields)]`](derive@super::Fields). +/// Type of the fields must implement [`WithDataType`](super::WithDataType) and [`ToOwnedDatum`](super::ToOwnedDatum). +/// +/// ``` +/// # use risingwave_common::types::Fields; +/// +/// #[derive(Fields)] +/// struct Data { +/// v1: i16, +/// v2: i32, +/// } +/// ``` +/// +/// You can add `#[primary_key]` attribute to one of the fields to specify the primary key of the table. +/// +/// ``` +/// # use risingwave_common::types::Fields; +/// +/// #[derive(Fields)] +/// struct Data { +/// #[primary_key] +/// v1: i16, +/// v2: i32, +/// } +/// ``` +/// +/// If the primary key is composite, you can add `#[primary_key(...)]` attribute to the struct to specify the order of the fields. +/// +/// ``` +/// # use risingwave_common::types::Fields; +/// +/// #[derive(Fields)] +/// #[primary_key(v2, v1)] +/// struct Data { +/// v1: i16, +/// v2: i32, +/// } +/// ``` pub trait Fields { - /// When the struct being converted to an [`Row`](crate::row::Row) or a [`DataChunk`](crate::array::DataChunk), it schema must be consistent with the `fields` call. + /// Return the schema of the struct. fn fields() -> Vec<(&'static str, DataType)>; + + /// Convert the struct to an `OwnedRow`. + fn into_owned_row(self) -> OwnedRow; + + /// The primary key of the table. + fn primary_key() -> &'static [usize] { + &[] + } + + /// Create a [`DataChunkBuilder`](crate::util::chunk_coalesce::DataChunkBuilder) with the schema of the struct. + fn data_chunk_builder(capacity: usize) -> DataChunkBuilder { + DataChunkBuilder::new( + Self::fields().into_iter().map(|(_, ty)| ty).collect(), + capacity, + ) + } } #[cfg(test)] @@ -48,8 +105,8 @@ mod tests { v7: Vec, v8: std::vec::Vec, v9: Option>, - v10: std::option::Option>>, - v11: Box, + v10: std::option::Option>>, + v11: Timestamp, v14: Sub, } diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index e45601d2e0240..71f352ce2f87d 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -600,13 +600,6 @@ pub trait ToOwnedDatum { fn to_owned_datum(self) -> Datum; } -impl ToOwnedDatum for Datum { - #[inline(always)] - fn to_owned_datum(self) -> Datum { - self - } -} - impl ToOwnedDatum for &Datum { #[inline(always)] fn to_owned_datum(self) -> Datum { @@ -614,17 +607,17 @@ impl ToOwnedDatum for &Datum { } } -impl ToOwnedDatum for Option<&ScalarImpl> { +impl> ToOwnedDatum for T { #[inline(always)] fn to_owned_datum(self) -> Datum { - self.cloned() + Some(self.into()) } } -impl ToOwnedDatum for DatumRef<'_> { +impl> ToOwnedDatum for Option { #[inline(always)] fn to_owned_datum(self) -> Datum { - self.map(ScalarRefImpl::into_scalar_impl) + self.map(Into::into) } } @@ -816,6 +809,36 @@ impl From> for ScalarImpl { } } +impl From> for ScalarImpl { + fn from(v: Vec) -> Self { + Self::List(v.into_iter().collect()) + } +} + +impl From>> for ScalarImpl { + fn from(v: Vec>) -> Self { + Self::List(v.into_iter().collect()) + } +} + +impl From> for ScalarImpl { + fn from(v: Vec) -> Self { + Self::List(v.iter().map(|s| s.as_str()).collect()) + } +} + +impl From> for ScalarImpl { + fn from(v: Vec) -> Self { + Self::Bytea(v.into()) + } +} + +impl From for ScalarImpl { + fn from(v: Bytes) -> Self { + Self::Bytea(v.as_ref().into()) + } +} + impl ScalarImpl { /// Creates a scalar from binary. pub fn from_binary(bytes: &Bytes, data_type: &DataType) -> Result {