Skip to content

Commit

Permalink
Merge pull request #3944 from JonBoyleCoding/diagnostics-branch
Browse files Browse the repository at this point in the history
Return Error From Postgres Float Deserialization
  • Loading branch information
weiznich authored May 24, 2024
2 parents 3db4ed2 + c19ace7 commit db6730c
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 53 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ Increasing the minimal supported Rust version will always be coupled at least wi

* The minimal officially supported rustc version is now 1.78.0
* Deprecated `sql_function!` in favour of `define_sql_function!` which provides compatibility with `#[dsl::auto_type]`
* Deserialization error messages now contain information about the field that failed to deserialize

## [2.1.0] 2023-05-26

Expand Down
8 changes: 7 additions & 1 deletion diesel/src/deserialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,13 @@ where
use crate::row::Field;

let field = row.get(0).ok_or(crate::result::UnexpectedEndOfRow)?;
T::from_nullable_sql(field.value())
T::from_nullable_sql(field.value()).map_err(|e| {
if e.is::<crate::result::UnexpectedNullError>() {
e
} else {
Box::new(crate::result::DeserializeFieldError::new(field, e))
}
})
}
}

Expand Down
1 change: 1 addition & 0 deletions diesel/src/pg/connection/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ impl PgResult {
)
}

#[inline(always)] // benchmarks indicate a ~1.7% improvement in instruction count for this
pub(super) fn column_name(&self, col_idx: usize) -> Option<&str> {
self.column_name_map
.get_or_init(|| {
Expand Down
54 changes: 34 additions & 20 deletions diesel/src/pg/types/floats/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,23 @@ impl ToSql<sql_types::Numeric, Pg> for PgNumeric {
impl FromSql<sql_types::Float, Pg> for f32 {
fn from_sql(value: PgValue<'_>) -> deserialize::Result<Self> {
let mut bytes = value.as_bytes();
debug_assert!(
bytes.len() <= 4,
"Received more than 4 bytes while decoding \
an f32. Was a double accidentally marked as float?"
);
debug_assert!(
bytes.len() >= 4,
"Received less than 4 bytes while decoding \
an f32."
);

if bytes.len() < 4 {
return deserialize::Result::Err(
"Received less than 4 bytes while decoding an f32. \
Was a numeric accidentally marked as float?"
.into(),
);
}

if bytes.len() > 4 {
return deserialize::Result::Err(
"Received more than 4 bytes while decoding an f32. \
Was a double accidentally marked as float?"
.into(),
);
}

bytes
.read_f32::<NetworkEndian>()
.map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync>)
Expand All @@ -137,16 +144,23 @@ impl FromSql<sql_types::Float, Pg> for f32 {
impl FromSql<sql_types::Double, Pg> for f64 {
fn from_sql(value: PgValue<'_>) -> deserialize::Result<Self> {
let mut bytes = value.as_bytes();
debug_assert!(
bytes.len() <= 8,
"Received less than 8 bytes while decoding \
an f64. Was a float accidentally marked as double?"
);
debug_assert!(
bytes.len() >= 8,
"Received more than 8 bytes while decoding \
an f64. Was a numeric accidentally marked as double?"
);

if bytes.len() < 8 {
return deserialize::Result::Err(
"Received less than 8 bytes while decoding an f64. \
Was a float accidentally marked as double?"
.into(),
);
}

if bytes.len() > 8 {
return deserialize::Result::Err(
"Received more than 8 bytes while decoding an f64. \
Was a numeric accidentally marked as double?"
.into(),
);
}

bytes
.read_f64::<NetworkEndian>()
.map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync>)
Expand Down
79 changes: 48 additions & 31 deletions diesel/src/pg/types/integers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,22 @@ impl ToSql<sql_types::Oid, Pg> for u32 {

#[cfg(feature = "postgres_backend")]
impl FromSql<sql_types::SmallInt, Pg> for i16 {
#[inline(always)]
fn from_sql(value: PgValue<'_>) -> deserialize::Result<Self> {
let mut bytes = value.as_bytes();
debug_assert!(
bytes.len() <= 2,
"Received more than 2 bytes decoding i16. \
Was an Integer expression accidentally identified as SmallInt?"
);
debug_assert!(
bytes.len() >= 2,
"Received fewer than 2 bytes decoding i16. \
Was an expression of a different type accidentally identified \
as SmallInt?"
);
if bytes.len() < 2 {
return emit_size_error(
"Received less than 2 bytes while decoding an i16. \
Was an expression of a different type accidentally marked as SmallInt?",
);
}

if bytes.len() > 2 {
return emit_size_error(
"Received more than 2 bytes while decoding an i16. \
Was an Integer expression accidentally marked as SmallInt?",
);
}
bytes
.read_i16::<NetworkEndian>()
.map_err(|e| Box::new(e) as Box<_>)
Expand All @@ -44,38 +47,52 @@ impl FromSql<sql_types::SmallInt, Pg> for i16 {

#[cfg(feature = "postgres_backend")]
impl FromSql<sql_types::Integer, Pg> for i32 {
#[inline(always)]
fn from_sql(value: PgValue<'_>) -> deserialize::Result<Self> {
let mut bytes = value.as_bytes();
debug_assert!(
bytes.len() <= 4,
"Received more than 4 bytes decoding i32. \
Was a BigInt expression accidentally identified as Integer?"
);
debug_assert!(
bytes.len() >= 4,
"Received fewer than 4 bytes decoding i32. \
Was a SmallInt expression accidentally identified as Integer?"
);
if bytes.len() < 4 {
return emit_size_error(
"Received less than 4 bytes while decoding an i32. \
Was an SmallInt expression accidentally marked as Integer?",
);
}

if bytes.len() > 4 {
return emit_size_error(
"Received more than 4 bytes while decoding an i32. \
Was an BigInt expression accidentally marked as Integer?",
);
}
bytes
.read_i32::<NetworkEndian>()
.map_err(|e| Box::new(e) as Box<_>)
}
}

#[cold]
#[inline(never)]
fn emit_size_error<T>(var_name: &str) -> deserialize::Result<T> {
deserialize::Result::Err(var_name.into())
}

#[cfg(feature = "postgres_backend")]
impl FromSql<sql_types::BigInt, Pg> for i64 {
#[inline(always)]
fn from_sql(value: PgValue<'_>) -> deserialize::Result<Self> {
let mut bytes = value.as_bytes();
debug_assert!(
bytes.len() <= 8,
"Received more than 8 bytes decoding i64. \
Was an expression of a different type misidentified as BigInt?"
);
debug_assert!(
bytes.len() >= 8,
"Received fewer than 8 bytes decoding i64. \
Was an Integer expression misidentified as BigInt?"
);
if bytes.len() < 8 {
return emit_size_error(
"Received less than 8 bytes while decoding an i64. \
Was an Integer expression accidentally marked as BigInt?",
);
}

if bytes.len() > 8 {
return emit_size_error(
"Received more than 8 bytes while decoding an i64. \
Was an expression of a different type expression accidentally marked as BigInt?"
);
}
bytes
.read_i64::<NetworkEndian>()
.map_err(|e| Box::new(e) as Box<_>)
Expand Down
44 changes: 44 additions & 0 deletions diesel/src/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -455,3 +455,47 @@ impl fmt::Display for EmptyChangeset {
}

impl StdError for EmptyChangeset {}

/// An error occurred while deserializing a field
#[derive(Debug)]
#[non_exhaustive]
pub struct DeserializeFieldError {
/// The name of the field that failed to deserialize
pub field_name: Option<String>,
/// The error that occurred while deserializing the field
pub error: Box<dyn StdError + Send + Sync>,
}

impl DeserializeFieldError {
#[cold]
pub(crate) fn new<'a, F, DB>(field: F, error: Box<dyn std::error::Error + Send + Sync>) -> Self
where
DB: crate::backend::Backend,
F: crate::row::Field<'a, DB>,
{
DeserializeFieldError {
field_name: field.field_name().map(|s| s.to_string()),
error,
}
}
}

impl StdError for DeserializeFieldError {
fn source(&self) -> Option<&(dyn StdError + 'static)> {
Some(&*self.error)
}
}

impl fmt::Display for DeserializeFieldError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if let Some(ref field_name) = self.field_name {
write!(
f,
"Error deserializing field '{}': {}",
field_name, self.error
)
} else {
write!(f, "Error deserializing field: {}", self.error)
}
}
}
1 change: 1 addition & 0 deletions diesel_bench/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ fast_run = []

[profile.release]
lto = true
debug = true
codegen-units = 1

[patch.crates-io]
Expand Down
114 changes: 113 additions & 1 deletion diesel_tests/tests/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1388,7 +1388,7 @@ where

#[cfg(feature = "postgres")]
#[test]
#[should_panic(expected = "Received more than 4 bytes decoding i32")]
#[should_panic(expected = "Received more than 4 bytes while decoding an i32")]
fn debug_check_catches_reading_bigint_as_i32_when_using_raw_sql() {
use diesel::dsl::sql;
use diesel::sql_types::Integer;
Expand Down Expand Up @@ -1574,3 +1574,115 @@ fn citext_fields() {

assert_eq!(lowercase_in_db, Some("lowercase_value".to_string()));
}

#[test]
#[cfg(feature = "postgres")]
fn deserialize_wrong_primitive_gives_good_error() {
let conn = &mut connection();

diesel::sql_query(
"CREATE TABLE test_table(\
bool BOOLEAN,
small SMALLINT, \
int INTEGER, \
big BIGINT, \
float FLOAT4, \
double FLOAT8,
text TEXT)",
)
.execute(conn)
.unwrap();
diesel::sql_query("INSERT INTO test_table VALUES('t', 1, 1, 1, 1, 1, 'long text long text')")
.execute(conn)
.unwrap();

let res = diesel::dsl::sql::<SmallInt>("SELECT bool FROM test_table").get_result::<i16>(conn);
assert!(res.is_err());
assert_eq!(
res.unwrap_err().to_string(),
"Error deserializing field 'bool': \
Received less than 2 bytes while decoding an i16. \
Was an expression of a different type accidentally marked as SmallInt?"
);

let res = diesel::dsl::sql::<SmallInt>("SELECT int FROM test_table").get_result::<i16>(conn);
assert!(res.is_err());
assert_eq!(
res.unwrap_err().to_string(),
"Error deserializing field 'int': \
Received more than 2 bytes while decoding an i16. \
Was an Integer expression accidentally marked as SmallInt?"
);

let res = diesel::dsl::sql::<Integer>("SELECT small FROM test_table").get_result::<i32>(conn);
assert!(res.is_err());
assert_eq!(
res.unwrap_err().to_string(),
"Error deserializing field 'small': \
Received less than 4 bytes while decoding an i32. \
Was an SmallInt expression accidentally marked as Integer?"
);

let res = diesel::dsl::sql::<Integer>("SELECT big FROM test_table").get_result::<i32>(conn);
assert!(res.is_err());
assert_eq!(
res.unwrap_err().to_string(),
"Error deserializing field 'big': \
Received more than 4 bytes while decoding an i32. \
Was an BigInt expression accidentally marked as Integer?"
);

let res = diesel::dsl::sql::<BigInt>("SELECT int FROM test_table").get_result::<i64>(conn);
assert!(res.is_err());
assert_eq!(
res.unwrap_err().to_string(),
"Error deserializing field 'int': \
Received less than 8 bytes while decoding an i64. \
Was an Integer expression accidentally marked as BigInt?"
);

let res = diesel::dsl::sql::<BigInt>("SELECT text FROM test_table").get_result::<i64>(conn);
assert!(res.is_err());
assert_eq!(
res.unwrap_err().to_string(),
"Error deserializing field 'text': \
Received more than 8 bytes while decoding an i64. \
Was an expression of a different type expression accidentally marked as BigInt?"
);

let res = diesel::dsl::sql::<Float>("SELECT small FROM test_table").get_result::<f32>(conn);
assert!(res.is_err());
assert_eq!(
res.unwrap_err().to_string(),
"Error deserializing field 'small': \
Received less than 4 bytes while decoding an f32. \
Was a numeric accidentally marked as float?"
);

let res = diesel::dsl::sql::<Float>("SELECT double FROM test_table").get_result::<f32>(conn);
assert!(res.is_err());
assert_eq!(
res.unwrap_err().to_string(),
"Error deserializing field 'double': \
Received more than 4 bytes while decoding an f32. \
Was a double accidentally marked as float?"
);

let res = diesel::dsl::sql::<Double>("SELECT float FROM test_table").get_result::<f64>(conn);
assert!(res.is_err());
assert_eq!(
res.unwrap_err().to_string(),
"Error deserializing field 'float': \
Received less than 8 bytes while decoding an f64. \
Was a float accidentally marked as double?"
);

let res = diesel::dsl::sql::<Double>("SELECT text FROM test_table").get_result::<f64>(conn);
assert!(res.is_err());
assert_eq!(
res.unwrap_err().to_string(),
"Error deserializing field 'text': \
Received more than 8 bytes while decoding an f64. \
Was a numeric accidentally marked as double?"
);
}

0 comments on commit db6730c

Please sign in to comment.