From 700cd1223f7f5deccbd03a656f1e7605d8c13764 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Uzarski?= Date: Tue, 26 Nov 2024 16:01:42 +0100 Subject: [PATCH] argconv: pointers This commit introduces a `CassPtr` type, generic over pointer `Properties`. It allows specific pointer-to-reference conversions based on the guarantees provided by the pointer type. Existing `Ref/Box/Arc`FFIs are adjusted, so they now allow interaction with new pointer type. You can say, that for a user of `argconv` API, new type is opaque. The only way to operate on it is to use the corresponding ffi API. --- scylla-rust-wrapper/src/argconv.rs | 369 +++++++++-- scylla-rust-wrapper/src/batch.rs | 48 +- scylla-rust-wrapper/src/binding.rs | 34 +- scylla-rust-wrapper/src/cass_types.rs | 133 ++-- scylla-rust-wrapper/src/cluster.rs | 265 ++++---- scylla-rust-wrapper/src/collection.rs | 92 ++- scylla-rust-wrapper/src/exec_profile.rs | 125 ++-- scylla-rust-wrapper/src/future.rs | 148 +++-- .../src/integration_testing.rs | 16 +- scylla-rust-wrapper/src/logging.rs | 19 +- scylla-rust-wrapper/src/metadata.rs | 236 +++---- scylla-rust-wrapper/src/prepared.rs | 32 +- scylla-rust-wrapper/src/query_error.rs | 54 +- scylla-rust-wrapper/src/query_result.rs | 597 ++++++++++-------- scylla-rust-wrapper/src/retry_policy.rs | 11 +- scylla-rust-wrapper/src/session.rs | 284 +++++---- scylla-rust-wrapper/src/ssl.rs | 32 +- scylla-rust-wrapper/src/statement.rs | 60 +- scylla-rust-wrapper/src/testing.rs | 2 +- scylla-rust-wrapper/src/tuple.rs | 26 +- scylla-rust-wrapper/src/user_type.rs | 16 +- scylla-rust-wrapper/src/uuid.rs | 19 +- 22 files changed, 1559 insertions(+), 1059 deletions(-) diff --git a/scylla-rust-wrapper/src/argconv.rs b/scylla-rust-wrapper/src/argconv.rs index 058ea45e..9a06ed09 100644 --- a/scylla-rust-wrapper/src/argconv.rs +++ b/scylla-rust-wrapper/src/argconv.rs @@ -1,7 +1,9 @@ use crate::types::size_t; use std::cmp::min; use std::ffi::CStr; +use std::marker::PhantomData; use std::os::raw::c_char; +use std::ptr::NonNull; use std::sync::Arc; pub unsafe fn ptr_to_cstr(ptr: *const c_char) -> Option<&'static str> { @@ -71,35 +73,307 @@ macro_rules! make_c_str { #[cfg(test)] pub(crate) use make_c_str; +mod sealed { + pub trait Sealed {} +} + +/// A trait representing an ownership of the pointer. +pub trait Ownership: sealed::Sealed {} + +/// Represents an exclusive ownership of the pointer. +/// For Box-allocated pointers. +pub struct Exclusive; +impl sealed::Sealed for Exclusive {} +impl Ownership for Exclusive {} + +/// Represents a shared ownership of the pointer. +/// For Arc-allocated pointers. +pub struct Shared; +impl sealed::Sealed for Shared {} +impl Ownership for Shared {} + +/// Represents a borrowed pointer. +/// For pointers created from some valid reference. +pub struct Borrowed; +impl sealed::Sealed for Borrowed {} +impl Ownership for Borrowed {} + +/// A trait representing mutability of the pointer. +pub trait Mutability: sealed::Sealed {} + +/// Represents immutable pointer. +pub struct Const; +impl sealed::Sealed for Const {} +impl Mutability for Const {} + +/// Represents mutable pointer. +pub struct Mut; +impl sealed::Sealed for Mut {} +impl Mutability for Mut {} + +/// Represents pointer properties (ownership + mutability). +pub trait Properties { + type Ownership: Ownership; + type Mutability: Mutability; +} + +impl Properties for (O, M) { + type Ownership = O; + type Mutability = M; +} + +/// An exclusive pointer type, generic over mutability. +pub type CassExclusivePtr = CassPtr; + +/// An exclusive const pointer. Should be used for Box-allocated objects +/// that need to be read-only. +pub type CassExclusiveConstPtr = CassExclusivePtr; + +/// An exclusive mutable pointer. Should be used for Box-allocated objects +/// that need to be mutable. +pub type CassExclusiveMutPtr = CassExclusivePtr; + +/// A shared const pointer. Should be used for Arc-allocated objects. +/// +/// Notice that this type does not provide interior mutability itself. +/// It is the responsiblity of the user of this API, to provide soundness +/// in this matter (aliasing ^ mutability). +/// Take for example [`CassDataType`](crate::cass_types::CassDataType). It +/// holds underlying data in [`UnsafeCell`](std::cell::UnsafeCell) to provide +/// interior mutability. +pub type CassSharedPtr = CassPtr; + +/// A borrowed const pointer. Should be used for objects that reference +/// some field of already allocated object. +/// +/// When operating on the pointer of this type, we do not make any assumptions +/// about the origin of the pointer, apart from the fact that it is obtained +/// from some valid reference. +/// +/// In particular, it may be a reference obtained from `Box/Arc::as_ref()`. +/// Notice, however, that we do not allow to convert such pointer back to `Box/Arc`, +/// since we do not have a guarantee that corresponding memory was allocated certain way. +/// OTOH, we can safely convert the pointer to a valid reference (assuming user did not +/// provide a pointer pointing to some garbage memory). +pub type CassBorrowedPtr = CassPtr; + +// repr(transparent), so the struct has the same layout as underlying Option>. +// Thanks to https://doc.rust-lang.org/std/option/#representation optimization, +// we are guaranteed, that for T: Sized, our struct has the same layout +// and function call ABI as simply NonNull. +#[repr(transparent)] +pub struct CassPtr { + ptr: Option>, + _phantom: PhantomData

, +} + +/// Unsafe implementation of clone for tests. +/// +/// In tests, we play a role as C API user. We need to be able +/// to reuse the pointer and to pass it to different API function calls. +/// API functions consume the pointer, thus we need some way to obtain +/// an owned version of the pointer again. +/// We introduce unsafe `clone` method for this reason. +#[cfg(test)] +impl CassPtr { + pub unsafe fn clone(&self) -> Self { + Self { + ptr: self.ptr, + _phantom: PhantomData, + } + } +} + +/// Methods for any pointer kind. +/// +/// Notice that there are no constructors except from null(), which is trivial. +/// Thanks to that, we can make some assumptions and guarantees based on +/// the origin of the pointer. For these, see `SAFETY` comments. +impl CassPtr { + fn null() -> Self { + CassPtr { + ptr: None, + _phantom: PhantomData, + } + } + + fn is_null(&self) -> bool { + self.ptr.is_none() + } + + /// Converts a pointer to an optional valid reference. + fn as_ref(&self) -> Option<&T> { + // SAFETY: We know that our pointers either come from a valid allocation (Box or Arc), + // or were created based on the reference to the field of some allocated object. + // Thus, if the pointer is non-null, we are guaranteed that it's valid. + unsafe { self.ptr.map(|p| p.as_ref()) } + } + + /// Converts a pointer to an optional valid reference, and consumes the pointer. + fn into_ref<'a>(self) -> Option<&'a T> { + // SAFETY: We know that our pointers either come from a valid allocation (Box or Arc), + // or were created based on the reference to the field of some allocated object. + // Thus, if the pointer is non-null, we are guaranteed that it's valid. + unsafe { self.ptr.map(|p| p.as_ref()) } + } + + /// Converts self to raw pointer. + fn as_raw(&self) -> *const T { + self.ptr + .map(|ptr| ptr.as_ptr() as *const _) + .unwrap_or(std::ptr::null()) + } +} + +/// Methods for any exclusive pointers (no matter the mutability). +impl CassPtr { + /// Creates a pointer based on a VALID Box allocation. + /// This is the only way to obtain such pointer. + fn from_box(b: Box) -> Self { + #[allow(clippy::disallowed_methods)] + let ptr = Box::into_raw(b); + CassPtr { + ptr: NonNull::new(ptr), + _phantom: PhantomData, + } + } + + /// Converts the pointer back to the owned Box (if not null). + fn into_box(self) -> Option> { + // SAFETY: We are guaranteed that if ptr is Some, then it was obtained + // from a valid Box allocation (see new() implementation). + unsafe { + self.ptr.map(|p| { + #[allow(clippy::disallowed_methods)] + Box::from_raw(p.as_ptr()) + }) + } + } +} + +/// Methods for exclusive mutable pointers. +impl CassPtr { + fn null_mut() -> Self { + CassPtr { + ptr: None, + _phantom: PhantomData, + } + } + + /// Converts a pointer to an optional valid, and mutable reference. + fn as_mut_ref(&mut self) -> Option<&mut T> { + // SAFETY: We are guaranteed that if ptr is Some, then it was obtained + // from a valid Box allocation (see new() implementation). + unsafe { self.ptr.map(|mut p| p.as_mut()) } + } +} + +/// Define this conversion for test purposes. +/// User receives a mutable exclusive pointer from constructor, +/// but following function calls need a const exclusive pointer. +#[cfg(test)] +impl CassPtr { + pub fn into_const(self) -> CassPtr { + CassPtr { + ptr: self.ptr, + _phantom: PhantomData, + } + } +} + +/// Methods for Shared pointers. +impl CassSharedPtr { + /// Creates a pointer based on a VALID Arc allocation. + fn from_arc(a: Arc) -> Self { + #[allow(clippy::disallowed_methods)] + let ptr = Arc::into_raw(a); + CassPtr { + ptr: NonNull::new(ptr as *mut T), + _phantom: PhantomData, + } + } + + /// Creates a pointer which borrows from a VALID Arc allocation. + fn from_ref(a: &Arc) -> Self { + #[allow(clippy::disallowed_methods)] + let ptr = Arc::as_ptr(a); + CassPtr { + ptr: NonNull::new(ptr as *mut T), + _phantom: PhantomData, + } + } + + /// Converts the pointer back to the Arc. + fn into_arc(self) -> Option> { + // SAFETY: The pointer can only be obtained via new() or from_ref(), + // both of which accept an Arc. + // It means that the pointer comes from a valid allocation. + unsafe { + self.ptr.map(|p| { + #[allow(clippy::disallowed_methods)] + Arc::from_raw(p.as_ptr()) + }) + } + } + + /// Converts the pointer to an Arc, by increasing its reference count. + fn clone_arced(self) -> Option> { + // SAFETY: The pointer can only be obtained via new() or from_ref(), + // both of which accept an Arc. + // It means that the pointer comes from a valid allocation. + unsafe { + self.ptr.map(|p| { + let ptr = p.as_ptr(); + #[allow(clippy::disallowed_methods)] + Arc::increment_strong_count(ptr); + #[allow(clippy::disallowed_methods)] + Arc::from_raw(ptr) + }) + } + } +} + +/// Methods for borrowed pointers. +impl CassPtr { + fn from_ref(r: &T) -> Self { + Self { + ptr: Some(NonNull::from(r)), + _phantom: PhantomData, + } + } +} + /// Defines a pointer manipulation API for non-shared heap-allocated data. /// /// Implement this trait for types that are allocated by the driver via [`Box::new`], /// and then returned to the user as a pointer. The user is responsible for freeing /// the memory associated with the pointer using corresponding driver's API function. -pub trait BoxFFI { - fn into_ptr(self: Box) -> *mut Self { - #[allow(clippy::disallowed_methods)] - Box::into_raw(self) +pub trait BoxFFI: Sized { + fn into_ptr(self: Box) -> CassExclusivePtr { + CassExclusivePtr::from_box(self) } - unsafe fn from_ptr(ptr: *mut Self) -> Box { - #[allow(clippy::disallowed_methods)] - Box::from_raw(ptr) + fn from_ptr(ptr: CassExclusivePtr) -> Option> { + ptr.into_box() } - unsafe fn as_maybe_ref<'a>(ptr: *const Self) -> Option<&'a Self> { - #[allow(clippy::disallowed_methods)] + fn as_ref(ptr: &CassExclusivePtr) -> Option<&Self> { ptr.as_ref() } - unsafe fn as_ref<'a>(ptr: *const Self) -> &'a Self { - #[allow(clippy::disallowed_methods)] - ptr.as_ref().unwrap() + fn into_ref<'a, M: Mutability>(ptr: CassExclusivePtr) -> Option<&'a Self> { + ptr.into_ref() } - unsafe fn as_mut_ref<'a>(ptr: *mut Self) -> &'a mut Self { - #[allow(clippy::disallowed_methods)] - ptr.as_mut().unwrap() + fn as_mut_ref(ptr: &mut CassExclusiveMutPtr) -> Option<&mut Self> { + ptr.as_mut_ref() } - unsafe fn free(ptr: *mut Self) { + fn free(ptr: CassExclusiveMutPtr) { std::mem::drop(BoxFFI::from_ptr(ptr)); } + #[cfg(test)] + fn null() -> CassExclusiveConstPtr { + CassExclusiveConstPtr::null() + } + fn null_mut() -> CassExclusiveMutPtr { + CassExclusiveMutPtr::null_mut() + } } /// Defines a pointer manipulation API for shared heap-allocated data. @@ -108,36 +382,37 @@ pub trait BoxFFI { /// The data should be allocated via [`Arc::new`], and then returned to the user as a pointer. /// The user is responsible for freeing the memory associated /// with the pointer using corresponding driver's API function. -pub trait ArcFFI { - fn as_ptr(self: &Arc) -> *const Self { - #[allow(clippy::disallowed_methods)] - Arc::as_ptr(self) +pub trait ArcFFI: Sized { + fn as_ptr(self: &Arc) -> CassSharedPtr { + CassSharedPtr::from_ref(self) } - fn into_ptr(self: Arc) -> *const Self { - #[allow(clippy::disallowed_methods)] - Arc::into_raw(self) + fn into_ptr(self: Arc) -> CassSharedPtr { + CassSharedPtr::from_arc(self) } - unsafe fn from_ptr(ptr: *const Self) -> Arc { - #[allow(clippy::disallowed_methods)] - Arc::from_raw(ptr) + fn from_ptr(ptr: CassSharedPtr) -> Option> { + ptr.into_arc() } - unsafe fn cloned_from_ptr(ptr: *const Self) -> Arc { - #[allow(clippy::disallowed_methods)] - Arc::increment_strong_count(ptr); - #[allow(clippy::disallowed_methods)] - Arc::from_raw(ptr) + fn cloned_from_ptr(ptr: CassSharedPtr) -> Option> { + ptr.clone_arced() } - unsafe fn as_maybe_ref<'a>(ptr: *const Self) -> Option<&'a Self> { - #[allow(clippy::disallowed_methods)] + fn as_ref(ptr: &CassSharedPtr) -> Option<&Self> { ptr.as_ref() } - unsafe fn as_ref<'a>(ptr: *const Self) -> &'a Self { - #[allow(clippy::disallowed_methods)] - ptr.as_ref().unwrap() + fn into_ref<'a>(ptr: CassSharedPtr) -> Option<&'a Self> { + ptr.into_ref() } - unsafe fn free(ptr: *const Self) { + fn to_raw(ptr: &CassSharedPtr) -> *const Self { + ptr.as_raw() + } + fn free(ptr: CassSharedPtr) { std::mem::drop(ArcFFI::from_ptr(ptr)); } + fn null() -> CassSharedPtr { + CassSharedPtr::null() + } + fn is_null(ptr: &CassSharedPtr) -> bool { + ptr.is_null() + } } /// Defines a pointer manipulation API for data owned by some other object. @@ -148,12 +423,20 @@ pub trait ArcFFI { /// For example: lifetime of CassRow is bound by the lifetime of CassResult. /// There is no API function that frees the CassRow. It should be automatically /// freed when user calls cass_result_free. -pub trait RefFFI { - fn as_ptr(&self) -> *const Self { - self as *const Self +pub trait RefFFI: Sized { + fn as_ptr(&self) -> CassBorrowedPtr { + CassBorrowedPtr::from_ref(self) } - unsafe fn as_ref<'a>(ptr: *const Self) -> &'a Self { - #[allow(clippy::disallowed_methods)] - ptr.as_ref().unwrap() + fn as_ref(ptr: &CassBorrowedPtr) -> Option<&Self> { + ptr.as_ref() + } + fn into_ref<'a>(ptr: CassBorrowedPtr) -> Option<&'a Self> { + ptr.into_ref() + } + fn null() -> CassBorrowedPtr { + CassBorrowedPtr::null() + } + fn is_null(ptr: &CassBorrowedPtr) -> bool { + ptr.is_null() } } diff --git a/scylla-rust-wrapper/src/batch.rs b/scylla-rust-wrapper/src/batch.rs index fe18e548..a6bae761 100644 --- a/scylla-rust-wrapper/src/batch.rs +++ b/scylla-rust-wrapper/src/batch.rs @@ -1,4 +1,4 @@ -use crate::argconv::{ArcFFI, BoxFFI}; +use crate::argconv::{ArcFFI, BoxFFI, CassExclusiveConstPtr, CassExclusiveMutPtr, CassSharedPtr}; use crate::cass_error::CassError; use crate::cass_types::CassConsistency; use crate::cass_types::{make_batch_type, CassBatchType}; @@ -28,7 +28,7 @@ pub struct CassBatchState { } #[no_mangle] -pub unsafe extern "C" fn cass_batch_new(type_: CassBatchType) -> *mut CassBatch { +pub unsafe extern "C" fn cass_batch_new(type_: CassBatchType) -> CassExclusiveMutPtr { if let Some(batch_type) = make_batch_type(type_) { BoxFFI::into_ptr(Box::new(CassBatch { state: Arc::new(CassBatchState { @@ -39,21 +39,21 @@ pub unsafe extern "C" fn cass_batch_new(type_: CassBatchType) -> *mut CassBatch exec_profile: None, })) } else { - std::ptr::null_mut() + BoxFFI::null_mut() } } #[no_mangle] -pub unsafe extern "C" fn cass_batch_free(batch: *mut CassBatch) { +pub unsafe extern "C" fn cass_batch_free(batch: CassExclusiveMutPtr) { BoxFFI::free(batch); } #[no_mangle] pub unsafe extern "C" fn cass_batch_set_consistency( - batch: *mut CassBatch, + mut batch: CassExclusiveMutPtr, consistency: CassConsistency, ) -> CassError { - let batch = BoxFFI::as_mut_ref(batch); + let batch = BoxFFI::as_mut_ref(&mut batch).unwrap(); let consistency = match consistency.try_into().ok() { Some(c) => c, None => return CassError::CASS_ERROR_LIB_BAD_PARAMS, @@ -67,10 +67,10 @@ pub unsafe extern "C" fn cass_batch_set_consistency( #[no_mangle] pub unsafe extern "C" fn cass_batch_set_serial_consistency( - batch: *mut CassBatch, + mut batch: CassExclusiveMutPtr, serial_consistency: CassConsistency, ) -> CassError { - let batch = BoxFFI::as_mut_ref(batch); + let batch = BoxFFI::as_mut_ref(&mut batch).unwrap(); let serial_consistency = match serial_consistency.try_into().ok() { Some(c) => c, None => return CassError::CASS_ERROR_LIB_BAD_PARAMS, @@ -84,13 +84,13 @@ pub unsafe extern "C" fn cass_batch_set_serial_consistency( #[no_mangle] pub unsafe extern "C" fn cass_batch_set_retry_policy( - batch: *mut CassBatch, - retry_policy: *const CassRetryPolicy, + mut batch: CassExclusiveMutPtr, + retry_policy: CassSharedPtr, ) -> CassError { - let batch = BoxFFI::as_mut_ref(batch); + let batch = BoxFFI::as_mut_ref(&mut batch).unwrap(); let maybe_arced_retry_policy: Option> = - ArcFFI::as_maybe_ref(retry_policy).map(|policy| match policy { + ArcFFI::as_ref(&retry_policy).map(|policy| match policy { CassRetryPolicy::DefaultRetryPolicy(default) => { default.clone() as Arc } @@ -107,10 +107,10 @@ pub unsafe extern "C" fn cass_batch_set_retry_policy( #[no_mangle] pub unsafe extern "C" fn cass_batch_set_timestamp( - batch: *mut CassBatch, + mut batch: CassExclusiveMutPtr, timestamp: cass_int64_t, ) -> CassError { - let batch = BoxFFI::as_mut_ref(batch); + let batch = BoxFFI::as_mut_ref(&mut batch).unwrap(); Arc::make_mut(&mut batch.state) .batch @@ -121,10 +121,10 @@ pub unsafe extern "C" fn cass_batch_set_timestamp( #[no_mangle] pub unsafe extern "C" fn cass_batch_set_request_timeout( - batch: *mut CassBatch, + mut batch: CassExclusiveMutPtr, timeout_ms: cass_uint64_t, ) -> CassError { - let batch = BoxFFI::as_mut_ref(batch); + let batch = BoxFFI::as_mut_ref(&mut batch).unwrap(); batch.batch_request_timeout_ms = Some(timeout_ms); CassError::CASS_OK @@ -132,10 +132,10 @@ pub unsafe extern "C" fn cass_batch_set_request_timeout( #[no_mangle] pub unsafe extern "C" fn cass_batch_set_is_idempotent( - batch: *mut CassBatch, + mut batch: CassExclusiveMutPtr, is_idempotent: cass_bool_t, ) -> CassError { - let batch = BoxFFI::as_mut_ref(batch); + let batch = BoxFFI::as_mut_ref(&mut batch).unwrap(); Arc::make_mut(&mut batch.state) .batch .set_is_idempotent(is_idempotent != 0); @@ -145,10 +145,10 @@ pub unsafe extern "C" fn cass_batch_set_is_idempotent( #[no_mangle] pub unsafe extern "C" fn cass_batch_set_tracing( - batch: *mut CassBatch, + mut batch: CassExclusiveMutPtr, enabled: cass_bool_t, ) -> CassError { - let batch = BoxFFI::as_mut_ref(batch); + let batch = BoxFFI::as_mut_ref(&mut batch).unwrap(); Arc::make_mut(&mut batch.state) .batch .set_tracing(enabled != 0); @@ -158,12 +158,12 @@ pub unsafe extern "C" fn cass_batch_set_tracing( #[no_mangle] pub unsafe extern "C" fn cass_batch_add_statement( - batch: *mut CassBatch, - statement: *const CassStatement, + mut batch: CassExclusiveMutPtr, + statement: CassExclusiveConstPtr, ) -> CassError { - let batch = BoxFFI::as_mut_ref(batch); + let batch = BoxFFI::as_mut_ref(&mut batch).unwrap(); let state = Arc::make_mut(&mut batch.state); - let statement = BoxFFI::as_ref(statement); + let statement = BoxFFI::as_ref(&statement).unwrap(); match &statement.statement { Statement::Simple(q) => state.batch.append_statement(q.query.clone()), diff --git a/scylla-rust-wrapper/src/binding.rs b/scylla-rust-wrapper/src/binding.rs index e9768889..17982268 100644 --- a/scylla-rust-wrapper/src/binding.rs +++ b/scylla-rust-wrapper/src/binding.rs @@ -53,7 +53,7 @@ macro_rules! make_index_binder { #[no_mangle] #[allow(clippy::redundant_closure_call)] pub unsafe extern "C" fn $fn_by_idx( - this: *mut $this, + mut this: CassExclusiveMutPtr<$this>, index: size_t, $($arg: $t), * ) -> CassError { @@ -61,7 +61,7 @@ macro_rules! make_index_binder { #[allow(unused_imports)] use crate::value::CassCqlValue::*; match ($e)($($arg), *) { - Ok(v) => $consume_v(BoxFFI::as_mut_ref(this), index as usize, v), + Ok(v) => $consume_v(BoxFFI::as_mut_ref(&mut this).unwrap(), index as usize, v), Err(e) => e, } } @@ -73,7 +73,7 @@ macro_rules! make_name_binder { #[no_mangle] #[allow(clippy::redundant_closure_call)] pub unsafe extern "C" fn $fn_by_name( - this: *mut $this, + mut this: CassExclusiveMutPtr<$this>, name: *const c_char, $($arg: $t), * ) -> CassError { @@ -82,7 +82,7 @@ macro_rules! make_name_binder { use crate::value::CassCqlValue::*; let name = ptr_to_cstr(name).unwrap(); match ($e)($($arg), *) { - Ok(v) => $consume_v(BoxFFI::as_mut_ref(this), name, v), + Ok(v) => $consume_v(BoxFFI::as_mut_ref(&mut this).unwrap(), name, v), Err(e) => e, } } @@ -94,7 +94,7 @@ macro_rules! make_name_n_binder { #[no_mangle] #[allow(clippy::redundant_closure_call)] pub unsafe extern "C" fn $fn_by_name_n( - this: *mut $this, + mut this: CassExclusiveMutPtr<$this>, name: *const c_char, name_length: size_t, $($arg: $t), * @@ -104,7 +104,7 @@ macro_rules! make_name_n_binder { use crate::value::CassCqlValue::*; let name = ptr_to_cstr_n(name, name_length).unwrap(); match ($e)($($arg), *) { - Ok(v) => $consume_v(BoxFFI::as_mut_ref(this), name, v), + Ok(v) => $consume_v(BoxFFI::as_mut_ref(&mut this).unwrap(), name, v), Err(e) => e, } } @@ -116,14 +116,14 @@ macro_rules! make_appender { #[no_mangle] #[allow(clippy::redundant_closure_call)] pub unsafe extern "C" fn $fn_append( - this: *mut $this, + mut this: CassExclusiveMutPtr<$this>, $($arg: $t), * ) -> CassError { // For some reason detected as unused, which is not true #[allow(unused_imports)] use crate::value::CassCqlValue::*; match ($e)($($arg), *) { - Ok(v) => $consume_v(BoxFFI::as_mut_ref(this), v), + Ok(v) => $consume_v(BoxFFI::as_mut_ref(&mut this).unwrap(), v), Err(e) => e, } } @@ -302,13 +302,13 @@ macro_rules! invoke_binder_maker_macro_with_type { $this, $consume_v, $fn, - |p: *const crate::collection::CassCollection| { - match std::convert::TryInto::try_into(BoxFFI::as_ref(p)) { + |p: CassExclusiveConstPtr| { + match std::convert::TryInto::try_into(BoxFFI::as_ref(&p).unwrap()) { Ok(v) => Ok(Some(v)), Err(_) => Err(CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE), } }, - [p @ *const crate::collection::CassCollection] + [p @ CassExclusiveConstPtr] ); }; (tuple, $macro_name:ident, $this:ty, $consume_v:expr, $fn:ident) => { @@ -316,10 +316,10 @@ macro_rules! invoke_binder_maker_macro_with_type { $this, $consume_v, $fn, - |p: *const crate::tuple::CassTuple| { - Ok(Some(BoxFFI::as_ref(p).into())) + |p: CassExclusiveConstPtr| { + Ok(Some(BoxFFI::as_ref(&p).unwrap().into())) }, - [p @ *const crate::tuple::CassTuple] + [p @ CassExclusiveConstPtr] ); }; (user_type, $macro_name:ident, $this:ty, $consume_v:expr, $fn:ident) => { @@ -327,8 +327,10 @@ macro_rules! invoke_binder_maker_macro_with_type { $this, $consume_v, $fn, - |p: *const crate::user_type::CassUserType| Ok(Some(BoxFFI::as_ref(p).into())), - [p @ *const crate::user_type::CassUserType] + |p: CassExclusiveConstPtr| { + Ok(Some(BoxFFI::as_ref(&p).unwrap().into())) + }, + [p @ CassExclusiveConstPtr] ); }; } diff --git a/scylla-rust-wrapper/src/cass_types.rs b/scylla-rust-wrapper/src/cass_types.rs index 6b479020..5a9a7eae 100644 --- a/scylla-rust-wrapper/src/cass_types.rs +++ b/scylla-rust-wrapper/src/cass_types.rs @@ -9,7 +9,6 @@ use std::cell::UnsafeCell; use std::collections::HashMap; use std::convert::TryFrom; use std::os::raw::c_char; -use std::ptr; use std::sync::Arc; pub(crate) use crate::cass_batch_types::CassBatchType; @@ -552,12 +551,10 @@ pub fn get_column_type(column_type: &ColumnType) -> CassDataType { CassDataType::new(inner) } -// Changed return type to const ptr - ArcFFI::into_ptr is const. -// It's probably not a good idea - but cppdriver doesn't guarantee -// thread safety apart from CassSession and CassFuture. -// This comment also applies to other functions that create CassDataType. #[no_mangle] -pub unsafe extern "C" fn cass_data_type_new(value_type: CassValueType) -> *const CassDataType { +pub unsafe extern "C" fn cass_data_type_new( + value_type: CassValueType, +) -> CassSharedPtr { let inner = match value_type { CassValueType::CASS_VALUE_TYPE_LIST => CassDataTypeInner::List { typ: None, @@ -574,49 +571,57 @@ pub unsafe extern "C" fn cass_data_type_new(value_type: CassValueType) -> *const }, CassValueType::CASS_VALUE_TYPE_UDT => CassDataTypeInner::UDT(UDTDataType::new()), CassValueType::CASS_VALUE_TYPE_CUSTOM => CassDataTypeInner::Custom("".to_string()), - CassValueType::CASS_VALUE_TYPE_UNKNOWN => return ptr::null_mut(), + CassValueType::CASS_VALUE_TYPE_UNKNOWN => return ArcFFI::null(), t if t < CassValueType::CASS_VALUE_TYPE_LAST_ENTRY => CassDataTypeInner::Value(t), - _ => return ptr::null_mut(), + _ => return ArcFFI::null(), }; ArcFFI::into_ptr(CassDataType::new_arced(inner)) } #[no_mangle] pub unsafe extern "C" fn cass_data_type_new_from_existing( - data_type: *const CassDataType, -) -> *const CassDataType { - let data_type = ArcFFI::as_ref(data_type); + data_type: CassSharedPtr, +) -> CassSharedPtr { + let data_type = ArcFFI::as_ref(&data_type).unwrap(); ArcFFI::into_ptr(CassDataType::new_arced(data_type.get_unchecked().clone())) } #[no_mangle] -pub unsafe extern "C" fn cass_data_type_new_tuple(item_count: size_t) -> *const CassDataType { +pub unsafe extern "C" fn cass_data_type_new_tuple( + item_count: size_t, +) -> CassSharedPtr { ArcFFI::into_ptr(CassDataType::new_arced(CassDataTypeInner::Tuple( Vec::with_capacity(item_count as usize), ))) } #[no_mangle] -pub unsafe extern "C" fn cass_data_type_new_udt(field_count: size_t) -> *const CassDataType { +pub unsafe extern "C" fn cass_data_type_new_udt( + field_count: size_t, +) -> CassSharedPtr { ArcFFI::into_ptr(CassDataType::new_arced(CassDataTypeInner::UDT( UDTDataType::with_capacity(field_count as usize), ))) } #[no_mangle] -pub unsafe extern "C" fn cass_data_type_free(data_type: *mut CassDataType) { +pub unsafe extern "C" fn cass_data_type_free(data_type: CassSharedPtr) { ArcFFI::free(data_type); } #[no_mangle] -pub unsafe extern "C" fn cass_data_type_type(data_type: *const CassDataType) -> CassValueType { - let data_type = ArcFFI::as_ref(data_type); +pub unsafe extern "C" fn cass_data_type_type( + data_type: CassSharedPtr, +) -> CassValueType { + let data_type = ArcFFI::as_ref(&data_type).unwrap(); data_type.get_unchecked().get_value_type() } #[no_mangle] -pub unsafe extern "C" fn cass_data_type_is_frozen(data_type: *const CassDataType) -> cass_bool_t { - let data_type = ArcFFI::as_ref(data_type); +pub unsafe extern "C" fn cass_data_type_is_frozen( + data_type: CassSharedPtr, +) -> cass_bool_t { + let data_type = ArcFFI::as_ref(&data_type).unwrap(); let is_frozen = match data_type.get_unchecked() { CassDataTypeInner::UDT(udt) => udt.frozen, CassDataTypeInner::List { frozen, .. } => *frozen, @@ -630,11 +635,11 @@ pub unsafe extern "C" fn cass_data_type_is_frozen(data_type: *const CassDataType #[no_mangle] pub unsafe extern "C" fn cass_data_type_type_name( - data_type: *const CassDataType, + data_type: CassSharedPtr, type_name: *mut *const c_char, type_name_length: *mut size_t, ) -> CassError { - let data_type = ArcFFI::as_ref(data_type); + let data_type = ArcFFI::as_ref(&data_type).unwrap(); match data_type.get_unchecked() { CassDataTypeInner::UDT(UDTDataType { name, .. }) => { write_str_to_c(name, type_name, type_name_length); @@ -646,7 +651,7 @@ pub unsafe extern "C" fn cass_data_type_type_name( #[no_mangle] pub unsafe extern "C" fn cass_data_type_set_type_name( - data_type: *const CassDataType, + data_type: CassSharedPtr, type_name: *const c_char, ) -> CassError { cass_data_type_set_type_name_n(data_type, type_name, strlen(type_name)) @@ -654,11 +659,11 @@ pub unsafe extern "C" fn cass_data_type_set_type_name( #[no_mangle] pub unsafe extern "C" fn cass_data_type_set_type_name_n( - data_type_raw: *const CassDataType, + data_type_raw: CassSharedPtr, type_name: *const c_char, type_name_length: size_t, ) -> CassError { - let data_type = ArcFFI::as_ref(data_type_raw); + let data_type = ArcFFI::as_ref(&data_type_raw).unwrap(); let type_name_string = ptr_to_cstr_n(type_name, type_name_length) .unwrap() .to_string(); @@ -674,11 +679,11 @@ pub unsafe extern "C" fn cass_data_type_set_type_name_n( #[no_mangle] pub unsafe extern "C" fn cass_data_type_keyspace( - data_type: *const CassDataType, + data_type: CassSharedPtr, keyspace: *mut *const c_char, keyspace_length: *mut size_t, ) -> CassError { - let data_type = ArcFFI::as_ref(data_type); + let data_type = ArcFFI::as_ref(&data_type).unwrap(); match data_type.get_unchecked() { CassDataTypeInner::UDT(UDTDataType { name, .. }) => { write_str_to_c(name, keyspace, keyspace_length); @@ -690,7 +695,7 @@ pub unsafe extern "C" fn cass_data_type_keyspace( #[no_mangle] pub unsafe extern "C" fn cass_data_type_set_keyspace( - data_type: *const CassDataType, + data_type: CassSharedPtr, keyspace: *const c_char, ) -> CassError { cass_data_type_set_keyspace_n(data_type, keyspace, strlen(keyspace)) @@ -698,11 +703,11 @@ pub unsafe extern "C" fn cass_data_type_set_keyspace( #[no_mangle] pub unsafe extern "C" fn cass_data_type_set_keyspace_n( - data_type: *const CassDataType, + data_type: CassSharedPtr, keyspace: *const c_char, keyspace_length: size_t, ) -> CassError { - let data_type = ArcFFI::as_ref(data_type); + let data_type = ArcFFI::as_ref(&data_type).unwrap(); let keyspace_string = ptr_to_cstr_n(keyspace, keyspace_length) .unwrap() .to_string(); @@ -718,11 +723,11 @@ pub unsafe extern "C" fn cass_data_type_set_keyspace_n( #[no_mangle] pub unsafe extern "C" fn cass_data_type_class_name( - data_type: *const CassDataType, + data_type: CassSharedPtr, class_name: *mut *const ::std::os::raw::c_char, class_name_length: *mut size_t, ) -> CassError { - let data_type = ArcFFI::as_ref(data_type); + let data_type = ArcFFI::as_ref(&data_type).unwrap(); match data_type.get_unchecked() { CassDataTypeInner::Custom(name) => { write_str_to_c(name, class_name, class_name_length); @@ -734,7 +739,7 @@ pub unsafe extern "C" fn cass_data_type_class_name( #[no_mangle] pub unsafe extern "C" fn cass_data_type_set_class_name( - data_type: *const CassDataType, + data_type: CassSharedPtr, class_name: *const ::std::os::raw::c_char, ) -> CassError { cass_data_type_set_class_name_n(data_type, class_name, strlen(class_name)) @@ -742,11 +747,11 @@ pub unsafe extern "C" fn cass_data_type_set_class_name( #[no_mangle] pub unsafe extern "C" fn cass_data_type_set_class_name_n( - data_type: *const CassDataType, + data_type: CassSharedPtr, class_name: *const ::std::os::raw::c_char, class_name_length: size_t, ) -> CassError { - let data_type = ArcFFI::as_ref(data_type); + let data_type = ArcFFI::as_ref(&data_type).unwrap(); let class_string = ptr_to_cstr_n(class_name, class_name_length) .unwrap() .to_string(); @@ -760,8 +765,10 @@ pub unsafe extern "C" fn cass_data_type_set_class_name_n( } #[no_mangle] -pub unsafe extern "C" fn cass_data_type_sub_type_count(data_type: *const CassDataType) -> size_t { - let data_type = ArcFFI::as_ref(data_type); +pub unsafe extern "C" fn cass_data_type_sub_type_count( + data_type: CassSharedPtr, +) -> size_t { + let data_type = ArcFFI::as_ref(&data_type).unwrap(); match data_type.get_unchecked() { CassDataTypeInner::Value(..) => 0, CassDataTypeInner::UDT(udt_data_type) => udt_data_type.field_types.len() as size_t, @@ -779,21 +786,23 @@ pub unsafe extern "C" fn cass_data_type_sub_type_count(data_type: *const CassDat } #[no_mangle] -pub unsafe extern "C" fn cass_data_sub_type_count(data_type: *const CassDataType) -> size_t { +pub unsafe extern "C" fn cass_data_sub_type_count( + data_type: CassSharedPtr, +) -> size_t { cass_data_type_sub_type_count(data_type) } #[no_mangle] pub unsafe extern "C" fn cass_data_type_sub_data_type( - data_type: *const CassDataType, + data_type: CassSharedPtr, index: size_t, -) -> *const CassDataType { - let data_type = ArcFFI::as_ref(data_type); +) -> CassSharedPtr { + let data_type = ArcFFI::as_ref(&data_type).unwrap(); let sub_type: Option<&Arc> = data_type.get_unchecked().get_sub_data_type(index as usize); match sub_type { - None => std::ptr::null(), + None => ArcFFI::null(), // Semantic from cppdriver which also returns non-owning pointer Some(arc) => ArcFFI::as_ptr(arc), } @@ -801,37 +810,37 @@ pub unsafe extern "C" fn cass_data_type_sub_data_type( #[no_mangle] pub unsafe extern "C" fn cass_data_type_sub_data_type_by_name( - data_type: *const CassDataType, + data_type: CassSharedPtr, name: *const ::std::os::raw::c_char, -) -> *const CassDataType { +) -> CassSharedPtr { cass_data_type_sub_data_type_by_name_n(data_type, name, strlen(name)) } #[no_mangle] pub unsafe extern "C" fn cass_data_type_sub_data_type_by_name_n( - data_type: *const CassDataType, + data_type: CassSharedPtr, name: *const ::std::os::raw::c_char, name_length: size_t, -) -> *const CassDataType { - let data_type = ArcFFI::as_ref(data_type); +) -> CassSharedPtr { + let data_type = ArcFFI::as_ref(&data_type).unwrap(); let name_str = ptr_to_cstr_n(name, name_length).unwrap(); match data_type.get_unchecked() { CassDataTypeInner::UDT(udt) => match udt.get_field_by_name(name_str) { - None => std::ptr::null(), + None => ArcFFI::null(), Some(t) => ArcFFI::as_ptr(t), }, - _ => std::ptr::null(), + _ => ArcFFI::null(), } } #[no_mangle] pub unsafe extern "C" fn cass_data_type_sub_type_name( - data_type: *const CassDataType, + data_type: CassSharedPtr, index: size_t, name: *mut *const ::std::os::raw::c_char, name_length: *mut size_t, ) -> CassError { - let data_type = ArcFFI::as_ref(data_type); + let data_type = ArcFFI::as_ref(&data_type).unwrap(); match data_type.get_unchecked() { CassDataTypeInner::UDT(udt) => match udt.field_types.get(index as usize) { None => CassError::CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS, @@ -846,13 +855,13 @@ pub unsafe extern "C" fn cass_data_type_sub_type_name( #[no_mangle] pub unsafe extern "C" fn cass_data_type_add_sub_type( - data_type: *const CassDataType, - sub_data_type: *const CassDataType, + data_type: CassSharedPtr, + sub_data_type: CassSharedPtr, ) -> CassError { - let data_type = ArcFFI::as_ref(data_type); + let data_type = ArcFFI::as_ref(&data_type).unwrap(); match data_type .get_mut_unchecked() - .add_sub_data_type(ArcFFI::cloned_from_ptr(sub_data_type)) + .add_sub_data_type(ArcFFI::cloned_from_ptr(sub_data_type).unwrap()) { Ok(()) => CassError::CASS_OK, Err(e) => e, @@ -861,24 +870,24 @@ pub unsafe extern "C" fn cass_data_type_add_sub_type( #[no_mangle] pub unsafe extern "C" fn cass_data_type_add_sub_type_by_name( - data_type: *const CassDataType, + data_type: CassSharedPtr, name: *const c_char, - sub_data_type: *const CassDataType, + sub_data_type: CassSharedPtr, ) -> CassError { cass_data_type_add_sub_type_by_name_n(data_type, name, strlen(name), sub_data_type) } #[no_mangle] pub unsafe extern "C" fn cass_data_type_add_sub_type_by_name_n( - data_type_raw: *const CassDataType, + data_type_raw: CassSharedPtr, name: *const c_char, name_length: size_t, - sub_data_type_raw: *const CassDataType, + sub_data_type_raw: CassSharedPtr, ) -> CassError { let name_string = ptr_to_cstr_n(name, name_length).unwrap().to_string(); - let sub_data_type = ArcFFI::cloned_from_ptr(sub_data_type_raw); + let sub_data_type = ArcFFI::cloned_from_ptr(sub_data_type_raw).unwrap(); - let data_type = ArcFFI::as_ref(data_type_raw); + let data_type = ArcFFI::as_ref(&data_type_raw).unwrap(); match data_type.get_mut_unchecked() { CassDataTypeInner::UDT(udt_data_type) => { // The Cpp Driver does not check whether field_types size @@ -892,7 +901,7 @@ pub unsafe extern "C" fn cass_data_type_add_sub_type_by_name_n( #[no_mangle] pub unsafe extern "C" fn cass_data_type_add_sub_value_type( - data_type: *const CassDataType, + data_type: CassSharedPtr, sub_value_type: CassValueType, ) -> CassError { let sub_data_type = CassDataType::new_arced(CassDataTypeInner::Value(sub_value_type)); @@ -901,7 +910,7 @@ pub unsafe extern "C" fn cass_data_type_add_sub_value_type( #[no_mangle] pub unsafe extern "C" fn cass_data_type_add_sub_value_type_by_name( - data_type: *const CassDataType, + data_type: CassSharedPtr, name: *const c_char, sub_value_type: CassValueType, ) -> CassError { @@ -911,7 +920,7 @@ pub unsafe extern "C" fn cass_data_type_add_sub_value_type_by_name( #[no_mangle] pub unsafe extern "C" fn cass_data_type_add_sub_value_type_by_name_n( - data_type: *const CassDataType, + data_type: CassSharedPtr, name: *const c_char, name_length: size_t, sub_value_type: CassValueType, diff --git a/scylla-rust-wrapper/src/cluster.rs b/scylla-rust-wrapper/src/cluster.rs index bde792b8..11621418 100644 --- a/scylla-rust-wrapper/src/cluster.rs +++ b/scylla-rust-wrapper/src/cluster.rs @@ -194,7 +194,7 @@ pub fn build_session_builder( } #[no_mangle] -pub unsafe extern "C" fn cass_cluster_new() -> *mut CassCluster { +pub unsafe extern "C" fn cass_cluster_new() -> CassExclusiveMutPtr { let default_execution_profile_builder = ExecutionProfileBuilder::default() .consistency(DEFAULT_CONSISTENCY) .request_timeout(Some(DEFAULT_REQUEST_TIMEOUT)); @@ -234,13 +234,13 @@ pub unsafe extern "C" fn cass_cluster_new() -> *mut CassCluster { } #[no_mangle] -pub unsafe extern "C" fn cass_cluster_free(cluster: *mut CassCluster) { +pub unsafe extern "C" fn cass_cluster_free(cluster: CassExclusiveMutPtr) { BoxFFI::free(cluster); } #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_contact_points( - cluster: *mut CassCluster, + cluster: CassExclusiveMutPtr, contact_points: *const c_char, ) -> CassError { cass_cluster_set_contact_points_n(cluster, contact_points, strlen(contact_points)) @@ -248,7 +248,7 @@ pub unsafe extern "C" fn cass_cluster_set_contact_points( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_contact_points_n( - cluster: *mut CassCluster, + cluster: CassExclusiveMutPtr, contact_points: *const c_char, contact_points_length: size_t, ) -> CassError { @@ -259,11 +259,11 @@ pub unsafe extern "C" fn cass_cluster_set_contact_points_n( } unsafe fn cluster_set_contact_points( - cluster_raw: *mut CassCluster, + mut cluster_raw: CassExclusiveMutPtr, contact_points_raw: *const c_char, contact_points_length: size_t, ) -> Result<(), CassError> { - let cluster = BoxFFI::as_mut_ref(cluster_raw); + let cluster = BoxFFI::as_mut_ref(&mut cluster_raw).unwrap(); let mut contact_points = ptr_to_cstr_n(contact_points_raw, contact_points_length) .ok_or(CassError::CASS_ERROR_LIB_BAD_PARAMS)? .split(',') @@ -289,7 +289,7 @@ unsafe fn cluster_set_contact_points( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_use_randomized_contact_points( - _cluster_raw: *mut CassCluster, + _cluster_raw: CassExclusiveMutPtr, _enabled: cass_bool_t, ) -> CassError { // FIXME: should set `use_randomized_contact_points` flag in cluster config @@ -299,7 +299,7 @@ pub unsafe extern "C" fn cass_cluster_set_use_randomized_contact_points( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_application_name( - cluster_raw: *mut CassCluster, + cluster_raw: CassExclusiveMutPtr, app_name: *const c_char, ) { cass_cluster_set_application_name_n(cluster_raw, app_name, strlen(app_name)) @@ -307,11 +307,11 @@ pub unsafe extern "C" fn cass_cluster_set_application_name( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_application_name_n( - cluster_raw: *mut CassCluster, + mut cluster_raw: CassExclusiveMutPtr, app_name: *const c_char, app_name_len: size_t, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw); + let cluster = BoxFFI::as_mut_ref(&mut cluster_raw).unwrap(); let app_name = ptr_to_cstr_n(app_name, app_name_len).unwrap().to_string(); cluster @@ -323,7 +323,7 @@ pub unsafe extern "C" fn cass_cluster_set_application_name_n( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_application_version( - cluster_raw: *mut CassCluster, + cluster_raw: CassExclusiveMutPtr, app_version: *const c_char, ) { cass_cluster_set_application_version_n(cluster_raw, app_version, strlen(app_version)) @@ -331,11 +331,11 @@ pub unsafe extern "C" fn cass_cluster_set_application_version( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_application_version_n( - cluster_raw: *mut CassCluster, + mut cluster_raw: CassExclusiveMutPtr, app_version: *const c_char, app_version_len: size_t, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw); + let cluster = BoxFFI::as_mut_ref(&mut cluster_raw).unwrap(); let app_version = ptr_to_cstr_n(app_version, app_version_len) .unwrap() .to_string(); @@ -349,10 +349,10 @@ pub unsafe extern "C" fn cass_cluster_set_application_version_n( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_client_id( - cluster_raw: *mut CassCluster, + mut cluster_raw: CassExclusiveMutPtr, client_id: CassUuid, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw); + let cluster = BoxFFI::as_mut_ref(&mut cluster_raw).unwrap(); let client_uuid: uuid::Uuid = client_id.into(); let client_uuid_str = client_uuid.to_string(); @@ -367,29 +367,29 @@ pub unsafe extern "C" fn cass_cluster_set_client_id( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_use_schema( - cluster_raw: *mut CassCluster, + mut cluster_raw: CassExclusiveMutPtr, enabled: cass_bool_t, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw); + let cluster = BoxFFI::as_mut_ref(&mut cluster_raw).unwrap(); cluster.session_builder.config.fetch_schema_metadata = enabled != 0; } #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_tcp_nodelay( - cluster_raw: *mut CassCluster, + mut cluster_raw: CassExclusiveMutPtr, enabled: cass_bool_t, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw); + let cluster = BoxFFI::as_mut_ref(&mut cluster_raw).unwrap(); cluster.session_builder.config.tcp_nodelay = enabled != 0; } #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_tcp_keepalive( - cluster_raw: *mut CassCluster, + mut cluster_raw: CassExclusiveMutPtr, enabled: cass_bool_t, delay_secs: c_uint, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw); + let cluster = BoxFFI::as_mut_ref(&mut cluster_raw).unwrap(); let enabled = enabled != 0; let tcp_keepalive_interval = enabled.then(|| Duration::from_secs(delay_secs as u64)); @@ -398,10 +398,10 @@ pub unsafe extern "C" fn cass_cluster_set_tcp_keepalive( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_connection_heartbeat_interval( - cluster_raw: *mut CassCluster, + mut cluster_raw: CassExclusiveMutPtr, interval_secs: c_uint, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw); + let cluster = BoxFFI::as_mut_ref(&mut cluster_raw).unwrap(); let keepalive_interval = (interval_secs > 0).then(|| Duration::from_secs(interval_secs as u64)); cluster.session_builder.config.keepalive_interval = keepalive_interval; @@ -409,10 +409,10 @@ pub unsafe extern "C" fn cass_cluster_set_connection_heartbeat_interval( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_connection_idle_timeout( - cluster_raw: *mut CassCluster, + mut cluster_raw: CassExclusiveMutPtr, timeout_secs: c_uint, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw); + let cluster = BoxFFI::as_mut_ref(&mut cluster_raw).unwrap(); let keepalive_timeout = (timeout_secs > 0).then(|| Duration::from_secs(timeout_secs as u64)); cluster.session_builder.config.keepalive_timeout = keepalive_timeout; @@ -420,19 +420,19 @@ pub unsafe extern "C" fn cass_cluster_set_connection_idle_timeout( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_connect_timeout( - cluster_raw: *mut CassCluster, + mut cluster_raw: CassExclusiveMutPtr, timeout_ms: c_uint, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw); + let cluster = BoxFFI::as_mut_ref(&mut cluster_raw).unwrap(); cluster.session_builder.config.connect_timeout = Duration::from_millis(timeout_ms.into()); } #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_request_timeout( - cluster_raw: *mut CassCluster, + mut cluster_raw: CassExclusiveMutPtr, timeout_ms: c_uint, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw); + let cluster = BoxFFI::as_mut_ref(&mut cluster_raw).unwrap(); exec_profile_builder_modify(&mut cluster.default_execution_profile_builder, |builder| { // 0 -> no timeout @@ -442,10 +442,10 @@ pub unsafe extern "C" fn cass_cluster_set_request_timeout( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_max_schema_wait_time( - cluster_raw: *mut CassCluster, + mut cluster_raw: CassExclusiveMutPtr, wait_time_ms: c_uint, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw); + let cluster = BoxFFI::as_mut_ref(&mut cluster_raw).unwrap(); cluster.session_builder.config.schema_agreement_timeout = Duration::from_millis(wait_time_ms.into()); @@ -453,10 +453,10 @@ pub unsafe extern "C" fn cass_cluster_set_max_schema_wait_time( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_schema_agreement_interval( - cluster_raw: *mut CassCluster, + mut cluster_raw: CassExclusiveMutPtr, interval_ms: c_uint, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw); + let cluster = BoxFFI::as_mut_ref(&mut cluster_raw).unwrap(); cluster.session_builder.config.schema_agreement_interval = Duration::from_millis(interval_ms.into()); @@ -464,21 +464,21 @@ pub unsafe extern "C" fn cass_cluster_set_schema_agreement_interval( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_port( - cluster_raw: *mut CassCluster, + mut cluster_raw: CassExclusiveMutPtr, port: c_int, ) -> CassError { if port <= 0 { return CassError::CASS_ERROR_LIB_BAD_PARAMS; } - let cluster = BoxFFI::as_mut_ref(cluster_raw); + let cluster = BoxFFI::as_mut_ref(&mut cluster_raw).unwrap(); cluster.port = port as u16; CassError::CASS_OK } #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_credentials( - cluster: *mut CassCluster, + cluster: CassExclusiveMutPtr, username: *const c_char, password: *const c_char, ) { @@ -493,7 +493,7 @@ pub unsafe extern "C" fn cass_cluster_set_credentials( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_credentials_n( - cluster_raw: *mut CassCluster, + mut cluster_raw: CassExclusiveMutPtr, username_raw: *const c_char, username_length: size_t, password_raw: *const c_char, @@ -503,20 +503,22 @@ pub unsafe extern "C" fn cass_cluster_set_credentials_n( let username = ptr_to_cstr_n(username_raw, username_length).unwrap(); let password = ptr_to_cstr_n(password_raw, password_length).unwrap(); - let cluster = BoxFFI::as_mut_ref(cluster_raw); + let cluster = BoxFFI::as_mut_ref(&mut cluster_raw).unwrap(); cluster.auth_username = Some(username.to_string()); cluster.auth_password = Some(password.to_string()); } #[no_mangle] -pub unsafe extern "C" fn cass_cluster_set_load_balance_round_robin(cluster_raw: *mut CassCluster) { - let cluster = BoxFFI::as_mut_ref(cluster_raw); +pub unsafe extern "C" fn cass_cluster_set_load_balance_round_robin( + mut cluster_raw: CassExclusiveMutPtr, +) { + let cluster = BoxFFI::as_mut_ref(&mut cluster_raw).unwrap(); cluster.load_balancing_config.load_balancing_kind = Some(LoadBalancingKind::RoundRobin); } #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_load_balance_dc_aware( - cluster: *mut CassCluster, + cluster: CassExclusiveMutPtr, local_dc: *const c_char, used_hosts_per_remote_dc: c_uint, allow_remote_dcs_for_local_cl: cass_bool_t, @@ -557,13 +559,13 @@ pub(crate) unsafe fn set_load_balance_dc_aware_n( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_load_balance_dc_aware_n( - cluster_raw: *mut CassCluster, + mut cluster_raw: CassExclusiveMutPtr, local_dc_raw: *const c_char, local_dc_length: size_t, used_hosts_per_remote_dc: c_uint, allow_remote_dcs_for_local_cl: cass_bool_t, ) -> CassError { - let cluster = BoxFFI::as_mut_ref(cluster_raw); + let cluster = BoxFFI::as_mut_ref(&mut cluster_raw).unwrap(); set_load_balance_dc_aware_n( &mut cluster.load_balancing_config, @@ -576,7 +578,7 @@ pub unsafe extern "C" fn cass_cluster_set_load_balance_dc_aware_n( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_load_balance_rack_aware( - cluster_raw: *mut CassCluster, + cluster_raw: CassExclusiveMutPtr, local_dc_raw: *const c_char, local_rack_raw: *const c_char, ) -> CassError { @@ -591,13 +593,13 @@ pub unsafe extern "C" fn cass_cluster_set_load_balance_rack_aware( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_load_balance_rack_aware_n( - cluster_raw: *mut CassCluster, + mut cluster_raw: CassExclusiveMutPtr, local_dc_raw: *const c_char, local_dc_length: size_t, local_rack_raw: *const c_char, local_rack_length: size_t, ) -> CassError { - let cluster = BoxFFI::as_mut_ref(cluster_raw); + let cluster = BoxFFI::as_mut_ref(&mut cluster_raw).unwrap(); set_load_balance_rack_aware_n( &mut cluster.load_balancing_config, @@ -638,7 +640,7 @@ pub(crate) unsafe fn set_load_balance_rack_aware_n( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_cloud_secure_connection_bundle_n( - _cluster_raw: *mut CassCluster, + _cluster_raw: CassExclusiveMutPtr, path: *const c_char, path_length: size_t, ) -> CassError { @@ -654,7 +656,7 @@ pub unsafe extern "C" fn cass_cluster_set_cloud_secure_connection_bundle_n( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_exponential_reconnect( - _cluster_raw: *mut CassCluster, + _cluster_raw: CassExclusiveMutPtr, base_delay_ms: cass_uint64_t, max_delay_ms: cass_uint64_t, ) -> CassError { @@ -689,7 +691,7 @@ pub extern "C" fn cass_custom_payload_new() -> *const CassCustomPayload { #[no_mangle] pub extern "C" fn cass_future_custom_payload_item( - _future: *mut CassFuture, + _future: CassExclusiveMutPtr, _i: size_t, _name: *const c_char, _name_length: size_t, @@ -700,16 +702,18 @@ pub extern "C" fn cass_future_custom_payload_item( } #[no_mangle] -pub extern "C" fn cass_future_custom_payload_item_count(_future: *mut CassFuture) -> size_t { +pub extern "C" fn cass_future_custom_payload_item_count( + _future: CassExclusiveMutPtr, +) -> size_t { 0 } #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_use_beta_protocol_version( - cluster_raw: *mut CassCluster, + mut cluster_raw: CassExclusiveMutPtr, enable: cass_bool_t, ) -> CassError { - let cluster = BoxFFI::as_mut_ref(cluster_raw); + let cluster = BoxFFI::as_mut_ref(&mut cluster_raw).unwrap(); cluster.use_beta_protocol_version = enable == cass_true; CassError::CASS_OK @@ -717,10 +721,10 @@ pub unsafe extern "C" fn cass_cluster_set_use_beta_protocol_version( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_protocol_version( - cluster_raw: *mut CassCluster, + cluster_raw: CassExclusiveMutPtr, protocol_version: c_int, ) -> CassError { - let cluster = BoxFFI::as_ref(cluster_raw); + let cluster = BoxFFI::as_ref(&cluster_raw).unwrap(); if protocol_version == 4 && !cluster.use_beta_protocol_version { // Rust Driver supports only protocol version 4 @@ -732,7 +736,7 @@ pub unsafe extern "C" fn cass_cluster_set_protocol_version( #[no_mangle] pub extern "C" fn cass_cluster_set_queue_size_event( - _cluster: *mut CassCluster, + _cluster: CassExclusiveMutPtr, _queue_size: c_uint, ) -> CassError { // In Cpp Driver this function is also a no-op... @@ -741,7 +745,7 @@ pub extern "C" fn cass_cluster_set_queue_size_event( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_constant_speculative_execution_policy( - cluster_raw: *mut CassCluster, + mut cluster_raw: CassExclusiveMutPtr, constant_delay_ms: cass_int64_t, max_speculative_executions: c_int, ) -> CassError { @@ -749,7 +753,7 @@ pub unsafe extern "C" fn cass_cluster_set_constant_speculative_execution_policy( return CassError::CASS_ERROR_LIB_BAD_PARAMS; } - let cluster = BoxFFI::as_mut_ref(cluster_raw); + let cluster = BoxFFI::as_mut_ref(&mut cluster_raw).unwrap(); let policy = SimpleSpeculativeExecutionPolicy { max_retry_count: max_speculative_executions as usize, @@ -765,9 +769,9 @@ pub unsafe extern "C" fn cass_cluster_set_constant_speculative_execution_policy( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_no_speculative_execution_policy( - cluster_raw: *mut CassCluster, + mut cluster_raw: CassExclusiveMutPtr, ) -> CassError { - let cluster = BoxFFI::as_mut_ref(cluster_raw); + let cluster = BoxFFI::as_mut_ref(&mut cluster_raw).unwrap(); exec_profile_builder_modify(&mut cluster.default_execution_profile_builder, |builder| { builder.speculative_execution_policy(None) @@ -778,19 +782,19 @@ pub unsafe extern "C" fn cass_cluster_set_no_speculative_execution_policy( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_token_aware_routing( - cluster_raw: *mut CassCluster, + mut cluster_raw: CassExclusiveMutPtr, enabled: cass_bool_t, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw); + let cluster = BoxFFI::as_mut_ref(&mut cluster_raw).unwrap(); cluster.load_balancing_config.token_awareness_enabled = enabled != 0; } #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_token_aware_routing_shuffle_replicas( - cluster_raw: *mut CassCluster, + mut cluster_raw: CassExclusiveMutPtr, enabled: cass_bool_t, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw); + let cluster = BoxFFI::as_mut_ref(&mut cluster_raw).unwrap(); cluster .load_balancing_config @@ -799,12 +803,12 @@ pub unsafe extern "C" fn cass_cluster_set_token_aware_routing_shuffle_replicas( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_retry_policy( - cluster_raw: *mut CassCluster, - retry_policy: *const CassRetryPolicy, + mut cluster_raw: CassExclusiveMutPtr, + retry_policy: CassSharedPtr, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw); + let cluster = BoxFFI::as_mut_ref(&mut cluster_raw).unwrap(); - let retry_policy: Arc = match ArcFFI::as_ref(retry_policy) { + let retry_policy: Arc = match ArcFFI::as_ref(&retry_policy).unwrap() { DefaultRetryPolicy(default) => Arc::clone(default) as _, FallthroughRetryPolicy(fallthrough) => Arc::clone(fallthrough) as _, DowngradingConsistencyRetryPolicy(downgrading) => Arc::clone(downgrading) as _, @@ -816,9 +820,12 @@ pub unsafe extern "C" fn cass_cluster_set_retry_policy( } #[no_mangle] -pub unsafe extern "C" fn cass_cluster_set_ssl(cluster: *mut CassCluster, ssl: *mut CassSsl) { - let cluster_from_raw = BoxFFI::as_mut_ref(cluster); - let cass_ssl = ArcFFI::cloned_from_ptr(ssl); +pub unsafe extern "C" fn cass_cluster_set_ssl( + mut cluster: CassExclusiveMutPtr, + ssl: CassSharedPtr, +) { + let cluster_from_raw = BoxFFI::as_mut_ref(&mut cluster).unwrap(); + let cass_ssl = ArcFFI::cloned_from_ptr(ssl).unwrap(); let ssl_context_builder = SslContextBuilder::from_ptr(cass_ssl.ssl_context); // Reference count is increased as tokio_openssl will try to free `ssl_context` when calling `SSL_free`. @@ -829,10 +836,10 @@ pub unsafe extern "C" fn cass_cluster_set_ssl(cluster: *mut CassCluster, ssl: *m #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_compression( - cluster: *mut CassCluster, + mut cluster: CassExclusiveMutPtr, compression_type: CassCompressionType, ) { - let cluster_from_raw = BoxFFI::as_mut_ref(cluster); + let cluster_from_raw = BoxFFI::as_mut_ref(&mut cluster).unwrap(); let compression = match compression_type { CassCompressionType::CASS_COMPRESSION_LZ4 => Some(Compression::Lz4), CassCompressionType::CASS_COMPRESSION_SNAPPY => Some(Compression::Snappy), @@ -844,23 +851,23 @@ pub unsafe extern "C" fn cass_cluster_set_compression( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_latency_aware_routing( - cluster: *mut CassCluster, + mut cluster: CassExclusiveMutPtr, enabled: cass_bool_t, ) { - let cluster = BoxFFI::as_mut_ref(cluster); + let cluster = BoxFFI::as_mut_ref(&mut cluster).unwrap(); cluster.load_balancing_config.latency_awareness_enabled = enabled != 0; } #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_latency_aware_routing_settings( - cluster: *mut CassCluster, + mut cluster: CassExclusiveMutPtr, exclusion_threshold: cass_double_t, scale_ms: cass_uint64_t, retry_period_ms: cass_uint64_t, update_rate_ms: cass_uint64_t, min_measured: cass_uint64_t, ) { - let cluster = BoxFFI::as_mut_ref(cluster); + let cluster = BoxFFI::as_mut_ref(&mut cluster).unwrap(); cluster.load_balancing_config.latency_awareness_builder = LatencyAwarenessBuilder::new() .exclusion_threshold(exclusion_threshold) .scale(Duration::from_millis(scale_ms)) @@ -871,10 +878,10 @@ pub unsafe extern "C" fn cass_cluster_set_latency_aware_routing_settings( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_consistency( - cluster: *mut CassCluster, + mut cluster: CassExclusiveMutPtr, consistency: CassConsistency, ) -> CassError { - let cluster = BoxFFI::as_mut_ref(cluster); + let cluster = BoxFFI::as_mut_ref(&mut cluster).unwrap(); let consistency: Consistency = match consistency.try_into() { Ok(c) => c, Err(_) => return CassError::CASS_ERROR_LIB_BAD_PARAMS, @@ -889,10 +896,10 @@ pub unsafe extern "C" fn cass_cluster_set_consistency( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_serial_consistency( - cluster: *mut CassCluster, + mut cluster: CassExclusiveMutPtr, serial_consistency: CassConsistency, ) -> CassError { - let cluster = BoxFFI::as_mut_ref(cluster); + let cluster = BoxFFI::as_mut_ref(&mut cluster).unwrap(); let serial_consistency: SerialConsistency = match serial_consistency.try_into() { Ok(c) => c, Err(_) => return CassError::CASS_ERROR_LIB_BAD_PARAMS, @@ -907,21 +914,21 @@ pub unsafe extern "C" fn cass_cluster_set_serial_consistency( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_execution_profile( - cluster: *mut CassCluster, + cluster: CassExclusiveMutPtr, name: *const c_char, - profile: *const CassExecProfile, + profile: CassExclusiveConstPtr, ) -> CassError { cass_cluster_set_execution_profile_n(cluster, name, strlen(name), profile) } #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_execution_profile_n( - cluster: *mut CassCluster, + mut cluster: CassExclusiveMutPtr, name: *const c_char, name_length: size_t, - profile: *const CassExecProfile, + profile: CassExclusiveConstPtr, ) -> CassError { - let cluster = BoxFFI::as_mut_ref(cluster); + let cluster = BoxFFI::as_mut_ref(&mut cluster).unwrap(); let name = if let Some(name) = ptr_to_cstr_n(name, name_length).and_then(|name| name.to_owned().try_into().ok()) { @@ -930,7 +937,7 @@ pub unsafe extern "C" fn cass_cluster_set_execution_profile_n( // Got NULL or empty string, which is invalid name for a profile. return CassError::CASS_ERROR_LIB_BAD_PARAMS; }; - let profile = if let Some(profile) = BoxFFI::as_maybe_ref(profile) { + let profile = if let Some(profile) = BoxFFI::as_ref(&profile) { profile.clone() } else { return CassError::CASS_ERROR_LIB_BAD_PARAMS; @@ -965,23 +972,28 @@ mod tests { let cluster_raw = cass_cluster_new(); { /* Test valid configurations */ - let cluster = BoxFFI::as_ref(cluster_raw); + let cluster = BoxFFI::as_ref(&cluster_raw).unwrap(); { assert_matches!(cluster.load_balancing_config.load_balancing_kind, None); assert!(cluster.load_balancing_config.token_awareness_enabled); assert!(!cluster.load_balancing_config.latency_awareness_enabled); } { - cass_cluster_set_token_aware_routing(cluster_raw, 0); + cass_cluster_set_token_aware_routing(cluster_raw.clone(), 0); assert_cass_error_eq!( - cass_cluster_set_load_balance_dc_aware(cluster_raw, c"eu".as_ptr(), 0, 0), + cass_cluster_set_load_balance_dc_aware( + cluster_raw.clone(), + c"eu".as_ptr(), + 0, + 0 + ), CassError::CASS_OK ); - cass_cluster_set_latency_aware_routing(cluster_raw, 1); + cass_cluster_set_latency_aware_routing(cluster_raw.clone(), 1); // These values cannot currently be tested to be set properly in the latency awareness builder, // but at least we test that the function completed successfully. cass_cluster_set_latency_aware_routing_settings( - cluster_raw, + cluster_raw.clone(), 2., 1, 2000, @@ -1002,7 +1014,7 @@ mod tests { // set preferred rack+dc assert_cass_error_eq!( cass_cluster_set_load_balance_rack_aware( - cluster_raw, + cluster_raw.clone(), c"eu-east".as_ptr(), c"rack1".as_ptr(), ), @@ -1024,7 +1036,12 @@ mod tests { // set back to preferred dc assert_cass_error_eq!( - cass_cluster_set_load_balance_dc_aware(cluster_raw, c"eu".as_ptr(), 0, 0), + cass_cluster_set_load_balance_dc_aware( + cluster_raw.clone(), + c"eu".as_ptr(), + 0, + 0 + ), CassError::CASS_OK ); @@ -1041,22 +1058,37 @@ mod tests { { // Nonzero deprecated parameters assert_cass_error_eq!( - cass_cluster_set_load_balance_dc_aware(cluster_raw, c"eu".as_ptr(), 1, 0), + cass_cluster_set_load_balance_dc_aware( + cluster_raw.clone(), + c"eu".as_ptr(), + 1, + 0 + ), CassError::CASS_ERROR_LIB_BAD_PARAMS ); assert_cass_error_eq!( - cass_cluster_set_load_balance_dc_aware(cluster_raw, c"eu".as_ptr(), 0, 1), + cass_cluster_set_load_balance_dc_aware( + cluster_raw.clone(), + c"eu".as_ptr(), + 0, + 1 + ), CassError::CASS_ERROR_LIB_BAD_PARAMS ); // null pointers assert_cass_error_eq!( - cass_cluster_set_load_balance_dc_aware(cluster_raw, std::ptr::null(), 0, 0), + cass_cluster_set_load_balance_dc_aware( + cluster_raw.clone(), + std::ptr::null(), + 0, + 0 + ), CassError::CASS_ERROR_LIB_BAD_PARAMS ); assert_cass_error_eq!( cass_cluster_set_load_balance_rack_aware( - cluster_raw, + cluster_raw.clone(), c"eu".as_ptr(), std::ptr::null(), ), @@ -1064,7 +1096,7 @@ mod tests { ); assert_cass_error_eq!( cass_cluster_set_load_balance_rack_aware( - cluster_raw, + cluster_raw.clone(), std::ptr::null(), c"rack".as_ptr(), ), @@ -1076,12 +1108,17 @@ mod tests { #[allow(clippy::manual_c_str_literals)] let empty_str = "\0".as_ptr() as *const i8; assert_cass_error_eq!( - cass_cluster_set_load_balance_dc_aware(cluster_raw, std::ptr::null(), 0, 0), + cass_cluster_set_load_balance_dc_aware( + cluster_raw.clone(), + std::ptr::null(), + 0, + 0 + ), CassError::CASS_ERROR_LIB_BAD_PARAMS ); assert_cass_error_eq!( cass_cluster_set_load_balance_rack_aware( - cluster_raw, + cluster_raw.clone(), c"eu".as_ptr(), empty_str, ), @@ -1089,7 +1126,7 @@ mod tests { ); assert_cass_error_eq!( cass_cluster_set_load_balance_rack_aware( - cluster_raw, + cluster_raw.clone(), empty_str, c"rack".as_ptr(), ), @@ -1110,16 +1147,16 @@ mod tests { let exec_profile_raw = cass_execution_profile_new(); { /* Test valid configurations */ - let cluster = BoxFFI::as_ref(cluster_raw); + let cluster = BoxFFI::as_ref(&cluster_raw).unwrap(); { assert!(cluster.execution_profile_map.is_empty()); } { assert_cass_error_eq!( cass_cluster_set_execution_profile( - cluster_raw, + cluster_raw.clone(), make_c_str!("profile1"), - exec_profile_raw + exec_profile_raw.clone().into_const() ), CassError::CASS_OK ); @@ -1132,10 +1169,10 @@ mod tests { let (c_str, c_strlen) = str_to_c_str_n("profile1"); assert_cass_error_eq!( cass_cluster_set_execution_profile_n( - cluster_raw, + cluster_raw.clone(), c_str, c_strlen, - exec_profile_raw + exec_profile_raw.clone().into_const() ), CassError::CASS_OK ); @@ -1147,9 +1184,9 @@ mod tests { assert_cass_error_eq!( cass_cluster_set_execution_profile( - cluster_raw, + cluster_raw.clone(), make_c_str!("profile2"), - exec_profile_raw + exec_profile_raw.clone().into_const() ), CassError::CASS_OK ); @@ -1165,9 +1202,9 @@ mod tests { // NULL name assert_cass_error_eq!( cass_cluster_set_execution_profile( - cluster_raw, + cluster_raw.clone(), std::ptr::null(), - exec_profile_raw + exec_profile_raw.clone().into_const() ), CassError::CASS_ERROR_LIB_BAD_PARAMS ); @@ -1176,9 +1213,9 @@ mod tests { // empty name assert_cass_error_eq!( cass_cluster_set_execution_profile( - cluster_raw, + cluster_raw.clone(), make_c_str!(""), - exec_profile_raw + exec_profile_raw.clone().into_const() ), CassError::CASS_ERROR_LIB_BAD_PARAMS ); @@ -1187,9 +1224,9 @@ mod tests { // NULL profile assert_cass_error_eq!( cass_cluster_set_execution_profile( - cluster_raw, + cluster_raw.clone(), make_c_str!("profile1"), - std::ptr::null() + BoxFFI::null(), ), CassError::CASS_ERROR_LIB_BAD_PARAMS ); diff --git a/scylla-rust-wrapper/src/collection.rs b/scylla-rust-wrapper/src/collection.rs index 8f5621ad..527c219f 100644 --- a/scylla-rust-wrapper/src/collection.rs +++ b/scylla-rust-wrapper/src/collection.rs @@ -137,7 +137,7 @@ impl TryFrom<&CassCollection> for CassCqlValue { pub unsafe extern "C" fn cass_collection_new( collection_type: CassCollectionType, item_count: size_t, -) -> *mut CassCollection { +) -> CassExclusiveMutPtr { let capacity = match collection_type { // Maps consist of a key and a value, so twice // the number of CassCqlValue will be stored. @@ -155,10 +155,10 @@ pub unsafe extern "C" fn cass_collection_new( #[no_mangle] unsafe extern "C" fn cass_collection_new_from_data_type( - data_type: *const CassDataType, + data_type: CassSharedPtr, item_count: size_t, -) -> *mut CassCollection { - let data_type = ArcFFI::cloned_from_ptr(data_type); +) -> CassExclusiveMutPtr { + let data_type = ArcFFI::cloned_from_ptr(data_type).unwrap(); let (capacity, collection_type) = match data_type.get_unchecked() { CassDataTypeInner::List { .. } => { (item_count, CassCollectionType::CASS_COLLECTION_TYPE_LIST) @@ -169,7 +169,7 @@ unsafe extern "C" fn cass_collection_new_from_data_type( CassDataTypeInner::Map { .. } => { (item_count * 2, CassCollectionType::CASS_COLLECTION_TYPE_MAP) } - _ => return std::ptr::null_mut(), + _ => return BoxFFI::null_mut(), }; let capacity = capacity as usize; @@ -183,9 +183,9 @@ unsafe extern "C" fn cass_collection_new_from_data_type( #[no_mangle] unsafe extern "C" fn cass_collection_data_type( - collection: *const CassCollection, -) -> *const CassDataType { - let collection_ref = BoxFFI::as_ref(collection); + collection: CassExclusiveConstPtr, +) -> CassSharedPtr { + let collection_ref = BoxFFI::as_ref(&collection).unwrap(); match &collection_ref.data_type { Some(dt) => ArcFFI::as_ptr(dt), @@ -203,7 +203,7 @@ unsafe extern "C" fn cass_collection_data_type( } #[no_mangle] -pub unsafe extern "C" fn cass_collection_free(collection: *mut CassCollection) { +pub unsafe extern "C" fn cass_collection_free(collection: CassExclusiveMutPtr) { BoxFFI::free(collection); } @@ -256,19 +256,19 @@ mod tests { let untyped_map = cass_collection_new(CassCollectionType::CASS_COLLECTION_TYPE_MAP, 2); assert_cass_error_eq!( - cass_collection_append_bool(untyped_map, false as cass_bool_t), + cass_collection_append_bool(untyped_map.clone(), false as cass_bool_t), CassError::CASS_OK ); assert_cass_error_eq!( - cass_collection_append_int16(untyped_map, 42), + cass_collection_append_int16(untyped_map.clone(), 42), CassError::CASS_OK ); assert_cass_error_eq!( - cass_collection_append_double(untyped_map, 42.42), + cass_collection_append_double(untyped_map.clone(), 42.42), CassError::CASS_OK ); assert_cass_error_eq!( - cass_collection_append_float(untyped_map, 42.42), + cass_collection_append_float(untyped_map.clone(), 42.42), CassError::CASS_OK ); cass_collection_free(untyped_map); @@ -285,19 +285,19 @@ mod tests { let untyped_map = cass_collection_new_from_data_type(dt_ptr, 2); assert_cass_error_eq!( - cass_collection_append_bool(untyped_map, false as cass_bool_t), + cass_collection_append_bool(untyped_map.clone(), false as cass_bool_t), CassError::CASS_OK ); assert_cass_error_eq!( - cass_collection_append_int16(untyped_map, 42), + cass_collection_append_int16(untyped_map.clone(), 42), CassError::CASS_OK ); assert_cass_error_eq!( - cass_collection_append_double(untyped_map, 42.42), + cass_collection_append_double(untyped_map.clone(), 42.42), CassError::CASS_OK ); assert_cass_error_eq!( - cass_collection_append_float(untyped_map, 42.42), + cass_collection_append_float(untyped_map.clone(), 42.42), CassError::CASS_OK ); cass_collection_free(untyped_map); @@ -316,27 +316,27 @@ mod tests { let half_typed_map = cass_collection_new_from_data_type(dt_ptr, 2); assert_cass_error_eq!( - cass_collection_append_bool(half_typed_map, false as cass_bool_t), + cass_collection_append_bool(half_typed_map.clone(), false as cass_bool_t), CassError::CASS_OK ); assert_cass_error_eq!( - cass_collection_append_int16(half_typed_map, 42), + cass_collection_append_int16(half_typed_map.clone(), 42), CassError::CASS_OK ); // Second entry -> key typecheck failed. assert_cass_error_eq!( - cass_collection_append_double(half_typed_map, 42.42), + cass_collection_append_double(half_typed_map.clone(), 42.42), CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE ); // Second entry -> typecheck succesful. assert_cass_error_eq!( - cass_collection_append_bool(half_typed_map, true as cass_bool_t), + cass_collection_append_bool(half_typed_map.clone(), true as cass_bool_t), CassError::CASS_OK ); assert_cass_error_eq!( - cass_collection_append_double(half_typed_map, 42.42), + cass_collection_append_double(half_typed_map.clone(), 42.42), CassError::CASS_OK ); cass_collection_free(half_typed_map); @@ -356,31 +356,31 @@ mod tests { frozen: false, }); let dt_ptr = ArcFFI::into_ptr(dt); - let bool_to_i16_map = cass_collection_new_from_data_type(dt_ptr, 2); + let bool_to_i16_map = cass_collection_new_from_data_type(dt_ptr.clone(), 2); // First entry -> typecheck successful. assert_cass_error_eq!( - cass_collection_append_bool(bool_to_i16_map, false as cass_bool_t), + cass_collection_append_bool(bool_to_i16_map.clone(), false as cass_bool_t), CassError::CASS_OK ); assert_cass_error_eq!( - cass_collection_append_int16(bool_to_i16_map, 42), + cass_collection_append_int16(bool_to_i16_map.clone(), 42), CassError::CASS_OK ); // Second entry -> key typecheck failed. assert_cass_error_eq!( - cass_collection_append_float(bool_to_i16_map, 42.42), + cass_collection_append_float(bool_to_i16_map.clone(), 42.42), CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE ); // Third entry -> value typecheck failed. assert_cass_error_eq!( - cass_collection_append_bool(bool_to_i16_map, true as cass_bool_t), + cass_collection_append_bool(bool_to_i16_map.clone(), true as cass_bool_t), CassError::CASS_OK ); assert_cass_error_eq!( - cass_collection_append_float(bool_to_i16_map, 42.42), + cass_collection_append_float(bool_to_i16_map.clone(), 42.42), CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE ); @@ -393,11 +393,11 @@ mod tests { let untyped_set = cass_collection_new(CassCollectionType::CASS_COLLECTION_TYPE_SET, 2); assert_cass_error_eq!( - cass_collection_append_bool(untyped_set, false as cass_bool_t), + cass_collection_append_bool(untyped_set.clone(), false as cass_bool_t), CassError::CASS_OK ); assert_cass_error_eq!( - cass_collection_append_int16(untyped_set, 42), + cass_collection_append_int16(untyped_set.clone(), 42), CassError::CASS_OK ); cass_collection_free(untyped_set); @@ -414,11 +414,11 @@ mod tests { let untyped_set = cass_collection_new_from_data_type(dt_ptr, 2); assert_cass_error_eq!( - cass_collection_append_bool(untyped_set, false as cass_bool_t), + cass_collection_append_bool(untyped_set.clone(), false as cass_bool_t), CassError::CASS_OK ); assert_cass_error_eq!( - cass_collection_append_int16(untyped_set, 42), + cass_collection_append_int16(untyped_set.clone(), 42), CassError::CASS_OK ); cass_collection_free(untyped_set); @@ -433,14 +433,14 @@ mod tests { frozen: false, }); let dt_ptr = ArcFFI::into_ptr(dt); - let bool_set = cass_collection_new_from_data_type(dt_ptr, 2); + let bool_set = cass_collection_new_from_data_type(dt_ptr.clone(), 2); assert_cass_error_eq!( - cass_collection_append_bool(bool_set, true as cass_bool_t), + cass_collection_append_bool(bool_set.clone(), true as cass_bool_t), CassError::CASS_OK ); assert_cass_error_eq!( - cass_collection_append_float(bool_set, 42.42), + cass_collection_append_float(bool_set.clone(), 42.42), CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE ); @@ -453,11 +453,11 @@ mod tests { let untyped_list = cass_collection_new(CassCollectionType::CASS_COLLECTION_TYPE_LIST, 2); assert_cass_error_eq!( - cass_collection_append_bool(untyped_list, false as cass_bool_t), + cass_collection_append_bool(untyped_list.clone(), false as cass_bool_t), CassError::CASS_OK ); assert_cass_error_eq!( - cass_collection_append_int16(untyped_list, 42), + cass_collection_append_int16(untyped_list.clone(), 42), CassError::CASS_OK ); cass_collection_free(untyped_list); @@ -474,11 +474,11 @@ mod tests { let untyped_list = cass_collection_new_from_data_type(dt_ptr, 2); assert_cass_error_eq!( - cass_collection_append_bool(untyped_list, false as cass_bool_t), + cass_collection_append_bool(untyped_list.clone(), false as cass_bool_t), CassError::CASS_OK ); assert_cass_error_eq!( - cass_collection_append_int16(untyped_list, 42), + cass_collection_append_int16(untyped_list.clone(), 42), CassError::CASS_OK ); cass_collection_free(untyped_list); @@ -493,14 +493,14 @@ mod tests { frozen: false, }); let dt_ptr = ArcFFI::into_ptr(dt); - let bool_list = cass_collection_new_from_data_type(dt_ptr, 2); + let bool_list = cass_collection_new_from_data_type(dt_ptr.clone(), 2); assert_cass_error_eq!( - cass_collection_append_bool(bool_list, true as cass_bool_t), + cass_collection_append_bool(bool_list.clone(), true as cass_bool_t), CassError::CASS_OK ); assert_cass_error_eq!( - cass_collection_append_float(bool_list, 42.42), + cass_collection_append_float(bool_list.clone(), 42.42), CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE ); @@ -519,16 +519,14 @@ mod tests { let empty_list = cass_collection_new(CassCollectionType::CASS_COLLECTION_TYPE_LIST, 2); // This would previously return a non Arc-based pointer. - let empty_list_dt = cass_collection_data_type(empty_list); + let empty_list_dt = cass_collection_data_type(empty_list.into_const()); let empty_set_dt = cass_data_type_new(CassValueType::CASS_VALUE_TYPE_SET); // This will try to increment the reference count of `empty_list_dt`. // Previously, this would fail, because `empty_list_dt` did not originate from an Arc allocation. - cass_data_type_add_sub_type(empty_set_dt, empty_list_dt); + cass_data_type_add_sub_type(empty_set_dt.clone(), empty_list_dt); - // Cast to *mut, because `cass_data_type_new` returns a *const. See the comment - // in this function to see why. - cass_data_type_free(empty_set_dt as *mut _) + cass_data_type_free(empty_set_dt) } } } diff --git a/scylla-rust-wrapper/src/exec_profile.rs b/scylla-rust-wrapper/src/exec_profile.rs index 281d9a3f..0d866751 100644 --- a/scylla-rust-wrapper/src/exec_profile.rs +++ b/scylla-rust-wrapper/src/exec_profile.rs @@ -13,7 +13,7 @@ use scylla::retry_policy::RetryPolicy; use scylla::speculative_execution::SimpleSpeculativeExecutionPolicy; use scylla::statement::Consistency; -use crate::argconv::{ptr_to_cstr_n, strlen, ArcFFI, BoxFFI}; +use crate::argconv::{ptr_to_cstr_n, strlen, ArcFFI, BoxFFI, CassExclusiveMutPtr, CassSharedPtr}; use crate::batch::CassBatch; use crate::cass_error::CassError; use crate::cass_types::CassConsistency; @@ -171,12 +171,14 @@ pub(crate) enum PerStatementExecProfileInner { } #[no_mangle] -pub unsafe extern "C" fn cass_execution_profile_new() -> *mut CassExecProfile { +pub unsafe extern "C" fn cass_execution_profile_new() -> CassExclusiveMutPtr { BoxFFI::into_ptr(Box::new(CassExecProfile::new())) } #[no_mangle] -pub unsafe extern "C" fn cass_execution_profile_free(profile: *mut CassExecProfile) { +pub unsafe extern "C" fn cass_execution_profile_free( + profile: CassExclusiveMutPtr, +) { BoxFFI::free(profile); } @@ -184,7 +186,7 @@ pub unsafe extern "C" fn cass_execution_profile_free(profile: *mut CassExecProfi #[no_mangle] pub unsafe extern "C" fn cass_statement_set_execution_profile( - statement: *mut CassStatement, + statement: CassExclusiveMutPtr, name: *const c_char, ) -> CassError { cass_statement_set_execution_profile_n(statement, name, strlen(name)) @@ -192,11 +194,11 @@ pub unsafe extern "C" fn cass_statement_set_execution_profile( #[no_mangle] pub unsafe extern "C" fn cass_statement_set_execution_profile_n( - statement: *mut CassStatement, + mut statement: CassExclusiveMutPtr, name: *const c_char, name_length: size_t, ) -> CassError { - let statement = BoxFFI::as_mut_ref(statement); + let statement = BoxFFI::as_mut_ref(&mut statement).unwrap(); let name: Option = ptr_to_cstr_n(name, name_length).and_then(|name| name.to_owned().try_into().ok()); statement.exec_profile = name.map(PerStatementExecProfile::new_unresolved); @@ -206,7 +208,7 @@ pub unsafe extern "C" fn cass_statement_set_execution_profile_n( #[no_mangle] pub unsafe extern "C" fn cass_batch_set_execution_profile( - batch: *mut CassBatch, + batch: CassExclusiveMutPtr, name: *const c_char, ) -> CassError { cass_batch_set_execution_profile_n(batch, name, strlen(name)) @@ -214,11 +216,11 @@ pub unsafe extern "C" fn cass_batch_set_execution_profile( #[no_mangle] pub unsafe extern "C" fn cass_batch_set_execution_profile_n( - batch: *mut CassBatch, + mut batch: CassExclusiveMutPtr, name: *const c_char, name_length: size_t, ) -> CassError { - let batch = BoxFFI::as_mut_ref(batch); + let batch = BoxFFI::as_mut_ref(&mut batch).unwrap(); let name: Option = ptr_to_cstr_n(name, name_length).and_then(|name| name.to_owned().try_into().ok()); batch.exec_profile = name.map(PerStatementExecProfile::new_unresolved); @@ -248,10 +250,10 @@ impl CassExecProfile { #[no_mangle] pub unsafe extern "C" fn cass_execution_profile_set_consistency( - profile: *mut CassExecProfile, + mut profile: CassExclusiveMutPtr, consistency: CassConsistency, ) -> CassError { - let profile_builder = BoxFFI::as_mut_ref(profile); + let profile_builder = BoxFFI::as_mut_ref(&mut profile).unwrap(); let consistency: Consistency = match consistency.try_into() { Ok(c) => c, Err(_) => return CassError::CASS_ERROR_LIB_BAD_PARAMS, @@ -264,9 +266,9 @@ pub unsafe extern "C" fn cass_execution_profile_set_consistency( #[no_mangle] pub unsafe extern "C" fn cass_execution_profile_set_no_speculative_execution_policy( - profile: *mut CassExecProfile, + mut profile: CassExclusiveMutPtr, ) -> CassError { - let profile_builder = BoxFFI::as_mut_ref(profile); + let profile_builder = BoxFFI::as_mut_ref(&mut profile).unwrap(); profile_builder.modify_in_place(|builder| builder.speculative_execution_policy(None)); @@ -275,11 +277,11 @@ pub unsafe extern "C" fn cass_execution_profile_set_no_speculative_execution_pol #[no_mangle] pub unsafe extern "C" fn cass_execution_profile_set_constant_speculative_execution_policy( - profile: *mut CassExecProfile, + mut profile: CassExclusiveMutPtr, constant_delay_ms: cass_int64_t, max_speculative_executions: cass_int32_t, ) -> CassError { - let profile_builder = BoxFFI::as_mut_ref(profile); + let profile_builder = BoxFFI::as_mut_ref(&mut profile).unwrap(); if constant_delay_ms < 0 || max_speculative_executions < 0 { return CassError::CASS_ERROR_LIB_BAD_PARAMS; } @@ -297,10 +299,10 @@ pub unsafe extern "C" fn cass_execution_profile_set_constant_speculative_executi #[no_mangle] pub unsafe extern "C" fn cass_execution_profile_set_latency_aware_routing( - profile: *mut CassExecProfile, + mut profile: CassExclusiveMutPtr, enabled: cass_bool_t, ) -> CassError { - let profile_builder = BoxFFI::as_mut_ref(profile); + let profile_builder = BoxFFI::as_mut_ref(&mut profile).unwrap(); profile_builder .load_balancing_config .latency_awareness_enabled = enabled != 0; @@ -310,14 +312,14 @@ pub unsafe extern "C" fn cass_execution_profile_set_latency_aware_routing( #[no_mangle] pub unsafe extern "C" fn cass_execution_profile_set_latency_aware_routing_settings( - profile: *mut CassExecProfile, + mut profile: CassExclusiveMutPtr, exclusion_threshold: cass_double_t, _scale_ms: cass_uint64_t, // Currently ignored, TODO: add this parameter to Rust driver retry_period_ms: cass_uint64_t, update_rate_ms: cass_uint64_t, min_measured: cass_uint64_t, ) { - let profile_builder = BoxFFI::as_mut_ref(profile); + let profile_builder = BoxFFI::as_mut_ref(&mut profile).unwrap(); profile_builder .load_balancing_config .latency_awareness_builder = LatencyAwarenessBuilder::new() @@ -329,7 +331,7 @@ pub unsafe extern "C" fn cass_execution_profile_set_latency_aware_routing_settin #[no_mangle] pub unsafe extern "C" fn cass_execution_profile_set_load_balance_dc_aware( - profile: *mut CassExecProfile, + profile: CassExclusiveMutPtr, local_dc: *const c_char, used_hosts_per_remote_dc: cass_uint32_t, allow_remote_dcs_for_local_cl: cass_bool_t, @@ -345,13 +347,13 @@ pub unsafe extern "C" fn cass_execution_profile_set_load_balance_dc_aware( #[no_mangle] pub unsafe extern "C" fn cass_execution_profile_set_load_balance_dc_aware_n( - profile: *mut CassExecProfile, + mut profile: CassExclusiveMutPtr, local_dc: *const c_char, local_dc_length: size_t, used_hosts_per_remote_dc: cass_uint32_t, allow_remote_dcs_for_local_cl: cass_bool_t, ) -> CassError { - let profile_builder = BoxFFI::as_mut_ref(profile); + let profile_builder = BoxFFI::as_mut_ref(&mut profile).unwrap(); set_load_balance_dc_aware_n( &mut profile_builder.load_balancing_config, @@ -364,7 +366,7 @@ pub unsafe extern "C" fn cass_execution_profile_set_load_balance_dc_aware_n( #[no_mangle] pub unsafe extern "C" fn cass_execution_profile_set_load_balance_rack_aware( - profile: *mut CassExecProfile, + profile: CassExclusiveMutPtr, local_dc_raw: *const c_char, local_rack_raw: *const c_char, ) -> CassError { @@ -379,13 +381,13 @@ pub unsafe extern "C" fn cass_execution_profile_set_load_balance_rack_aware( #[no_mangle] pub unsafe extern "C" fn cass_execution_profile_set_load_balance_rack_aware_n( - profile: *mut CassExecProfile, + mut profile: CassExclusiveMutPtr, local_dc_raw: *const c_char, local_dc_length: size_t, local_rack_raw: *const c_char, local_rack_length: size_t, ) -> CassError { - let profile_builder = BoxFFI::as_mut_ref(profile); + let profile_builder = BoxFFI::as_mut_ref(&mut profile).unwrap(); set_load_balance_rack_aware_n( &mut profile_builder.load_balancing_config, @@ -398,9 +400,9 @@ pub unsafe extern "C" fn cass_execution_profile_set_load_balance_rack_aware_n( #[no_mangle] pub unsafe extern "C" fn cass_execution_profile_set_load_balance_round_robin( - profile: *mut CassExecProfile, + mut profile: CassExclusiveMutPtr, ) -> CassError { - let profile_builder = BoxFFI::as_mut_ref(profile); + let profile_builder = BoxFFI::as_mut_ref(&mut profile).unwrap(); profile_builder.load_balancing_config.load_balancing_kind = Some(LoadBalancingKind::RoundRobin); CassError::CASS_OK @@ -408,10 +410,10 @@ pub unsafe extern "C" fn cass_execution_profile_set_load_balance_round_robin( #[no_mangle] pub unsafe extern "C" fn cass_execution_profile_set_request_timeout( - profile: *mut CassExecProfile, + mut profile: CassExclusiveMutPtr, timeout_ms: cass_uint64_t, ) -> CassError { - let profile_builder = BoxFFI::as_mut_ref(profile); + let profile_builder = BoxFFI::as_mut_ref(&mut profile).unwrap(); profile_builder.modify_in_place(|builder| { builder.request_timeout(Some(std::time::Duration::from_millis(timeout_ms))) }); @@ -421,15 +423,15 @@ pub unsafe extern "C" fn cass_execution_profile_set_request_timeout( #[no_mangle] pub unsafe extern "C" fn cass_execution_profile_set_retry_policy( - profile: *mut CassExecProfile, - retry_policy: *const CassRetryPolicy, + mut profile: CassExclusiveMutPtr, + retry_policy: CassSharedPtr, ) -> CassError { - let retry_policy: Arc = match ArcFFI::as_ref(retry_policy) { + let retry_policy: Arc = match ArcFFI::as_ref(&retry_policy).unwrap() { DefaultRetryPolicy(default) => Arc::clone(default) as _, FallthroughRetryPolicy(fallthrough) => Arc::clone(fallthrough) as _, DowngradingConsistencyRetryPolicy(downgrading) => Arc::clone(downgrading) as _, }; - let profile_builder = BoxFFI::as_mut_ref(profile); + let profile_builder = BoxFFI::as_mut_ref(&mut profile).unwrap(); profile_builder.modify_in_place(|builder| builder.retry_policy(retry_policy)); CassError::CASS_OK @@ -437,10 +439,10 @@ pub unsafe extern "C" fn cass_execution_profile_set_retry_policy( #[no_mangle] pub unsafe extern "C" fn cass_execution_profile_set_serial_consistency( - profile: *mut CassExecProfile, + mut profile: CassExclusiveMutPtr, serial_consistency: CassConsistency, ) -> CassError { - let profile_builder = BoxFFI::as_mut_ref(profile); + let profile_builder = BoxFFI::as_mut_ref(&mut profile).unwrap(); let maybe_serial_consistency = if serial_consistency == CassConsistency::CASS_CONSISTENCY_UNKNOWN { @@ -458,10 +460,10 @@ pub unsafe extern "C" fn cass_execution_profile_set_serial_consistency( #[no_mangle] pub unsafe extern "C" fn cass_execution_profile_set_token_aware_routing( - profile: *mut CassExecProfile, + mut profile: CassExclusiveMutPtr, enabled: cass_bool_t, ) -> CassError { - let profile_builder = BoxFFI::as_mut_ref(profile); + let profile_builder = BoxFFI::as_mut_ref(&mut profile).unwrap(); profile_builder .load_balancing_config .token_awareness_enabled = enabled != 0; @@ -471,10 +473,10 @@ pub unsafe extern "C" fn cass_execution_profile_set_token_aware_routing( #[no_mangle] pub unsafe extern "C" fn cass_execution_profile_set_token_aware_routing_shuffle_replicas( - profile: *mut CassExecProfile, + mut profile: CassExclusiveMutPtr, enabled: cass_bool_t, ) -> CassError { - let profile_builder = BoxFFI::as_mut_ref(profile); + let profile_builder = BoxFFI::as_mut_ref(&mut profile).unwrap(); profile_builder .load_balancing_config .token_aware_shuffling_replicas_enabled = enabled != 0; @@ -519,28 +521,28 @@ mod tests { let profile_raw = cass_execution_profile_new(); { /* Test valid configurations */ - let profile = BoxFFI::as_ref(profile_raw); + let profile = BoxFFI::as_ref(&profile_raw).unwrap(); { assert_matches!(profile.load_balancing_config.load_balancing_kind, None); assert!(profile.load_balancing_config.token_awareness_enabled); assert!(!profile.load_balancing_config.latency_awareness_enabled); } { - cass_execution_profile_set_token_aware_routing(profile_raw, 0); + cass_execution_profile_set_token_aware_routing(profile_raw.clone(), 0); assert_cass_error_eq!( cass_execution_profile_set_load_balance_dc_aware( - profile_raw, + profile_raw.clone(), c"eu".as_ptr(), 0, 0 ), CassError::CASS_OK ); - cass_execution_profile_set_latency_aware_routing(profile_raw, 1); + cass_execution_profile_set_latency_aware_routing(profile_raw.clone(), 1); // These values cannot currently be tested to be set properly in the latency awareness builder, // but at least we test that the function completed successfully. cass_execution_profile_set_latency_aware_routing_settings( - profile_raw, + profile_raw.clone(), 2., 1, 2000, @@ -563,7 +565,7 @@ mod tests { // Nonzero deprecated parameters assert_cass_error_eq!( cass_execution_profile_set_load_balance_dc_aware( - profile_raw, + profile_raw.clone(), c"eu".as_ptr(), 1, 0 @@ -572,7 +574,7 @@ mod tests { ); assert_cass_error_eq!( cass_execution_profile_set_load_balance_dc_aware( - profile_raw, + profile_raw.clone(), c"eu".as_ptr(), 0, 1 @@ -618,14 +620,14 @@ mod tests { let statement_raw = cass_statement_new(empty_query, 0); let batch_raw = cass_batch_new(CassBatchType::CASS_BATCH_TYPE_LOGGED); assert_cass_error_eq!( - cass_batch_add_statement(batch_raw, statement_raw), + cass_batch_add_statement(batch_raw.clone(), statement_raw.clone().into_const()), CassError::CASS_OK ); { /* Test valid configurations */ - let statement = BoxFFI::as_ref(statement_raw); - let batch = BoxFFI::as_ref(batch_raw); + let statement = BoxFFI::as_ref(&statement_raw).unwrap(); + let batch = BoxFFI::as_ref(&batch_raw).unwrap(); { assert!(statement.exec_profile.is_none()); assert!(batch.exec_profile.is_none()); @@ -634,11 +636,14 @@ mod tests { let valid_name = "profile"; let valid_name_c_str = make_c_str!("profile"); assert_cass_error_eq!( - cass_statement_set_execution_profile(statement_raw, valid_name_c_str,), + cass_statement_set_execution_profile( + statement_raw.clone(), + valid_name_c_str, + ), CassError::CASS_OK ); assert_cass_error_eq!( - cass_batch_set_execution_profile(batch_raw, valid_name_c_str,), + cass_batch_set_execution_profile(batch_raw.clone(), valid_name_c_str,), CassError::CASS_OK ); assert_eq!( @@ -669,11 +674,14 @@ mod tests { { // NULL name sets exec profile to None assert_cass_error_eq!( - cass_statement_set_execution_profile(statement_raw, std::ptr::null::()), + cass_statement_set_execution_profile( + statement_raw.clone(), + std::ptr::null::() + ), CassError::CASS_OK ); assert_cass_error_eq!( - cass_batch_set_execution_profile(batch_raw, std::ptr::null::()), + cass_batch_set_execution_profile(batch_raw.clone(), std::ptr::null::()), CassError::CASS_OK ); assert!(statement.exec_profile.is_none()); @@ -685,7 +693,7 @@ mod tests { let (valid_name_c_str, valid_name_len) = str_to_c_str_n(valid_name); assert_cass_error_eq!( cass_statement_set_execution_profile_n( - statement_raw, + statement_raw.clone(), valid_name_c_str, valid_name_len, ), @@ -693,7 +701,7 @@ mod tests { ); assert_cass_error_eq!( cass_batch_set_execution_profile_n( - batch_raw, + batch_raw.clone(), valid_name_c_str, valid_name_len, ), @@ -727,11 +735,14 @@ mod tests { { // empty name sets exec profile to None assert_cass_error_eq!( - cass_statement_set_execution_profile(statement_raw, make_c_str!("")), + cass_statement_set_execution_profile( + statement_raw.clone(), + make_c_str!("") + ), CassError::CASS_OK ); assert_cass_error_eq!( - cass_batch_set_execution_profile(batch_raw, make_c_str!("")), + cass_batch_set_execution_profile(batch_raw.clone(), make_c_str!("")), CassError::CASS_OK ); assert!(statement.exec_profile.is_none()); diff --git a/scylla-rust-wrapper/src/future.rs b/scylla-rust-wrapper/src/future.rs index 328f764b..42fcc511 100644 --- a/scylla-rust-wrapper/src/future.rs +++ b/scylla-rust-wrapper/src/future.rs @@ -75,8 +75,8 @@ struct JoinHandleTimeout(JoinHandle<()>); impl CassFuture { pub fn make_raw( fut: impl Future + Send + 'static, - ) -> *mut CassFuture { - Self::new_from_future(fut).into_raw() as *mut _ + ) -> CassSharedPtr { + Self::new_from_future(fut).into_raw() } pub fn new_from_future( @@ -97,7 +97,7 @@ impl CassFuture { }; if let Some(bound_cb) = maybe_cb { let fut_ptr = ArcFFI::as_ptr(&cass_fut_clone); - bound_cb.invoke(fut_ptr); + bound_cb.invoke(ArcFFI::to_raw(&fut_ptr)); } cass_fut_clone.wait_for_value.notify_all(); @@ -282,7 +282,7 @@ impl CassFuture { CassError::CASS_OK } - fn into_raw(self: Arc) -> *const Self { + fn into_raw(self: Arc) -> CassSharedPtr { ArcFFI::into_ptr(self) } } @@ -295,31 +295,36 @@ impl CheckSendSync for CassFuture {} #[no_mangle] pub unsafe extern "C" fn cass_future_set_callback( - future_raw: *const CassFuture, + future_raw: CassSharedPtr, callback: CassFutureCallback, data: *mut ::std::os::raw::c_void, ) -> CassError { - ArcFFI::as_ref(future_raw).set_callback(future_raw, callback, data) + ArcFFI::as_ref(&future_raw) + .unwrap() + .set_callback(ArcFFI::to_raw(&future_raw), callback, data) } #[no_mangle] -pub unsafe extern "C" fn cass_future_wait(future_raw: *const CassFuture) { - ArcFFI::as_ref(future_raw).with_waited_result(|_| ()); +pub unsafe extern "C" fn cass_future_wait(future_raw: CassSharedPtr) { + ArcFFI::as_ref(&future_raw) + .unwrap() + .with_waited_result(|_| ()); } #[no_mangle] pub unsafe extern "C" fn cass_future_wait_timed( - future_raw: *const CassFuture, + future_raw: CassSharedPtr, timeout_us: cass_duration_t, ) -> cass_bool_t { - ArcFFI::as_ref(future_raw) + ArcFFI::as_ref(&future_raw) + .unwrap() .with_waited_result_timed(|_| (), Duration::from_micros(timeout_us)) .is_ok() as cass_bool_t } #[no_mangle] -pub unsafe extern "C" fn cass_future_ready(future_raw: *const CassFuture) -> cass_bool_t { - let state_guard = ArcFFI::as_ref(future_raw).state.lock().unwrap(); +pub unsafe extern "C" fn cass_future_ready(future_raw: CassSharedPtr) -> cass_bool_t { + let state_guard = ArcFFI::as_ref(&future_raw).unwrap().state.lock().unwrap(); match state_guard.value { None => cass_false, Some(_) => cass_true, @@ -327,95 +332,106 @@ pub unsafe extern "C" fn cass_future_ready(future_raw: *const CassFuture) -> cas } #[no_mangle] -pub unsafe extern "C" fn cass_future_error_code(future_raw: *const CassFuture) -> CassError { - ArcFFI::as_ref(future_raw).with_waited_result(|r: &mut CassFutureResult| match r { - Ok(CassResultValue::QueryError(err)) => err.to_cass_error(), - Err((err, _)) => *err, - _ => CassError::CASS_OK, - }) +pub unsafe extern "C" fn cass_future_error_code( + future_raw: CassSharedPtr, +) -> CassError { + ArcFFI::as_ref(&future_raw) + .unwrap() + .with_waited_result(|r: &mut CassFutureResult| match r { + Ok(CassResultValue::QueryError(err)) => err.to_cass_error(), + Err((err, _)) => *err, + _ => CassError::CASS_OK, + }) } #[no_mangle] pub unsafe extern "C" fn cass_future_error_message( - future: *mut CassFuture, + future: CassSharedPtr, message: *mut *const ::std::os::raw::c_char, message_length: *mut size_t, ) { - ArcFFI::as_ref(future).with_waited_state(|state: &mut CassFutureState| { - let value = &state.value; - let msg = state - .err_string - .get_or_insert_with(|| match value.as_ref().unwrap() { - Ok(CassResultValue::QueryError(err)) => err.msg(), - Err((_, s)) => s.msg(), - _ => "".to_string(), - }); - write_str_to_c(msg.as_str(), message, message_length); - }); + ArcFFI::as_ref(&future) + .unwrap() + .with_waited_state(|state: &mut CassFutureState| { + let value = &state.value; + let msg = state + .err_string + .get_or_insert_with(|| match value.as_ref().unwrap() { + Ok(CassResultValue::QueryError(err)) => err.msg(), + Err((_, s)) => s.msg(), + _ => "".to_string(), + }); + write_str_to_c(msg.as_str(), message, message_length); + }); } #[no_mangle] -pub unsafe extern "C" fn cass_future_free(future_raw: *const CassFuture) { +pub unsafe extern "C" fn cass_future_free(future_raw: CassSharedPtr) { ArcFFI::free(future_raw); } #[no_mangle] pub unsafe extern "C" fn cass_future_get_result( - future_raw: *const CassFuture, -) -> *const CassResult { - ArcFFI::as_ref(future_raw) + future_raw: CassSharedPtr, +) -> CassSharedPtr { + ArcFFI::as_ref(&future_raw) + .unwrap() .with_waited_result(|r: &mut CassFutureResult| -> Option> { match r.as_ref().ok()? { CassResultValue::QueryResult(qr) => Some(qr.clone()), _ => None, } }) - .map_or(std::ptr::null(), ArcFFI::into_ptr) + .map_or(ArcFFI::null(), ArcFFI::into_ptr) } #[no_mangle] pub unsafe extern "C" fn cass_future_get_error_result( - future_raw: *const CassFuture, -) -> *const CassErrorResult { - ArcFFI::as_ref(future_raw) + future_raw: CassSharedPtr, +) -> CassSharedPtr { + ArcFFI::as_ref(&future_raw) + .unwrap() .with_waited_result(|r: &mut CassFutureResult| -> Option> { match r.as_ref().ok()? { CassResultValue::QueryError(qr) => Some(qr.clone()), _ => None, } }) - .map_or(std::ptr::null(), ArcFFI::into_ptr) + .map_or(ArcFFI::null(), ArcFFI::into_ptr) } #[no_mangle] pub unsafe extern "C" fn cass_future_get_prepared( - future_raw: *mut CassFuture, -) -> *const CassPrepared { - ArcFFI::as_ref(future_raw) + future_raw: CassSharedPtr, +) -> CassSharedPtr { + ArcFFI::as_ref(&future_raw) + .unwrap() .with_waited_result(|r: &mut CassFutureResult| -> Option> { match r.as_ref().ok()? { CassResultValue::Prepared(p) => Some(p.clone()), _ => None, } }) - .map_or(std::ptr::null(), ArcFFI::into_ptr) + .map_or(ArcFFI::null(), ArcFFI::into_ptr) } #[no_mangle] pub unsafe extern "C" fn cass_future_tracing_id( - future: *const CassFuture, + future: CassSharedPtr, tracing_id: *mut CassUuid, ) -> CassError { - ArcFFI::as_ref(future).with_waited_result(|r: &mut CassFutureResult| match r { - Ok(CassResultValue::QueryResult(result)) => match result.tracing_id { - Some(id) => { - *tracing_id = CassUuid::from(id); - CassError::CASS_OK - } - None => CassError::CASS_ERROR_LIB_NO_TRACING_ID, - }, - _ => CassError::CASS_ERROR_LIB_INVALID_FUTURE_TYPE, - }) + ArcFFI::as_ref(&future) + .unwrap() + .with_waited_result(|r: &mut CassFutureResult| match r { + Ok(CassResultValue::QueryResult(result)) => match result.tracing_id { + Some(id) => { + *tracing_id = CassUuid::from(id); + CassError::CASS_OK + } + None => CassError::CASS_ERROR_LIB_NO_TRACING_ID, + }, + _ => CassError::CASS_ERROR_LIB_INVALID_FUTURE_TYPE, + }) } #[cfg(test)] @@ -443,9 +459,9 @@ mod tests { }; let cass_fut = CassFuture::make_raw(fut); - struct PtrWrapper(*mut CassFuture); + struct PtrWrapper(CassSharedPtr); unsafe impl Send for PtrWrapper {} - let wrapped_cass_fut = PtrWrapper(cass_fut); + let wrapped_cass_fut = PtrWrapper(unsafe { cass_fut.clone() }); unsafe { let handle = thread::spawn(move || { let wrapper = wrapped_cass_fut; @@ -474,11 +490,13 @@ mod tests { unsafe { // This should timeout on tokio::time::timeout. - let timed_result = cass_future_wait_timed(cass_fut, HUNDRED_MILLIS_IN_MICROS / 5); + let timed_result = + cass_future_wait_timed(cass_fut.clone(), HUNDRED_MILLIS_IN_MICROS / 5); assert_eq!(0, timed_result); // This should timeout as well. - let timed_result = cass_future_wait_timed(cass_fut, HUNDRED_MILLIS_IN_MICROS / 5); + let timed_result = + cass_future_wait_timed(cass_fut.clone(), HUNDRED_MILLIS_IN_MICROS / 5); assert_eq!(0, timed_result); // Verify that future eventually resolves, even though timeouts occurred before. @@ -516,7 +534,11 @@ mod tests { let flag_ptr = Box::into_raw(flag); assert_cass_error_eq!( - cass_future_set_callback(cass_fut, Some(mark_flag_cb), flag_ptr as *mut c_void), + cass_future_set_callback( + cass_fut.clone(), + Some(mark_flag_cb), + flag_ptr as *mut c_void + ), CassError::CASS_OK ); @@ -526,7 +548,7 @@ mod tests { // Callback executed after awaiting. { let (cass_fut, flag_ptr) = create_future_and_flag(); - cass_future_wait(cass_fut); + cass_future_wait(cass_fut.clone()); assert_cass_future_error_message_eq!(cass_fut, Some(ERROR_MSG)); assert!(*flag_ptr); @@ -551,10 +573,12 @@ mod tests { let (cass_fut, flag_ptr) = create_future_and_flag(); // This should timeout on tokio::time::timeout. - let timed_result = cass_future_wait_timed(cass_fut, HUNDRED_MILLIS_IN_MICROS / 5); + let timed_result = + cass_future_wait_timed(cass_fut.clone(), HUNDRED_MILLIS_IN_MICROS / 5); assert_eq!(0, timed_result); // This should timeout as well. - let timed_result = cass_future_wait_timed(cass_fut, HUNDRED_MILLIS_IN_MICROS / 5); + let timed_result = + cass_future_wait_timed(cass_fut.clone(), HUNDRED_MILLIS_IN_MICROS / 5); assert_eq!(0, timed_result); // Await and check result. diff --git a/scylla-rust-wrapper/src/integration_testing.rs b/scylla-rust-wrapper/src/integration_testing.rs index 0fd4007f..2ed44cb4 100644 --- a/scylla-rust-wrapper/src/integration_testing.rs +++ b/scylla-rust-wrapper/src/integration_testing.rs @@ -1,34 +1,36 @@ use std::ffi::{c_char, CString}; use crate::{ - argconv::BoxFFI, + argconv::{BoxFFI, CassExclusiveConstPtr}, cluster::CassCluster, types::{cass_int32_t, cass_uint16_t, size_t}, }; #[no_mangle] pub unsafe extern "C" fn testing_cluster_get_connect_timeout( - cluster_raw: *const CassCluster, + cluster_raw: CassExclusiveConstPtr, ) -> cass_uint16_t { - let cluster = BoxFFI::as_ref(cluster_raw); + let cluster = BoxFFI::as_ref(&cluster_raw).unwrap(); cluster.get_session_config().connect_timeout.as_millis() as cass_uint16_t } #[no_mangle] -pub unsafe extern "C" fn testing_cluster_get_port(cluster_raw: *const CassCluster) -> cass_int32_t { - let cluster = BoxFFI::as_ref(cluster_raw); +pub unsafe extern "C" fn testing_cluster_get_port( + cluster_raw: CassExclusiveConstPtr, +) -> cass_int32_t { + let cluster = BoxFFI::as_ref(&cluster_raw).unwrap(); cluster.get_port() as cass_int32_t } #[no_mangle] pub unsafe extern "C" fn testing_cluster_get_contact_points( - cluster_raw: *const CassCluster, + cluster_raw: CassExclusiveConstPtr, contact_points: *mut *mut c_char, contact_points_length: *mut size_t, ) { - let cluster = BoxFFI::as_ref(cluster_raw); + let cluster = BoxFFI::as_ref(&cluster_raw).unwrap(); let contact_points_string = cluster.get_contact_points().join(","); let length = contact_points_string.len(); diff --git a/scylla-rust-wrapper/src/logging.rs b/scylla-rust-wrapper/src/logging.rs index c1a43f82..0aca2b53 100644 --- a/scylla-rust-wrapper/src/logging.rs +++ b/scylla-rust-wrapper/src/logging.rs @@ -1,4 +1,4 @@ -use crate::argconv::{arr_to_cstr, ptr_to_cstr, str_to_arr, RefFFI}; +use crate::argconv::{arr_to_cstr, ptr_to_cstr, str_to_arr, CassBorrowedPtr, RefFFI}; use crate::cass_log_types::{CassLogLevel, CassLogMessage}; use crate::types::size_t; use crate::LOGGER; @@ -17,9 +17,13 @@ use tracing_subscriber::Layer; impl RefFFI for CassLogMessage {} pub type CassLogCallback = - Option; + Option, data: *mut c_void)>; -unsafe extern "C" fn noop_log_callback(_message: *const CassLogMessage, _data: *mut c_void) {} +unsafe extern "C" fn noop_log_callback( + _message: CassBorrowedPtr, + _data: *mut c_void, +) { +} pub struct Logger { pub cb: CassLogCallback, @@ -64,8 +68,11 @@ impl TryFrom for Level { pub const CASS_LOG_MAX_MESSAGE_SIZE: usize = 1024; -pub unsafe extern "C" fn stderr_log_callback(message: *const CassLogMessage, _data: *mut c_void) { - let message = RefFFI::as_ref(message); +pub unsafe extern "C" fn stderr_log_callback( + message: CassBorrowedPtr, + _data: *mut c_void, +) { + let message = RefFFI::as_ref(&message).unwrap(); eprintln!( "{} [{}] ({}:{}) {}", @@ -132,7 +139,7 @@ where if let Some(log_cb) = logger.cb { unsafe { - log_cb(&log_message as *const CassLogMessage, logger.data); + log_cb(RefFFI::as_ptr(&log_message), logger.data); } } } diff --git a/scylla-rust-wrapper/src/metadata.rs b/scylla-rust-wrapper/src/metadata.rs index 2965c74c..2c515486 100644 --- a/scylla-rust-wrapper/src/metadata.rs +++ b/scylla-rust-wrapper/src/metadata.rs @@ -97,68 +97,68 @@ pub unsafe fn create_table_metadata( } #[no_mangle] -pub unsafe extern "C" fn cass_schema_meta_free(schema_meta: *mut CassSchemaMeta) { +pub unsafe extern "C" fn cass_schema_meta_free(schema_meta: CassExclusiveMutPtr) { BoxFFI::free(schema_meta); } #[no_mangle] pub unsafe extern "C" fn cass_schema_meta_keyspace_by_name( - schema_meta: *const CassSchemaMeta, + schema_meta: CassExclusiveConstPtr, keyspace_name: *const c_char, -) -> *const CassKeyspaceMeta { +) -> CassBorrowedPtr { cass_schema_meta_keyspace_by_name_n(schema_meta, keyspace_name, strlen(keyspace_name)) } #[no_mangle] pub unsafe extern "C" fn cass_schema_meta_keyspace_by_name_n( - schema_meta: *const CassSchemaMeta, + schema_meta: CassExclusiveConstPtr, keyspace_name: *const c_char, keyspace_name_length: size_t, -) -> *const CassKeyspaceMeta { +) -> CassBorrowedPtr { if keyspace_name.is_null() { - return std::ptr::null(); + return RefFFI::null(); } - let metadata = BoxFFI::as_ref(schema_meta); + let metadata = BoxFFI::as_ref(&schema_meta).unwrap(); let keyspace = ptr_to_cstr_n(keyspace_name, keyspace_name_length).unwrap(); let keyspace_meta = metadata.keyspaces.get(keyspace); match keyspace_meta { - Some(meta) => meta as *const CassKeyspaceMeta, - None => std::ptr::null(), + Some(meta) => RefFFI::as_ptr(meta), + None => RefFFI::null(), } } #[no_mangle] pub unsafe extern "C" fn cass_keyspace_meta_name( - keyspace_meta: *const CassKeyspaceMeta, + keyspace_meta: CassBorrowedPtr, name: *mut *const c_char, name_length: *mut size_t, ) { - let keyspace_meta = RefFFI::as_ref(keyspace_meta); + let keyspace_meta = RefFFI::as_ref(&keyspace_meta).unwrap(); write_str_to_c(keyspace_meta.name.as_str(), name, name_length) } #[no_mangle] pub unsafe extern "C" fn cass_keyspace_meta_user_type_by_name( - keyspace_meta: *const CassKeyspaceMeta, + keyspace_meta: CassBorrowedPtr, type_: *const c_char, -) -> *const CassDataType { +) -> CassSharedPtr { cass_keyspace_meta_user_type_by_name_n(keyspace_meta, type_, strlen(type_)) } #[no_mangle] pub unsafe extern "C" fn cass_keyspace_meta_user_type_by_name_n( - keyspace_meta: *const CassKeyspaceMeta, + keyspace_meta: CassBorrowedPtr, type_: *const c_char, type_length: size_t, -) -> *const CassDataType { +) -> CassSharedPtr { if type_.is_null() { - return std::ptr::null(); + return ArcFFI::null(); } - let keyspace_meta = RefFFI::as_ref(keyspace_meta); + let keyspace_meta = RefFFI::as_ref(&keyspace_meta).unwrap(); let user_type_name = ptr_to_cstr_n(type_, type_length).unwrap(); match keyspace_meta @@ -166,294 +166,296 @@ pub unsafe extern "C" fn cass_keyspace_meta_user_type_by_name_n( .get(user_type_name) { Some(udt) => ArcFFI::as_ptr(udt), - None => std::ptr::null(), + None => ArcFFI::null(), } } #[no_mangle] pub unsafe extern "C" fn cass_keyspace_meta_table_by_name( - keyspace_meta: *const CassKeyspaceMeta, + keyspace_meta: CassBorrowedPtr, table: *const c_char, -) -> *const CassTableMeta { +) -> CassBorrowedPtr { cass_keyspace_meta_table_by_name_n(keyspace_meta, table, strlen(table)) } #[no_mangle] pub unsafe extern "C" fn cass_keyspace_meta_table_by_name_n( - keyspace_meta: *const CassKeyspaceMeta, + keyspace_meta: CassBorrowedPtr, table: *const c_char, table_length: size_t, -) -> *const CassTableMeta { +) -> CassBorrowedPtr { if table.is_null() { - return std::ptr::null(); + return RefFFI::null(); } - let keyspace_meta = RefFFI::as_ref(keyspace_meta); + let keyspace_meta = RefFFI::as_ref(&keyspace_meta).unwrap(); let table_name = ptr_to_cstr_n(table, table_length).unwrap(); let table_meta = keyspace_meta.tables.get(table_name); match table_meta { Some(meta) => RefFFI::as_ptr(meta), - None => std::ptr::null(), + None => RefFFI::null(), } } #[no_mangle] pub unsafe extern "C" fn cass_table_meta_name( - table_meta: *const CassTableMeta, + table_meta: CassBorrowedPtr, name: *mut *const c_char, name_length: *mut size_t, ) { - let table_meta = RefFFI::as_ref(table_meta); + let table_meta = RefFFI::as_ref(&table_meta).unwrap(); write_str_to_c(table_meta.name.as_str(), name, name_length) } #[no_mangle] -pub unsafe extern "C" fn cass_table_meta_column_count(table_meta: *const CassTableMeta) -> size_t { - let table_meta = RefFFI::as_ref(table_meta); +pub unsafe extern "C" fn cass_table_meta_column_count( + table_meta: CassBorrowedPtr, +) -> size_t { + let table_meta = RefFFI::as_ref(&table_meta).unwrap(); table_meta.columns_metadata.len() as size_t } #[no_mangle] pub unsafe extern "C" fn cass_table_meta_partition_key( - table_meta: *const CassTableMeta, + table_meta: CassBorrowedPtr, index: size_t, -) -> *const CassColumnMeta { - let table_meta = RefFFI::as_ref(table_meta); +) -> CassBorrowedPtr { + let table_meta = RefFFI::as_ref(&table_meta).unwrap(); match table_meta.partition_keys.get(index as usize) { Some(column_name) => match table_meta.columns_metadata.get(column_name) { - Some(column_meta) => column_meta as *const CassColumnMeta, - None => std::ptr::null(), + Some(column_meta) => RefFFI::as_ptr(column_meta), + None => RefFFI::null(), }, - None => std::ptr::null(), + None => RefFFI::null(), } } #[no_mangle] pub unsafe extern "C" fn cass_table_meta_partition_key_count( - table_meta: *const CassTableMeta, + table_meta: CassBorrowedPtr, ) -> size_t { - let table_meta = RefFFI::as_ref(table_meta); + let table_meta = RefFFI::as_ref(&table_meta).unwrap(); table_meta.partition_keys.len() as size_t } #[no_mangle] pub unsafe extern "C" fn cass_table_meta_clustering_key( - table_meta: *const CassTableMeta, + table_meta: CassBorrowedPtr, index: size_t, -) -> *const CassColumnMeta { - let table_meta = RefFFI::as_ref(table_meta); +) -> CassBorrowedPtr { + let table_meta = RefFFI::as_ref(&table_meta).unwrap(); match table_meta.clustering_keys.get(index as usize) { Some(column_name) => match table_meta.columns_metadata.get(column_name) { - Some(column_meta) => column_meta as *const CassColumnMeta, - None => std::ptr::null(), + Some(column_meta) => RefFFI::as_ptr(column_meta), + None => RefFFI::null(), }, - None => std::ptr::null(), + None => RefFFI::null(), } } #[no_mangle] pub unsafe extern "C" fn cass_table_meta_clustering_key_count( - table_meta: *const CassTableMeta, + table_meta: CassBorrowedPtr, ) -> size_t { - let table_meta = RefFFI::as_ref(table_meta); + let table_meta = RefFFI::as_ref(&table_meta).unwrap(); table_meta.clustering_keys.len() as size_t } #[no_mangle] pub unsafe extern "C" fn cass_table_meta_column_by_name( - table_meta: *const CassTableMeta, + table_meta: CassBorrowedPtr, column: *const c_char, -) -> *const CassColumnMeta { +) -> CassBorrowedPtr { cass_table_meta_column_by_name_n(table_meta, column, strlen(column)) } #[no_mangle] pub unsafe extern "C" fn cass_table_meta_column_by_name_n( - table_meta: *const CassTableMeta, + table_meta: CassBorrowedPtr, column: *const c_char, column_length: size_t, -) -> *const CassColumnMeta { +) -> CassBorrowedPtr { if column.is_null() { - return std::ptr::null(); + return RefFFI::null(); } - let table_meta = RefFFI::as_ref(table_meta); + let table_meta = RefFFI::as_ref(&table_meta).unwrap(); let column_name = ptr_to_cstr_n(column, column_length).unwrap(); match table_meta.columns_metadata.get(column_name) { - Some(column_meta) => column_meta as *const CassColumnMeta, - None => std::ptr::null(), + Some(column_meta) => RefFFI::as_ptr(column_meta), + None => RefFFI::null(), } } #[no_mangle] pub unsafe extern "C" fn cass_column_meta_name( - column_meta: *const CassColumnMeta, + column_meta: CassBorrowedPtr, name: *mut *const c_char, name_length: *mut size_t, ) { - let column_meta = RefFFI::as_ref(column_meta); + let column_meta = RefFFI::as_ref(&column_meta).unwrap(); write_str_to_c(column_meta.name.as_str(), name, name_length) } #[no_mangle] pub unsafe extern "C" fn cass_column_meta_data_type( - column_meta: *const CassColumnMeta, -) -> *const CassDataType { - let column_meta = RefFFI::as_ref(column_meta); + column_meta: CassBorrowedPtr, +) -> CassSharedPtr { + let column_meta = RefFFI::as_ref(&column_meta).unwrap(); ArcFFI::as_ptr(&column_meta.column_type) } #[no_mangle] pub unsafe extern "C" fn cass_column_meta_type( - column_meta: *const CassColumnMeta, + column_meta: CassBorrowedPtr, ) -> CassColumnType { - let column_meta = RefFFI::as_ref(column_meta); + let column_meta = RefFFI::as_ref(&column_meta).unwrap(); column_meta.column_kind } #[no_mangle] pub unsafe extern "C" fn cass_keyspace_meta_materialized_view_by_name( - keyspace_meta: *const CassKeyspaceMeta, + keyspace_meta: CassBorrowedPtr, view: *const c_char, -) -> *const CassMaterializedViewMeta { +) -> CassBorrowedPtr { cass_keyspace_meta_materialized_view_by_name_n(keyspace_meta, view, strlen(view)) } #[no_mangle] pub unsafe extern "C" fn cass_keyspace_meta_materialized_view_by_name_n( - keyspace_meta: *const CassKeyspaceMeta, + keyspace_meta: CassBorrowedPtr, view: *const c_char, view_length: size_t, -) -> *const CassMaterializedViewMeta { +) -> CassBorrowedPtr { if view.is_null() { - return std::ptr::null(); + return RefFFI::null(); } - let keyspace_meta = RefFFI::as_ref(keyspace_meta); + let keyspace_meta = RefFFI::as_ref(&keyspace_meta).unwrap(); let view_name = ptr_to_cstr_n(view, view_length).unwrap(); match keyspace_meta.views.get(view_name) { Some(view_meta) => RefFFI::as_ptr(view_meta.as_ref()), - None => std::ptr::null(), + None => RefFFI::null(), } } #[no_mangle] pub unsafe extern "C" fn cass_table_meta_materialized_view_by_name( - table_meta: *const CassTableMeta, + table_meta: CassBorrowedPtr, view: *const c_char, -) -> *const CassMaterializedViewMeta { +) -> CassBorrowedPtr { cass_table_meta_materialized_view_by_name_n(table_meta, view, strlen(view)) } #[no_mangle] pub unsafe extern "C" fn cass_table_meta_materialized_view_by_name_n( - table_meta: *const CassTableMeta, + table_meta: CassBorrowedPtr, view: *const c_char, view_length: size_t, -) -> *const CassMaterializedViewMeta { +) -> CassBorrowedPtr { if view.is_null() { - return std::ptr::null(); + return RefFFI::null(); } - let table_meta = RefFFI::as_ref(table_meta); + let table_meta = RefFFI::as_ref(&table_meta).unwrap(); let view_name = ptr_to_cstr_n(view, view_length).unwrap(); match table_meta.views.get(view_name) { Some(view_meta) => RefFFI::as_ptr(view_meta.as_ref()), - None => std::ptr::null(), + None => RefFFI::null(), } } #[no_mangle] pub unsafe extern "C" fn cass_table_meta_materialized_view_count( - table_meta: *const CassTableMeta, + table_meta: CassBorrowedPtr, ) -> size_t { - let table_meta = RefFFI::as_ref(table_meta); + let table_meta = RefFFI::as_ref(&table_meta).unwrap(); table_meta.views.len() as size_t } #[no_mangle] pub unsafe extern "C" fn cass_table_meta_materialized_view( - table_meta: *const CassTableMeta, + table_meta: CassBorrowedPtr, index: size_t, -) -> *const CassMaterializedViewMeta { - let table_meta = RefFFI::as_ref(table_meta); +) -> CassBorrowedPtr { + let table_meta = RefFFI::as_ref(&table_meta).unwrap(); match table_meta.views.iter().nth(index as usize) { Some(view_meta) => RefFFI::as_ptr(view_meta.1.as_ref()), - None => std::ptr::null(), + None => RefFFI::null(), } } #[no_mangle] pub unsafe extern "C" fn cass_materialized_view_meta_column_by_name( - view_meta: *const CassMaterializedViewMeta, + view_meta: CassBorrowedPtr, column: *const c_char, -) -> *const CassColumnMeta { +) -> CassBorrowedPtr { cass_materialized_view_meta_column_by_name_n(view_meta, column, strlen(column)) } #[no_mangle] pub unsafe extern "C" fn cass_materialized_view_meta_column_by_name_n( - view_meta: *const CassMaterializedViewMeta, + view_meta: CassBorrowedPtr, column: *const c_char, column_length: size_t, -) -> *const CassColumnMeta { +) -> CassBorrowedPtr { if column.is_null() { - return std::ptr::null(); + return RefFFI::null(); } - let view_meta = RefFFI::as_ref(view_meta); + let view_meta = RefFFI::as_ref(&view_meta).unwrap(); let column_name = ptr_to_cstr_n(column, column_length).unwrap(); match view_meta.view_metadata.columns_metadata.get(column_name) { - Some(column_meta) => column_meta as *const CassColumnMeta, - None => std::ptr::null(), + Some(column_meta) => RefFFI::as_ptr(column_meta), + None => RefFFI::null(), } } #[no_mangle] pub unsafe extern "C" fn cass_materialized_view_meta_name( - view_meta: *const CassMaterializedViewMeta, + view_meta: CassBorrowedPtr, name: *mut *const c_char, name_length: *mut size_t, ) { - let view_meta = RefFFI::as_ref(view_meta); + let view_meta = RefFFI::as_ref(&view_meta).unwrap(); write_str_to_c(view_meta.name.as_str(), name, name_length) } #[no_mangle] pub unsafe extern "C" fn cass_materialized_view_meta_base_table( - view_meta: *const CassMaterializedViewMeta, -) -> *const CassTableMeta { - let view_meta = RefFFI::as_ref(view_meta); + view_meta: CassBorrowedPtr, +) -> CassBorrowedPtr { + let view_meta = RefFFI::as_ref(&view_meta).unwrap(); match view_meta.base_table.upgrade() { Some(arc) => RefFFI::as_ptr(&arc), - None => std::ptr::null(), + None => RefFFI::null(), } } #[no_mangle] pub unsafe extern "C" fn cass_materialized_view_meta_column_count( - view_meta: *const CassMaterializedViewMeta, + view_meta: CassBorrowedPtr, ) -> size_t { - let view_meta = RefFFI::as_ref(view_meta); + let view_meta = RefFFI::as_ref(&view_meta).unwrap(); view_meta.view_metadata.columns_metadata.len() as size_t } #[no_mangle] pub unsafe extern "C" fn cass_materialized_view_meta_column( - view_meta: *const CassMaterializedViewMeta, + view_meta: CassBorrowedPtr, index: size_t, -) -> *const CassColumnMeta { - let view_meta = RefFFI::as_ref(view_meta); +) -> CassBorrowedPtr { + let view_meta = RefFFI::as_ref(&view_meta).unwrap(); match view_meta .view_metadata @@ -461,53 +463,53 @@ pub unsafe extern "C" fn cass_materialized_view_meta_column( .iter() .nth(index as usize) { - Some(column_entry) => column_entry.1 as *const CassColumnMeta, - None => std::ptr::null(), + Some(column_entry) => RefFFI::as_ptr(column_entry.1), + None => RefFFI::null(), } } #[no_mangle] pub unsafe extern "C" fn cass_materialized_view_meta_partition_key_count( - view_meta: *const CassMaterializedViewMeta, + view_meta: CassBorrowedPtr, ) -> size_t { - let view_meta = RefFFI::as_ref(view_meta); + let view_meta = RefFFI::as_ref(&view_meta).unwrap(); view_meta.view_metadata.partition_keys.len() as size_t } pub unsafe extern "C" fn cass_materialized_view_meta_partition_key( - view_meta: *const CassMaterializedViewMeta, + view_meta: CassBorrowedPtr, index: size_t, -) -> *const CassColumnMeta { - let view_meta = RefFFI::as_ref(view_meta); +) -> CassBorrowedPtr { + let view_meta = RefFFI::as_ref(&view_meta).unwrap(); match view_meta.view_metadata.partition_keys.get(index as usize) { Some(column_name) => match view_meta.view_metadata.columns_metadata.get(column_name) { - Some(column_meta) => column_meta as *const CassColumnMeta, - None => std::ptr::null(), + Some(column_meta) => RefFFI::as_ptr(column_meta), + None => RefFFI::null(), }, - None => std::ptr::null(), + None => RefFFI::null(), } } #[no_mangle] pub unsafe extern "C" fn cass_materialized_view_meta_clustering_key_count( - view_meta: *const CassMaterializedViewMeta, + view_meta: CassBorrowedPtr, ) -> size_t { - let view_meta = RefFFI::as_ref(view_meta); + let view_meta = RefFFI::as_ref(&view_meta).unwrap(); view_meta.view_metadata.clustering_keys.len() as size_t } pub unsafe extern "C" fn cass_materialized_view_meta_clustering_key( - view_meta: *const CassMaterializedViewMeta, + view_meta: CassBorrowedPtr, index: size_t, -) -> *const CassColumnMeta { - let view_meta = RefFFI::as_ref(view_meta); +) -> CassBorrowedPtr { + let view_meta = RefFFI::as_ref(&view_meta).unwrap(); match view_meta.view_metadata.clustering_keys.get(index as usize) { Some(column_name) => match view_meta.view_metadata.columns_metadata.get(column_name) { - Some(column_meta) => column_meta as *const CassColumnMeta, - None => std::ptr::null(), + Some(column_meta) => RefFFI::as_ptr(column_meta), + None => RefFFI::null(), }, - None => std::ptr::null(), + None => RefFFI::null(), } } diff --git a/scylla-rust-wrapper/src/prepared.rs b/scylla-rust-wrapper/src/prepared.rs index 3fc4927e..9ea0515a 100644 --- a/scylla-rust-wrapper/src/prepared.rs +++ b/scylla-rust-wrapper/src/prepared.rs @@ -75,15 +75,15 @@ impl CassPrepared { impl ArcFFI for CassPrepared {} #[no_mangle] -pub unsafe extern "C" fn cass_prepared_free(prepared_raw: *const CassPrepared) { +pub unsafe extern "C" fn cass_prepared_free(prepared_raw: CassSharedPtr) { ArcFFI::free(prepared_raw); } #[no_mangle] pub unsafe extern "C" fn cass_prepared_bind( - prepared_raw: *const CassPrepared, -) -> *mut CassStatement { - let prepared: Arc<_> = ArcFFI::cloned_from_ptr(prepared_raw); + prepared_raw: CassSharedPtr, +) -> CassExclusiveMutPtr { + let prepared: Arc<_> = ArcFFI::cloned_from_ptr(prepared_raw).unwrap(); let bound_values_size = prepared.statement.get_variable_col_specs().len(); // cloning prepared statement's arc, because creating CassStatement should not invalidate @@ -103,12 +103,12 @@ pub unsafe extern "C" fn cass_prepared_bind( #[no_mangle] pub unsafe extern "C" fn cass_prepared_parameter_name( - prepared_raw: *const CassPrepared, + prepared_raw: CassSharedPtr, index: size_t, name: *mut *const c_char, name_length: *mut size_t, ) -> CassError { - let prepared = ArcFFI::as_ref(prepared_raw); + let prepared = ArcFFI::as_ref(&prepared_raw).unwrap(); match prepared .statement @@ -125,38 +125,38 @@ pub unsafe extern "C" fn cass_prepared_parameter_name( #[no_mangle] pub unsafe extern "C" fn cass_prepared_parameter_data_type( - prepared_raw: *const CassPrepared, + prepared_raw: CassSharedPtr, index: size_t, -) -> *const CassDataType { - let prepared = ArcFFI::as_ref(prepared_raw); +) -> CassSharedPtr { + let prepared = ArcFFI::as_ref(&prepared_raw).unwrap(); match prepared.variable_col_data_types.get(index as usize) { Some(dt) => ArcFFI::as_ptr(dt), - None => std::ptr::null(), + None => ArcFFI::null(), } } #[no_mangle] pub unsafe extern "C" fn cass_prepared_parameter_data_type_by_name( - prepared_raw: *const CassPrepared, + prepared_raw: CassSharedPtr, name: *const c_char, -) -> *const CassDataType { +) -> CassSharedPtr { cass_prepared_parameter_data_type_by_name_n(prepared_raw, name, strlen(name)) } #[no_mangle] pub unsafe extern "C" fn cass_prepared_parameter_data_type_by_name_n( - prepared_raw: *const CassPrepared, + prepared_raw: CassSharedPtr, name: *const c_char, name_length: size_t, -) -> *const CassDataType { - let prepared = ArcFFI::as_ref(prepared_raw); +) -> CassSharedPtr { + let prepared = ArcFFI::as_ref(&prepared_raw).unwrap(); let parameter_name = ptr_to_cstr_n(name, name_length).expect("Prepared parameter name is not UTF-8"); let data_type = prepared.get_variable_data_type_by_name(parameter_name); match data_type { Some(dt) => ArcFFI::as_ptr(dt), - None => std::ptr::null(), + None => ArcFFI::null(), } } diff --git a/scylla-rust-wrapper/src/query_error.rs b/scylla-rust-wrapper/src/query_error.rs index a6aa376f..a8898014 100644 --- a/scylla-rust-wrapper/src/query_error.rs +++ b/scylla-rust-wrapper/src/query_error.rs @@ -56,21 +56,23 @@ impl From<&WriteType> for CassWriteType { } #[no_mangle] -pub unsafe extern "C" fn cass_error_result_free(error_result: *const CassErrorResult) { +pub unsafe extern "C" fn cass_error_result_free(error_result: CassSharedPtr) { ArcFFI::free(error_result); } #[no_mangle] -pub unsafe extern "C" fn cass_error_result_code(error_result: *const CassErrorResult) -> CassError { - let error_result: &CassErrorResult = ArcFFI::as_ref(error_result); +pub unsafe extern "C" fn cass_error_result_code( + error_result: CassSharedPtr, +) -> CassError { + let error_result: &CassErrorResult = ArcFFI::as_ref(&error_result).unwrap(); error_result.to_cass_error() } #[no_mangle] pub unsafe extern "C" fn cass_error_result_consistency( - error_result: *const CassErrorResult, + error_result: CassSharedPtr, ) -> CassConsistency { - let error_result: &CassErrorResult = ArcFFI::as_ref(error_result); + let error_result: &CassErrorResult = ArcFFI::as_ref(&error_result).unwrap(); match error_result { CassErrorResult::Query(QueryError::DbError( DbError::Unavailable { consistency, .. }, @@ -98,9 +100,9 @@ pub unsafe extern "C" fn cass_error_result_consistency( #[no_mangle] pub unsafe extern "C" fn cass_error_result_responses_received( - error_result: *const CassErrorResult, + error_result: CassSharedPtr, ) -> cass_int32_t { - let error_result: &CassErrorResult = ArcFFI::as_ref(error_result); + let error_result: &CassErrorResult = ArcFFI::as_ref(&error_result).unwrap(); match error_result { CassErrorResult::Query(QueryError::DbError(DbError::Unavailable { alive, .. }, _)) => { *alive @@ -123,9 +125,9 @@ pub unsafe extern "C" fn cass_error_result_responses_received( #[no_mangle] pub unsafe extern "C" fn cass_error_result_responses_required( - error_result: *const CassErrorResult, + error_result: CassSharedPtr, ) -> cass_int32_t { - let error_result: &CassErrorResult = ArcFFI::as_ref(error_result); + let error_result: &CassErrorResult = ArcFFI::as_ref(&error_result).unwrap(); match error_result { CassErrorResult::Query(QueryError::DbError(DbError::Unavailable { required, .. }, _)) => { *required @@ -148,9 +150,9 @@ pub unsafe extern "C" fn cass_error_result_responses_required( #[no_mangle] pub unsafe extern "C" fn cass_error_result_num_failures( - error_result: *const CassErrorResult, + error_result: CassSharedPtr, ) -> cass_int32_t { - let error_result: &CassErrorResult = ArcFFI::as_ref(error_result); + let error_result: &CassErrorResult = ArcFFI::as_ref(&error_result).unwrap(); match error_result { CassErrorResult::Query(QueryError::DbError( DbError::ReadFailure { numfailures, .. }, @@ -166,9 +168,9 @@ pub unsafe extern "C" fn cass_error_result_num_failures( #[no_mangle] pub unsafe extern "C" fn cass_error_result_data_present( - error_result: *const CassErrorResult, + error_result: CassSharedPtr, ) -> cass_bool_t { - let error_result: &CassErrorResult = ArcFFI::as_ref(error_result); + let error_result: &CassErrorResult = ArcFFI::as_ref(&error_result).unwrap(); match error_result { CassErrorResult::Query(QueryError::DbError( DbError::ReadTimeout { data_present, .. }, @@ -196,9 +198,9 @@ pub unsafe extern "C" fn cass_error_result_data_present( #[no_mangle] pub unsafe extern "C" fn cass_error_result_write_type( - error_result: *const CassErrorResult, + error_result: CassSharedPtr, ) -> CassWriteType { - let error_result: &CassErrorResult = ArcFFI::as_ref(error_result); + let error_result: &CassErrorResult = ArcFFI::as_ref(&error_result).unwrap(); match error_result { CassErrorResult::Query(QueryError::DbError( DbError::WriteTimeout { write_type, .. }, @@ -214,11 +216,11 @@ pub unsafe extern "C" fn cass_error_result_write_type( #[no_mangle] pub unsafe extern "C" fn cass_error_result_keyspace( - error_result: *const CassErrorResult, + error_result: CassSharedPtr, c_keyspace: *mut *const ::std::os::raw::c_char, c_keyspace_len: *mut size_t, ) -> CassError { - let error_result: &CassErrorResult = ArcFFI::as_ref(error_result); + let error_result: &CassErrorResult = ArcFFI::as_ref(&error_result).unwrap(); match error_result { CassErrorResult::Query(QueryError::DbError(DbError::AlreadyExists { keyspace, .. }, _)) => { write_str_to_c(keyspace.as_str(), c_keyspace, c_keyspace_len); @@ -237,11 +239,11 @@ pub unsafe extern "C" fn cass_error_result_keyspace( #[no_mangle] pub unsafe extern "C" fn cass_error_result_table( - error_result: *const CassErrorResult, + error_result: CassSharedPtr, c_table: *mut *const ::std::os::raw::c_char, c_table_len: *mut size_t, ) -> CassError { - let error_result: &CassErrorResult = ArcFFI::as_ref(error_result); + let error_result: &CassErrorResult = ArcFFI::as_ref(&error_result).unwrap(); match error_result { CassErrorResult::Query(QueryError::DbError(DbError::AlreadyExists { table, .. }, _)) => { write_str_to_c(table.as_str(), c_table, c_table_len); @@ -253,11 +255,11 @@ pub unsafe extern "C" fn cass_error_result_table( #[no_mangle] pub unsafe extern "C" fn cass_error_result_function( - error_result: *const CassErrorResult, + error_result: CassSharedPtr, c_function: *mut *const ::std::os::raw::c_char, c_function_len: *mut size_t, ) -> CassError { - let error_result: &CassErrorResult = ArcFFI::as_ref(error_result); + let error_result: &CassErrorResult = ArcFFI::as_ref(&error_result).unwrap(); match error_result { CassErrorResult::Query(QueryError::DbError( DbError::FunctionFailure { function, .. }, @@ -271,8 +273,10 @@ pub unsafe extern "C" fn cass_error_result_function( } #[no_mangle] -pub unsafe extern "C" fn cass_error_num_arg_types(error_result: *const CassErrorResult) -> size_t { - let error_result: &CassErrorResult = ArcFFI::as_ref(error_result); +pub unsafe extern "C" fn cass_error_num_arg_types( + error_result: CassSharedPtr, +) -> size_t { + let error_result: &CassErrorResult = ArcFFI::as_ref(&error_result).unwrap(); match error_result { CassErrorResult::Query(QueryError::DbError( DbError::FunctionFailure { arg_types, .. }, @@ -284,12 +288,12 @@ pub unsafe extern "C" fn cass_error_num_arg_types(error_result: *const CassError #[no_mangle] pub unsafe extern "C" fn cass_error_result_arg_type( - error_result: *const CassErrorResult, + error_result: CassSharedPtr, index: size_t, arg_type: *mut *const ::std::os::raw::c_char, arg_type_length: *mut size_t, ) -> CassError { - let error_result: &CassErrorResult = ArcFFI::as_ref(error_result); + let error_result: &CassErrorResult = ArcFFI::as_ref(&error_result).unwrap(); match error_result { CassErrorResult::Query(QueryError::DbError( DbError::FunctionFailure { arg_types, .. }, diff --git a/scylla-rust-wrapper/src/query_result.rs b/scylla-rust-wrapper/src/query_result.rs index 5cce03ef..e21889a3 100644 --- a/scylla-rust-wrapper/src/query_result.rs +++ b/scylla-rust-wrapper/src/query_result.rs @@ -370,14 +370,16 @@ pub enum CassIterator { impl BoxFFI for CassIterator {} #[no_mangle] -pub unsafe extern "C" fn cass_iterator_free(iterator: *mut CassIterator) { +pub unsafe extern "C" fn cass_iterator_free(iterator: CassExclusiveMutPtr) { BoxFFI::free(iterator); } // After creating an iterator we have to call next() before accessing the value #[no_mangle] -pub unsafe extern "C" fn cass_iterator_next(iterator: *mut CassIterator) -> cass_bool_t { - let mut iter = BoxFFI::as_mut_ref(iterator); +pub unsafe extern "C" fn cass_iterator_next( + mut iterator: CassExclusiveMutPtr, +) -> cass_bool_t { + let mut iter = BoxFFI::as_mut_ref(&mut iterator).unwrap(); match &mut iter { CassIterator::CassResultIterator(result_iterator) => { @@ -476,66 +478,68 @@ pub unsafe extern "C" fn cass_iterator_next(iterator: *mut CassIterator) -> cass } #[no_mangle] -pub unsafe extern "C" fn cass_iterator_get_row(iterator: *const CassIterator) -> *const CassRow { - let iter = BoxFFI::as_ref(iterator); +pub unsafe extern "C" fn cass_iterator_get_row( + iterator: CassExclusiveConstPtr, +) -> CassBorrowedPtr { + let iter = BoxFFI::as_ref(&iterator).unwrap(); // Defined only for result iterator, for other types should return null if let CassIterator::CassResultIterator(result_iterator) = iter { let iter_position = match result_iterator.position { Some(pos) => pos, - None => return std::ptr::null(), + None => return RefFFI::null(), }; let CassResultKind::Rows(CassRowsResult { rows, .. }) = &result_iterator.result.kind else { - return std::ptr::null(); + return RefFFI::null(); }; let row: &CassRow = match rows.get(iter_position) { Some(row) => row, - None => return std::ptr::null(), + None => return RefFFI::null(), }; - return row; + return RefFFI::as_ptr(row); } - std::ptr::null() + RefFFI::null() } #[no_mangle] pub unsafe extern "C" fn cass_iterator_get_column( - iterator: *const CassIterator, -) -> *const CassValue { - let iter = BoxFFI::as_ref(iterator); + iterator: CassExclusiveConstPtr, +) -> CassBorrowedPtr { + let iter = BoxFFI::as_ref(&iterator).unwrap(); // Defined only for row iterator, for other types should return null if let CassIterator::CassRowIterator(row_iterator) = iter { let iter_position = match row_iterator.position { Some(pos) => pos, - None => return std::ptr::null(), + None => return RefFFI::null(), }; let value = match row_iterator.row.columns.get(iter_position) { Some(col) => col, - None => return std::ptr::null(), + None => return RefFFI::null(), }; - return value as *const CassValue; + return RefFFI::as_ptr(value); } - std::ptr::null() + RefFFI::null() } #[no_mangle] pub unsafe extern "C" fn cass_iterator_get_value( - iterator: *const CassIterator, -) -> *const CassValue { - let iter = BoxFFI::as_ref(iterator); + iterator: CassExclusiveConstPtr, +) -> CassBorrowedPtr { + let iter = BoxFFI::as_ref(&iterator).unwrap(); // Defined only for collections(list, set and map) or tuple iterator, for other types should return null if let CassIterator::CassCollectionIterator(collection_iterator) = iter { let iter_position = match collection_iterator.position { Some(pos) => pos, - None => return std::ptr::null(), + None => return RefFFI::null(), }; let value = match &collection_iterator.value.value { @@ -549,80 +553,80 @@ pub unsafe extern "C" fn cass_iterator_get_value( map.get(map_entry_index) .map(|(key, value)| if iter_position % 2 == 0 { key } else { value }) } - _ => return std::ptr::null(), + _ => return RefFFI::null(), }; if value.is_none() { - return std::ptr::null(); + return RefFFI::null(); } - return value.unwrap() as *const CassValue; + return RefFFI::as_ptr(value.unwrap()); } - std::ptr::null() + RefFFI::null() } #[no_mangle] pub unsafe extern "C" fn cass_iterator_get_map_key( - iterator: *const CassIterator, -) -> *const CassValue { - let iter = BoxFFI::as_ref(iterator); + iterator: CassExclusiveConstPtr, +) -> CassBorrowedPtr { + let iter = BoxFFI::as_ref(&iterator).unwrap(); if let CassIterator::CassMapIterator(map_iterator) = iter { let iter_position = match map_iterator.position { Some(pos) => pos, - None => return std::ptr::null(), + None => return RefFFI::null(), }; let entry = match &map_iterator.value.value { Some(Value::CollectionValue(Collection::Map(map))) => map.get(iter_position), - _ => return std::ptr::null(), + _ => return RefFFI::null(), }; if entry.is_none() { - return std::ptr::null(); + return RefFFI::null(); } - return &entry.unwrap().0 as *const CassValue; + return RefFFI::as_ptr(&entry.unwrap().0); } - std::ptr::null() + RefFFI::null() } #[no_mangle] pub unsafe extern "C" fn cass_iterator_get_map_value( - iterator: *const CassIterator, -) -> *const CassValue { - let iter = BoxFFI::as_ref(iterator); + iterator: CassExclusiveConstPtr, +) -> CassBorrowedPtr { + let iter = BoxFFI::as_ref(&iterator).unwrap(); if let CassIterator::CassMapIterator(map_iterator) = iter { let iter_position = match map_iterator.position { Some(pos) => pos, - None => return std::ptr::null(), + None => return RefFFI::null(), }; let entry = match &map_iterator.value.value { Some(Value::CollectionValue(Collection::Map(map))) => map.get(iter_position), - _ => return std::ptr::null(), + _ => return RefFFI::null(), }; if entry.is_none() { - return std::ptr::null(); + return RefFFI::null(); } - return &entry.unwrap().1 as *const CassValue; + return RefFFI::as_ptr(&entry.unwrap().1); } - std::ptr::null() + RefFFI::null() } #[no_mangle] pub unsafe extern "C" fn cass_iterator_get_user_type_field_name( - iterator: *const CassIterator, + iterator: CassExclusiveConstPtr, name: *mut *const c_char, name_length: *mut size_t, ) -> CassError { - let iter = BoxFFI::as_ref(iterator); + let iter = BoxFFI::as_ref(&iterator).unwrap(); if let CassIterator::CassUdtIterator(udt_iterator) = iter { let iter_position = match udt_iterator.position { @@ -653,45 +657,45 @@ pub unsafe extern "C" fn cass_iterator_get_user_type_field_name( #[no_mangle] pub unsafe extern "C" fn cass_iterator_get_user_type_field_value( - iterator: *const CassIterator, -) -> *const CassValue { - let iter = BoxFFI::as_ref(iterator); + iterator: CassExclusiveConstPtr, +) -> CassBorrowedPtr { + let iter = BoxFFI::as_ref(&iterator).unwrap(); if let CassIterator::CassUdtIterator(udt_iterator) = iter { let iter_position = match udt_iterator.position { Some(pos) => pos, - None => return std::ptr::null(), + None => return RefFFI::null(), }; let udt_entry_opt = match &udt_iterator.value.value { Some(Value::CollectionValue(Collection::UserDefinedType { fields, .. })) => { fields.get(iter_position) } - _ => return std::ptr::null(), + _ => return RefFFI::null(), }; return match udt_entry_opt { Some(udt_entry) => match &udt_entry.1 { - Some(value) => value as *const CassValue, - None => std::ptr::null(), + Some(value) => RefFFI::as_ptr(value), + None => RefFFI::null(), }, - None => std::ptr::null(), + None => RefFFI::null(), }; } - std::ptr::null() + RefFFI::null() } #[no_mangle] pub unsafe extern "C" fn cass_iterator_get_keyspace_meta( - iterator: *const CassIterator, -) -> *const CassKeyspaceMeta { - let iter = BoxFFI::as_ref(iterator); + iterator: CassExclusiveConstPtr, +) -> CassBorrowedPtr { + let iter = BoxFFI::as_ref(&iterator).unwrap(); if let CassIterator::CassSchemaMetaIterator(schema_meta_iterator) = iter { let iter_position = match schema_meta_iterator.position { Some(pos) => pos, - None => return std::ptr::null(), + None => return RefFFI::null(), }; let schema_meta_entry_opt = &schema_meta_iterator @@ -701,24 +705,24 @@ pub unsafe extern "C" fn cass_iterator_get_keyspace_meta( .nth(iter_position); return match schema_meta_entry_opt { - Some(schema_meta_entry) => schema_meta_entry.1 as *const CassKeyspaceMeta, - None => std::ptr::null(), + Some(schema_meta_entry) => RefFFI::as_ptr(schema_meta_entry.1), + None => RefFFI::null(), }; } - std::ptr::null() + RefFFI::null() } #[no_mangle] pub unsafe extern "C" fn cass_iterator_get_table_meta( - iterator: *const CassIterator, -) -> *const CassTableMeta { - let iter = BoxFFI::as_ref(iterator); + iterator: CassExclusiveConstPtr, +) -> CassBorrowedPtr { + let iter = BoxFFI::as_ref(&iterator).unwrap(); if let CassIterator::CassKeyspaceMetaTableIterator(keyspace_meta_iterator) = iter { let iter_position = match keyspace_meta_iterator.position { Some(pos) => pos, - None => return std::ptr::null(), + None => return RefFFI::null(), }; let table_meta_entry_opt = keyspace_meta_iterator @@ -729,23 +733,23 @@ pub unsafe extern "C" fn cass_iterator_get_table_meta( return match table_meta_entry_opt { Some(table_meta_entry) => RefFFI::as_ptr(table_meta_entry.1.as_ref()), - None => std::ptr::null(), + None => RefFFI::null(), }; } - std::ptr::null() + RefFFI::null() } #[no_mangle] pub unsafe extern "C" fn cass_iterator_get_user_type( - iterator: *const CassIterator, -) -> *const CassDataType { - let iter = BoxFFI::as_ref(iterator); + iterator: CassExclusiveConstPtr, +) -> CassSharedPtr { + let iter = BoxFFI::as_ref(&iterator).unwrap(); if let CassIterator::CassKeyspaceMetaUserTypeIterator(keyspace_meta_iterator) = iter { let iter_position = match keyspace_meta_iterator.position { Some(pos) => pos, - None => return std::ptr::null(), + None => return ArcFFI::null(), }; let udt_to_type_entry_opt = keyspace_meta_iterator @@ -756,24 +760,24 @@ pub unsafe extern "C" fn cass_iterator_get_user_type( return match udt_to_type_entry_opt { Some(udt_to_type_entry) => ArcFFI::as_ptr(udt_to_type_entry.1), - None => std::ptr::null(), + None => ArcFFI::null(), }; } - std::ptr::null() + ArcFFI::null() } #[no_mangle] pub unsafe extern "C" fn cass_iterator_get_column_meta( - iterator: *const CassIterator, -) -> *const CassColumnMeta { - let iter = BoxFFI::as_ref(iterator); + iterator: CassExclusiveConstPtr, +) -> CassBorrowedPtr { + let iter = BoxFFI::as_ref(&iterator).unwrap(); match iter { CassIterator::CassTableMetaIterator(table_meta_iterator) => { let iter_position = match table_meta_iterator.position { Some(pos) => pos, - None => return std::ptr::null(), + None => return RefFFI::null(), }; let column_meta_entry_opt = table_meta_iterator @@ -783,14 +787,14 @@ pub unsafe extern "C" fn cass_iterator_get_column_meta( .nth(iter_position); match column_meta_entry_opt { - Some(column_meta_entry) => column_meta_entry.1 as *const CassColumnMeta, - None => std::ptr::null(), + Some(column_meta_entry) => RefFFI::as_ptr(column_meta_entry.1), + None => RefFFI::null(), } } CassIterator::CassViewMetaIterator(view_meta_iterator) => { let iter_position = match view_meta_iterator.position { Some(pos) => pos, - None => return std::ptr::null(), + None => return RefFFI::null(), }; let column_meta_entry_opt = view_meta_iterator @@ -801,54 +805,56 @@ pub unsafe extern "C" fn cass_iterator_get_column_meta( .nth(iter_position); match column_meta_entry_opt { - Some(column_meta_entry) => column_meta_entry.1 as *const CassColumnMeta, - None => std::ptr::null(), + Some(column_meta_entry) => RefFFI::as_ptr(column_meta_entry.1), + None => RefFFI::null(), } } - _ => std::ptr::null(), + _ => RefFFI::null(), } } #[no_mangle] pub unsafe extern "C" fn cass_iterator_get_materialized_view_meta( - iterator: *const CassIterator, -) -> *const CassMaterializedViewMeta { - let iter = BoxFFI::as_ref(iterator); + iterator: CassExclusiveConstPtr, +) -> CassBorrowedPtr { + let iter = BoxFFI::as_ref(&iterator).unwrap(); match iter { CassIterator::CassKeyspaceMetaViewIterator(keyspace_meta_iterator) => { let iter_position = match keyspace_meta_iterator.position { Some(pos) => pos, - None => return std::ptr::null(), + None => return RefFFI::null(), }; let view_meta_entry_opt = keyspace_meta_iterator.value.views.iter().nth(iter_position); match view_meta_entry_opt { Some(view_meta_entry) => RefFFI::as_ptr(view_meta_entry.1.as_ref()), - None => std::ptr::null(), + None => RefFFI::null(), } } CassIterator::CassTableMetaIterator(table_meta_iterator) => { let iter_position = match table_meta_iterator.position { Some(pos) => pos, - None => return std::ptr::null(), + None => return RefFFI::null(), }; let view_meta_entry_opt = table_meta_iterator.value.views.iter().nth(iter_position); match view_meta_entry_opt { Some(view_meta_entry) => RefFFI::as_ptr(view_meta_entry.1.as_ref()), - None => std::ptr::null(), + None => RefFFI::null(), } } - _ => std::ptr::null(), + _ => RefFFI::null(), } } #[no_mangle] -pub unsafe extern "C" fn cass_iterator_from_result(result: *const CassResult) -> *mut CassIterator { - let result_from_raw = ArcFFI::cloned_from_ptr(result); +pub unsafe extern "C" fn cass_iterator_from_result( + result: CassSharedPtr, +) -> CassExclusiveMutPtr { + let result_from_raw = ArcFFI::cloned_from_ptr(result).unwrap(); let iterator = CassResultIterator { result: result_from_raw, @@ -859,8 +865,10 @@ pub unsafe extern "C" fn cass_iterator_from_result(result: *const CassResult) -> } #[no_mangle] -pub unsafe extern "C" fn cass_iterator_from_row(row: *const CassRow) -> *mut CassIterator { - let row_from_raw = RefFFI::as_ref(row); +pub unsafe extern "C" fn cass_iterator_from_row( + row: CassBorrowedPtr, +) -> CassExclusiveMutPtr { + let row_from_raw = RefFFI::into_ref(row).unwrap(); let iterator = CassRowIterator { row: row_from_raw, @@ -872,20 +880,20 @@ pub unsafe extern "C" fn cass_iterator_from_row(row: *const CassRow) -> *mut Cas #[no_mangle] pub unsafe extern "C" fn cass_iterator_from_collection( - value: *const CassValue, -) -> *mut CassIterator { - let is_collection = cass_value_is_collection(value) != 0; + value: CassBorrowedPtr, +) -> CassExclusiveMutPtr { + let is_collection = value_is_collection(&value) != 0; - if value.is_null() || !is_collection { - return std::ptr::null_mut(); + if RefFFI::is_null(&value) || !is_collection { + return BoxFFI::null_mut(); } - let val = RefFFI::as_ref(value); - let item_count = cass_value_item_count(value); - let item_count = match cass_value_type(value) { + let item_count = value_item_count(&value); + let item_count = match value_type(&value) { CassValueType::CASS_VALUE_TYPE_MAP => item_count * 2, _ => item_count, }; + let val = RefFFI::into_ref(value).unwrap(); let iterator = CassCollectionIterator { value: val, @@ -897,8 +905,10 @@ pub unsafe extern "C" fn cass_iterator_from_collection( } #[no_mangle] -pub unsafe extern "C" fn cass_iterator_from_tuple(value: *const CassValue) -> *mut CassIterator { - let tuple = RefFFI::as_ref(value); +pub unsafe extern "C" fn cass_iterator_from_tuple( + value: CassBorrowedPtr, +) -> CassExclusiveMutPtr { + let tuple = RefFFI::into_ref(value).unwrap(); if let Some(Value::CollectionValue(Collection::Tuple(val))) = &tuple.value { let item_count = val.len(); @@ -911,12 +921,14 @@ pub unsafe extern "C" fn cass_iterator_from_tuple(value: *const CassValue) -> *m return BoxFFI::into_ptr(Box::new(CassIterator::CassCollectionIterator(iterator))); } - std::ptr::null_mut() + BoxFFI::null_mut() } #[no_mangle] -pub unsafe extern "C" fn cass_iterator_from_map(value: *const CassValue) -> *mut CassIterator { - let map = RefFFI::as_ref(value); +pub unsafe extern "C" fn cass_iterator_from_map( + value: CassBorrowedPtr, +) -> CassExclusiveMutPtr { + let map = RefFFI::into_ref(value).unwrap(); if let Some(Value::CollectionValue(Collection::Map(val))) = &map.value { let item_count = val.len(); @@ -929,14 +941,14 @@ pub unsafe extern "C" fn cass_iterator_from_map(value: *const CassValue) -> *mut return BoxFFI::into_ptr(Box::new(CassIterator::CassMapIterator(iterator))); } - std::ptr::null_mut() + BoxFFI::null_mut() } #[no_mangle] pub unsafe extern "C" fn cass_iterator_fields_from_user_type( - value: *const CassValue, -) -> *mut CassIterator { - let udt = RefFFI::as_ref(value); + value: CassBorrowedPtr, +) -> CassExclusiveMutPtr { + let udt = RefFFI::into_ref(value).unwrap(); if let Some(Value::CollectionValue(Collection::UserDefinedType { fields, .. })) = &udt.value { let item_count = fields.len(); @@ -949,14 +961,14 @@ pub unsafe extern "C" fn cass_iterator_fields_from_user_type( return BoxFFI::into_ptr(Box::new(CassIterator::CassUdtIterator(iterator))); } - std::ptr::null_mut() + BoxFFI::null_mut() } #[no_mangle] pub unsafe extern "C" fn cass_iterator_keyspaces_from_schema_meta( - schema_meta: *const CassSchemaMeta, -) -> *mut CassIterator { - let metadata = BoxFFI::as_ref(schema_meta); + schema_meta: CassExclusiveConstPtr, +) -> CassExclusiveMutPtr { + let metadata = BoxFFI::into_ref(schema_meta).unwrap(); let iterator = CassSchemaMetaIterator { value: metadata, @@ -969,9 +981,9 @@ pub unsafe extern "C" fn cass_iterator_keyspaces_from_schema_meta( #[no_mangle] pub unsafe extern "C" fn cass_iterator_tables_from_keyspace_meta( - keyspace_meta: *const CassKeyspaceMeta, -) -> *mut CassIterator { - let metadata = RefFFI::as_ref(keyspace_meta); + keyspace_meta: CassBorrowedPtr, +) -> CassExclusiveMutPtr { + let metadata = RefFFI::into_ref(keyspace_meta).unwrap(); let iterator = CassKeyspaceMetaIterator { value: metadata, @@ -986,9 +998,9 @@ pub unsafe extern "C" fn cass_iterator_tables_from_keyspace_meta( #[no_mangle] pub unsafe extern "C" fn cass_iterator_materialized_views_from_keyspace_meta( - keyspace_meta: *const CassKeyspaceMeta, -) -> *mut CassIterator { - let metadata = RefFFI::as_ref(keyspace_meta); + keyspace_meta: CassBorrowedPtr, +) -> CassExclusiveMutPtr { + let metadata = RefFFI::into_ref(keyspace_meta).unwrap(); let iterator = CassKeyspaceMetaIterator { value: metadata, @@ -1003,9 +1015,9 @@ pub unsafe extern "C" fn cass_iterator_materialized_views_from_keyspace_meta( #[no_mangle] pub unsafe extern "C" fn cass_iterator_user_types_from_keyspace_meta( - keyspace_meta: *const CassKeyspaceMeta, -) -> *mut CassIterator { - let metadata = RefFFI::as_ref(keyspace_meta); + keyspace_meta: CassBorrowedPtr, +) -> CassExclusiveMutPtr { + let metadata = RefFFI::into_ref(keyspace_meta).unwrap(); let iterator = CassKeyspaceMetaIterator { value: metadata, @@ -1020,9 +1032,9 @@ pub unsafe extern "C" fn cass_iterator_user_types_from_keyspace_meta( #[no_mangle] pub unsafe extern "C" fn cass_iterator_columns_from_table_meta( - table_meta: *const CassTableMeta, -) -> *mut CassIterator { - let metadata = RefFFI::as_ref(table_meta); + table_meta: CassBorrowedPtr, +) -> CassExclusiveMutPtr { + let metadata = RefFFI::into_ref(table_meta).unwrap(); let iterator = CassTableMetaIterator { value: metadata, @@ -1034,9 +1046,9 @@ pub unsafe extern "C" fn cass_iterator_columns_from_table_meta( } pub unsafe extern "C" fn cass_iterator_materialized_views_from_table_meta( - table_meta: *const CassTableMeta, -) -> *mut CassIterator { - let metadata = RefFFI::as_ref(table_meta); + table_meta: CassBorrowedPtr, +) -> CassExclusiveMutPtr { + let metadata = RefFFI::into_ref(table_meta).unwrap(); let iterator = CassTableMetaIterator { value: metadata, @@ -1048,9 +1060,9 @@ pub unsafe extern "C" fn cass_iterator_materialized_views_from_table_meta( } pub unsafe extern "C" fn cass_iterator_columns_from_materialized_view_meta( - view_meta: *const CassMaterializedViewMeta, -) -> *mut CassIterator { - let metadata = RefFFI::as_ref(view_meta); + view_meta: CassBorrowedPtr, +) -> CassExclusiveMutPtr { + let metadata = RefFFI::into_ref(view_meta).unwrap(); let iterator = CassViewMetaIterator { value: metadata, @@ -1062,37 +1074,43 @@ pub unsafe extern "C" fn cass_iterator_columns_from_materialized_view_meta( } #[no_mangle] -pub unsafe extern "C" fn cass_result_free(result_raw: *const CassResult) { +pub unsafe extern "C" fn cass_result_free(result_raw: CassSharedPtr) { ArcFFI::free(result_raw); } #[no_mangle] -pub unsafe extern "C" fn cass_result_has_more_pages(result: *const CassResult) -> cass_bool_t { - let result = ArcFFI::as_ref(result); +pub unsafe extern "C" fn cass_result_has_more_pages( + result: CassSharedPtr, +) -> cass_bool_t { + result_has_more_pages(&result) +} + +unsafe fn result_has_more_pages(result: &CassSharedPtr) -> cass_bool_t { + let result = ArcFFI::as_ref(result).unwrap(); (!result.paging_state_response.finished()) as cass_bool_t } #[no_mangle] pub unsafe extern "C" fn cass_row_get_column( - row_raw: *const CassRow, + row_raw: CassBorrowedPtr, index: size_t, -) -> *const CassValue { - let row: &CassRow = RefFFI::as_ref(row_raw); +) -> CassBorrowedPtr { + let row: &CassRow = RefFFI::as_ref(&row_raw).unwrap(); let index_usize: usize = index.try_into().unwrap(); let column_value = match row.columns.get(index_usize) { Some(val) => val, - None => return std::ptr::null(), + None => return RefFFI::null(), }; - column_value as *const CassValue + RefFFI::as_ptr(column_value) } #[no_mangle] pub unsafe extern "C" fn cass_row_get_column_by_name( - row: *const CassRow, + row: CassBorrowedPtr, name: *const c_char, -) -> *const CassValue { +) -> CassBorrowedPtr { let name_str = ptr_to_cstr(name).unwrap(); let name_length = name_str.len(); @@ -1101,11 +1119,11 @@ pub unsafe extern "C" fn cass_row_get_column_by_name( #[no_mangle] pub unsafe extern "C" fn cass_row_get_column_by_name_n( - row: *const CassRow, + row: CassBorrowedPtr, name: *const c_char, name_length: size_t, -) -> *const CassValue { - let row_from_raw = RefFFI::as_ref(row); +) -> CassBorrowedPtr { + let row_from_raw = RefFFI::as_ref(&row).unwrap(); let mut name_str = ptr_to_cstr_n(name, name_length).unwrap(); let mut is_case_sensitive = false; @@ -1125,20 +1143,20 @@ pub unsafe extern "C" fn cass_row_get_column_by_name_n( || !is_case_sensitive && col_spec.name.eq_ignore_ascii_case(name_str) }) .map(|(index, _)| match row_from_raw.columns.get(index) { - Some(value) => value as *const CassValue, - None => std::ptr::null(), + Some(value) => RefFFI::as_ptr(value), + None => RefFFI::null(), }) - .unwrap_or(std::ptr::null()) + .unwrap_or(RefFFI::null()) } #[no_mangle] pub unsafe extern "C" fn cass_result_column_name( - result: *const CassResult, + result: CassSharedPtr, index: size_t, name: *mut *const c_char, name_length: *mut size_t, ) -> CassError { - let result_from_raw = ArcFFI::as_ref(result); + let result_from_raw = ArcFFI::as_ref(&result).unwrap(); let index_usize: usize = index.try_into().unwrap(); let CassResultKind::Rows(CassRowsResult { metadata, .. }) = &result_from_raw.kind else { @@ -1158,11 +1176,11 @@ pub unsafe extern "C" fn cass_result_column_name( #[no_mangle] pub unsafe extern "C" fn cass_result_column_type( - result: *const CassResult, + result: CassSharedPtr, index: size_t, ) -> CassValueType { let data_type_ptr = cass_result_column_data_type(result, index); - if data_type_ptr.is_null() { + if ArcFFI::is_null(&data_type_ptr) { return CassValueType::CASS_VALUE_TYPE_UNKNOWN; } cass_data_type_type(data_type_ptr) @@ -1170,51 +1188,58 @@ pub unsafe extern "C" fn cass_result_column_type( #[no_mangle] pub unsafe extern "C" fn cass_result_column_data_type( - result: *const CassResult, + result: CassSharedPtr, index: size_t, -) -> *const CassDataType { - let result_from_raw: &CassResult = ArcFFI::as_ref(result); +) -> CassSharedPtr { + let result_from_raw: &CassResult = ArcFFI::as_ref(&result).unwrap(); let index_usize: usize = index .try_into() .expect("Provided index is out of bounds. Max possible value is usize::MAX"); let CassResultKind::Rows(CassRowsResult { metadata, .. }) = &result_from_raw.kind else { - return std::ptr::null(); + return ArcFFI::null(); }; metadata .col_specs .get(index_usize) .map(|col_spec| ArcFFI::as_ptr(&col_spec.data_type)) - .unwrap_or(std::ptr::null()) + .unwrap_or(ArcFFI::null()) } #[no_mangle] -pub unsafe extern "C" fn cass_value_type(value: *const CassValue) -> CassValueType { - let value_from_raw = RefFFI::as_ref(value); +pub unsafe extern "C" fn cass_value_type(value: CassBorrowedPtr) -> CassValueType { + value_type(&value) +} + +unsafe fn value_type(value: &CassBorrowedPtr) -> CassValueType { + let value_from_raw = RefFFI::as_ref(value).unwrap(); cass_data_type_type(ArcFFI::as_ptr(&value_from_raw.value_type)) } #[no_mangle] -pub unsafe extern "C" fn cass_value_data_type(value: *const CassValue) -> *const CassDataType { - let value_from_raw = RefFFI::as_ref(value); +pub unsafe extern "C" fn cass_value_data_type( + value: CassBorrowedPtr, +) -> CassSharedPtr { + let value_from_raw = RefFFI::as_ref(&value).unwrap(); ArcFFI::as_ptr(&value_from_raw.value_type) } macro_rules! val_ptr_to_ref_ensure_non_null { ($ptr:ident) => {{ - if $ptr.is_null() { - return CassError::CASS_ERROR_LIB_NULL_VALUE; + let maybe_ref = RefFFI::as_ref(&$ptr); + match maybe_ref { + Some(r) => r, + None => return CassError::CASS_ERROR_LIB_NULL_VALUE, } - RefFFI::as_ref($ptr) }}; } #[no_mangle] pub unsafe extern "C" fn cass_value_get_float( - value: *const CassValue, + value: CassBorrowedPtr, output: *mut cass_float_t, ) -> CassError { let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); @@ -1229,7 +1254,7 @@ pub unsafe extern "C" fn cass_value_get_float( #[no_mangle] pub unsafe extern "C" fn cass_value_get_double( - value: *const CassValue, + value: CassBorrowedPtr, output: *mut cass_double_t, ) -> CassError { let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); @@ -1244,7 +1269,7 @@ pub unsafe extern "C" fn cass_value_get_double( #[no_mangle] pub unsafe extern "C" fn cass_value_get_bool( - value: *const CassValue, + value: CassBorrowedPtr, output: *mut cass_bool_t, ) -> CassError { let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); @@ -1261,7 +1286,7 @@ pub unsafe extern "C" fn cass_value_get_bool( #[no_mangle] pub unsafe extern "C" fn cass_value_get_int8( - value: *const CassValue, + value: CassBorrowedPtr, output: *mut cass_int8_t, ) -> CassError { let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); @@ -1276,7 +1301,7 @@ pub unsafe extern "C" fn cass_value_get_int8( #[no_mangle] pub unsafe extern "C" fn cass_value_get_int16( - value: *const CassValue, + value: CassBorrowedPtr, output: *mut cass_int16_t, ) -> CassError { let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); @@ -1291,7 +1316,7 @@ pub unsafe extern "C" fn cass_value_get_int16( #[no_mangle] pub unsafe extern "C" fn cass_value_get_uint32( - value: *const CassValue, + value: CassBorrowedPtr, output: *mut cass_uint32_t, ) -> CassError { let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); @@ -1306,7 +1331,7 @@ pub unsafe extern "C" fn cass_value_get_uint32( #[no_mangle] pub unsafe extern "C" fn cass_value_get_int32( - value: *const CassValue, + value: CassBorrowedPtr, output: *mut cass_int32_t, ) -> CassError { let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); @@ -1321,7 +1346,7 @@ pub unsafe extern "C" fn cass_value_get_int32( #[no_mangle] pub unsafe extern "C" fn cass_value_get_int64( - value: *const CassValue, + value: CassBorrowedPtr, output: *mut cass_int64_t, ) -> CassError { let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); @@ -1343,7 +1368,7 @@ pub unsafe extern "C" fn cass_value_get_int64( #[no_mangle] pub unsafe extern "C" fn cass_value_get_uuid( - value: *const CassValue, + value: CassBorrowedPtr, output: *mut CassUuid, ) -> CassError { let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); @@ -1361,7 +1386,7 @@ pub unsafe extern "C" fn cass_value_get_uuid( #[no_mangle] pub unsafe extern "C" fn cass_value_get_inet( - value: *const CassValue, + value: CassBorrowedPtr, output: *mut CassInet, ) -> CassError { let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); @@ -1376,12 +1401,12 @@ pub unsafe extern "C" fn cass_value_get_inet( #[no_mangle] pub unsafe extern "C" fn cass_value_get_decimal( - value: *const CassValue, + value: CassBorrowedPtr, varint: *mut *const cass_byte_t, varint_size: *mut size_t, scale: *mut cass_int32_t, ) -> CassError { - let val: &CassValue = RefFFI::as_ref(value); + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); let decimal = match &val.value { Some(Value::RegularValue(CqlValue::Decimal(decimal))) => decimal, Some(_) => return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE, @@ -1398,7 +1423,7 @@ pub unsafe extern "C" fn cass_value_get_decimal( #[no_mangle] pub unsafe extern "C" fn cass_value_get_string( - value: *const CassValue, + value: CassBorrowedPtr, output: *mut *const c_char, output_size: *mut size_t, ) -> CassError { @@ -1423,7 +1448,7 @@ pub unsafe extern "C" fn cass_value_get_string( #[no_mangle] pub unsafe extern "C" fn cass_value_get_duration( - value: *const CassValue, + value: CassBorrowedPtr, months: *mut cass_int32_t, days: *mut cass_int32_t, nanos: *mut cass_int64_t, @@ -1445,7 +1470,7 @@ pub unsafe extern "C" fn cass_value_get_duration( #[no_mangle] pub unsafe extern "C" fn cass_value_get_bytes( - value: *const CassValue, + value: CassBorrowedPtr, output: *mut *const cass_byte_t, output_size: *mut size_t, ) -> CassError { @@ -1471,14 +1496,20 @@ pub unsafe extern "C" fn cass_value_get_bytes( } #[no_mangle] -pub unsafe extern "C" fn cass_value_is_null(value: *const CassValue) -> cass_bool_t { - let val: &CassValue = RefFFI::as_ref(value); +pub unsafe extern "C" fn cass_value_is_null(value: CassBorrowedPtr) -> cass_bool_t { + let val: &CassValue = RefFFI::as_ref(&value).unwrap(); val.value.is_none() as cass_bool_t } #[no_mangle] -pub unsafe extern "C" fn cass_value_is_collection(value: *const CassValue) -> cass_bool_t { - let val = RefFFI::as_ref(value); +pub unsafe extern "C" fn cass_value_is_collection( + value: CassBorrowedPtr, +) -> cass_bool_t { + value_is_collection(&value) +} + +unsafe fn value_is_collection(value: &CassBorrowedPtr) -> cass_bool_t { + let val = RefFFI::as_ref(value).unwrap(); matches!( val.value_type.get_unchecked().get_value_type(), @@ -1489,16 +1520,20 @@ pub unsafe extern "C" fn cass_value_is_collection(value: *const CassValue) -> ca } #[no_mangle] -pub unsafe extern "C" fn cass_value_is_duration(value: *const CassValue) -> cass_bool_t { - let val = RefFFI::as_ref(value); +pub unsafe extern "C" fn cass_value_is_duration(value: CassBorrowedPtr) -> cass_bool_t { + let val = RefFFI::as_ref(&value).unwrap(); (val.value_type.get_unchecked().get_value_type() == CassValueType::CASS_VALUE_TYPE_DURATION) as cass_bool_t } #[no_mangle] -pub unsafe extern "C" fn cass_value_item_count(collection: *const CassValue) -> size_t { - let val = RefFFI::as_ref(collection); +pub unsafe extern "C" fn cass_value_item_count(collection: CassBorrowedPtr) -> size_t { + value_item_count(&collection) +} + +unsafe fn value_item_count(collection: &CassBorrowedPtr) -> size_t { + let val = RefFFI::as_ref(collection).unwrap(); match &val.value { Some(Value::CollectionValue(Collection::List(list))) => list.len() as size_t, @@ -1514,9 +1549,9 @@ pub unsafe extern "C" fn cass_value_item_count(collection: *const CassValue) -> #[no_mangle] pub unsafe extern "C" fn cass_value_primary_sub_type( - collection: *const CassValue, + collection: CassBorrowedPtr, ) -> CassValueType { - let val = RefFFI::as_ref(collection); + let val = RefFFI::as_ref(&collection).unwrap(); match val.value_type.get_unchecked() { CassDataTypeInner::List { @@ -1533,9 +1568,9 @@ pub unsafe extern "C" fn cass_value_primary_sub_type( #[no_mangle] pub unsafe extern "C" fn cass_value_secondary_sub_type( - collection: *const CassValue, + collection: CassBorrowedPtr, ) -> CassValueType { - let val = RefFFI::as_ref(collection); + let val = RefFFI::as_ref(&collection).unwrap(); match val.value_type.get_unchecked() { CassDataTypeInner::Map { @@ -1547,8 +1582,8 @@ pub unsafe extern "C" fn cass_value_secondary_sub_type( } #[no_mangle] -pub unsafe extern "C" fn cass_result_row_count(result_raw: *const CassResult) -> size_t { - let result = ArcFFI::as_ref(result_raw); +pub unsafe extern "C" fn cass_result_row_count(result_raw: CassSharedPtr) -> size_t { + let result = ArcFFI::as_ref(&result_raw).unwrap(); let CassResultKind::Rows(CassRowsResult { rows, .. }) = &result.kind else { return 0; @@ -1558,8 +1593,8 @@ pub unsafe extern "C" fn cass_result_row_count(result_raw: *const CassResult) -> } #[no_mangle] -pub unsafe extern "C" fn cass_result_column_count(result_raw: *const CassResult) -> size_t { - let result = ArcFFI::as_ref(result_raw); +pub unsafe extern "C" fn cass_result_column_count(result_raw: CassSharedPtr) -> size_t { + let result = ArcFFI::as_ref(&result_raw).unwrap(); let CassResultKind::Rows(CassRowsResult { metadata, .. }) = &result.kind else { return 0; @@ -1569,29 +1604,29 @@ pub unsafe extern "C" fn cass_result_column_count(result_raw: *const CassResult) } #[no_mangle] -pub unsafe extern "C" fn cass_result_first_row(result_raw: *const CassResult) -> *const CassRow { - let result = ArcFFI::as_ref(result_raw); +pub unsafe extern "C" fn cass_result_first_row( + result_raw: CassSharedPtr, +) -> CassBorrowedPtr { + let result = ArcFFI::as_ref(&result_raw).unwrap(); let CassResultKind::Rows(CassRowsResult { rows, .. }) = &result.kind else { - return std::ptr::null(); + return RefFFI::null(); }; - rows.first() - .map(|row| row as *const CassRow) - .unwrap_or(std::ptr::null()) + rows.first().map(RefFFI::as_ptr).unwrap_or(RefFFI::null()) } #[no_mangle] pub unsafe extern "C" fn cass_result_paging_state_token( - result: *const CassResult, + result: CassSharedPtr, paging_state: *mut *const c_char, paging_state_size: *mut size_t, ) -> CassError { - if cass_result_has_more_pages(result) == cass_false { + if result_has_more_pages(&result) == cass_false { return CassError::CASS_ERROR_LIB_NO_PAGING_STATE; } - let result_from_raw = ArcFFI::as_ref(result); + let result_from_raw = ArcFFI::as_ref(&result).unwrap(); match &result_from_raw.paging_state_response { PagingStateResponse::HasMorePages { state } => match state.as_bytes_slice() { @@ -1623,7 +1658,7 @@ mod tests { }; use crate::{ - argconv::ArcFFI, + argconv::{ArcFFI, RefFFI}, cass_error::CassError, cass_types::{CassDataType, CassDataTypeInner, CassValueType}, query_result::{ @@ -1634,7 +1669,7 @@ mod tests { use super::{ cass_result_column_count, cass_result_column_type, create_cass_rows_from_rows, CassResult, - CassResultKind, CassResultMetadata, CassRowsResult, + CassResultKind, CassResultMetadata, CassRowsResult, CassSharedPtr, }; fn col_spec(name: &'static str, typ: ColumnType<'static>) -> ColumnSpec<'static> { @@ -1677,7 +1712,7 @@ mod tests { } unsafe fn cass_result_column_name_rust_str( - result_ptr: *const CassResult, + result_ptr: CassSharedPtr, column_index: u64, ) -> Option<&'static str> { let mut name_ptr: *const c_char = std::ptr::null(); @@ -1694,36 +1729,39 @@ mod tests { #[test] fn rows_cass_result_api_test() { - let result = create_cass_rows_result(); + let result = Arc::new(create_cass_rows_result()); unsafe { - let result_ptr = std::ptr::addr_of!(result); + let result_ptr = ArcFFI::as_ptr(&result); // cass_result_column_count test { - let column_count = cass_result_column_count(result_ptr); + let column_count = cass_result_column_count(result_ptr.clone()); assert_eq!(3, column_count); } // cass_result_column_name test { - let first_column_name = cass_result_column_name_rust_str(result_ptr, 0).unwrap(); + let first_column_name = + cass_result_column_name_rust_str(result_ptr.clone(), 0).unwrap(); assert_eq!(FIRST_COLUMN_NAME, first_column_name); - let second_column_name = cass_result_column_name_rust_str(result_ptr, 1).unwrap(); + let second_column_name = + cass_result_column_name_rust_str(result_ptr.clone(), 1).unwrap(); assert_eq!(SECOND_COLUMN_NAME, second_column_name); - let third_column_name = cass_result_column_name_rust_str(result_ptr, 2).unwrap(); + let third_column_name = + cass_result_column_name_rust_str(result_ptr.clone(), 2).unwrap(); assert_eq!(THIRD_COLUMN_NAME, third_column_name); } // cass_result_column_type test { - let first_col_type = cass_result_column_type(result_ptr, 0); + let first_col_type = cass_result_column_type(result_ptr.clone(), 0); assert_eq!(CassValueType::CASS_VALUE_TYPE_BIGINT, first_col_type); - let second_col_type = cass_result_column_type(result_ptr, 1); + let second_col_type = cass_result_column_type(result_ptr.clone(), 1); assert_eq!(CassValueType::CASS_VALUE_TYPE_VARINT, second_col_type); - let third_col_type = cass_result_column_type(result_ptr, 2); + let third_col_type = cass_result_column_type(result_ptr.clone(), 2); assert_eq!(CassValueType::CASS_VALUE_TYPE_LIST, third_col_type); - let out_of_bound_col_type = cass_result_column_type(result_ptr, 555); + let out_of_bound_col_type = cass_result_column_type(result_ptr.clone(), 555); assert_eq!( CassValueType::CASS_VALUE_TYPE_UNKNOWN, out_of_bound_col_type @@ -1732,24 +1770,24 @@ mod tests { // cass_result_column_data_type test { - let first_col_data_type = - ArcFFI::as_ref(cass_result_column_data_type(result_ptr, 0)); + let first_col_data_type_ptr = cass_result_column_data_type(result_ptr.clone(), 0); + let first_col_data_type = ArcFFI::as_ref(&first_col_data_type_ptr).unwrap(); assert_eq!( &CassDataType::new(CassDataTypeInner::Value( CassValueType::CASS_VALUE_TYPE_BIGINT )), first_col_data_type ); - let second_col_data_type = - ArcFFI::as_ref(cass_result_column_data_type(result_ptr, 1)); + let second_col_data_type_ptr = cass_result_column_data_type(result_ptr.clone(), 1); + let second_col_data_type = ArcFFI::as_ref(&second_col_data_type_ptr).unwrap(); assert_eq!( &CassDataType::new(CassDataTypeInner::Value( CassValueType::CASS_VALUE_TYPE_VARINT )), second_col_data_type ); - let third_col_data_type = - ArcFFI::as_ref(cass_result_column_data_type(result_ptr, 2)); + let third_col_data_type_ptr = cass_result_column_data_type(result_ptr.clone(), 2); + let third_col_data_type = ArcFFI::as_ref(&third_col_data_type_ptr).unwrap(); assert_eq!( &CassDataType::new(CassDataTypeInner::List { typ: Some(CassDataType::new_arced(CassDataTypeInner::Value( @@ -1760,7 +1798,7 @@ mod tests { third_col_data_type ); let out_of_bound_col_data_type = cass_result_column_data_type(result_ptr, 555); - assert!(out_of_bound_col_data_type.is_null()); + assert!(ArcFFI::is_null(&out_of_bound_col_data_type)); } } } @@ -1775,19 +1813,22 @@ mod tests { #[test] fn non_rows_cass_result_api_test() { - let result = create_non_rows_cass_result(); + let result = Arc::new(create_non_rows_cass_result()); // Check that API functions do not panic when rows are empty - e.g. for INSERT queries. unsafe { - let result_ptr = std::ptr::addr_of!(result); + let result_ptr = ArcFFI::as_ptr(&result); - assert_eq!(0, cass_result_column_count(result_ptr)); + assert_eq!(0, cass_result_column_count(result_ptr.clone())); assert_eq!( CassValueType::CASS_VALUE_TYPE_UNKNOWN, - cass_result_column_type(result_ptr, 0) + cass_result_column_type(result_ptr.clone(), 0) ); - assert!(cass_result_column_data_type(result_ptr, 0).is_null()); - assert!(cass_result_first_row(result_ptr).is_null()); + assert!(ArcFFI::is_null(&cass_result_column_data_type( + result_ptr.clone(), + 0 + ))); + assert!(RefFFI::is_null(&cass_result_first_row(result_ptr.clone()))); { let mut name_ptr: *const c_char = std::ptr::null(); @@ -1809,41 +1850,41 @@ mod tests { extern "C" { pub fn cass_statement_set_paging_state( statement: *mut CassStatement, - result: *const CassResult, + result: CassSharedPtr, ) -> CassError; } extern "C" { - pub fn cass_result_row_count(result: *const CassResult) -> size_t; + pub fn cass_result_row_count(result: CassSharedPtr) -> size_t; } extern "C" { - pub fn cass_result_column_count(result: *const CassResult) -> size_t; + pub fn cass_result_column_count(result: CassSharedPtr) -> size_t; } extern "C" { pub fn cass_result_column_name( - result: *const CassResult, + result: CassSharedPtr, index: size_t, name: *mut *const ::std::os::raw::c_char, name_length: *mut size_t, ) -> CassError; } extern "C" { - pub fn cass_result_column_type(result: *const CassResult, index: size_t) -> CassValueType; + pub fn cass_result_column_type(result: CassSharedPtr, index: size_t) -> CassValueType; } extern "C" { pub fn cass_result_column_data_type( - result: *const CassResult, + result: CassSharedPtr, index: size_t, ) -> *const CassDataType; } extern "C" { - pub fn cass_result_first_row(result: *const CassResult) -> *const CassRow; + pub fn cass_result_first_row(result: CassSharedPtr) -> CassBorrowedPtr; } extern "C" { - pub fn cass_result_has_more_pages(result: *const CassResult) -> cass_bool_t; + pub fn cass_result_has_more_pages(result: CassSharedPtr) -> cass_bool_t; } extern "C" { pub fn cass_result_paging_state_token( - result: *const CassResult, + result: CassSharedPtr, paging_state: *mut *const ::std::os::raw::c_char, paging_state_size: *mut size_t, ) -> CassError; @@ -1853,120 +1894,120 @@ extern "C" { // CassIterator functions: /* extern "C" { - pub fn cass_iterator_type(iterator: *mut CassIterator) -> CassIteratorType; + pub fn cass_iterator_type(iterator: CassExclusiveMutPtr) -> CassIteratorType; } extern "C" { - pub fn cass_iterator_from_row(row: *const CassRow) -> *mut CassIterator; + pub fn cass_iterator_from_row(row: CassBorrowedPtr) -> CassExclusiveMutPtr; } extern "C" { - pub fn cass_iterator_from_collection(value: *const CassValue) -> *mut CassIterator; + pub fn cass_iterator_from_collection(value: CassBorrowedPtr) -> CassExclusiveMutPtr; } extern "C" { - pub fn cass_iterator_from_map(value: *const CassValue) -> *mut CassIterator; + pub fn cass_iterator_from_map(value: CassBorrowedPtr) -> CassExclusiveMutPtr; } extern "C" { - pub fn cass_iterator_from_tuple(value: *const CassValue) -> *mut CassIterator; + pub fn cass_iterator_from_tuple(value: CassBorrowedPtr) -> CassExclusiveMutPtr; } extern "C" { - pub fn cass_iterator_fields_from_user_type(value: *const CassValue) -> *mut CassIterator; + pub fn cass_iterator_fields_from_user_type(value: CassBorrowedPtr) -> CassExclusiveMutPtr; } extern "C" { pub fn cass_iterator_keyspaces_from_schema_meta( schema_meta: *const CassSchemaMeta, - ) -> *mut CassIterator; + ) -> CassExclusiveMutPtr; } extern "C" { pub fn cass_iterator_tables_from_keyspace_meta( keyspace_meta: *const CassKeyspaceMeta, - ) -> *mut CassIterator; + ) -> CassExclusiveMutPtr; } extern "C" { pub fn cass_iterator_materialized_views_from_keyspace_meta( keyspace_meta: *const CassKeyspaceMeta, - ) -> *mut CassIterator; + ) -> CassExclusiveMutPtr; } extern "C" { pub fn cass_iterator_user_types_from_keyspace_meta( keyspace_meta: *const CassKeyspaceMeta, - ) -> *mut CassIterator; + ) -> CassExclusiveMutPtr; } extern "C" { pub fn cass_iterator_functions_from_keyspace_meta( keyspace_meta: *const CassKeyspaceMeta, - ) -> *mut CassIterator; + ) -> CassExclusiveMutPtr; } extern "C" { pub fn cass_iterator_aggregates_from_keyspace_meta( keyspace_meta: *const CassKeyspaceMeta, - ) -> *mut CassIterator; + ) -> CassExclusiveMutPtr; } extern "C" { pub fn cass_iterator_fields_from_keyspace_meta( keyspace_meta: *const CassKeyspaceMeta, - ) -> *mut CassIterator; + ) -> CassExclusiveMutPtr; } extern "C" { pub fn cass_iterator_columns_from_table_meta( table_meta: *const CassTableMeta, - ) -> *mut CassIterator; + ) -> CassExclusiveMutPtr; } extern "C" { pub fn cass_iterator_indexes_from_table_meta( table_meta: *const CassTableMeta, - ) -> *mut CassIterator; + ) -> CassExclusiveMutPtr; } extern "C" { pub fn cass_iterator_materialized_views_from_table_meta( table_meta: *const CassTableMeta, - ) -> *mut CassIterator; + ) -> CassExclusiveMutPtr; } extern "C" { pub fn cass_iterator_fields_from_table_meta( table_meta: *const CassTableMeta, - ) -> *mut CassIterator; + ) -> CassExclusiveMutPtr; } extern "C" { pub fn cass_iterator_columns_from_materialized_view_meta( view_meta: *const CassMaterializedViewMeta, - ) -> *mut CassIterator; + ) -> CassExclusiveMutPtr; } extern "C" { pub fn cass_iterator_fields_from_materialized_view_meta( view_meta: *const CassMaterializedViewMeta, - ) -> *mut CassIterator; + ) -> CassExclusiveMutPtr; } extern "C" { pub fn cass_iterator_fields_from_column_meta( column_meta: *const CassColumnMeta, - ) -> *mut CassIterator; + ) -> CassExclusiveMutPtr; } extern "C" { pub fn cass_iterator_fields_from_index_meta( index_meta: *const CassIndexMeta, - ) -> *mut CassIterator; + ) -> CassExclusiveMutPtr; } extern "C" { pub fn cass_iterator_fields_from_function_meta( function_meta: *const CassFunctionMeta, - ) -> *mut CassIterator; + ) -> CassExclusiveMutPtr; } extern "C" { pub fn cass_iterator_fields_from_aggregate_meta( aggregate_meta: *const CassAggregateMeta, - ) -> *mut CassIterator; + ) -> CassExclusiveMutPtr; } extern "C" { - pub fn cass_iterator_get_column(iterator: *const CassIterator) -> *const CassValue; + pub fn cass_iterator_get_column(iterator: *const CassIterator) -> CassBorrowedPtr; } extern "C" { - pub fn cass_iterator_get_value(iterator: *const CassIterator) -> *const CassValue; + pub fn cass_iterator_get_value(iterator: *const CassIterator) -> CassBorrowedPtr; } extern "C" { - pub fn cass_iterator_get_map_key(iterator: *const CassIterator) -> *const CassValue; + pub fn cass_iterator_get_map_key(iterator: *const CassIterator) -> CassBorrowedPtr; } extern "C" { - pub fn cass_iterator_get_map_value(iterator: *const CassIterator) -> *const CassValue; + pub fn cass_iterator_get_map_value(iterator: *const CassIterator) -> CassBorrowedPtr; } extern "C" { pub fn cass_iterator_get_user_type_field_name( @@ -1978,7 +2019,7 @@ extern "C" { extern "C" { pub fn cass_iterator_get_user_type_field_value( iterator: *const CassIterator, - ) -> *const CassValue; + ) -> CassBorrowedPtr; } extern "C" { pub fn cass_iterator_get_keyspace_meta( @@ -2020,7 +2061,7 @@ extern "C" { ) -> CassError; } extern "C" { - pub fn cass_iterator_get_meta_field_value(iterator: *const CassIterator) -> *const CassValue; + pub fn cass_iterator_get_meta_field_value(iterator: *const CassIterator) -> CassBorrowedPtr; } */ @@ -2028,16 +2069,16 @@ extern "C" { /* extern "C" { pub fn cass_row_get_column_by_name( - row: *const CassRow, + row: CassBorrowedPtr, name: *const ::std::os::raw::c_char, - ) -> *const CassValue; + ) -> CassBorrowedPtr; } extern "C" { pub fn cass_row_get_column_by_name_n( - row: *const CassRow, + row: CassBorrowedPtr, name: *const ::std::os::raw::c_char, name_length: size_t, - ) -> *const CassValue; + ) -> CassBorrowedPtr; } */ @@ -2045,24 +2086,24 @@ extern "C" { /* #[no_mangle] pub unsafe extern "C" fn cass_value_get_bytes( - value: *const CassValue, + value: CassBorrowedPtr, output: *mut *const cass_byte_t, output_size: *mut size_t, ) -> CassError { } extern "C" { - pub fn cass_value_data_type(value: *const CassValue) -> *const CassDataType; + pub fn cass_value_data_type(value: CassBorrowedPtr) -> *const CassDataType; } extern "C" { - pub fn cass_value_type(value: *const CassValue) -> CassValueType; + pub fn cass_value_type(value: CassBorrowedPtr) -> CassValueType; } extern "C" { - pub fn cass_value_item_count(collection: *const CassValue) -> size_t; + pub fn cass_value_item_count(collection: CassBorrowedPtr) -> size_t; } extern "C" { - pub fn cass_value_primary_sub_type(collection: *const CassValue) -> CassValueType; + pub fn cass_value_primary_sub_type(collection: CassBorrowedPtr) -> CassValueType; } extern "C" { - pub fn cass_value_secondary_sub_type(collection: *const CassValue) -> CassValueType; + pub fn cass_value_secondary_sub_type(collection: CassBorrowedPtr) -> CassValueType; } */ diff --git a/scylla-rust-wrapper/src/retry_policy.rs b/scylla-rust-wrapper/src/retry_policy.rs index 6945c32a..210c76d5 100644 --- a/scylla-rust-wrapper/src/retry_policy.rs +++ b/scylla-rust-wrapper/src/retry_policy.rs @@ -2,7 +2,7 @@ use scylla::retry_policy::{DefaultRetryPolicy, FallthroughRetryPolicy}; use scylla::transport::downgrading_consistency_retry_policy::DowngradingConsistencyRetryPolicy; use std::sync::Arc; -use crate::argconv::ArcFFI; +use crate::argconv::{ArcFFI, CassSharedPtr}; pub enum RetryPolicy { DefaultRetryPolicy(Arc), @@ -15,27 +15,28 @@ pub type CassRetryPolicy = RetryPolicy; impl ArcFFI for CassRetryPolicy {} #[no_mangle] -pub extern "C" fn cass_retry_policy_default_new() -> *const CassRetryPolicy { +pub extern "C" fn cass_retry_policy_default_new() -> CassSharedPtr { ArcFFI::into_ptr(Arc::new(RetryPolicy::DefaultRetryPolicy(Arc::new( DefaultRetryPolicy, )))) } #[no_mangle] -pub extern "C" fn cass_retry_policy_downgrading_consistency_new() -> *const CassRetryPolicy { +pub extern "C" fn cass_retry_policy_downgrading_consistency_new() -> CassSharedPtr +{ ArcFFI::into_ptr(Arc::new(RetryPolicy::DowngradingConsistencyRetryPolicy( Arc::new(DowngradingConsistencyRetryPolicy), ))) } #[no_mangle] -pub extern "C" fn cass_retry_policy_fallthrough_new() -> *const CassRetryPolicy { +pub extern "C" fn cass_retry_policy_fallthrough_new() -> CassSharedPtr { ArcFFI::into_ptr(Arc::new(RetryPolicy::FallthroughRetryPolicy(Arc::new( FallthroughRetryPolicy, )))) } #[no_mangle] -pub unsafe extern "C" fn cass_retry_policy_free(retry_policy: *const CassRetryPolicy) { +pub unsafe extern "C" fn cass_retry_policy_free(retry_policy: CassSharedPtr) { ArcFFI::free(retry_policy); } diff --git a/scylla-rust-wrapper/src/session.rs b/scylla-rust-wrapper/src/session.rs index 1517cebc..62b06dc5 100644 --- a/scylla-rust-wrapper/src/session.rs +++ b/scylla-rust-wrapper/src/session.rs @@ -75,7 +75,7 @@ impl CassSessionInner { session_opt: &'static RwLock>, cluster: &CassCluster, keyspace: Option, - ) -> *const CassFuture { + ) -> CassSharedPtr { let session_builder = build_session_builder(cluster); let exec_profile_map = cluster.execution_profile_map().clone(); @@ -142,40 +142,40 @@ pub type CassSession = RwLock>; impl ArcFFI for CassSession {} #[no_mangle] -pub unsafe extern "C" fn cass_session_new() -> *mut CassSession { +pub unsafe extern "C" fn cass_session_new() -> CassSharedPtr { let session = Arc::new(RwLock::new(None::)); - ArcFFI::into_ptr(session) as *mut CassSession + ArcFFI::into_ptr(session) } #[no_mangle] pub unsafe extern "C" fn cass_session_connect( - session_raw: *mut CassSession, - cluster_raw: *const CassCluster, -) -> *const CassFuture { - let session_opt = ArcFFI::as_ref(session_raw); - let cluster: &CassCluster = BoxFFI::as_ref(cluster_raw); + session_raw: CassSharedPtr, + cluster_raw: CassExclusiveConstPtr, +) -> CassSharedPtr { + let session_opt = ArcFFI::into_ref(session_raw).unwrap(); + let cluster: &CassCluster = BoxFFI::as_ref(&cluster_raw).unwrap(); CassSessionInner::connect(session_opt, cluster, None) } #[no_mangle] pub unsafe extern "C" fn cass_session_connect_keyspace( - session_raw: *mut CassSession, - cluster_raw: *const CassCluster, + session_raw: CassSharedPtr, + cluster_raw: CassExclusiveConstPtr, keyspace: *const c_char, -) -> *const CassFuture { +) -> CassSharedPtr { cass_session_connect_keyspace_n(session_raw, cluster_raw, keyspace, strlen(keyspace)) } #[no_mangle] pub unsafe extern "C" fn cass_session_connect_keyspace_n( - session_raw: *mut CassSession, - cluster_raw: *const CassCluster, + session_raw: CassSharedPtr, + cluster_raw: CassExclusiveConstPtr, keyspace: *const c_char, keyspace_length: size_t, -) -> *const CassFuture { - let session_opt = ArcFFI::as_ref(session_raw); - let cluster: &CassCluster = BoxFFI::as_ref(cluster_raw); +) -> CassSharedPtr { + let session_opt = ArcFFI::into_ref(session_raw).unwrap(); + let cluster: &CassCluster = BoxFFI::as_ref(&cluster_raw).unwrap(); let keyspace = ptr_to_cstr_n(keyspace, keyspace_length).map(ToOwned::to_owned); CassSessionInner::connect(session_opt, cluster, keyspace) @@ -183,11 +183,11 @@ pub unsafe extern "C" fn cass_session_connect_keyspace_n( #[no_mangle] pub unsafe extern "C" fn cass_session_execute_batch( - session_raw: *mut CassSession, - batch_raw: *const CassBatch, -) -> *const CassFuture { - let session_opt = ArcFFI::as_ref(session_raw); - let batch_from_raw = BoxFFI::as_ref(batch_raw); + session_raw: CassSharedPtr, + batch_raw: CassExclusiveConstPtr, +) -> CassSharedPtr { + let session_opt = ArcFFI::into_ref(session_raw).unwrap(); + let batch_from_raw = BoxFFI::as_ref(&batch_raw).unwrap(); let mut state = batch_from_raw.state.clone(); let request_timeout_ms = batch_from_raw.batch_request_timeout_ms; @@ -249,13 +249,13 @@ async fn request_with_timeout( #[no_mangle] pub unsafe extern "C" fn cass_session_execute( - session_raw: *mut CassSession, - statement_raw: *const CassStatement, -) -> *const CassFuture { - let session_opt = ArcFFI::as_ref(session_raw); + session_raw: CassSharedPtr, + statement_raw: CassExclusiveConstPtr, +) -> CassSharedPtr { + let session_opt = ArcFFI::into_ref(session_raw).unwrap(); // DO NOT refer to `statement_opt` inside the async block, as I've done just to face a segfault. - let statement_opt = BoxFFI::as_ref(statement_raw); + let statement_opt = BoxFFI::as_ref(&statement_raw).unwrap(); let paging_state = statement_opt.paging_state.clone(); let paging_enabled = statement_opt.paging_enabled; let bound_values = statement_opt.bound_values.clone(); @@ -376,11 +376,11 @@ pub unsafe extern "C" fn cass_session_execute( #[no_mangle] pub unsafe extern "C" fn cass_session_prepare_from_existing( - cass_session: *mut CassSession, - statement: *const CassStatement, -) -> *const CassFuture { - let session = ArcFFI::as_ref(cass_session); - let cass_statement = BoxFFI::as_ref(statement); + cass_session: CassSharedPtr, + statement: CassExclusiveConstPtr, +) -> CassSharedPtr { + let session = ArcFFI::into_ref(cass_session).unwrap(); + let cass_statement = BoxFFI::as_ref(&statement).unwrap(); let statement = cass_statement.statement.clone(); CassFuture::make_raw(async move { @@ -412,18 +412,18 @@ pub unsafe extern "C" fn cass_session_prepare_from_existing( #[no_mangle] pub unsafe extern "C" fn cass_session_prepare( - session: *mut CassSession, + session: CassSharedPtr, query: *const c_char, -) -> *const CassFuture { +) -> CassSharedPtr { cass_session_prepare_n(session, query, strlen(query)) } #[no_mangle] pub unsafe extern "C" fn cass_session_prepare_n( - cass_session_raw: *mut CassSession, + cass_session_raw: CassSharedPtr, query: *const c_char, query_length: size_t, -) -> *const CassFuture { +) -> CassSharedPtr { let query_str = ptr_to_cstr_n(query, query_length) // Apparently nullptr denotes an empty statement string. // It seems to be intended (for some weird reason, why not save a round-trip???) @@ -431,7 +431,7 @@ pub unsafe extern "C" fn cass_session_prepare_n( // There is a test for this: `NullStringApiArgsTest.Integration_Cassandra_PrepareNullQuery`. .unwrap_or_default(); let query = Query::new(query_str.to_string()); - let cass_session = ArcFFI::as_ref(cass_session_raw); + let cass_session = ArcFFI::into_ref(cass_session_raw).unwrap(); CassFuture::make_raw(async move { let session_guard = cass_session.read().await; @@ -458,13 +458,15 @@ pub unsafe extern "C" fn cass_session_prepare_n( } #[no_mangle] -pub unsafe extern "C" fn cass_session_free(session_raw: *mut CassSession) { +pub unsafe extern "C" fn cass_session_free(session_raw: CassSharedPtr) { ArcFFI::free(session_raw); } #[no_mangle] -pub unsafe extern "C" fn cass_session_close(session: *mut CassSession) -> *const CassFuture { - let session_opt = ArcFFI::as_ref(session); +pub unsafe extern "C" fn cass_session_close( + session: CassSharedPtr, +) -> CassSharedPtr { + let session_opt = ArcFFI::into_ref(session).unwrap(); CassFuture::make_raw(async move { let mut session_guard = session_opt.write().await; @@ -482,8 +484,10 @@ pub unsafe extern "C" fn cass_session_close(session: *mut CassSession) -> *const } #[no_mangle] -pub unsafe extern "C" fn cass_session_get_client_id(session: *const CassSession) -> CassUuid { - let cass_session = ArcFFI::as_ref(session); +pub unsafe extern "C" fn cass_session_get_client_id( + session: CassSharedPtr, +) -> CassUuid { + let cass_session = ArcFFI::as_ref(&session).unwrap(); let client_id: uuid::Uuid = cass_session.blocking_read().as_ref().unwrap().client_id; client_id.into() @@ -491,9 +495,9 @@ pub unsafe extern "C" fn cass_session_get_client_id(session: *const CassSession) #[no_mangle] pub unsafe extern "C" fn cass_session_get_schema_meta( - session: *const CassSession, -) -> *const CassSchemaMeta { - let cass_session = ArcFFI::as_ref(session); + session: CassSharedPtr, +) -> CassExclusiveConstPtr { + let cass_session = ArcFFI::as_ref(&session).unwrap(); let mut keyspaces: HashMap = HashMap::new(); for (keyspace_name, keyspace) in cass_session @@ -602,7 +606,9 @@ mod tests { future::{ cass_future_error_code, cass_future_error_message, cass_future_free, cass_future_wait, }, - retry_policy::{cass_retry_policy_default_new, cass_retry_policy_fallthrough_new}, + retry_policy::{ + cass_retry_policy_default_new, cass_retry_policy_fallthrough_new, CassRetryPolicy, + }, statement::{cass_statement_free, cass_statement_new, cass_statement_set_retry_policy}, testing::assert_cass_error_eq, types::cass_bool_t, @@ -623,15 +629,15 @@ mod tests { .try_init(); } - unsafe fn cass_future_wait_check_and_free(fut: *const CassFuture) { - cass_future_wait(fut); - if cass_future_error_code(fut) != CassError::CASS_OK { + unsafe fn cass_future_wait_check_and_free(fut: CassSharedPtr) { + cass_future_wait(fut.clone()); + if cass_future_error_code(fut.clone()) != CassError::CASS_OK { let mut message: *const c_char = std::ptr::null(); let mut message_len: size_t = 0; - cass_future_error_message(fut as *mut CassFuture, &mut message, &mut message_len); + cass_future_error_message(fut.clone(), &mut message, &mut message_len); eprintln!("{:?}", ptr_to_cstr_n(message, message_len)); } - assert_cass_error_eq!(cass_future_error_code(fut), CassError::CASS_OK); + assert_cass_error_eq!(cass_future_error_code(fut.clone()), CassError::CASS_OK); cass_future_free(fut); } @@ -714,36 +720,49 @@ mod tests { let (c_ip, c_ip_len) = str_to_c_str_n(ip.as_str()); assert_cass_error_eq!( - cass_cluster_set_contact_points_n(cluster_raw, c_ip, c_ip_len), + cass_cluster_set_contact_points_n(cluster_raw.clone(), c_ip, c_ip_len), CassError::CASS_OK ); let session_raw = cass_session_new(); let profile_raw = cass_execution_profile_new(); { - cass_future_wait_check_and_free(cass_session_connect(session_raw, cluster_raw)); + cass_future_wait_check_and_free(cass_session_connect( + session_raw.clone(), + cluster_raw.clone().into_const(), + )); // Initially, the profile map is empty. - assert!(ArcFFI::as_ref(session_raw) + assert!(ArcFFI::as_ref(&session_raw) + .unwrap() .blocking_read() .as_ref() .unwrap() .exec_profile_map .is_empty()); - cass_cluster_set_execution_profile(cluster_raw, make_c_str!("prof"), profile_raw); + cass_cluster_set_execution_profile( + cluster_raw.clone(), + make_c_str!("prof"), + profile_raw.clone().into_const(), + ); // Mutations in cluster do not affect the session that was connected before. - assert!(ArcFFI::as_ref(session_raw) + assert!(ArcFFI::as_ref(&session_raw) + .unwrap() .blocking_read() .as_ref() .unwrap() .exec_profile_map .is_empty()); - cass_future_wait_check_and_free(cass_session_close(session_raw)); + cass_future_wait_check_and_free(cass_session_close(session_raw.clone())); // Mutations in cluster are now propagated to the session. - cass_future_wait_check_and_free(cass_session_connect(session_raw, cluster_raw)); - let profile_map_keys = ArcFFI::as_ref(session_raw) + cass_future_wait_check_and_free(cass_session_connect( + session_raw.clone(), + cluster_raw.clone().into_const(), + )); + let profile_map_keys = ArcFFI::as_ref(&session_raw) + .unwrap() .blocking_read() .as_ref() .unwrap() @@ -756,7 +775,7 @@ mod tests { std::iter::once(ExecProfileName::try_from("prof".to_owned()).unwrap()) .collect::>() ); - cass_future_wait_check_and_free(cass_session_close(session_raw)); + cass_future_wait_check_and_free(cass_session_close(session_raw.clone())); } cass_execution_profile_free(profile_raw); cass_session_free(session_raw); @@ -798,7 +817,7 @@ mod tests { let (c_ip, c_ip_len) = str_to_c_str_n(ip.as_str()); assert_cass_error_eq!( - cass_cluster_set_contact_points_n(cluster_raw, c_ip, c_ip_len), + cass_cluster_set_contact_points_n(cluster_raw.clone(), c_ip, c_ip_len), CassError::CASS_OK ); @@ -817,31 +836,41 @@ mod tests { let statement_raw = cass_statement_new(invalid_query, 0); let batch_raw = cass_batch_new(CassBatchType::CASS_BATCH_TYPE_LOGGED); assert_cass_error_eq!( - cass_batch_add_statement(batch_raw, statement_raw), + cass_batch_add_statement(batch_raw.clone(), statement_raw.clone().into_const()), CassError::CASS_OK ); assert_cass_error_eq!( - cass_cluster_set_execution_profile(cluster_raw, valid_name_c_str, profile_raw,), + cass_cluster_set_execution_profile( + cluster_raw.clone(), + valid_name_c_str, + profile_raw.clone().into_const(), + ), CassError::CASS_OK ); - cass_future_wait_check_and_free(cass_session_connect(session_raw, cluster_raw)); + cass_future_wait_check_and_free(cass_session_connect( + session_raw.clone(), + cluster_raw.clone().into_const(), + )); { /* Test valid configurations */ - let statement = BoxFFI::as_ref(statement_raw); - let batch = BoxFFI::as_ref(batch_raw); + let statement = BoxFFI::as_ref(&statement_raw).unwrap(); + let batch = BoxFFI::as_ref(&batch_raw).unwrap(); { assert!(statement.exec_profile.is_none()); assert!(batch.exec_profile.is_none()); // Set exec profile - it is not yet resolved. assert_cass_error_eq!( - cass_statement_set_execution_profile(statement_raw, valid_name_c_str,), + cass_statement_set_execution_profile( + statement_raw.clone(), + valid_name_c_str, + ), CassError::CASS_OK ); assert_cass_error_eq!( - cass_batch_set_execution_profile(batch_raw, valid_name_c_str,), + cass_batch_set_execution_profile(batch_raw.clone(), valid_name_c_str,), CassError::CASS_OK ); assert_eq!( @@ -871,7 +900,10 @@ mod tests { // Make a query - this should resolve the profile. assert_cass_error_eq!( - cass_future_error_code(cass_session_execute(session_raw, statement_raw)), + cass_future_error_code(cass_session_execute( + session_raw.clone(), + statement_raw.clone().into_const() + )), CassError::CASS_ERROR_SERVER_WRITE_FAILURE ); assert!(statement @@ -884,7 +916,10 @@ mod tests { .as_handle() .is_some()); assert_cass_error_eq!( - cass_future_error_code(cass_session_execute_batch(session_raw, batch_raw,)), + cass_future_error_code(cass_session_execute_batch( + session_raw.clone(), + batch_raw.clone().into_const(), + )), CassError::CASS_ERROR_SERVER_WRITE_FAILURE ); assert!(batch @@ -899,11 +934,14 @@ mod tests { // NULL name sets exec profile to None assert_cass_error_eq!( - cass_statement_set_execution_profile(statement_raw, std::ptr::null::()), + cass_statement_set_execution_profile( + statement_raw.clone(), + std::ptr::null::() + ), CassError::CASS_OK ); assert_cass_error_eq!( - cass_batch_set_execution_profile(batch_raw, std::ptr::null::()), + cass_batch_set_execution_profile(batch_raw.clone(), std::ptr::null::()), CassError::CASS_OK ); assert!(statement.exec_profile.is_none()); @@ -912,7 +950,7 @@ mod tests { // valid name again, but of nonexisting profile! assert_cass_error_eq!( cass_statement_set_execution_profile_n( - statement_raw, + statement_raw.clone(), nonexisting_name_c_str, nonexisting_name_len, ), @@ -920,7 +958,7 @@ mod tests { ); assert_cass_error_eq!( cass_batch_set_execution_profile_n( - batch_raw, + batch_raw.clone(), nonexisting_name_c_str, nonexisting_name_len, ), @@ -953,7 +991,10 @@ mod tests { // So when we now issue a query, it should end with error and leave exec_profile_handle uninitialised. assert_cass_error_eq!( - cass_future_error_code(cass_session_execute(session_raw, statement_raw)), + cass_future_error_code(cass_session_execute( + session_raw.clone(), + statement_raw.clone().into_const() + )), CassError::CASS_ERROR_LIB_EXECUTION_PROFILE_INVALID ); assert_eq!( @@ -969,7 +1010,10 @@ mod tests { &nonexisting_name.to_owned().try_into().unwrap() ); assert_cass_error_eq!( - cass_future_error_code(cass_session_execute_batch(session_raw, batch_raw)), + cass_future_error_code(cass_session_execute_batch( + session_raw.clone(), + batch_raw.clone().into_const() + )), CassError::CASS_ERROR_LIB_EXECUTION_PROFILE_INVALID ); assert_eq!( @@ -987,7 +1031,7 @@ mod tests { } } - cass_future_wait_check_and_free(cass_session_close(session_raw)); + cass_future_wait_check_and_free(cass_session_close(session_raw.clone())); cass_execution_profile_free(profile_raw); cass_statement_free(statement_raw); cass_batch_free(batch_raw); @@ -1062,13 +1106,13 @@ mod tests { let (c_ip, c_ip_len) = str_to_c_str_n(ip.as_str()); assert_cass_error_eq!( - cass_cluster_set_contact_points_n(cluster_raw, c_ip, c_ip_len,), + cass_cluster_set_contact_points_n(cluster_raw.clone(), c_ip, c_ip_len,), CassError::CASS_OK ); let fallthrough_policy = cass_retry_policy_fallthrough_new(); let default_policy = cass_retry_policy_default_new(); - cass_cluster_set_retry_policy(cluster_raw, fallthrough_policy); + cass_cluster_set_retry_policy(cluster_raw.clone(), fallthrough_policy.clone()); let session_raw = cass_session_new(); @@ -1077,7 +1121,10 @@ mod tests { let profile_name_c_str = make_c_str!("profile"); assert_cass_error_eq!( - cass_execution_profile_set_retry_policy(profile_raw, default_policy), + cass_execution_profile_set_retry_policy( + profile_raw.clone(), + default_policy.clone() + ), CassError::CASS_OK ); @@ -1085,21 +1132,36 @@ mod tests { let statement_raw = cass_statement_new(query, 0); let batch_raw = cass_batch_new(CassBatchType::CASS_BATCH_TYPE_LOGGED); assert_cass_error_eq!( - cass_batch_add_statement(batch_raw, statement_raw), + cass_batch_add_statement(batch_raw.clone(), statement_raw.clone().into_const()), CassError::CASS_OK ); assert_cass_error_eq!( - cass_cluster_set_execution_profile(cluster_raw, profile_name_c_str, profile_raw,), + cass_cluster_set_execution_profile( + cluster_raw.clone(), + profile_name_c_str, + profile_raw.clone().into_const(), + ), CassError::CASS_OK ); - cass_future_wait_check_and_free(cass_session_connect(session_raw, cluster_raw)); + cass_future_wait_check_and_free(cass_session_connect( + session_raw.clone(), + cluster_raw.clone().into_const(), + )); { - let execute_query = - || cass_future_error_code(cass_session_execute(session_raw, statement_raw)); - let execute_batch = - || cass_future_error_code(cass_session_execute_batch(session_raw, batch_raw)); + let execute_query = || { + cass_future_error_code(cass_session_execute( + session_raw.clone(), + statement_raw.clone().into_const(), + )) + }; + let execute_batch = || { + cass_future_error_code(cass_session_execute_batch( + session_raw.clone(), + batch_raw.clone().into_const(), + )) + }; fn reset_proxy_rules(proxy: &mut RunningProxy) { proxy.running_nodes[0].change_request_rules(Some( @@ -1138,11 +1200,11 @@ mod tests { let set_provided_exec_profile = |name| { // Set statement/batch exec profile. assert_cass_error_eq!( - cass_statement_set_execution_profile(statement_raw, name,), + cass_statement_set_execution_profile(statement_raw.clone(), name,), CassError::CASS_OK ); assert_cass_error_eq!( - cass_batch_set_execution_profile(batch_raw, name,), + cass_batch_set_execution_profile(batch_raw.clone(), name,), CassError::CASS_OK ); }; @@ -1152,18 +1214,18 @@ mod tests { let unset_exec_profile = || { set_provided_exec_profile(std::ptr::null::()); }; - let set_retry_policy_on_stmt = |policy| { + let set_retry_policy_on_stmt = |policy: CassSharedPtr| { assert_cass_error_eq!( - cass_statement_set_retry_policy(statement_raw, policy,), + cass_statement_set_retry_policy(statement_raw.clone(), policy.clone(),), CassError::CASS_OK ); assert_cass_error_eq!( - cass_batch_set_retry_policy(batch_raw, policy,), + cass_batch_set_retry_policy(batch_raw.clone(), policy,), CassError::CASS_OK ); }; let unset_retry_policy_on_stmt = || { - set_retry_policy_on_stmt(std::ptr::null()); + set_retry_policy_on_stmt(ArcFFI::null()); }; // ### START TESTING @@ -1183,7 +1245,7 @@ mod tests { assert_query_with_fallthrough_policy(&mut proxy); // F - F - set_retry_policy_on_stmt(fallthrough_policy); + set_retry_policy_on_stmt(fallthrough_policy.clone()); assert_query_with_fallthrough_policy(&mut proxy); // F D F @@ -1195,15 +1257,15 @@ mod tests { assert_query_with_default_policy(&mut proxy); // F D F - set_retry_policy_on_stmt(fallthrough_policy); + set_retry_policy_on_stmt(fallthrough_policy.clone()); assert_query_with_fallthrough_policy(&mut proxy); // F D D - set_retry_policy_on_stmt(default_policy); + set_retry_policy_on_stmt(default_policy.clone()); assert_query_with_default_policy(&mut proxy); // F D F - set_retry_policy_on_stmt(fallthrough_policy); + set_retry_policy_on_stmt(fallthrough_policy.clone()); assert_query_with_fallthrough_policy(&mut proxy); // F - F @@ -1227,7 +1289,7 @@ mod tests { assert_query_with_default_policy(&mut proxy); } - cass_future_wait_check_and_free(cass_session_close(session_raw)); + cass_future_wait_check_and_free(cass_session_close(session_raw.clone())); cass_execution_profile_free(profile_raw); cass_statement_free(statement_raw); cass_batch_free(batch_raw); @@ -1248,20 +1310,28 @@ mod tests { let (c_ip, c_ip_len) = str_to_c_str_n(ip); assert_cass_error_eq!( - cass_cluster_set_contact_points_n(cluster_raw, c_ip, c_ip_len), + cass_cluster_set_contact_points_n(cluster_raw.clone(), c_ip, c_ip_len), CassError::CASS_OK ); - cass_cluster_set_latency_aware_routing(cluster_raw, true as cass_bool_t); + cass_cluster_set_latency_aware_routing(cluster_raw.clone(), true as cass_bool_t); let session_raw = cass_session_new(); let profile_raw = cass_execution_profile_new(); assert_cass_error_eq!( - cass_execution_profile_set_latency_aware_routing(profile_raw, true as cass_bool_t), + cass_execution_profile_set_latency_aware_routing( + profile_raw.clone(), + true as cass_bool_t + ), CassError::CASS_OK ); let profile_name = make_c_str!("latency_aware"); - cass_cluster_set_execution_profile(cluster_raw, profile_name, profile_raw); + cass_cluster_set_execution_profile( + cluster_raw.clone(), + profile_name, + profile_raw.clone().into_const(), + ); { - let cass_future = cass_session_connect(session_raw, cluster_raw); + let cass_future = + cass_session_connect(session_raw.clone(), cluster_raw.clone().into_const()); cass_future_wait(cass_future); // The exact outcome is not important, we only test that we don't panic. } @@ -1284,19 +1354,19 @@ mod tests { let cluster_raw = cass_cluster_new(); assert_cass_error_eq!( - cass_cluster_set_contact_points_n(cluster_raw, c_ip, c_ip_len), + cass_cluster_set_contact_points_n(cluster_raw.clone(), c_ip, c_ip_len), CassError::CASS_OK ); - cass_cluster_set_latency_aware_routing(cluster_raw, true as cass_bool_t); + cass_cluster_set_latency_aware_routing(cluster_raw.clone(), true as cass_bool_t); let session_raw = cass_session_new(); let profile_raw = cass_execution_profile_new(); assert_cass_error_eq!( - cass_execution_profile_set_latency_aware_routing(profile_raw, true as cass_bool_t), + cass_execution_profile_set_latency_aware_routing(profile_raw.clone(), true as cass_bool_t), CassError::CASS_OK ); - cass_cluster_set_execution_profile(cluster_raw, profile_name, profile_raw); + cass_cluster_set_execution_profile(cluster_raw.clone(), profile_name, profile_raw.clone().into_const()); { - let cass_future = cass_session_connect(session_raw, cluster_raw); + let cass_future = cass_session_connect(session_raw.clone(), cluster_raw.clone().into_const()); // This checks that we don't use-after-free the cluster inside the future. cass_cluster_free(cluster_raw); diff --git a/scylla-rust-wrapper/src/ssl.rs b/scylla-rust-wrapper/src/ssl.rs index ba14a24b..70cf992d 100644 --- a/scylla-rust-wrapper/src/ssl.rs +++ b/scylla-rust-wrapper/src/ssl.rs @@ -1,4 +1,5 @@ use crate::argconv::ArcFFI; +use crate::argconv::CassSharedPtr; use crate::cass_error::CassError; use crate::types::size_t; use libc::{c_int, strlen}; @@ -27,13 +28,13 @@ pub const CASS_SSL_VERIFY_PEER_IDENTITY: i32 = 0x02; pub const CASS_SSL_VERIFY_PEER_IDENTITY_DNS: i32 = 0x04; #[no_mangle] -pub unsafe extern "C" fn cass_ssl_new() -> *const CassSsl { +pub unsafe extern "C" fn cass_ssl_new() -> CassSharedPtr { openssl_sys::init(); cass_ssl_new_no_lib_init() } #[no_mangle] -pub unsafe extern "C" fn cass_ssl_new_no_lib_init() -> *const CassSsl { +pub unsafe extern "C" fn cass_ssl_new_no_lib_init() -> CassSharedPtr { let ssl_context: *mut SSL_CTX = SSL_CTX_new(TLS_method()); let trusted_store: *mut X509_STORE = X509_STORE_new(); @@ -64,7 +65,7 @@ impl Drop for CassSsl { } #[no_mangle] -pub unsafe extern "C" fn cass_ssl_free(ssl: *mut CassSsl) { +pub unsafe extern "C" fn cass_ssl_free(ssl: CassSharedPtr) { ArcFFI::free(ssl); } @@ -96,7 +97,7 @@ unsafe extern "C" fn pem_password_callback( #[no_mangle] pub unsafe extern "C" fn cass_ssl_add_trusted_cert( - ssl: *mut CassSsl, + ssl: CassSharedPtr, cert: *const c_char, ) -> CassError { if cert.is_null() { @@ -108,11 +109,11 @@ pub unsafe extern "C" fn cass_ssl_add_trusted_cert( #[no_mangle] pub unsafe extern "C" fn cass_ssl_add_trusted_cert_n( - ssl: *mut CassSsl, + ssl: CassSharedPtr, cert: *const c_char, cert_length: size_t, ) -> CassError { - let ssl = ArcFFI::cloned_from_ptr(ssl); + let ssl = ArcFFI::cloned_from_ptr(ssl).unwrap(); let bio = BIO_new_mem_buf(cert as *const c_void, cert_length.try_into().unwrap()); if bio.is_null() { @@ -139,8 +140,8 @@ pub unsafe extern "C" fn cass_ssl_add_trusted_cert_n( } #[no_mangle] -pub unsafe extern "C" fn cass_ssl_set_verify_flags(ssl: *mut CassSsl, flags: i32) { - let ssl = ArcFFI::cloned_from_ptr(ssl); +pub unsafe extern "C" fn cass_ssl_set_verify_flags(ssl: CassSharedPtr, flags: i32) { + let ssl = ArcFFI::cloned_from_ptr(ssl).unwrap(); match flags { CASS_SSL_VERIFY_NONE => { @@ -164,7 +165,10 @@ pub unsafe extern "C" fn cass_ssl_set_verify_flags(ssl: *mut CassSsl, flags: i32 } #[no_mangle] -pub unsafe extern "C" fn cass_ssl_set_cert(ssl: *mut CassSsl, cert: *const c_char) -> CassError { +pub unsafe extern "C" fn cass_ssl_set_cert( + ssl: CassSharedPtr, + cert: *const c_char, +) -> CassError { if cert.is_null() { return CassError::CASS_ERROR_SSL_INVALID_CERT; } @@ -174,11 +178,11 @@ pub unsafe extern "C" fn cass_ssl_set_cert(ssl: *mut CassSsl, cert: *const c_cha #[no_mangle] pub unsafe extern "C" fn cass_ssl_set_cert_n( - ssl: *mut CassSsl, + ssl: CassSharedPtr, cert: *const c_char, cert_length: size_t, ) -> CassError { - let ssl = ArcFFI::cloned_from_ptr(ssl); + let ssl = ArcFFI::cloned_from_ptr(ssl).unwrap(); let bio = BIO_new_mem_buf(cert as *const c_void, cert_length.try_into().unwrap()); if bio.is_null() { @@ -246,7 +250,7 @@ unsafe extern "C" fn SSL_CTX_use_certificate_chain_bio( #[no_mangle] pub unsafe extern "C" fn cass_ssl_set_private_key( - ssl: *mut CassSsl, + ssl: CassSharedPtr, key: *const c_char, password: *mut c_char, ) -> CassError { @@ -265,13 +269,13 @@ pub unsafe extern "C" fn cass_ssl_set_private_key( #[no_mangle] pub unsafe extern "C" fn cass_ssl_set_private_key_n( - ssl: *mut CassSsl, + ssl: CassSharedPtr, key: *const c_char, key_length: size_t, password: *mut c_char, _password_length: size_t, ) -> CassError { - let ssl = ArcFFI::cloned_from_ptr(ssl); + let ssl = ArcFFI::cloned_from_ptr(ssl).unwrap(); let bio = BIO_new_mem_buf(key as *const c_void, key_length.try_into().unwrap()); if bio.is_null() { diff --git a/scylla-rust-wrapper/src/statement.rs b/scylla-rust-wrapper/src/statement.rs index 7627d116..e1e354ec 100644 --- a/scylla-rust-wrapper/src/statement.rs +++ b/scylla-rust-wrapper/src/statement.rs @@ -162,7 +162,7 @@ impl CassStatement { pub unsafe extern "C" fn cass_statement_new( query: *const c_char, parameter_count: size_t, -) -> *mut CassStatement { +) -> CassExclusiveMutPtr { cass_statement_new_n(query, strlen(query), parameter_count) } @@ -171,10 +171,10 @@ pub unsafe extern "C" fn cass_statement_new_n( query: *const c_char, query_length: size_t, parameter_count: size_t, -) -> *mut CassStatement { +) -> CassExclusiveMutPtr { let query_str = match ptr_to_cstr_n(query, query_length) { Some(v) => v, - None => return std::ptr::null_mut(), + None => return BoxFFI::null_mut(), }; let query = Query::new(query_str.to_string()); @@ -196,19 +196,19 @@ pub unsafe extern "C" fn cass_statement_new_n( } #[no_mangle] -pub unsafe extern "C" fn cass_statement_free(statement_raw: *mut CassStatement) { +pub unsafe extern "C" fn cass_statement_free(statement_raw: CassExclusiveMutPtr) { BoxFFI::free(statement_raw); } #[no_mangle] pub unsafe extern "C" fn cass_statement_set_consistency( - statement: *mut CassStatement, + mut statement: CassExclusiveMutPtr, consistency: CassConsistency, ) -> CassError { let consistency_opt = get_consistency_from_cass_consistency(consistency); if let Some(consistency) = consistency_opt { - match &mut BoxFFI::as_mut_ref(statement).statement { + match &mut BoxFFI::as_mut_ref(&mut statement).unwrap().statement { Statement::Simple(inner) => inner.query.set_consistency(consistency), Statement::Prepared(inner) => { Arc::make_mut(inner).statement.set_consistency(consistency) @@ -221,10 +221,10 @@ pub unsafe extern "C" fn cass_statement_set_consistency( #[no_mangle] pub unsafe extern "C" fn cass_statement_set_paging_size( - statement_raw: *mut CassStatement, + mut statement_raw: CassExclusiveMutPtr, page_size: c_int, ) -> CassError { - let statement = BoxFFI::as_mut_ref(statement_raw); + let statement = BoxFFI::as_mut_ref(&mut statement_raw).unwrap(); if page_size <= 0 { // Cpp driver sets the page size flag only for positive page size provided by user. statement.paging_enabled = false; @@ -241,11 +241,11 @@ pub unsafe extern "C" fn cass_statement_set_paging_size( #[no_mangle] pub unsafe extern "C" fn cass_statement_set_paging_state( - statement: *mut CassStatement, - result: *const CassResult, + mut statement: CassExclusiveMutPtr, + result: CassSharedPtr, ) -> CassError { - let statement = BoxFFI::as_mut_ref(statement); - let result = ArcFFI::as_ref(result); + let statement = BoxFFI::as_mut_ref(&mut statement).unwrap(); + let result = ArcFFI::as_ref(&result).unwrap(); match &result.paging_state_response { PagingStateResponse::HasMorePages { state } => statement.paging_state.clone_from(state), @@ -256,11 +256,11 @@ pub unsafe extern "C" fn cass_statement_set_paging_state( #[no_mangle] pub unsafe extern "C" fn cass_statement_set_paging_state_token( - statement: *mut CassStatement, + mut statement: CassExclusiveMutPtr, paging_state: *const c_char, paging_state_size: size_t, ) -> CassError { - let statement_from_raw = BoxFFI::as_mut_ref(statement); + let statement_from_raw = BoxFFI::as_mut_ref(&mut statement).unwrap(); if paging_state.is_null() { statement_from_raw.paging_state = PagingState::start(); @@ -275,10 +275,10 @@ pub unsafe extern "C" fn cass_statement_set_paging_state_token( #[no_mangle] pub unsafe extern "C" fn cass_statement_set_is_idempotent( - statement_raw: *mut CassStatement, + mut statement_raw: CassExclusiveMutPtr, is_idempotent: cass_bool_t, ) -> CassError { - match &mut BoxFFI::as_mut_ref(statement_raw).statement { + match &mut BoxFFI::as_mut_ref(&mut statement_raw).unwrap().statement { Statement::Simple(inner) => inner.query.set_is_idempotent(is_idempotent != 0), Statement::Prepared(inner) => Arc::make_mut(inner) .statement @@ -290,10 +290,10 @@ pub unsafe extern "C" fn cass_statement_set_is_idempotent( #[no_mangle] pub unsafe extern "C" fn cass_statement_set_tracing( - statement_raw: *mut CassStatement, + mut statement_raw: CassExclusiveMutPtr, enabled: cass_bool_t, ) -> CassError { - match &mut BoxFFI::as_mut_ref(statement_raw).statement { + match &mut BoxFFI::as_mut_ref(&mut statement_raw).unwrap().statement { Statement::Simple(inner) => inner.query.set_tracing(enabled != 0), Statement::Prepared(inner) => Arc::make_mut(inner).statement.set_tracing(enabled != 0), } @@ -303,11 +303,11 @@ pub unsafe extern "C" fn cass_statement_set_tracing( #[no_mangle] pub unsafe extern "C" fn cass_statement_set_retry_policy( - statement: *mut CassStatement, - retry_policy: *const CassRetryPolicy, + mut statement: CassExclusiveMutPtr, + retry_policy: CassSharedPtr, ) -> CassError { let maybe_arced_retry_policy: Option> = - ArcFFI::as_maybe_ref(retry_policy).map(|policy| match policy { + ArcFFI::as_ref(&retry_policy).map(|policy| match policy { CassRetryPolicy::DefaultRetryPolicy(default) => { default.clone() as Arc } @@ -315,7 +315,7 @@ pub unsafe extern "C" fn cass_statement_set_retry_policy( CassRetryPolicy::DowngradingConsistencyRetryPolicy(downgrading) => downgrading.clone(), }); - match &mut BoxFFI::as_mut_ref(statement).statement { + match &mut BoxFFI::as_mut_ref(&mut statement).unwrap().statement { Statement::Simple(inner) => inner.query.set_retry_policy(maybe_arced_retry_policy), Statement::Prepared(inner) => Arc::make_mut(inner) .statement @@ -327,7 +327,7 @@ pub unsafe extern "C" fn cass_statement_set_retry_policy( #[no_mangle] pub unsafe extern "C" fn cass_statement_set_serial_consistency( - statement: *mut CassStatement, + mut statement: CassExclusiveMutPtr, serial_consistency: CassConsistency, ) -> CassError { // cpp-driver doesn't validate passed value in any way. @@ -344,7 +344,7 @@ pub unsafe extern "C" fn cass_statement_set_serial_consistency( _ => return CassError::CASS_ERROR_LIB_BAD_PARAMS, }; - match &mut BoxFFI::as_mut_ref(statement).statement { + match &mut BoxFFI::as_mut_ref(&mut statement).unwrap().statement { Statement::Simple(inner) => inner.query.set_serial_consistency(Some(consistency)), Statement::Prepared(inner) => Arc::make_mut(inner) .statement @@ -373,10 +373,10 @@ fn get_consistency_from_cass_consistency(consistency: CassConsistency) -> Option #[no_mangle] pub unsafe extern "C" fn cass_statement_set_timestamp( - statement: *mut CassStatement, + mut statement: CassExclusiveMutPtr, timestamp: cass_int64_t, ) -> CassError { - match &mut BoxFFI::as_mut_ref(statement).statement { + match &mut BoxFFI::as_mut_ref(&mut statement).unwrap().statement { Statement::Simple(inner) => inner.query.set_timestamp(Some(timestamp)), Statement::Prepared(inner) => Arc::make_mut(inner) .statement @@ -388,7 +388,7 @@ pub unsafe extern "C" fn cass_statement_set_timestamp( #[no_mangle] pub unsafe extern "C" fn cass_statement_set_request_timeout( - statement: *mut CassStatement, + mut statement: CassExclusiveMutPtr, timeout_ms: cass_uint64_t, ) -> CassError { // The maximum duration for a sleep is 68719476734 milliseconds (approximately 2.2 years). @@ -399,7 +399,7 @@ pub unsafe extern "C" fn cass_statement_set_request_timeout( return CassError::CASS_ERROR_LIB_BAD_PARAMS; } - let statement_from_raw = BoxFFI::as_mut_ref(statement); + let statement_from_raw = BoxFFI::as_mut_ref(&mut statement).unwrap(); statement_from_raw.request_timeout_ms = Some(timeout_ms); CassError::CASS_OK @@ -407,10 +407,10 @@ pub unsafe extern "C" fn cass_statement_set_request_timeout( #[no_mangle] pub unsafe extern "C" fn cass_statement_reset_parameters( - statement_raw: *mut CassStatement, + mut statement_raw: CassExclusiveMutPtr, count: size_t, ) -> CassError { - let statement = BoxFFI::as_mut_ref(statement_raw); + let statement = BoxFFI::as_mut_ref(&mut statement_raw).unwrap(); statement.reset_bound_values(count as usize); CassError::CASS_OK diff --git a/scylla-rust-wrapper/src/testing.rs b/scylla-rust-wrapper/src/testing.rs index 344d228c..2f28e020 100644 --- a/scylla-rust-wrapper/src/testing.rs +++ b/scylla-rust-wrapper/src/testing.rs @@ -21,7 +21,7 @@ macro_rules! assert_cass_future_error_message_eq { let mut ___message: *const c_char = ::std::ptr::null(); let mut ___msg_len: size_t = 0; - cass_future_error_message($cass_fut, &mut ___message, &mut ___msg_len); + cass_future_error_message($cass_fut.clone(), &mut ___message, &mut ___msg_len); assert_eq!(ptr_to_cstr_n(___message, ___msg_len), $error_msg_opt); }; } diff --git a/scylla-rust-wrapper/src/tuple.rs b/scylla-rust-wrapper/src/tuple.rs index 288d5c67..3d385a32 100644 --- a/scylla-rust-wrapper/src/tuple.rs +++ b/scylla-rust-wrapper/src/tuple.rs @@ -63,7 +63,7 @@ impl From<&CassTuple> for CassCqlValue { } #[no_mangle] -pub unsafe extern "C" fn cass_tuple_new(item_count: size_t) -> *mut CassTuple { +pub unsafe extern "C" fn cass_tuple_new(item_count: size_t) -> CassExclusiveMutPtr { BoxFFI::into_ptr(Box::new(CassTuple { data_type: None, items: vec![None; item_count as usize], @@ -72,12 +72,12 @@ pub unsafe extern "C" fn cass_tuple_new(item_count: size_t) -> *mut CassTuple { #[no_mangle] unsafe extern "C" fn cass_tuple_new_from_data_type( - data_type: *const CassDataType, -) -> *mut CassTuple { - let data_type = ArcFFI::cloned_from_ptr(data_type); + data_type: CassSharedPtr, +) -> CassExclusiveMutPtr { + let data_type = ArcFFI::cloned_from_ptr(data_type).unwrap(); let item_count = match data_type.get_unchecked() { CassDataTypeInner::Tuple(v) => v.len(), - _ => return std::ptr::null_mut(), + _ => return BoxFFI::null_mut(), }; BoxFFI::into_ptr(Box::new(CassTuple { data_type: Some(data_type), @@ -86,13 +86,15 @@ unsafe extern "C" fn cass_tuple_new_from_data_type( } #[no_mangle] -unsafe extern "C" fn cass_tuple_free(tuple: *mut CassTuple) { +unsafe extern "C" fn cass_tuple_free(tuple: CassExclusiveMutPtr) { BoxFFI::free(tuple); } #[no_mangle] -unsafe extern "C" fn cass_tuple_data_type(tuple: *const CassTuple) -> *const CassDataType { - match &BoxFFI::as_ref(tuple).data_type { +unsafe extern "C" fn cass_tuple_data_type( + tuple: CassExclusiveConstPtr, +) -> CassSharedPtr { + match &BoxFFI::as_ref(&tuple).unwrap().data_type { Some(t) => ArcFFI::as_ptr(t), None => ArcFFI::as_ptr(&UNTYPED_TUPLE_TYPE), } @@ -136,16 +138,14 @@ mod tests { let empty_tuple = cass_tuple_new(2); // This would previously return a non Arc-based pointer. - let empty_tuple_dt = cass_tuple_data_type(empty_tuple); + let empty_tuple_dt = cass_tuple_data_type(empty_tuple.into_const()); let empty_set_dt = cass_data_type_new(CassValueType::CASS_VALUE_TYPE_SET); // This will try to increment the reference count of `empty_tuple_dt`. // Previously, this would fail, because `empty_tuple_dt` did not originate from an Arc allocation. - cass_data_type_add_sub_type(empty_set_dt, empty_tuple_dt); + cass_data_type_add_sub_type(empty_set_dt.clone(), empty_tuple_dt); - // Cast to *mut, because `cass_data_type_new` returns a *const. See the comment - // in this function to see why. - cass_data_type_free(empty_set_dt as *mut _) + cass_data_type_free(empty_set_dt) } } } diff --git a/scylla-rust-wrapper/src/user_type.rs b/scylla-rust-wrapper/src/user_type.rs index 9335bfac..4b160109 100644 --- a/scylla-rust-wrapper/src/user_type.rs +++ b/scylla-rust-wrapper/src/user_type.rs @@ -83,9 +83,9 @@ impl From<&CassUserType> for CassCqlValue { #[no_mangle] pub unsafe extern "C" fn cass_user_type_new_from_data_type( - data_type_raw: *const CassDataType, -) -> *mut CassUserType { - let data_type = ArcFFI::cloned_from_ptr(data_type_raw); + data_type_raw: CassSharedPtr, +) -> CassExclusiveMutPtr { + let data_type = ArcFFI::cloned_from_ptr(data_type_raw).unwrap(); match data_type.get_unchecked() { CassDataTypeInner::UDT(udt_data_type) => { @@ -95,19 +95,19 @@ pub unsafe extern "C" fn cass_user_type_new_from_data_type( field_values, })) } - _ => std::ptr::null_mut(), + _ => BoxFFI::null_mut(), } } #[no_mangle] -pub unsafe extern "C" fn cass_user_type_free(user_type: *mut CassUserType) { +pub unsafe extern "C" fn cass_user_type_free(user_type: CassExclusiveMutPtr) { BoxFFI::free(user_type); } #[no_mangle] pub unsafe extern "C" fn cass_user_type_data_type( - user_type: *const CassUserType, -) -> *const CassDataType { - ArcFFI::as_ptr(&BoxFFI::as_ref(user_type).data_type) + user_type: CassExclusiveConstPtr, +) -> CassSharedPtr { + ArcFFI::as_ptr(&BoxFFI::as_ref(&user_type).unwrap().data_type) } prepare_binders_macro!(@index_and_name CassUserType, diff --git a/scylla-rust-wrapper/src/uuid.rs b/scylla-rust-wrapper/src/uuid.rs index b2343803..29ec4772 100644 --- a/scylla-rust-wrapper/src/uuid.rs +++ b/scylla-rust-wrapper/src/uuid.rs @@ -98,7 +98,7 @@ pub unsafe extern "C" fn cass_uuid_max_from_time(timestamp: cass_uint64_t, outpu } #[no_mangle] -pub unsafe extern "C" fn cass_uuid_gen_new() -> *mut CassUuidGen { +pub unsafe extern "C" fn cass_uuid_gen_new() -> CassExclusiveMutPtr { // Inspired by C++ driver implementation in its intent. // The original driver tries to generate a number that // uniquely identifies this machine and the current process. @@ -122,7 +122,9 @@ pub unsafe extern "C" fn cass_uuid_gen_new() -> *mut CassUuidGen { } #[no_mangle] -pub unsafe extern "C" fn cass_uuid_gen_new_with_node(node: cass_uint64_t) -> *mut CassUuidGen { +pub unsafe extern "C" fn cass_uuid_gen_new_with_node( + node: cass_uint64_t, +) -> CassExclusiveMutPtr { BoxFFI::into_ptr(Box::new(CassUuidGen { clock_seq_and_node: rand_clock_seq_and_node(node & 0x0000FFFFFFFFFFFF), last_timestamp: AtomicU64::new(0), @@ -130,8 +132,11 @@ pub unsafe extern "C" fn cass_uuid_gen_new_with_node(node: cass_uint64_t) -> *mu } #[no_mangle] -pub unsafe extern "C" fn cass_uuid_gen_time(uuid_gen: *mut CassUuidGen, output: *mut CassUuid) { - let uuid_gen = BoxFFI::as_mut_ref(uuid_gen); +pub unsafe extern "C" fn cass_uuid_gen_time( + mut uuid_gen: CassExclusiveMutPtr, + output: *mut CassUuid, +) { + let uuid_gen = BoxFFI::as_mut_ref(&mut uuid_gen).unwrap(); let uuid = CassUuid { time_and_version: set_version(monotonic_timestamp(&mut uuid_gen.last_timestamp), 1), @@ -157,11 +162,11 @@ pub unsafe extern "C" fn cass_uuid_gen_random(_uuid_gen: *mut CassUuidGen, outpu #[no_mangle] pub unsafe extern "C" fn cass_uuid_gen_from_time( - uuid_gen: *mut CassUuidGen, + mut uuid_gen: CassExclusiveMutPtr, timestamp: cass_uint64_t, output: *mut CassUuid, ) { - let uuid_gen = BoxFFI::as_mut_ref(uuid_gen); + let uuid_gen = BoxFFI::as_mut_ref(&mut uuid_gen).unwrap(); let uuid = CassUuid { time_and_version: set_version(from_unix_timestamp(timestamp), 1), @@ -250,6 +255,6 @@ pub unsafe extern "C" fn cass_uuid_from_string_n( } #[no_mangle] -pub unsafe extern "C" fn cass_uuid_gen_free(uuid_gen: *mut CassUuidGen) { +pub unsafe extern "C" fn cass_uuid_gen_free(uuid_gen: CassExclusiveMutPtr) { BoxFFI::free(uuid_gen); }