diff --git a/dozer-ingestion/src/connectors/snowflake/connection/client.rs b/dozer-ingestion/src/connectors/snowflake/connection/client.rs index dfc6538557..100c2a03b7 100644 --- a/dozer-ingestion/src/connectors/snowflake/connection/client.rs +++ b/dozer-ingestion/src/connectors/snowflake/connection/client.rs @@ -20,11 +20,9 @@ use odbc::odbc_safe::{AutocommitOn, Odbc3}; use odbc::{ColumnDescriptor, Cursor, DiagnosticRecord, Environment, Executed, HasResult}; use rand::distributions::Alphanumeric; use rand::Rng; -use std::cell::RefCell; use std::collections::HashMap; use std::fmt::Write; use std::ops::Deref; -use std::rc::Rc; use super::helpers::is_network_failure; use super::pool::{Conn, Pool}; @@ -252,7 +250,7 @@ impl<'env> Client<'env> { exec_first_exists(&self.pool, &query).map_or_else(Self::parse_not_exist_error, Ok) } - pub fn fetch(&self, query: String) -> ExecIter<'env> { + pub fn fetch(&self, query: String) -> Result, SnowflakeError> { exec_iter(self.pool.clone(), query) } @@ -283,7 +281,7 @@ impl<'env> Client<'env> { ORDER BY TABLE_NAME, ORDINAL_POSITION" ); - let results = exec_iter(self.pool.clone(), query); + let results = exec_iter(self.pool.clone(), query)?; let mut schemas: IndexMap)> = IndexMap::new(); for (idx, result) in results.enumerate() { @@ -388,7 +386,7 @@ impl<'env> Client<'env> { pub fn fetch_keys(&self) -> Result>, SnowflakeError> { 'retry: loop { let query = "SHOW PRIMARY KEYS IN SCHEMA".to_string(); - let results = exec_iter(self.pool.clone(), query); + let results = exec_iter(self.pool.clone(), query)?; let mut keys: HashMap> = HashMap::new(); for result in results { let row_data = match result { @@ -477,16 +475,14 @@ fn exec_first_exists(pool: &Pool, query: &str) -> Result ExecIter { +fn exec_iter(pool: Pool, query: String) -> Result { use genawaiter::{ rc::{gen, Gen}, yield_, }; + use ExecIterResult::*; - let schema = Rc::new(RefCell::new(None::>)); - let schema_ref = schema.clone(); - - let mut generator: Gen, (), _> = gen!({ + let mut generator: Gen = gen!({ let mut cursor_position = 0u64; 'retry: loop { let conn = pool.get_conn().map_err(QueryError)?; @@ -498,23 +494,22 @@ fn exec_iter(pool: Pool, query: String) -> ExecIter { None => break, }; let cols = data.num_result_cols().map_err(|e| QueryError(e.into()))?; - let mut vec = Vec::new(); + let mut schema = Vec::new(); for i in 1..(cols + 1) { let value = i.try_into(); let column_descriptor = match value { Ok(v) => data.describe_col(v).map_err(|e| QueryError(e.into()))?, Err(e) => Err(SchemaConversionError(e))?, }; - vec.push(column_descriptor) + schema.push(column_descriptor) } - schema.borrow_mut().replace(vec); + yield_!(Schema(schema.clone())); while let Some(cursor) = retry!(data.fetch(),'retry).map_err(|e| QueryError(e.into()))? { - let fields = - get_fields_from_cursor(cursor, cols, schema.borrow().as_deref().unwrap())?; - yield_!(fields); + let fields = get_fields_from_cursor(cursor, cols, &schema)?; + yield_!(Row(fields)); cursor_position += 1; } } @@ -524,7 +519,7 @@ fn exec_iter(pool: Pool, query: String) -> ExecIter { Ok::<(), SnowflakeError>(()) }); - let iterator = std::iter::from_fn(move || { + let mut iterator = std::iter::from_fn(move || { use genawaiter::GeneratorState::*; match generator.resume() { Yielded(fields) => Some(Ok(fields)), @@ -533,20 +528,32 @@ fn exec_iter(pool: Pool, query: String) -> ExecIter { } }); - ExecIter { + let schema = match iterator.next() { + Some(Ok(Schema(schema))) => Some(schema), + Some(Err(err)) => Err(err)?, + None => None, + _ => unreachable!(), + }; + + Ok(ExecIter { iterator: Box::new(iterator), - schema: schema_ref, - } + schema, + }) +} + +enum ExecIterResult { + Schema(Vec), + Row(Vec), } pub struct ExecIter<'env> { - iterator: Box, SnowflakeError>> + 'env>, - schema: Rc>>>, + iterator: Box> + 'env>, + schema: Option>, } impl<'env> ExecIter<'env> { - pub fn schema(&self) -> Option> { - self.schema.borrow().deref().clone() + pub fn schema(&self) -> Option<&Vec> { + self.schema.as_ref() } } @@ -554,7 +561,18 @@ impl<'env> Iterator for ExecIter<'env> { type Item = Result, SnowflakeError>; fn next(&mut self) -> Option { - self.iterator.next() + use ExecIterResult::*; + loop { + let result = match self.iterator.next()? { + Ok(Schema(schema)) => { + self.schema = Some(schema); + continue; + } + Ok(Row(row)) => Ok(row), + Err(err) => Err(err), + }; + return Some(result); + } } } diff --git a/dozer-ingestion/src/connectors/snowflake/stream_consumer.rs b/dozer-ingestion/src/connectors/snowflake/stream_consumer.rs index 8cb5e8e8ef..6511a5895c 100644 --- a/dozer-ingestion/src/connectors/snowflake/stream_consumer.rs +++ b/dozer-ingestion/src/connectors/snowflake/stream_consumer.rs @@ -113,10 +113,11 @@ impl StreamConsumer { client.exec(&query)?; } - let rows = client.fetch(format!("SELECT * FROM {temp_table_name};")); - if let Some(schema) = rows.schema() { + let rows = client.fetch(format!("SELECT * FROM {temp_table_name};"))?; + let schema = rows.schema(); + if let Some(schema) = schema { let schema_len = schema.len(); - let mut truncated_schema = schema; + let mut truncated_schema = schema.clone(); truncated_schema.truncate(schema_len - 3); let columns_length = schema_len;