Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make #[derive(PostgresType)] impl its own FromDatum #1381

2 changes: 1 addition & 1 deletion pgrx-examples/custom_types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ no-schema-generation = [ "pgrx/no-schema-generation", "pgrx-tests/no-schema-gene
[dependencies]
pgrx = { path = "../../pgrx", default-features = false }
maplit = "1.0.2"
serde = "1.0"
serde = { version = "1.0", features = ["derive"] }

[dev-dependencies]
pgrx-tests = { path = "../../pgrx-tests" }
Expand Down
5 changes: 3 additions & 2 deletions pgrx-examples/custom_types/src/fixed_size.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
use core::ffi::CStr;
use pgrx::prelude::*;
use pgrx::{opname, pg_operator, PgVarlena, PgVarlenaInOutFuncs, StringInfo};
use serde::{Deserialize, Serialize};
use std::str::FromStr;

#[derive(Copy, Clone, PostgresType)]
#[derive(Copy, Clone, PostgresType, Serialize, Deserialize)]
#[pgvarlena_inoutfuncs]
pub struct FixedF32Array {
array: [f32; 91],
array: [f32; 32],
workingjubilee marked this conversation as resolved.
Show resolved Hide resolved
}

impl PgVarlenaInOutFuncs for FixedF32Array {
Expand Down
46 changes: 45 additions & 1 deletion pgrx-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -752,8 +752,52 @@ fn impl_postgres_type(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream>
};

// all #[derive(PostgresType)] need to implement that trait
// and also the FromDatum and IntoDatum
stream.extend(quote! {
impl #generics ::pgrx::PostgresType for #name #generics { }
impl #generics ::pgrx::datum::PostgresType for #name #generics { }

impl #generics ::pgrx::datum::IntoDatum for #name #generics {
fn into_datum(self) -> Option<::pgrx::pg_sys::Datum> {
#[allow(deprecated)]
Some(unsafe { ::pgrx::cbor_encode(&self) }.into())
}

fn type_oid() -> ::pgrx::pg_sys::Oid {
::pgrx::wrappers::rust_regtypein::<Self>()
}
}

impl #generics ::pgrx::datum::FromDatum for #name #generics {
unsafe fn from_polymorphic_datum(
datum: ::pgrx::pg_sys::Datum,
is_null: bool,
_typoid: ::pgrx::pg_sys::Oid,
) -> Option<Self> {
if is_null {
None
} else {
#[allow(deprecated)]
::pgrx::cbor_decode(datum.cast_mut_ptr())
}
}

unsafe fn from_datum_in_memory_context(
mut memory_context: ::pgrx::memcxt::PgMemoryContexts,
datum: ::pgrx::pg_sys::Datum,
is_null: bool,
_typoid: ::pgrx::pg_sys::Oid,
) -> Option<Self> {
if is_null {
None
} else {
memory_context.switch_to(|_| {
// this gets the varlena Datum copied into this memory context
let varlena = ::pgrx::pg_sys::pg_detoast_datum_copy(datum.cast_mut_ptr());
Self::from_datum(varlena.into(), is_null)
})
}
}
}
});

// and if we don't have custom inout/funcs, we use the JsonInOutFuncs trait
Expand Down
4 changes: 2 additions & 2 deletions pgrx-tests/src/tests/postgres_type_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use pgrx::{InOutFuncs, PgVarlena, PgVarlenaInOutFuncs, StringInfo};
use serde::{Deserialize, Serialize};
use std::str::FromStr;

#[derive(Copy, Clone, PostgresType)]
#[derive(Copy, Clone, PostgresType, Serialize, Deserialize)]
#[pgvarlena_inoutfuncs]
pub struct VarlenaType {
a: f32,
Expand All @@ -38,7 +38,7 @@ impl PgVarlenaInOutFuncs for VarlenaType {
}
}

#[derive(Copy, Clone, PostgresType)]
#[derive(Copy, Clone, PostgresType, Serialize, Deserialize)]
#[pgvarlena_inoutfuncs]
pub enum VarlenaEnumType {
A,
Expand Down
86 changes: 8 additions & 78 deletions pgrx/src/datum/varlena.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
use crate::pg_sys::{VARATT_SHORT_MAX, VARHDRSZ_SHORT};
use crate::{
pg_sys, rust_regtypein, set_varsize, set_varsize_short, vardata_any, varsize_any,
varsize_any_exhdr, void_mut_ptr, FromDatum, IntoDatum, PgMemoryContexts, PostgresType,
StringInfo,
varsize_any_exhdr, void_mut_ptr, FromDatum, IntoDatum, PgMemoryContexts, StringInfo,
};
use pgrx_sql_entity_graph::metadata::{
ArgumentError, Returns, ReturnsError, SqlMapping, SqlTranslatable,
Expand Down Expand Up @@ -60,8 +59,9 @@ impl Clone for PallocdVarlena {
/// use std::str::FromStr;
///
/// use pgrx::prelude::*;
/// use serde::{Serialize, Deserialize};
///
/// #[derive(Copy, Clone, PostgresType)]
/// #[derive(Copy, Clone, PostgresType, Serialize, Deserialize)]
/// #[pgvarlena_inoutfuncs]
/// struct MyType {
/// a: f32,
Expand Down Expand Up @@ -378,50 +378,8 @@ where
}
}

impl<T> IntoDatum for T
where
T: PostgresType + Serialize,
{
fn into_datum(self) -> Option<pg_sys::Datum> {
Some(cbor_encode(&self).into())
}

fn type_oid() -> pg_sys::Oid {
crate::rust_regtypein::<T>()
}
}

impl<'de, T> FromDatum for T
where
T: PostgresType + Deserialize<'de>,
{
unsafe fn from_polymorphic_datum(
datum: pg_sys::Datum,
is_null: bool,
_typoid: pg_sys::Oid,
) -> Option<Self> {
if is_null {
None
} else {
cbor_decode(datum.cast_mut_ptr())
}
}

unsafe fn from_datum_in_memory_context(
memory_context: PgMemoryContexts,
datum: pg_sys::Datum,
is_null: bool,
_typoid: pg_sys::Oid,
) -> Option<Self> {
if is_null {
None
} else {
cbor_decode_into_context(memory_context, datum.cast_mut_ptr())
}
}
}

fn cbor_encode<T>(input: T) -> *const pg_sys::varlena
#[doc(hidden)]
pub unsafe fn cbor_encode<T>(input: T) -> *const pg_sys::varlena
where
T: Serialize,
{
Expand All @@ -439,6 +397,7 @@ where
varlena as *const pg_sys::varlena
}

#[doc(hidden)]
pub unsafe fn cbor_decode<'de, T>(input: *mut pg_sys::varlena) -> T
where
T: Deserialize<'de>,
Expand All @@ -450,6 +409,8 @@ where
serde_cbor::from_slice(slice).expect("failed to decode CBOR")
}

#[doc(hidden)]
#[deprecated(since = "0.12.0", note = "just use the FromDatum impl")]
pub unsafe fn cbor_decode_into_context<'de, T>(
mut memory_context: PgMemoryContexts,
input: *mut pg_sys::varlena,
Expand All @@ -464,37 +425,6 @@ where
})
}

#[allow(dead_code)]
fn json_encode<T>(input: T) -> *const pg_sys::varlena
where
T: Serialize,
{
let mut serialized = StringInfo::new();

serialized.push_bytes(&[0u8; pg_sys::VARHDRSZ]); // reserve space for the header
serde_json::to_writer(&mut serialized, &input).expect("failed to encode as JSON");

let size = serialized.len();
let varlena = serialized.into_char_ptr();
unsafe {
set_varsize(varlena as *mut pg_sys::varlena, size as i32);
}

varlena as *const pg_sys::varlena
}

#[allow(dead_code)]
unsafe fn json_decode<'de, T>(input: *mut pg_sys::varlena) -> T
where
T: Deserialize<'de>,
{
let varlena = pg_sys::pg_detoast_datum_packed(input as *mut pg_sys::varlena);
let len = varsize_any_exhdr(varlena);
let data = vardata_any(varlena);
let slice = std::slice::from_raw_parts(data as *const u8, len);
serde_json::from_slice(slice).expect("failed to decode JSON")
}

workingjubilee marked this conversation as resolved.
Show resolved Hide resolved
unsafe impl<T> SqlTranslatable for PgVarlena<T>
where
T: SqlTranslatable + Copy,
Expand Down