Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Customizable serialization of PostgresType #965

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 39 additions & 5 deletions pgx-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@ use std::collections::HashSet;

use proc_macro2::Ident;
use quote::{quote, ToTokens};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::{parse_macro_input, Attribute, Data, DeriveInput, Item, ItemImpl};
use syn::{
parse_macro_input, Attribute, Data, DeriveInput, GenericParam, Item, ItemImpl, Lifetime,
LifetimeDef, Token,
};

use operators::{impl_postgres_eq, impl_postgres_hash, impl_postgres_ord};
use pgx_sql_entity_graph::{
Expand Down Expand Up @@ -706,9 +710,13 @@ Optionally accepts the following attributes:

* `inoutfuncs(some_in_fn, some_out_fn)`: Define custom in/out functions for the type.
* `pgvarlena_inoutfuncs(some_in_fn, some_out_fn)`: Define custom in/out functions for the `PgVarlena` of this type.
* `custom_serializer`: Define your own implementation of `pgx::datum::Serializer` trait (only for `Serialize/Deserialize`-implementing types)
* `sql`: Same arguments as [`#[pgx(sql = ..)]`](macro@pgx).
*/
#[proc_macro_derive(PostgresType, attributes(inoutfuncs, pgvarlena_inoutfuncs, requires, pgx))]
#[proc_macro_derive(
PostgresType,
attributes(inoutfuncs, pgvarlena_inoutfuncs, requires, pgx, custom_serializer)
)]
pub fn postgres_type(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as syn::DeriveInput);

Expand Down Expand Up @@ -740,7 +748,8 @@ fn impl_postgres_type(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream>
}
}

if args.is_empty() {
// If no in/out parameters are defined
if args.iter().filter(|a| a != &&PostgresTypeAttribute::CustomSerializer).next().is_none() {
// assume the user wants us to implement the InOutFuncs
args.insert(PostgresTypeAttribute::Default);
}
Expand All @@ -755,6 +764,28 @@ fn impl_postgres_type(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream>
impl #generics ::pgx::PostgresType for #name #generics { }
});

if !args.contains(&PostgresTypeAttribute::PgVarlenaInOutFuncs)
&& !args.contains(&PostgresTypeAttribute::CustomSerializer)
{
let mut lt_generics = generics.clone();
let mut de = LifetimeDef::new(Lifetime::new("'de", generics.span()));
let bounds = generics
.params
.iter()
.filter_map(|p| match p {
GenericParam::Type(_) => None,
GenericParam::Const(_) => None,
GenericParam::Lifetime(lt) => Some(lt.clone().lifetime),
})
.collect::<Punctuated<Lifetime, Token![+]>>();
de.bounds = bounds;
lt_generics.params.insert(0, GenericParam::Lifetime(de));
stream.extend(quote! {
impl #generics ::pgx::datum::Serializer for #name #generics { }
impl #lt_generics ::pgx::datum::Deserializer<'de> for #name #generics { }
});
}

// and if we don't have custom inout/funcs, we use the JsonInOutFuncs trait
// which implements _in and _out #[pg_extern] functions that just return the type itself
if args.contains(&PostgresTypeAttribute::Default) {
Expand Down Expand Up @@ -931,6 +962,7 @@ enum PostgresTypeAttribute {
InOutFuncs,
PgVarlenaInOutFuncs,
Default,
CustomSerializer,
}

fn parse_postgres_type_args(attributes: &[Attribute]) -> HashSet<PostgresTypeAttribute> {
Expand All @@ -948,6 +980,10 @@ fn parse_postgres_type_args(attributes: &[Attribute]) -> HashSet<PostgresTypeAtt
categorized_attributes.insert(PostgresTypeAttribute::PgVarlenaInOutFuncs);
}

"custom_serializer" => {
categorized_attributes.insert(PostgresTypeAttribute::CustomSerializer);
}

_ => {
// we can just ignore attributes we don't understand
}
Expand Down Expand Up @@ -1091,8 +1127,6 @@ pub fn pg_trigger(attrs: TokenStream, input: TokenStream) -> TokenStream {
fn wrapped(attrs: TokenStream, input: TokenStream) -> Result<TokenStream, syn::Error> {
use pgx_sql_entity_graph::{PgTrigger, PgTriggerAttribute};
use syn::parse::Parser;
use syn::punctuated::Punctuated;
use syn::Token;

let attributes =
Punctuated::<PgTriggerAttribute, Token![,]>::parse_terminated.parse(attrs)?;
Expand Down
51 changes: 47 additions & 4 deletions pgx-tests/src/tests/postgres_type_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ Use of this source code is governed by the MIT license that can be found in the
*/
use core::ffi::CStr;
use pgx::prelude::*;
use pgx::{InOutFuncs, PgVarlena, PgVarlenaInOutFuncs, StringInfo};
use pgx::{Deserializer, InOutFuncs, PgVarlena, PgVarlenaInOutFuncs, Serializer, StringInfo};
use serde::{Deserialize, Serialize};
use std::io::Write;
use std::str::FromStr;

#[derive(Copy, Clone, PostgresType)]
Expand Down Expand Up @@ -152,18 +153,38 @@ pub enum JsonEnumType {
E2 { b: f32 },
}

#[derive(Serialize, Deserialize, PostgresType)]
#[custom_serializer]
pub struct CustomSerialized;

impl Serializer for CustomSerialized {
fn to_writer<W: Write>(&self, mut writer: W) {
writer.write(&[1]).expect("can't write");
}
}

impl<'de> Deserializer<'de> for CustomSerialized {
fn from_slice(slice: &'de [u8]) -> Self {
if slice != &[1] {
panic!("wrong type")
} else {
CustomSerialized
}
}
}

#[cfg(any(test, feature = "pg_test"))]
#[pgx::pg_schema]
mod tests {
#[allow(unused_imports)]
use crate as pgx_tests;

use crate::tests::postgres_type_tests::{
CustomTextFormatSerializedEnumType, CustomTextFormatSerializedType, JsonEnumType, JsonType,
VarlenaEnumType, VarlenaType,
CustomSerialized, CustomTextFormatSerializedEnumType, CustomTextFormatSerializedType,
JsonEnumType, JsonType, VarlenaEnumType, VarlenaType,
};
use pgx::prelude::*;
use pgx::PgVarlena;
use pgx::{varsize_any_exhdr, PgVarlena};

#[pg_test]
fn test_mytype() -> Result<(), pgx::spi::Error> {
Expand Down Expand Up @@ -253,4 +274,26 @@ mod tests {
assert!(matches!(result, JsonEnumType::E1 { a } if a == 1.0));
Ok(())
}

#[pg_test]
fn custom_serializer() {
let datum = CustomSerialized.into_datum().unwrap();
// Ensure we actually get our custom format, not the default CBOR
unsafe {
let input = datum.cast_mut_ptr();
let varlena = pg_sys::pg_detoast_datum_packed(input as *mut pg_sys::varlena);
let len = varsize_any_exhdr(varlena);
assert_eq!(len, 1);
}
}

#[pg_test]
fn custom_serializer_end_to_end() {
let s = CustomSerialized;
let _ = Spi::get_one_with_args::<CustomSerialized>(
r#"SELECT $1"#,
vec![(PgOid::Custom(CustomSerialized::type_oid()), s.into_datum())],
)
.unwrap();
}
}
120 changes: 119 additions & 1 deletion pgx/src/datum/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,138 @@ pub use json::*;
pub use numeric::{AnyNumeric, Numeric};
use once_cell::sync::Lazy;
pub use range::*;
use serde::{Deserialize, Serialize};
use std::any::TypeId;
pub use time_stamp::*;
pub use time_stamp_with_timezone::*;
pub use time_with_timezone::*;
pub use tuples::*;
pub use varlena::*;

use crate::PgBox;
use crate::{pg_sys, PgBox, PgMemoryContexts, StringInfo};
use pgx_sql_entity_graph::RustSqlMapping;

/// A tagging trait to indicate a user type is also meant to be used by Postgres
/// Implemented automatically by `#[derive(PostgresType)]`
pub trait PostgresType {}

/// Serializing to datum
///
/// Default implementation uses CBOR and Varlena
pub trait Serializer: Serialize {
/// Serializes the value to Datum
///
/// Default implementation wraps the output of `Self::to_writer` into a Varlena
fn serialize(&self) -> pg_sys::Datum {
let mut serialized = StringInfo::new();

serialized.push_bytes(&[0u8; pg_sys::VARHDRSZ]); // reserve space fo the header
self.to_writer(&mut serialized);

let size = serialized.len() as usize;
let varlena = serialized.into_char_ptr();
unsafe {
crate::set_varsize(varlena as *mut pg_sys::varlena, size as i32);
}

(varlena as *const pg_sys::varlena).into()
}

/// Serializes the value to a writer
///
/// Default implementation serializes to CBOR
fn to_writer<W: std::io::Write>(&self, writer: W) {
serde_cbor::to_writer(writer, &self).expect("failed to encode as CBOR");
}
}

/// Deserializing from datum
///
/// Default implementation uses CBOR and Varlena
pub trait Deserializer<'de>: Deserialize<'de> {
/// Deserializes datum into a value
///
/// Default implementation assumes datum to be a varlena and uses `Self::from_slice`
/// to deserialize the actual value.
fn deserialize(datum: pg_sys::Datum) -> Self {
unsafe {
let input = datum.cast_mut_ptr();
let varlena = pg_sys::pg_detoast_datum_packed(input as *mut pg_sys::varlena);
let len = crate::varsize_any_exhdr(varlena);
let data = crate::vardata_any(varlena);
let slice = std::slice::from_raw_parts(data as *const u8, len);
Self::from_slice(slice)
}
}

/// Deserializes datum into a value into a given context
/// Default implementation assumes datum to be a varlena and uses `Self::from_slice`
/// to deserialize the actual value.
fn deserialize_into_context(
mut memory_context: PgMemoryContexts,
datum: pg_sys::Datum,
) -> Self {
unsafe {
memory_context.switch_to(|_| {
let input = datum.cast_mut_ptr();
// this gets the varlena Datum copied into this memory context
let varlena = pg_sys::pg_detoast_datum_copy(input as *mut pg_sys::varlena);
<Self as Deserializer<'de>>::deserialize(varlena.into())
})
}
}

/// Deserializes a value from a slice
///
/// Default implementation deserializes from CBOR.
fn from_slice(slice: &'de [u8]) -> Self {
serde_cbor::from_slice(slice).expect("failed to decode CBOR")
}
}

impl<T> IntoDatum for T
where
T: PostgresType + Serializer,
{
fn into_datum(self) -> Option<pg_sys::Datum> {
Some(Serializer::serialize(&self))
}

fn type_oid() -> pg_sys::Oid {
crate::rust_regtypein::<T>()
}
}

impl<'de, T> FromDatum for T
where
T: PostgresType + Deserializer<'de>,
{
unsafe fn from_polymorphic_datum(
datum: pg_sys::Datum,
is_null: bool,
_typoid: pg_sys::Oid,
) -> Option<Self> {
if is_null {
None
} else {
Some(<T as Deserializer>::deserialize(datum))
}
}

unsafe fn from_datum_in_memory_context(
memory_context: PgMemoryContexts,
datum: pg_sys::Datum,
is_null: bool,
_typoid: pg_sys::Oid,
) -> Option<Self> {
if is_null {
None
} else {
Some(T::deserialize_into_context(memory_context, datum))
}
}
}

/// A type which can have it's [`core::any::TypeId`]s registered for Rust to SQL mapping.
///
/// An example use of this trait:
Expand Down
Loading