Skip to content

Commit

Permalink
query_result: Introduce CassIteratorStateInfo enum
Browse files Browse the repository at this point in the history
Refactored result and row iterators to wrap value and position fields in
state info enum.
  • Loading branch information
Gor027 committed Jun 27, 2023
1 parent 80e0985 commit aa2c76c
Showing 1 changed file with 71 additions and 42 deletions.
113 changes: 71 additions & 42 deletions scylla-rust-wrapper/src/query_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,46 @@ pub struct CassValue {
pub column_type: &'static ColumnType,
}

enum CassIteratorStateInfo<T> {
NoValue,
ValueNoPosition { value: T },
PositionNoValue { position: usize },
Value { value: T, position: usize },
}

impl<T> CassIteratorStateInfo<T> {
/// Update iterator's state and return new state info.
/// Increments an existing position, otherwise sets it to 0.
fn advance(&mut self) {
// Store a dummy NoValue temporarily as we cannot move out StateInfo fields.
let old_state_info = std::mem::replace(self, CassIteratorStateInfo::NoValue);
*self = match old_state_info {
CassIteratorStateInfo::Value { value, position } => CassIteratorStateInfo::Value {
value,
position: position + 1,
},
CassIteratorStateInfo::PositionNoValue { position } => {
CassIteratorStateInfo::PositionNoValue {
position: position + 1,
}
}
CassIteratorStateInfo::ValueNoPosition { value } => {
CassIteratorStateInfo::Value { value, position: 0 }
}
CassIteratorStateInfo::NoValue => {
CassIteratorStateInfo::PositionNoValue { position: 0 }
}
};
}
}

pub struct CassResultIterator {
result: Arc<CassResult>,
row: CassRow,
position: Option<usize>,
state_info: CassIteratorStateInfo<CassRow>,
}

pub struct CassRowIterator {
row: &'static CassRow,
position: Option<usize>,
state_info: CassIteratorStateInfo<&'static CassRow>,
}

/// For sequential iteration over collection types
Expand Down Expand Up @@ -227,24 +258,29 @@ pub unsafe extern "C" fn cass_iterator_next(iterator: *mut CassIterator) -> cass

match iter {
CassIterator::CassResultIterator(result_iterator) => {
let new_pos: usize = result_iterator.position.map_or(0, |prev_pos| prev_pos + 1);

result_iterator.position = Some(new_pos);

match result_iterator.result.result.rows_num() {
Some(rs) if new_pos < rs => {
decode_next_row(result_iterator.result.as_ref(), &mut result_iterator.row)
as cass_bool_t
}
_ => false as cass_bool_t,
result_iterator.state_info.advance();

if let CassIteratorStateInfo::Value { value, position } =
&mut result_iterator.state_info
{
return match result_iterator.result.result.rows_num() {
Some(rs) if *position < rs => {
decode_next_row(result_iterator.result.as_ref(), value) as cass_bool_t
}
_ => false as cass_bool_t,
};
}

false as cass_bool_t
}
CassIterator::CassRowIterator(row_iterator) => {
let new_pos: usize = row_iterator.position.map_or(0, |prev_pos| prev_pos + 1);
row_iterator.state_info.advance();

row_iterator.position = Some(new_pos);
if let CassIteratorStateInfo::Value { value, position } = row_iterator.state_info {
return (position < value.columns.len()) as cass_bool_t;
}

(new_pos < row_iterator.row.columns.len()) as cass_bool_t
false as cass_bool_t
}
CassIterator::CassCollectionIterator(collection_iterator) => match collection_iterator {
CassCollectionIterator::SequenceIterator(seq_iterator) => {
Expand Down Expand Up @@ -414,14 +450,13 @@ pub unsafe extern "C" fn cass_iterator_get_row(iterator: *const CassIterator) ->
let iter = ptr_to_ref(iterator);

// Defined only for result iterator, for other types should return null
if let CassIterator::CassResultIterator(result_iterator) = iter {
let iter_position = match result_iterator.position {
Some(pos) => pos,
None => return std::ptr::null(),
};

return match result_iterator.result.result.rows_num() {
Some(rows_count) if iter_position < rows_count => &result_iterator.row,
if let CassIterator::CassResultIterator(CassResultIterator {
result,
state_info: CassIteratorStateInfo::Value { value, position },
}) = iter
{
return match result.result.rows_num() {
Some(rows_count) if *position < rows_count => value,
_ => std::ptr::null(),
};
}
Expand All @@ -436,13 +471,11 @@ pub unsafe extern "C" fn cass_iterator_get_column(
let iter = ptr_to_ref(iterator);

// Defined only for row iterator, for other types should return null
if let CassIterator::CassRowIterator(row_iterator) = iter {
let iter_position = match row_iterator.position {
Some(pos) => pos,
None => return std::ptr::null(),
};

let value = match row_iterator.row.columns.get(iter_position) {
if let CassIterator::CassRowIterator(CassRowIterator {
state_info: CassIteratorStateInfo::Value { value, position },
}) = iter
{
let value = match value.columns.get(*position) {
Some(col) => col,
None => return std::ptr::null(),
};
Expand Down Expand Up @@ -553,11 +586,7 @@ pub unsafe extern "C" fn cass_iterator_get_user_type_field_name(
..
}) => {
assert!(position.map(|pos| pos < *count).is_some()); // assertion copied from c++ driver
write_str_to_c(
field_name.as_str(), // safe to unwrap if cass_iterator_next succeeded
name,
name_length,
);
write_str_to_c(field_name.as_str(), name, name_length);
CassError::CASS_OK
}
_ => CassError::CASS_ERROR_LIB_BAD_PARAMS,
Expand Down Expand Up @@ -760,8 +789,7 @@ pub unsafe extern "C" fn cass_iterator_from_result(result: *const CassResult) ->

let iterator = CassResultIterator {
result: result_from_raw,
row,
position: None,
state_info: CassIteratorStateInfo::ValueNoPosition { value: row },
};

Box::into_raw(Box::new(CassIterator::CassResultIterator(iterator)))
Expand All @@ -772,8 +800,9 @@ pub unsafe extern "C" fn cass_iterator_from_row(row: *const CassRow) -> *mut Cas
let row_from_raw = ptr_to_ref(row);

let iterator = CassRowIterator {
row: row_from_raw,
position: None,
state_info: CassIteratorStateInfo::ValueNoPosition {
value: row_from_raw,
},
};

Box::into_raw(Box::new(CassIterator::CassRowIterator(iterator)))
Expand Down Expand Up @@ -883,7 +912,7 @@ pub unsafe extern "C" fn cass_iterator_fields_from_user_type(
ColumnType::UserDefinedType { field_types, .. } => field_types.as_slice(),
_ => panic!("Unexpected column type for map collection"),
};
let udt_iterator = UdtIterator::new(fields, frame_slice); // safe to unwrap as is_null is false
let udt_iterator = UdtIterator::new(fields, frame_slice);
let iterator = CassUdtIterator {
udt_iterator,
field_name: None,
Expand Down

0 comments on commit aa2c76c

Please sign in to comment.