Skip to content

Commit

Permalink
fix(types): fix Serial in Fields derive and add tests (#13926)
Browse files Browse the repository at this point in the history
Signed-off-by: TennyZhuang <[email protected]>
  • Loading branch information
TennyZhuang authored Dec 11, 2023
1 parent ad2073f commit 52632ae
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 47 deletions.
48 changes: 12 additions & 36 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions src/common/fields-derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,8 @@ proc-macro = true
proc-macro2 = "1"
quote = "1"
syn = { version = "2", features = ["full", "extra-traits"] }

[dev-dependencies]
expect-test = "1"
indoc = "2"
prettyplease = "0.2"
12 changes: 12 additions & 0 deletions src/common/fields-derive/src/gen/test_output.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
impl ::risingwave_common::types::Fields for Data {
fn fields() -> Vec<(&'static str, ::risingwave_common::types::DataType)> {
vec![
("v1", < i16 as ::risingwave_common::types::WithDataType >
::default_data_type()), ("v2", < std::primitive::i32 as
::risingwave_common::types::WithDataType > ::default_data_type()), ("v3", <
bool as ::risingwave_common::types::WithDataType > ::default_data_type()),
("v4", < Serial as ::risingwave_common::types::WithDataType >
::default_data_type())
]
}
}
49 changes: 44 additions & 5 deletions src/common/fields-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,19 @@ use syn::{Data, DeriveInput, Field, Result};

#[proc_macro_derive(Fields)]
pub fn fields(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input: DeriveInput = syn::parse_macro_input! {tokens};
inner(tokens.into()).into()
}

match gen(input) {
Ok(tokens) => tokens.into(),
Err(err) => err.to_compile_error().into(),
fn inner(tokens: TokenStream) -> TokenStream {
match gen(tokens) {
Ok(tokens) => tokens,
Err(err) => err.to_compile_error(),
}
}

fn gen(input: DeriveInput) -> Result<TokenStream> {
fn gen(tokens: TokenStream) -> Result<TokenStream> {
let input: DeriveInput = syn::parse2(tokens)?;

let DeriveInput {
attrs: _attrs,
vis: _vis,
Expand Down Expand Up @@ -71,3 +75,38 @@ fn gen(input: DeriveInput) -> Result<TokenStream> {
}
})
}

#[cfg(test)]
mod tests {
use indoc::indoc;
use proc_macro2::TokenStream;
use syn::File;

fn pretty_print(output: TokenStream) -> String {
let output: File = syn::parse2(output).unwrap();
prettyplease::unparse(&output)
}

#[test]
fn test_gen() {
let code = indoc! {r#"
#[derive(Fields)]
struct Data {
v1: i16,
v2: std::primitive::i32,
v3: bool,
v4: Serial,
}
"#};

let input: TokenStream = str::parse(code).unwrap();

let output = super::gen(input).unwrap();

let output = pretty_print(output);

let expected = expect_test::expect_file!["gen/test_output.rs"];

expected.assert_eq(&output);
}
}
9 changes: 3 additions & 6 deletions src/common/src/types/with_data_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,15 @@ use std::sync::Arc;
use bytes::Bytes;

use super::{
DataType, Date, Decimal, Fields, Int256, Interval, JsonbRef, JsonbVal, StructType, Time,
Timestamp, Timestamptz, F32, F64,
DataType, Date, Decimal, Fields, Int256, Interval, JsonbRef, JsonbVal, Serial, StructType,
Time, Timestamp, Timestamptz, F32, F64,
};

/// A trait for all physical types that can be associated with a [`DataType`].
///
/// This is also a helper for [`Fields`](derive@crate::types::Fields) derive macro.
pub trait WithDataType {
/// Returns the most obvious [`DataType`] for the rust type.
///
/// There may be more than one [`DataType`] corresponding to the same rust type,
/// for example, [`Int64`](DataType::Int64) and [`Serial`](DataType::Serial) are
/// both expressed as i64. This method will return [`Int64`](DataType::Int64).
fn default_data_type() -> DataType;
}

Expand Down Expand Up @@ -94,6 +90,7 @@ impl_with_data_type!(f64, DataType::Float64);
impl_with_data_type!(F64, DataType::Float64);
impl_with_data_type!(rust_decimal::Decimal, DataType::Decimal);
impl_with_data_type!(Decimal, DataType::Decimal);
impl_with_data_type!(Serial, DataType::Serial);

impl<'a> WithDataType for &'a str {
fn default_data_type() -> DataType {
Expand Down

0 comments on commit 52632ae

Please sign in to comment.