diff --git a/README.md b/README.md index 860a9c8c..20baebf8 100644 --- a/README.md +++ b/README.md @@ -160,14 +160,11 @@ The driver inherits almost all the features of C/C++ and Rust drivers, such as: cass_statement_bind_custom[by_name] - Binding is not implemented for custom types in the Rust driver.
Binding Decimal and Duration types requires encoding raw bytes into BigDecimal and CqlDuration types in the Rust driver.
Note: The driver does not validate the types of the values passed to queries. + Binding is not implemented for custom types in the Rust driver.
Binding Decimal type requires encoding raw bytes into BigDecimal type in the Rust driver.
Note: The driver does not validate the types of the values passed to queries. cass_statement_bind_decimal[by_name] - - cass_statement_bind_duration[by_name] - Future @@ -190,40 +187,27 @@ The driver inherits almost all the features of C/C++ and Rust drivers, such as: cass_collection_append_custom[_n] - Unimplemented because of the same reasons as binding for statements.
Note: The driver does not check whether the type of the appended value is compatible with the type of the collection items. + Unimplemented because of the same reasons as binding for statements.
Note: The driver does not check whether the type of the appended value is compatible with the type of the collection items. cass_collection_append_decimal - - cass_collection_append_duration - User Defined Type cass_user_type_set_custom[by_name] - Unimplemented because of the same reasons as binding for statements.
Note: The driver does not check whether the type of the value being set for a field of the UDT is compatible with the field's actual type. + Unimplemented because of the same reasons as binding for statements.
Note: The driver does not check whether the type of the value being set for a field of the UDT is compatible with the field's actual type. cass_user_type_set_decimal[by_name] - - cass_user_type_set_duration[by_name] - Value - - cass_value_is_duration - Unimplemented - cass_value_get_decimal - Getting raw bytes of Decimal and Duration values requires lazy deserialization feature in the Rust driver. - - - cass_value_get_duration + Getting raw bytes of Decimal values requires lazy deserialization feature in the Rust driver. cass_value_get_bytes diff --git a/scylla-rust-wrapper/src/binding.rs b/scylla-rust-wrapper/src/binding.rs index f0941859..8fee47cd 100644 --- a/scylla-rust-wrapper/src/binding.rs +++ b/scylla-rust-wrapper/src/binding.rs @@ -137,9 +137,8 @@ macro_rules! make_appender { } // TODO: Types for which binding is not implemented yet: -// custom - Not implemented in Rust driver? +// custom - Not implemented in Rust driver // decimal -// duration - DURATION not implemented in Rust Driver macro_rules! invoke_binder_maker_macro_with_type { (null, $macro_name:ident, $this:ty, $consume_v:expr, $fn:ident) => { @@ -277,6 +276,21 @@ macro_rules! invoke_binder_maker_macro_with_type { [v @ crate::inet::CassInet] ); }; + (duration, $macro_name:ident, $this:ty, $consume_v:expr, $fn:ident) => { + $macro_name!( + $this, + $consume_v, + $fn, + |m, d, n| { + Ok(Some(Duration(scylla::frame::value::CqlDuration { + months: m, + days: d, + nanoseconds: n + }))) + }, + [m @ cass_int32_t, d @ cass_int32_t, n @ cass_int64_t] + ); + }; (collection, $macro_name:ident, $this:ty, $consume_v:expr, $fn:ident) => { $macro_name!( $this, diff --git a/scylla-rust-wrapper/src/collection.rs b/scylla-rust-wrapper/src/collection.rs index 845ea79b..3abc3994 100644 --- a/scylla-rust-wrapper/src/collection.rs +++ b/scylla-rust-wrapper/src/collection.rs @@ -88,6 +88,7 @@ make_binders!(string_n, cass_collection_append_string_n); make_binders!(bytes, cass_collection_append_bytes); make_binders!(uuid, cass_collection_append_uuid); make_binders!(inet, cass_collection_append_inet); +make_binders!(duration, cass_collection_append_duration); make_binders!(collection, cass_collection_append_collection); make_binders!(tuple, cass_collection_append_tuple); make_binders!(user_type, cass_collection_append_user_type); diff --git a/scylla-rust-wrapper/src/query_result.rs b/scylla-rust-wrapper/src/query_result.rs index 1ff33949..7983d9ff 100644 --- a/scylla-rust-wrapper/src/query_result.rs +++ b/scylla-rust-wrapper/src/query_result.rs @@ -915,12 +915,21 @@ pub unsafe extern "C" fn cass_value_data_type(value: *const CassValue) -> *const Arc::as_ptr(&value_from_raw.value_type) } +macro_rules! val_ptr_to_ref_ensure_non_null { + ($ptr:ident) => {{ + if $ptr.is_null() { + return CassError::CASS_ERROR_LIB_NULL_VALUE; + } + ptr_to_ref($ptr) + }}; +} + #[no_mangle] pub unsafe extern "C" fn cass_value_get_float( value: *const CassValue, output: *mut cass_float_t, ) -> CassError { - let val: &CassValue = ptr_to_ref(value); + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); match val.value { Some(Value::RegularValue(CqlValue::Float(f))) => std::ptr::write(output, f), Some(_) => return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE, @@ -935,7 +944,7 @@ pub unsafe extern "C" fn cass_value_get_double( value: *const CassValue, output: *mut cass_double_t, ) -> CassError { - let val: &CassValue = ptr_to_ref(value); + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); match val.value { Some(Value::RegularValue(CqlValue::Double(d))) => std::ptr::write(output, d), Some(_) => return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE, @@ -950,7 +959,7 @@ pub unsafe extern "C" fn cass_value_get_bool( value: *const CassValue, output: *mut cass_bool_t, ) -> CassError { - let val: &CassValue = ptr_to_ref(value); + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); match val.value { Some(Value::RegularValue(CqlValue::Boolean(b))) => { std::ptr::write(output, b as cass_bool_t) @@ -967,7 +976,7 @@ pub unsafe extern "C" fn cass_value_get_int8( value: *const CassValue, output: *mut cass_int8_t, ) -> CassError { - let val: &CassValue = ptr_to_ref(value); + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); match val.value { Some(Value::RegularValue(CqlValue::TinyInt(i))) => std::ptr::write(output, i), Some(_) => return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE, @@ -982,7 +991,7 @@ pub unsafe extern "C" fn cass_value_get_int16( value: *const CassValue, output: *mut cass_int16_t, ) -> CassError { - let val: &CassValue = ptr_to_ref(value); + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); match val.value { Some(Value::RegularValue(CqlValue::SmallInt(i))) => std::ptr::write(output, i), Some(_) => return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE, @@ -997,7 +1006,7 @@ pub unsafe extern "C" fn cass_value_get_uint32( value: *const CassValue, output: *mut cass_uint32_t, ) -> CassError { - let val: &CassValue = ptr_to_ref(value); + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); match val.value { Some(Value::RegularValue(CqlValue::Date(u))) => std::ptr::write(output, u.0), // FIXME: hack Some(_) => return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE, @@ -1012,7 +1021,7 @@ pub unsafe extern "C" fn cass_value_get_int32( value: *const CassValue, output: *mut cass_int32_t, ) -> CassError { - let val: &CassValue = ptr_to_ref(value); + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); match val.value { Some(Value::RegularValue(CqlValue::Int(i))) => std::ptr::write(output, i), Some(_) => return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE, @@ -1027,7 +1036,7 @@ pub unsafe extern "C" fn cass_value_get_int64( value: *const CassValue, output: *mut cass_int64_t, ) -> CassError { - let val: &CassValue = ptr_to_ref(value); + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); match val.value { Some(Value::RegularValue(CqlValue::BigInt(i))) => std::ptr::write(output, i), Some(Value::RegularValue(CqlValue::Counter(i))) => { @@ -1049,7 +1058,7 @@ pub unsafe extern "C" fn cass_value_get_uuid( value: *const CassValue, output: *mut CassUuid, ) -> CassError { - let val: &CassValue = ptr_to_ref(value); + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); match val.value { Some(Value::RegularValue(CqlValue::Uuid(uuid))) => std::ptr::write(output, uuid.into()), Some(Value::RegularValue(CqlValue::Timeuuid(uuid))) => { @@ -1067,7 +1076,7 @@ pub unsafe extern "C" fn cass_value_get_inet( value: *const CassValue, output: *mut CassInet, ) -> CassError { - let val: &CassValue = ptr_to_ref(value); + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); match val.value { Some(Value::RegularValue(CqlValue::Inet(inet))) => std::ptr::write(output, inet.into()), Some(_) => return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE, @@ -1083,7 +1092,7 @@ pub unsafe extern "C" fn cass_value_get_string( output: *mut *const c_char, output_size: *mut size_t, ) -> CassError { - let val: &CassValue = ptr_to_ref(value); + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); match &val.value { // It seems that cpp driver doesn't check the type - you can call _get_string // on any type and get internal represenation. I don't see how to do it easily in @@ -1102,17 +1111,35 @@ pub unsafe extern "C" fn cass_value_get_string( CassError::CASS_OK } +#[no_mangle] +pub unsafe extern "C" fn cass_value_get_duration( + value: *const CassValue, + months: *mut cass_int32_t, + days: *mut cass_int32_t, + nanos: *mut cass_int64_t, +) -> CassError { + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); + + match &val.value { + Some(Value::RegularValue(CqlValue::Duration(duration))) => { + std::ptr::write(months, duration.months); + std::ptr::write(days, duration.days); + std::ptr::write(nanos, duration.nanoseconds); + } + Some(_) => return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE, + None => return CassError::CASS_ERROR_LIB_NULL_VALUE, + } + + CassError::CASS_OK +} + #[no_mangle] pub unsafe extern "C" fn cass_value_get_bytes( value: *const CassValue, output: *mut *const cass_byte_t, output_size: *mut size_t, ) -> CassError { - if value.is_null() { - return CassError::CASS_ERROR_LIB_NULL_VALUE; - } - - let value_from_raw: &CassValue = ptr_to_ref(value); + let value_from_raw: &CassValue = val_ptr_to_ref_ensure_non_null!(value); // FIXME: This should be implemented for all CQL types // Note: currently rust driver does not allow to get raw bytes of the CQL value. @@ -1138,12 +1165,19 @@ pub unsafe extern "C" fn cass_value_is_null(value: *const CassValue) -> cass_boo pub unsafe extern "C" fn cass_value_is_collection(value: *const CassValue) -> cass_bool_t { let val = ptr_to_ref(value); - match val.value { - Some(Value::CollectionValue(Collection::List(_))) => true as cass_bool_t, - Some(Value::CollectionValue(Collection::Set(_))) => true as cass_bool_t, - Some(Value::CollectionValue(Collection::Map(_))) => true as cass_bool_t, - _ => false as cass_bool_t, - } + matches!( + val.value_type.get_value_type(), + CassValueType::CASS_VALUE_TYPE_LIST + | CassValueType::CASS_VALUE_TYPE_SET + | CassValueType::CASS_VALUE_TYPE_MAP + ) as cass_bool_t +} + +#[no_mangle] +pub unsafe extern "C" fn cass_value_is_duration(value: *const CassValue) -> cass_bool_t { + let val = ptr_to_ref(value); + + (val.value_type.get_value_type() == CassValueType::CASS_VALUE_TYPE_DURATION) as cass_bool_t } #[no_mangle] @@ -1505,26 +1539,12 @@ pub unsafe extern "C" fn cass_value_get_decimal( scale: *mut cass_int32_t, ) -> CassError { } -#[no_mangle] -pub unsafe extern "C" fn cass_value_get_duration( - value: *const CassValue, - months: *mut cass_int32_t, - days: *mut cass_int32_t, - nanos: *mut cass_int64_t, -) -> CassError { -} extern "C" { pub fn cass_value_data_type(value: *const CassValue) -> *const CassDataType; } extern "C" { pub fn cass_value_type(value: *const CassValue) -> CassValueType; } -extern "C" { - pub fn cass_value_is_collection(value: *const CassValue) -> cass_bool_t; -} -extern "C" { - pub fn cass_value_is_duration(value: *const CassValue) -> cass_bool_t; -} extern "C" { pub fn cass_value_item_count(collection: *const CassValue) -> size_t; } diff --git a/scylla-rust-wrapper/src/statement.rs b/scylla-rust-wrapper/src/statement.rs index cf042d55..ec1111c5 100644 --- a/scylla-rust-wrapper/src/statement.rs +++ b/scylla-rust-wrapper/src/statement.rs @@ -462,6 +462,12 @@ make_binders!( cass_statement_bind_inet_by_name, cass_statement_bind_inet_by_name_n ); +make_binders!( + duration, + cass_statement_bind_duration, + cass_statement_bind_duration_by_name, + cass_statement_bind_duration_by_name_n +); make_binders!( collection, cass_statement_bind_collection, diff --git a/scylla-rust-wrapper/src/tuple.rs b/scylla-rust-wrapper/src/tuple.rs index 04109574..31195ae9 100644 --- a/scylla-rust-wrapper/src/tuple.rs +++ b/scylla-rust-wrapper/src/tuple.rs @@ -105,6 +105,7 @@ make_binders!(string_n, cass_tuple_set_string_n); make_binders!(bytes, cass_tuple_set_bytes); make_binders!(uuid, cass_tuple_set_uuid); make_binders!(inet, cass_tuple_set_inet); +make_binders!(duration, cass_tuple_set_duration); make_binders!(collection, cass_tuple_set_collection); make_binders!(tuple, cass_tuple_set_tuple); make_binders!(user_type, cass_tuple_set_user_type); diff --git a/scylla-rust-wrapper/src/user_type.rs b/scylla-rust-wrapper/src/user_type.rs index 53b16a12..4a651f85 100644 --- a/scylla-rust-wrapper/src/user_type.rs +++ b/scylla-rust-wrapper/src/user_type.rs @@ -181,6 +181,12 @@ make_binders!( cass_user_type_set_inet_by_name, cass_user_type_set_inet_by_name_n ); +make_binders!( + duration, + cass_user_type_set_duration, + cass_user_type_set_duration_by_name, + cass_user_type_set_duration_by_name_n +); make_binders!( collection, cass_user_type_set_collection, diff --git a/scylla-rust-wrapper/src/value.rs b/scylla-rust-wrapper/src/value.rs index 1e5558ce..bb760669 100644 --- a/scylla-rust-wrapper/src/value.rs +++ b/scylla-rust-wrapper/src/value.rs @@ -1,7 +1,10 @@ use std::{convert::TryInto, net::IpAddr}; use scylla::{ - frame::{response::result::ColumnType, value::CqlDate}, + frame::{ + response::result::ColumnType, + value::{CqlDate, CqlDuration}, + }, serialize::{ value::{ BuiltinSerializationErrorKind, MapSerializationErrorKind, SerializeCql, @@ -40,6 +43,7 @@ pub enum CassCqlValue { Uuid(Uuid), Date(CqlDate), Inet(IpAddr), + Duration(CqlDuration), Tuple(Vec>), List(Vec), Map(Vec<(CassCqlValue, CassCqlValue)>), @@ -117,6 +121,9 @@ impl CassCqlValue { CassCqlValue::Inet(v) => { ::serialize(v, &ColumnType::Inet, writer) } + CassCqlValue::Duration(v) => { + ::serialize(v, &ColumnType::Duration, writer) + } CassCqlValue::Tuple(fields) => serialize_tuple_like(fields.iter(), writer), CassCqlValue::List(l) => serialize_sequence(l.len(), l.iter(), writer), CassCqlValue::Map(m) => { diff --git a/src/testing_unimplemented.cpp b/src/testing_unimplemented.cpp index f66cc759..353f6beb 100644 --- a/src/testing_unimplemented.cpp +++ b/src/testing_unimplemented.cpp @@ -169,13 +169,6 @@ cass_collection_append_decimal(CassCollection* collection, cass_int32_t scale){ throw std::runtime_error("UNIMPLEMENTED cass_collection_append_decimal\n"); } -CASS_EXPORT CassError -cass_collection_append_duration(CassCollection* collection, - cass_int32_t months, - cass_int32_t days, - cass_int64_t nanos){ - throw std::runtime_error("UNIMPLEMENTED cass_collection_append_duration\n"); -} CASS_EXPORT const CassValue* cass_column_meta_field_by_name(const CassColumnMeta* column_meta, const char* name){ @@ -373,22 +366,6 @@ cass_statement_bind_decimal_by_name(CassStatement* statement, throw std::runtime_error("UNIMPLEMENTED cass_statement_bind_decimal_by_name\n"); } CASS_EXPORT CassError -cass_statement_bind_duration(CassStatement* statement, - size_t index, - cass_int32_t months, - cass_int32_t days, - cass_int64_t nanos){ - throw std::runtime_error("UNIMPLEMENTED cass_statement_bind_duration\n"); -} -CASS_EXPORT CassError -cass_statement_bind_duration_by_name(CassStatement* statement, - const char* name, - cass_int32_t months, - cass_int32_t days, - cass_int64_t nanos){ - throw std::runtime_error("UNIMPLEMENTED cass_statement_bind_duration_by_name\n"); -} -CASS_EXPORT CassError cass_statement_set_custom_payload(CassStatement* statement, const CassCustomPayload* payload){ throw std::runtime_error("UNIMPLEMENTED cass_statement_set_custom_payload\n"); @@ -455,14 +432,6 @@ cass_tuple_set_decimal(CassTuple* tuple, throw std::runtime_error("UNIMPLEMENTED cass_tuple_set_decimal\n"); } CASS_EXPORT CassError -cass_tuple_set_duration(CassTuple* tuple, - size_t index, - cass_int32_t months, - cass_int32_t days, - cass_int64_t nanos){ - throw std::runtime_error("UNIMPLEMENTED cass_tuple_set_duration\n"); -} -CASS_EXPORT CassError cass_user_type_set_custom(CassUserType* user_type, size_t index, const char* class_name, @@ -487,24 +456,9 @@ cass_user_type_set_decimal_by_name(CassUserType* user_type, throw std::runtime_error("UNIMPLEMENTED cass_user_type_set_decimal_by_name\n"); } CASS_EXPORT CassError -cass_user_type_set_duration_by_name(CassUserType* user_type, - const char* name, - cass_int32_t months, - cass_int32_t days, - cass_int64_t nanos){ - throw std::runtime_error("UNIMPLEMENTED cass_user_type_set_duration_by_name\n"); -} -CASS_EXPORT CassError cass_value_get_decimal(const CassValue* value, const cass_byte_t** varint, size_t* varint_size, cass_int32_t* scale){ throw std::runtime_error("UNIMPLEMENTED cass_value_get_decimal\n"); -} -CASS_EXPORT CassError -cass_value_get_duration(const CassValue* value, - cass_int32_t* months, - cass_int32_t* days, - cass_int64_t* nanos){ - throw std::runtime_error("UNIMPLEMENTED cass_value_get_duration\n"); } \ No newline at end of file