diff --git a/diesel_derives/src/as_changeset.rs b/diesel_derives/src/as_changeset.rs index b4382cc5a428..5920300145e9 100644 --- a/diesel_derives/src/as_changeset.rs +++ b/diesel_derives/src/as_changeset.rs @@ -61,24 +61,33 @@ pub fn derive(item: DeriveInput) -> Result { None => treat_none_as_null, }; - match field.serialize_as.as_ref() { - Some(AttributeSpanWrapper { item: ty, .. }) => { + match (field.serialize_as.as_ref(), field.serialize_fn.as_ref()) { + (Some(AttributeSpanWrapper { item: ty, .. }), serialize_fn) => { direct_field_ty.push(field_changeset_ty_serialize_as( field, table_name, ty, treat_none_as_null, )?); - direct_field_assign.push(field_changeset_expr_serialize_as( - field, - table_name, - ty, - treat_none_as_null, - )?); + if let Some(AttributeSpanWrapper { item: function, .. }) = serialize_fn { + direct_field_ty.push(field_changeset_expr_serialize_fn( + field, + table_name, + function, + treat_none_as_null, + )?); + } else { + direct_field_assign.push(field_changeset_expr_serialize_as( + field, + table_name, + ty, + treat_none_as_null, + )?); + } generate_borrowed_changeset = false; // as soon as we hit one field with #[diesel(serialize_as)] there is no point in generating the impl of AsChangeset for borrowed structs } - None => { + (None, None) => { direct_field_ty.push(field_changeset_ty( field, table_name, @@ -104,6 +113,12 @@ pub fn derive(item: DeriveInput) -> Result { treat_none_as_null, )?); } + (None, Some(AttributeSpanWrapper { attribute_span, .. })) => { + return Err(syn::Error::new( + *attribute_span, + "`#[diesel(serialize_fn)]` requires `#[diesel(serialize_as)]` to be declared as well", + )); + } } } @@ -222,3 +237,20 @@ fn field_changeset_expr_serialize_as( Ok(quote!(#column.eq(::std::convert::Into::<#ty>::into(self.#field_name)))) } } + +fn field_changeset_expr_serialize_fn( + field: &Field, + table_name: &Path, + function: &Expr, + treat_none_as_null: bool, +) -> Result { + let field_name = &field.name; + let column_name = field.column_name()?; + column_name.valid_ident()?; + let column: Expr = parse_quote!(#table_name::#column_name); + if !treat_none_as_null && is_option_ty(&field.ty) { + Ok(quote!(self.#field_name.map(|x| #column.eq((#function)(x))))) + } else { + Ok(quote!(#column.eq((#function)(self.#field_name)))) + } +} diff --git a/diesel_derives/src/attrs.rs b/diesel_derives/src/attrs.rs index c7d82fe250a9..61f1f6244be7 100644 --- a/diesel_derives/src/attrs.rs +++ b/diesel_derives/src/attrs.rs @@ -13,9 +13,10 @@ use crate::deprecated::ParseDeprecated; use crate::parsers::{BelongsTo, MysqlType, PostgresType, SqliteType}; use crate::util::{ parse_eq, parse_paren, unknown_attribute, BELONGS_TO_NOTE, COLUMN_NAME_NOTE, - DESERIALIZE_AS_NOTE, MYSQL_TYPE_NOTE, POSTGRES_TYPE_NOTE, SELECT_EXPRESSION_NOTE, - SELECT_EXPRESSION_TYPE_NOTE, SERIALIZE_AS_NOTE, SQLITE_TYPE_NOTE, SQL_TYPE_NOTE, - TABLE_NAME_NOTE, TREAT_NONE_AS_DEFAULT_VALUE_NOTE, TREAT_NONE_AS_NULL_NOTE, + DESERIALIZE_AS_NOTE, DESERIALIZE_FN_NOTE, MYSQL_TYPE_NOTE, POSTGRES_TYPE_NOTE, + SELECT_EXPRESSION_NOTE, SELECT_EXPRESSION_TYPE_NOTE, SERIALIZE_AS_NOTE, SERIALIZE_FN_NOTE, + SQLITE_TYPE_NOTE, SQL_TYPE_NOTE, TABLE_NAME_NOTE, TREAT_NONE_AS_DEFAULT_VALUE_NOTE, + TREAT_NONE_AS_NULL_NOTE, }; use crate::util::{parse_paren_list, CHECK_FOR_BACKEND_NOTE}; @@ -40,6 +41,8 @@ pub enum FieldAttr { SerializeAs(Ident, TypePath), DeserializeAs(Ident, TypePath), + SerializeFn(Ident, Expr), + DeserializeFn(Ident, Expr), SelectExpression(Ident, Expr), SelectExpressionType(Ident, Type), } @@ -145,6 +148,14 @@ impl Parse for FieldAttr { name, parse_eq(input, DESERIALIZE_AS_NOTE)?, )), + "serialize_fn" => Ok(FieldAttr::SerializeFn( + name, + parse_eq(input, SERIALIZE_FN_NOTE)?, + )), + "deserialize_fn" => Ok(FieldAttr::DeserializeFn( + name, + parse_eq(input, DESERIALIZE_FN_NOTE)?, + )), "select_expression" => Ok(FieldAttr::SelectExpression( name, parse_eq(input, SELECT_EXPRESSION_NOTE)?, @@ -179,6 +190,8 @@ impl MySpanned for FieldAttr { | FieldAttr::TreatNoneAsDefaultValue(ident, _) | FieldAttr::SerializeAs(ident, _) | FieldAttr::DeserializeAs(ident, _) + | FieldAttr::SerializeFn(ident, _) + | FieldAttr::DeserializeFn(ident, _) | FieldAttr::SelectExpression(ident, _) | FieldAttr::SelectExpressionType(ident, _) => ident.span(), } diff --git a/diesel_derives/src/field.rs b/diesel_derives/src/field.rs index 88a0d8cbe17c..ff4fbceb61b5 100644 --- a/diesel_derives/src/field.rs +++ b/diesel_derives/src/field.rs @@ -14,6 +14,8 @@ pub struct Field { pub treat_none_as_null: Option>, pub serialize_as: Option>, pub deserialize_as: Option>, + pub serialize_fn: Option>, + pub deserialize_fn: Option>, pub select_expression: Option>, pub select_expression_type: Option>, pub embed: Option>, @@ -29,6 +31,8 @@ impl Field { let mut sql_type = None; let mut serialize_as = None; let mut deserialize_as = None; + let mut serialize_fn = None; + let mut deserialize_fn = None; let mut embed = None; let mut select_expression = None; let mut select_expression_type = None; @@ -81,6 +85,20 @@ impl Field { ident_span, }) } + FieldAttr::SerializeFn(_, value) => { + serialize_fn = Some(AttributeSpanWrapper { + item: value, + attribute_span, + ident_span, + }) + } + FieldAttr::DeserializeFn(_, value) => { + deserialize_fn = Some(AttributeSpanWrapper { + item: value, + attribute_span, + ident_span, + }) + } FieldAttr::SelectExpression(_, value) => { select_expression = Some(AttributeSpanWrapper { item: value, @@ -125,6 +143,8 @@ impl Field { treat_none_as_null, serialize_as, deserialize_as, + serialize_fn, + deserialize_fn, select_expression, select_expression_type, embed, diff --git a/diesel_derives/src/insertable.rs b/diesel_derives/src/insertable.rs index f1774bc43cc3..5e4de836140f 100644 --- a/diesel_derives/src/insertable.rs +++ b/diesel_derives/src/insertable.rs @@ -67,14 +67,18 @@ fn derive_into_single_table( None => treat_none_as_default_value, }; - match (field.serialize_as.as_ref(), field.embed()) { - (None, true) => { + match ( + field.serialize_as.as_ref(), + field.serialize_fn.as_ref(), + field.embed(), + ) { + (None, None, true) => { direct_field_ty.push(field_ty_embed(field, None)); direct_field_assign.push(field_expr_embed(field, None)); ref_field_ty.push(field_ty_embed(field, Some(quote!(&'insert)))); ref_field_assign.push(field_expr_embed(field, Some(quote!(&)))); } - (None, false) => { + (None, None, false) => { direct_field_ty.push(field_ty( field, table_name, @@ -100,28 +104,50 @@ fn derive_into_single_table( treat_none_as_default_value, )?); } - (Some(AttributeSpanWrapper { item: ty, .. }), false) => { + (Some(AttributeSpanWrapper { item: ty, .. }), serialize_fn, false) => { direct_field_ty.push(field_ty_serialize_as( field, table_name, ty, treat_none_as_default_value, )?); - direct_field_assign.push(field_expr_serialize_as( - field, - table_name, - ty, - treat_none_as_default_value, - )?); + if let Some(AttributeSpanWrapper { item: function, .. }) = serialize_fn { + direct_field_assign.push(field_expr_serialize_fn( + field, + table_name, + ty, + function, + treat_none_as_default_value, + )?); + } else { + direct_field_assign.push(field_expr_serialize_as( + field, + table_name, + ty, + treat_none_as_default_value, + )?); + } generate_borrowed_insert = false; // as soon as we hit one field with #[diesel(serialize_as)] there is no point in generating the impl of Insertable for borrowed structs } - (Some(AttributeSpanWrapper { attribute_span, .. }), true) => { + (Some(AttributeSpanWrapper { attribute_span, .. }), _, true) => { return Err(syn::Error::new( *attribute_span, "`#[diesel(embed)]` cannot be combined with `#[diesel(serialize_as)]`", )); } + (None, Some(AttributeSpanWrapper { attribute_span, .. }), true) => { + return Err(syn::Error::new( + *attribute_span, + "`#[diesel(embed)]` cannot be combined with `#[diesel(serialize_fn)]`", + )); + } + (None, Some(AttributeSpanWrapper { attribute_span, .. }), false) => { + return Err(syn::Error::new( + *attribute_span, + "`#[diesel(serialize_fn)]` requires `#[diesel(serialize_as)]` to be declared as well", + )); + } } } @@ -227,7 +253,7 @@ fn field_expr_serialize_as( Ok(quote!(self.#field_name.map(|x| #column.eq(::std::convert::Into::<#ty>::into(x))))) } else { Ok( - quote!(std::option::Option::Some(#column.eq(::std::convert::Into::<#ty>::into(self.#field_name)))), + quote!(::std::option::Option::Some(#column.eq(::std::convert::Into::<#ty>::into(self.#field_name)))), ) } } else { @@ -235,6 +261,29 @@ fn field_expr_serialize_as( } } +fn field_expr_serialize_fn( + field: &Field, + table_name: &Path, + ty: &Type, + function: &Expr, + treat_none_as_default_value: bool, +) -> Result { + let field_name = &field.name; + let column_name = field.column_name()?; + column_name.valid_ident()?; + let column = quote!(#table_name::#column_name); + + if treat_none_as_default_value { + if is_option_ty(ty) { + Ok(quote!(self.#field_name.map(|x| #column.eq((#function)(x))))) + } else { + Ok(quote!(::std::option::Option::Some(#column.eq((#function)(self.#field_name))))) + } + } else { + Ok(quote!(#column.eq((#function)(self.#field_name)))) + } +} + fn field_ty( field: &Field, table_name: &Path, diff --git a/diesel_derives/src/util.rs b/diesel_derives/src/util.rs index 4f9930fcc113..8cbbd1525284 100644 --- a/diesel_derives/src/util.rs +++ b/diesel_derives/src/util.rs @@ -10,6 +10,8 @@ pub const COLUMN_NAME_NOTE: &str = "column_name = foo"; pub const SQL_TYPE_NOTE: &str = "sql_type = Foo"; pub const SERIALIZE_AS_NOTE: &str = "serialize_as = Foo"; pub const DESERIALIZE_AS_NOTE: &str = "deserialize_as = Foo"; +pub const SERIALIZE_FN_NOTE: &str = "serialize_fn = some_function"; +pub const DESERIALIZE_FN_NOTE: &str = "deserialize_fn = some_function"; pub const TABLE_NAME_NOTE: &str = "table_name = foo"; pub const TREAT_NONE_AS_DEFAULT_VALUE_NOTE: &str = "treat_none_as_default_value = true"; pub const TREAT_NONE_AS_NULL_NOTE: &str = "treat_none_as_null = true"; diff --git a/diesel_derives/tests/insertable.rs b/diesel_derives/tests/insertable.rs index d6eaeef83ac2..8bee5145dd26 100644 --- a/diesel_derives/tests/insertable.rs +++ b/diesel_derives/tests/insertable.rs @@ -377,3 +377,203 @@ fn embedded_struct() { let expected = vec![(1, "Sean".to_string(), Some("Black".to_string()))]; assert_eq!(Ok(expected), saved); } + +#[test] +fn serialize_fn_custom_option_field_closure() { + struct UserName(String); + impl From for String { + fn from(value: UserName) -> Self { + value.0 + } + } + + enum HairColor { + Green, + } + + impl From for String { + fn from(value: HairColor) -> Self { + match value { + HairColor::Green => "Green".into(), + } + } + } + + #[derive(Insertable)] + #[diesel(table_name = users)] + #[diesel(treat_none_as_default_value = false)] + struct NewUser { + #[diesel(serialize_as = String)] + name: UserName, + #[diesel(serialize_as = Option)] + #[diesel(serialize_fn = |x: Option| x.map(Into::into))] + hair_color: Option, + } + + let conn = &mut connection(); + let new_user = NewUser { + name: UserName("Sean".into()), + hair_color: Some(HairColor::Green), + }; + insert_into(users::table) + .values(new_user) + .execute(conn) + .unwrap(); + + let saved = users::table + .select((users::name, users::hair_color)) + .load::<(String, Option)>(conn); + let expected = vec![("Sean".to_string(), Some("Green".to_string()))]; + assert_eq!(Ok(expected), saved); +} + +#[test] +fn serialize_fn_custom_option_field_function() { + struct UserName(String); + impl From for String { + fn from(value: UserName) -> Self { + value.0 + } + } + + enum HairColor { + Green, + } + + fn hair_color_to_string(value: Option) -> Option { + value.map(|value| match value { + HairColor::Green => "Green".into(), + }) + } + + #[derive(Insertable)] + #[diesel(table_name = users)] + #[diesel(treat_none_as_default_value = false)] + struct NewUser { + #[diesel(serialize_as = String)] + name: UserName, + #[diesel(serialize_as = Option)] + #[diesel(serialize_fn = hair_color_to_string)] + hair_color: Option, + } + + let conn = &mut connection(); + let new_user = NewUser { + name: UserName("Sean".into()), + hair_color: Some(HairColor::Green), + }; + insert_into(users::table) + .values(new_user) + .execute(conn) + .unwrap(); + + let saved = users::table + .select((users::name, users::hair_color)) + .load::<(String, Option)>(conn); + let expected = vec![("Sean".to_string(), Some("Green".to_string()))]; + assert_eq!(Ok(expected), saved); +} + +#[test] +fn serialize_fn_custom_option_field_associated_function() { + struct UserName(String); + impl From for String { + fn from(value: UserName) -> Self { + value.0 + } + } + + enum HairColor { + Green, + } + + impl HairColor { + fn to_string(value: Option) -> Option { + value.map(|value| match value { + HairColor::Green => "Green".into(), + }) + } + } + + #[derive(Insertable)] + #[diesel(table_name = users)] + #[diesel(treat_none_as_default_value = false)] + struct NewUser { + #[diesel(serialize_as = String)] + name: UserName, + #[diesel(serialize_as = Option)] + #[diesel(serialize_fn = HairColor::to_string)] + hair_color: Option, + } + + let conn = &mut connection(); + let new_user = NewUser { + name: UserName("Sean".into()), + hair_color: Some(HairColor::Green), + }; + insert_into(users::table) + .values(new_user) + .execute(conn) + .unwrap(); + + let saved = users::table + .select((users::name, users::hair_color)) + .load::<(String, Option)>(conn); + let expected = vec![("Sean".to_string(), Some("Green".to_string()))]; + assert_eq!(Ok(expected), saved); +} + +#[test] +fn serialize_fn_overrides_from() { + struct UserName(String); + impl From for String { + fn from(value: UserName) -> Self { + value.0 + } + } + + enum HairColor { + Green, + } + + impl From for String { + fn from(value: HairColor) -> Self { + match value { + HairColor::Green => "error".into(), + } + } + } + + fn hair_color_to_string(value: HairColor) -> String { + match value { + HairColor::Green => "Green".into(), + } + } + + #[derive(Insertable)] + #[diesel(table_name = users)] + #[diesel(treat_none_as_default_value = false)] + struct NewUser { + #[diesel(serialize_as = String)] + name: UserName, + #[diesel(serialize_as = String)] + #[diesel(serialize_fn = hair_color_to_string)] + hair_color: HairColor, + } + + let conn = &mut connection(); + let new_user = NewUser { + name: UserName("Sean".into()), + hair_color: HairColor::Green, + }; + insert_into(users::table) + .values(new_user) + .execute(conn) + .unwrap(); + + let saved = users::table + .select((users::name, users::hair_color)) + .load::<(String, Option)>(conn); + let expected = vec![("Sean".to_string(), Some("Green".to_string()))]; + assert_eq!(Ok(expected), saved); +}