From 18f1038b57dd6ce34c826dce3c38173057ef02b2 Mon Sep 17 00:00:00 2001 From: QuenKar <47681251+QuenKar@users.noreply.github.com> Date: Mon, 18 Sep 2023 16:16:47 +0800 Subject: [PATCH] feat: impl cast_with_opt --- src/common/time/src/date.rs | 10 + src/datatypes/src/data_type.rs | 22 +- src/datatypes/src/types/binary_type.rs | 1 + src/datatypes/src/types/cast.rs | 364 ++++++++++------------ src/datatypes/src/types/primitive_type.rs | 57 ++-- 5 files changed, 233 insertions(+), 221 deletions(-) diff --git a/src/common/time/src/date.rs b/src/common/time/src/date.rs index 021085f06013..e1a687d07cdb 100644 --- a/src/common/time/src/date.rs +++ b/src/common/time/src/date.rs @@ -136,4 +136,14 @@ mod tests { let d: Date = 42.into(); assert_eq!(42, d.val()); } + + #[test] + fn test_to_secs() { + let d = Date::from_str("1970-01-01").unwrap(); + assert_eq!(d.to_secs(), 0); + let d = Date::from_str("1970-01-02").unwrap(); + assert_eq!(d.to_secs(), 24 * 3600); + let d = Date::from_str("1970-01-03").unwrap(); + assert_eq!(d.to_secs(), 2 * 24 * 3600); + } } diff --git a/src/datatypes/src/data_type.rs b/src/datatypes/src/data_type.rs index 77511fda1ef5..3c5388586a43 100644 --- a/src/datatypes/src/data_type.rs +++ b/src/datatypes/src/data_type.rs @@ -114,6 +114,10 @@ impl ConcreteDataType { matches!(self, ConcreteDataType::Boolean(_)) } + pub fn is_string(&self) -> bool { + matches!(self, ConcreteDataType::String(_)) + } + pub fn is_stringifiable(&self) -> bool { matches!( self, @@ -151,6 +155,22 @@ impl ConcreteDataType { ) } + pub fn is_numeric(&self) -> bool { + matches!( + self, + ConcreteDataType::Int8(_) + | ConcreteDataType::Int16(_) + | ConcreteDataType::Int32(_) + | ConcreteDataType::Int64(_) + | ConcreteDataType::UInt8(_) + | ConcreteDataType::UInt16(_) + | ConcreteDataType::UInt32(_) + | ConcreteDataType::UInt64(_) + | ConcreteDataType::Float32(_) + | ConcreteDataType::Float64(_) + ) + } + pub fn numerics() -> Vec { vec![ ConcreteDataType::int8_datatype(), @@ -406,7 
+426,7 @@ pub trait DataType: std::fmt::Debug + Send + Sync { /// use it as a timestamp. fn is_timestamp_compatible(&self) -> bool; - /// Casts the value to this DataType. + /// Casts the value to specific DataType. /// Return None if cast failed. fn try_cast(&self, from: Value) -> Option; } diff --git a/src/datatypes/src/types/binary_type.rs b/src/datatypes/src/types/binary_type.rs index a2d32ba39136..c9e8d7f12b6e 100644 --- a/src/datatypes/src/types/binary_type.rs +++ b/src/datatypes/src/types/binary_type.rs @@ -61,6 +61,7 @@ impl DataType for BinaryType { fn try_cast(&self, from: Value) -> Option { match from { Value::Binary(v) => Some(Value::Binary(v)), + Value::String(v) => Some(Value::Binary(Bytes::from(v.as_utf8().as_bytes()))), _ => None, } } diff --git a/src/datatypes/src/types/cast.rs b/src/datatypes/src/types/cast.rs index 35ee99c030a6..071826d6edee 100644 --- a/src/datatypes/src/types/cast.rs +++ b/src/datatypes/src/types/cast.rs @@ -12,261 +12,238 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::data_type::ConcreteDataType; -use crate::types::{IntervalType, TimeType}; +use crate::data_type::{ConcreteDataType, DataType}; +use crate::error::{self, Error, Result}; +use crate::types::TimeType; use crate::value::Value; +/// Cast options for cast functions. 
+#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] +pub struct CastOption { + /// decide how to handle cast failures, + /// either return NULL (strict=false) or return ERR (strict=true) + pub strict: bool, +} + +impl CastOption { + pub fn is_strict(&self) -> bool { + self.strict + } +} + +pub fn cast_with_opt( + src_value: Value, + dest_type: &ConcreteDataType, + cast_option: &CastOption, +) -> Result { + if !can_cast_type(&src_value, dest_type) { + if cast_option.strict { + return Err(invalid_type_cast(&src_value, dest_type)); + } else { + return Ok(Value::Null); + } + } + let new_value = dest_type.try_cast(src_value.clone()); + match new_value { + Some(v) => Ok(v), + None => { + if cast_option.strict { + Err(invalid_type_cast(&src_value, dest_type)) + } else { + Ok(Value::Null) + } + } + } +} + // Return true if the src_value can be casted to dest_type, // Otherwise, return false. -pub fn can_cast_type(src_value: &Value, dest_type: ConcreteDataType) -> bool { +pub fn can_cast_type(src_value: &Value, dest_type: &ConcreteDataType) -> bool { use ConcreteDataType::*; - use IntervalType::*; use TimeType::*; - let src_type = src_value.data_type(); + let src_type = &src_value.data_type(); if src_type == dest_type { return true; } match (src_type, dest_type) { + // null type cast (_, Null(_)) => true, - // numeric types + + // boolean type cast + (_, Boolean(_)) => src_type.is_numeric() || src_type.is_string(), + (Boolean(_), _) => dest_type.is_numeric() || dest_type.is_string(), + + // numeric types cast ( - Boolean(_), UInt8(_) | UInt16(_) | UInt32(_) | UInt64(_) | Int8(_) | Int16(_) | Int32(_) | Int64(_) | Float32(_) | Float64(_) | String(_), - ) => true, - ( - UInt8(_), - Boolean(_) | UInt16(_) | UInt32(_) | UInt64(_) | Int16(_) | Int32(_) | Int64(_) + UInt8(_) | UInt16(_) | UInt32(_) | UInt64(_) | Int8(_) | Int16(_) | Int32(_) | Int64(_) | Float32(_) | Float64(_) | String(_), ) => true, - ( - UInt16(_), - Boolean(_) | UInt8(_) | UInt32(_) | UInt64(_) | Int32(_) 
| Int64(_) | Float32(_) - | Float64(_) | String(_), - ) => true, - (UInt32(_), Boolean(_) | UInt64(_) | Int64(_) | Float64(_) | String(_)) => true, - (UInt64(_), Boolean(_) | String(_)) => true, - ( - Int8(_), - Boolean(_) | Int16(_) | Int32(_) | Int64(_) | Float32(_) | Float64(_) | String(_), - ) => true, - (Int16(_), Boolean(_) | Int32(_) | Int64(_) | Float32(_) | Float64(_) | String(_)) => true, - (Int32(_), Boolean(_) | Int64(_) | Float32(_) | Float64(_) | String(_) | Date(_)) => true, - (Int64(_), Boolean(_) | Float64(_) | String(_) | DateTime(_) | Timestamp(_) | Time(_)) => { - true - } - ( - Float32(_), - Boolean(_) | UInt8(_) | UInt16(_) | UInt32(_) | UInt64(_) | Int8(_) | Int16(_) - | Int32(_) | Int64(_) | Float64(_) | String(_), - ) => true, - ( - Float64(_), - Boolean(_) | UInt8(_) | UInt16(_) | UInt32(_) | UInt64(_) | Int8(_) | Int16(_) - | Int32(_) | Int64(_) | String(_), - ) => true, - ( - String(_), - Boolean(_) | UInt8(_) | UInt16(_) | UInt32(_) | UInt64(_) | Int8(_) | Int16(_) - | Int32(_) | Int64(_) | Float32(_) | Float64(_) | Date(_) | DateTime(_) | Timestamp(_) - | Time(_) | Interval(_), - ) => true, - // temporal types + + (String(_), Binary(_)) => true, + + // temporal types cast + // Date type (Date(_), Int32(_) | Timestamp(_) | String(_)) => true, + (Int32(_) | String(_) | Timestamp(_), Date(_)) => true, + // DateTime type (DateTime(_), Int64(_) | Timestamp(_) | String(_)) => true, - (Timestamp(_), Int64(_) | Date(_) | DateTime(_) | String(_)) => true, + (Int64(_) | Timestamp(_) | String(_), DateTime(_)) => true, + // Timestamp type + (Timestamp(_), Int64(_) | String(_)) => true, + (Int64(_) | String(_), Timestamp(_)) => true, + // Time type (Time(_), String(_)) => true, (Time(Second(_)), Int32(_)) => true, (Time(Millisecond(_)), Int32(_)) => true, (Time(Microsecond(_)), Int64(_)) => true, (Time(Nanosecond(_)), Int64(_)) => true, - (Interval(_), String(_)) => true, - (Interval(YearMonth(_)), Int32(_)) => true, - (Interval(DayTime(_)), 
Int64(_)) => true, - (Interval(MonthDayNano(_)), _) => false, + // TODO(QuenKar): interval type cast + // (Interval(_), String(_)) => true, + // other situations return false (_, _) => false, } } +fn invalid_type_cast(src_value: &Value, dest_type: &ConcreteDataType) -> Error { + let src_type = src_value.data_type(); + if src_type.is_string() { + error::CastTypeSnafu { + msg: format!("Could not parse string '{}' to {}", src_value, dest_type), + } + .build() + } else if src_type.is_numeric() && dest_type.is_numeric() { + error::CastTypeSnafu { + msg: format!( + "Type {} with value {} can't be cast because the value is out of range for the destination type {}", + src_type, + src_value, + dest_type + ), + } + .build() + } else { + error::CastTypeSnafu { + msg: format!( + "Type {} with value {} can't be cast to the destination type {}", + src_type, src_value, dest_type + ), + } + .build() + } +} + #[cfg(test)] mod tests { use std::str::FromStr; use common_base::bytes::StringBytes; use common_time::time::Time; - use common_time::{Date, DateTime, Interval, Timestamp}; + use common_time::{Date, DateTime, Timestamp}; use ordered_float::OrderedFloat; use super::*; macro_rules! 
test_can_cast { - ($src_value: expr, $($dest_type: ident),*) => { + ($src_value: expr, $($dest_type: ident),+) => { $( let val = $src_value; let t = ConcreteDataType::$dest_type(); - assert_eq!(can_cast_type(&val, t), true); + assert_eq!(can_cast_type(&val, &t), true); )* }; } - #[test] - fn test_can_cast_type() { - // uint8 -> other types - test_can_cast!( - Value::UInt8(0), - null_datatype, - uint8_datatype, - uint16_datatype, - uint32_datatype, - uint64_datatype, - int16_datatype, - int32_datatype, - int64_datatype, - float32_datatype, - float64_datatype, - string_datatype - ); - - // uint16 -> other types - test_can_cast!( - Value::UInt16(0), - null_datatype, - uint8_datatype, - uint16_datatype, - uint32_datatype, - uint64_datatype, - int32_datatype, - int64_datatype, - float32_datatype, - float64_datatype, - string_datatype - ); - - // uint32 -> other types - test_can_cast!( - Value::UInt32(0), - null_datatype, - uint32_datatype, - uint64_datatype, - int64_datatype, - float64_datatype, - string_datatype - ); - - // uint64 -> other types - test_can_cast!( - Value::UInt64(0), - null_datatype, - uint64_datatype, - string_datatype - ); - - // int8 -> other types - test_can_cast!( - Value::Int8(0), - null_datatype, - int16_datatype, - int32_datatype, - int64_datatype, - float32_datatype, - float64_datatype, - string_datatype - ); + macro_rules! 
test_primitive_cast { + ($($value: expr),*) => { + $( + test_can_cast!( + $value, + uint8_datatype, + uint16_datatype, + uint32_datatype, + uint64_datatype, + int8_datatype, + int16_datatype, + int32_datatype, + int64_datatype, + float32_datatype, + float64_datatype + ); + )* + }; + } - // int16 -> other types - test_can_cast!( - Value::Int16(0), - null_datatype, - int16_datatype, - int32_datatype, - int64_datatype, - float32_datatype, - float64_datatype, - string_datatype - ); + #[test] + fn test_cast_with_opt() { + // non-strict mode + let cast_option = CastOption { strict: false }; + let src_value = Value::Int8(-1); + let dest_type = ConcreteDataType::uint8_datatype(); + let res = cast_with_opt(src_value, &dest_type, &cast_option); + assert!(res.is_ok()); + assert_eq!(res.unwrap(), Value::Null); - // int32 -> other types - test_can_cast!( - Value::Int32(0), - null_datatype, - int32_datatype, - int64_datatype, - float32_datatype, - float64_datatype, - string_datatype, - date_datatype + // strict mode + let cast_option = CastOption { strict: true }; + let src_value = Value::Int8(-1); + let dest_type = ConcreteDataType::uint8_datatype(); + let res = cast_with_opt(src_value, &dest_type, &cast_option); + assert!(res.is_err()); + assert_eq!( + res.unwrap_err().to_string(), + "Type Int8 with value -1 can't be cast because the value is out of range for the destination type UInt8" ); - // int64 -> other types - test_can_cast!( - Value::Int64(0), - null_datatype, - int64_datatype, - float64_datatype, - string_datatype, - datetime_datatype, - timestamp_second_datatype, - time_second_datatype + let src_value = Value::String(StringBytes::from("abc")); + let dest_type = ConcreteDataType::uint8_datatype(); + let res = cast_with_opt(src_value, &dest_type, &cast_option); + assert!(res.is_err()); + assert_eq!( + res.unwrap_err().to_string(), + "Could not parse string 'abc' to UInt8" ); - // float32 -> other types - test_can_cast!( - Value::Float32(OrderedFloat(0.0)), - 
null_datatype, - uint8_datatype, - uint16_datatype, - uint32_datatype, - uint64_datatype, - int8_datatype, - int16_datatype, - int32_datatype, - int64_datatype, - float32_datatype, - float64_datatype, - string_datatype + let src_value = Value::Timestamp(Timestamp::new_second(10)); + let dest_type = ConcreteDataType::int8_datatype(); + let res = cast_with_opt(src_value, &dest_type, &cast_option); + assert!(res.is_err()); + assert_eq!( + res.unwrap_err().to_string(), + "Type Timestamp with value 1970-01-01 08:00:10+0800 can't be cast to the destination type Int8" ); + } - // float64 -> other types - test_can_cast!( - Value::Float64(OrderedFloat(0.0)), - null_datatype, - uint8_datatype, - uint16_datatype, - uint32_datatype, - uint64_datatype, - int8_datatype, - int16_datatype, - int32_datatype, - int64_datatype, - float64_datatype, - string_datatype + #[test] + fn test_can_cast_type() { + // numeric cast + test_primitive_cast!( + Value::UInt8(0), + Value::UInt16(1), + Value::UInt32(2), + Value::UInt64(3), + Value::Int8(4), + Value::Int16(5), + Value::Int32(6), + Value::Int64(7), + Value::Float32(OrderedFloat(8.0)), + Value::Float64(OrderedFloat(9.0)), + Value::String(StringBytes::from("10")) ); // string -> other types test_can_cast!( Value::String(StringBytes::from("0")), null_datatype, - uint8_datatype, - uint16_datatype, - uint32_datatype, - uint64_datatype, - int8_datatype, - int16_datatype, - int32_datatype, - int64_datatype, - float32_datatype, - float64_datatype, - string_datatype, + boolean_datatype, date_datatype, datetime_datatype, timestamp_second_datatype, - time_second_datatype, - interval_year_month_datatype, - interval_day_time_datatype, - interval_month_day_nano_datatype + binary_datatype ); // date -> other types @@ -303,12 +280,5 @@ mod tests { null_datatype, string_datatype ); - - // interval -> other types - test_can_cast!( - Value::Interval(Interval::from_year_month(0)), - null_datatype, - string_datatype - ); } } diff --git 
a/src/datatypes/src/types/primitive_type.rs b/src/datatypes/src/types/primitive_type.rs index 42093992c7dd..7bf90c964a3c 100644 --- a/src/datatypes/src/types/primitive_type.rs +++ b/src/datatypes/src/types/primitive_type.rs @@ -289,19 +289,29 @@ macro_rules! define_non_timestamp_primitive { }; } -define_non_timestamp_primitive!(u8, UInt8, UInt8Type, UInt64Type, UInt8, Float32, Float64); define_non_timestamp_primitive!( - u16, UInt16, UInt16Type, UInt64Type, UInt8, UInt16, Float32, Float64 + u8, UInt8, UInt8Type, UInt64Type, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, + Float32, Float64 ); define_non_timestamp_primitive!( - u32, UInt32, UInt32Type, UInt64Type, UInt8, UInt16, UInt32, Float32, Float64 + u16, UInt16, UInt16Type, UInt64Type, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, + Float32, Float64 ); define_non_timestamp_primitive!( - u64, UInt64, UInt64Type, UInt64Type, UInt8, UInt16, UInt32, UInt64, Float32, Float64 + u32, UInt32, UInt32Type, UInt64Type, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, + Float32, Float64 ); -define_non_timestamp_primitive!(i8, Int8, Int8Type, Int64Type, Int8, Float32, Float64); define_non_timestamp_primitive!( - i16, Int16, Int16Type, Int64Type, Int8, Int16, UInt8, Float32, Float64 + u64, UInt64, UInt64Type, UInt64Type, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, + Float32, Float64 +); +define_non_timestamp_primitive!( + i8, Int8, Int8Type, Int64Type, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, + Float32, Float64 +); +define_non_timestamp_primitive!( + i16, Int16, Int16Type, Int64Type, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, + Float32, Float64 ); define_non_timestamp_primitive!( @@ -309,27 +319,32 @@ define_non_timestamp_primitive!( Float32, Float32Type, Float64Type, - Float32, - UInt8, - UInt16, Int8, Int16, - Int32 + Int32, + Int64, + UInt8, + UInt16, + UInt32, + UInt64, + Float32, + Float64 ); define_non_timestamp_primitive!( f64, Float64, 
Float64Type, Float64Type, - Float32, - Float64, - UInt8, - UInt16, - UInt32, Int8, Int16, Int32, - Int64 + Int64, + UInt8, + UInt16, + UInt32, + UInt64, + Float32, + Float64 ); // Timestamp primitive: @@ -389,33 +404,26 @@ impl DataType for Int64Type { } impl DataType for Int32Type { - #[doc = " Name of this data type."] fn name(&self) -> &str { "Int32" } - #[doc = " Returns id of the Logical data type."] fn logical_type_id(&self) -> LogicalTypeId { LogicalTypeId::Int32 } - #[doc = " Returns the default value of this type."] fn default_value(&self) -> Value { Value::Int32(0) } - #[doc = " Convert this type as [arrow::datatypes::DataType]."] fn as_arrow_type(&self) -> ArrowDataType { ArrowDataType::Int32 } - #[doc = " Creates a mutable vector with given `capacity` of this type."] fn create_mutable_vector(&self, capacity: usize) -> Box { Box::new(PrimitiveVectorBuilder::::with_capacity(capacity)) } - #[doc = " Returns true if the data type is compatible with timestamp type so we can"] - #[doc = " use it as a timestamp."] fn is_timestamp_compatible(&self) -> bool { false } @@ -426,8 +434,11 @@ impl DataType for Int32Type { Value::Int8(v) => num::cast::cast(v).map(Value::Int32), Value::Int16(v) => num::cast::cast(v).map(Value::Int32), Value::Int32(v) => Some(Value::Int32(v)), + Value::Int64(v) => num::cast::cast(v).map(Value::Int32), Value::UInt8(v) => num::cast::cast(v).map(Value::Int32), Value::UInt16(v) => num::cast::cast(v).map(Value::Int32), + Value::UInt32(v) => num::cast::cast(v).map(Value::Int32), + Value::UInt64(v) => num::cast::cast(v).map(Value::Int32), Value::Float32(v) => num::cast::cast(v).map(Value::Int32), Value::Float64(v) => num::cast::cast(v).map(Value::Int32), Value::String(v) => v.as_utf8().parse::().map(Value::Int32).ok(),