From 1d835a94f87650a7ed314f0f88a8c922c2ebbad7 Mon Sep 17 00:00:00 2001
From: Solomon <108011288+abcpro1@users.noreply.github.com>
Date: Wed, 20 Sep 2023 06:47:05 +0000
Subject: [PATCH] fix: snowflake regression (#2061)

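Start the query exec iterator eagerly, when it is created, so that a
valid schema is available before the first row is consumed.
The iterator is considered empty if the schema is `None`.
`exec_iter` and `Client::fetch` now return a `Result`, so an error
raised while starting the iterator is reported to the caller.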
---
 .../connectors/snowflake/connection/client.rs | 68 ++++++++++++-------
 .../connectors/snowflake/stream_consumer.rs   |  4 +-
 2 files changed, 45 insertions(+), 27 deletions(-)

diff --git a/dozer-ingestion/src/connectors/snowflake/connection/client.rs b/dozer-ingestion/src/connectors/snowflake/connection/client.rs
index dfc6538557..100c2a03b7 100644
--- a/dozer-ingestion/src/connectors/snowflake/connection/client.rs
+++ b/dozer-ingestion/src/connectors/snowflake/connection/client.rs
@@ -20,11 +20,9 @@ use odbc::odbc_safe::{AutocommitOn, Odbc3};
 use odbc::{ColumnDescriptor, Cursor, DiagnosticRecord, Environment, Executed, HasResult};
 use rand::distributions::Alphanumeric;
 use rand::Rng;
-use std::cell::RefCell;
 use std::collections::HashMap;
 use std::fmt::Write;
 use std::ops::Deref;
-use std::rc::Rc;
 
 use super::helpers::is_network_failure;
 use super::pool::{Conn, Pool};
@@ -252,7 +250,7 @@ impl<'env> Client<'env> {
         exec_first_exists(&self.pool, &query).map_or_else(Self::parse_not_exist_error, Ok)
     }
 
-    pub fn fetch(&self, query: String) -> ExecIter<'env> {
+    pub fn fetch(&self, query: String) -> Result<ExecIter<'env>, SnowflakeError> {
         exec_iter(self.pool.clone(), query)
     }
 
@@ -283,7 +281,7 @@ impl<'env> Client<'env> {
             ORDER BY TABLE_NAME, ORDINAL_POSITION"
         );
 
-        let results = exec_iter(self.pool.clone(), query);
+        let results = exec_iter(self.pool.clone(), query)?;
         let mut schemas: IndexMap<String, (usize, Result<Schema, SnowflakeSchemaError>)> =
             IndexMap::new();
         for (idx, result) in results.enumerate() {
@@ -388,7 +386,7 @@ impl<'env> Client<'env> {
     pub fn fetch_keys(&self) -> Result<HashMap<String, Vec<String>>, SnowflakeError> {
         'retry: loop {
             let query = "SHOW PRIMARY KEYS IN SCHEMA".to_string();
-            let results = exec_iter(self.pool.clone(), query);
+            let results = exec_iter(self.pool.clone(), query)?;
             let mut keys: HashMap<String, Vec<String>> = HashMap::new();
             for result in results {
                 let row_data = match result {
@@ -477,16 +475,14 @@ fn exec_first_exists(pool: &Pool, query: &str) -> Result<bool, Box<DiagnosticRec
     }
 }
 
-fn exec_iter(pool: Pool, query: String) -> ExecIter {
+fn exec_iter(pool: Pool, query: String) -> Result<ExecIter, SnowflakeError> {
     use genawaiter::{
         rc::{gen, Gen},
         yield_,
     };
+    use ExecIterResult::*;
 
-    let schema = Rc::new(RefCell::new(None::<Vec<ColumnDescriptor>>));
-    let schema_ref = schema.clone();
-
-    let mut generator: Gen<Vec<Field>, (), _> = gen!({
+    let mut generator: Gen<ExecIterResult, (), _> = gen!({
         let mut cursor_position = 0u64;
         'retry: loop {
             let conn = pool.get_conn().map_err(QueryError)?;
@@ -498,23 +494,22 @@ fn exec_iter(pool: Pool, query: String) -> ExecIter {
                     None => break,
                 };
                 let cols = data.num_result_cols().map_err(|e| QueryError(e.into()))?;
-                let mut vec = Vec::new();
+                let mut schema = Vec::new();
                 for i in 1..(cols + 1) {
                     let value = i.try_into();
                     let column_descriptor = match value {
                         Ok(v) => data.describe_col(v).map_err(|e| QueryError(e.into()))?,
                         Err(e) => Err(SchemaConversionError(e))?,
                     };
-                    vec.push(column_descriptor)
+                    schema.push(column_descriptor)
                 }
-                schema.borrow_mut().replace(vec);
+                yield_!(Schema(schema.clone()));
 
                 while let Some(cursor) =
                     retry!(data.fetch(),'retry).map_err(|e| QueryError(e.into()))?
                 {
-                    let fields =
-                        get_fields_from_cursor(cursor, cols, schema.borrow().as_deref().unwrap())?;
-                    yield_!(fields);
+                    let fields = get_fields_from_cursor(cursor, cols, &schema)?;
+                    yield_!(Row(fields));
                     cursor_position += 1;
                 }
             }
@@ -524,7 +519,7 @@ fn exec_iter(pool: Pool, query: String) -> ExecIter {
         Ok::<(), SnowflakeError>(())
     });
 
-    let iterator = std::iter::from_fn(move || {
+    let mut iterator = std::iter::from_fn(move || {
         use genawaiter::GeneratorState::*;
         match generator.resume() {
             Yielded(fields) => Some(Ok(fields)),
@@ -533,20 +528,32 @@ fn exec_iter(pool: Pool, query: String) -> ExecIter {
         }
     });
 
-    ExecIter {
+    let schema = match iterator.next() {
+        Some(Ok(Schema(schema))) => Some(schema),
+        Some(Err(err)) => Err(err)?,
+        None => None,
+        _ => unreachable!(),
+    };
+
+    Ok(ExecIter {
         iterator: Box::new(iterator),
-        schema: schema_ref,
-    }
+        schema,
+    })
+}
+
+enum ExecIterResult {
+    Schema(Vec<ColumnDescriptor>),
+    Row(Vec<Field>),
 }
 
 pub struct ExecIter<'env> {
-    iterator: Box<dyn Iterator<Item = Result<Vec<Field>, SnowflakeError>> + 'env>,
-    schema: Rc<RefCell<Option<Vec<ColumnDescriptor>>>>,
+    iterator: Box<dyn Iterator<Item = Result<ExecIterResult, SnowflakeError>> + 'env>,
+    schema: Option<Vec<ColumnDescriptor>>,
 }
 
 impl<'env> ExecIter<'env> {
-    pub fn schema(&self) -> Option<Vec<ColumnDescriptor>> {
-        self.schema.borrow().deref().clone()
+    pub fn schema(&self) -> Option<&Vec<ColumnDescriptor>> {
+        self.schema.as_ref()
     }
 }
 
@@ -554,7 +561,18 @@ impl<'env> Iterator for ExecIter<'env> {
     type Item = Result<Vec<Field>, SnowflakeError>;
 
     fn next(&mut self) -> Option<Self::Item> {
-        self.iterator.next()
+        use ExecIterResult::*;
+        loop {
+            let result = match self.iterator.next()? {
+                Ok(Schema(schema)) => {
+                    self.schema = Some(schema);
+                    continue;
+                }
+                Ok(Row(row)) => Ok(row),
+                Err(err) => Err(err),
+            };
+            return Some(result);
+        }
     }
 }
 
diff --git a/dozer-ingestion/src/connectors/snowflake/stream_consumer.rs b/dozer-ingestion/src/connectors/snowflake/stream_consumer.rs
index 8cb5e8e8ef..4ef1c6148c 100644
--- a/dozer-ingestion/src/connectors/snowflake/stream_consumer.rs
+++ b/dozer-ingestion/src/connectors/snowflake/stream_consumer.rs
@@ -113,10 +113,10 @@ impl StreamConsumer {
             client.exec(&query)?;
         }
 
-        let rows = client.fetch(format!("SELECT * FROM {temp_table_name};"));
+        let rows = client.fetch(format!("SELECT * FROM {temp_table_name};"))?;
         if let Some(schema) = rows.schema() {
             let schema_len = schema.len();
-            let mut truncated_schema = schema;
+            let mut truncated_schema = schema.clone();
             truncated_schema.truncate(schema_len - 3);
 
             let columns_length = schema_len;