Skip to content

Commit

Permalink
Merge branch 'main' into fix/json_path
Browse files Browse the repository at this point in the history
  • Loading branch information
Chloe Kim authored Sep 20, 2023
2 parents f40b3dd + 1d835a9 commit 1d5f5a4
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 27 deletions.
68 changes: 43 additions & 25 deletions dozer-ingestion/src/connectors/snowflake/connection/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,9 @@ use odbc::odbc_safe::{AutocommitOn, Odbc3};
use odbc::{ColumnDescriptor, Cursor, DiagnosticRecord, Environment, Executed, HasResult};
use rand::distributions::Alphanumeric;
use rand::Rng;
use std::cell::RefCell;
use std::collections::HashMap;
use std::fmt::Write;
use std::ops::Deref;
use std::rc::Rc;

use super::helpers::is_network_failure;
use super::pool::{Conn, Pool};
Expand Down Expand Up @@ -252,7 +250,7 @@ impl<'env> Client<'env> {
exec_first_exists(&self.pool, &query).map_or_else(Self::parse_not_exist_error, Ok)
}

/// Runs `query` by handing it, together with a clone of this client's pool, to `exec_iter`
/// and returns whatever `exec_iter` returns.
// NOTE(review): the two signature lines below look like the before/after pair of a diff
// (plain `ExecIter` vs. `Result<ExecIter, SnowflakeError>`); confirm which one is current.
pub fn fetch(&self, query: String) -> ExecIter<'env> {
pub fn fetch(&self, query: String) -> Result<ExecIter<'env>, SnowflakeError> {
exec_iter(self.pool.clone(), query)
}

Expand Down Expand Up @@ -283,7 +281,7 @@ impl<'env> Client<'env> {
ORDER BY TABLE_NAME, ORDINAL_POSITION"
);

let results = exec_iter(self.pool.clone(), query);
let results = exec_iter(self.pool.clone(), query)?;
let mut schemas: IndexMap<String, (usize, Result<Schema, SnowflakeSchemaError>)> =
IndexMap::new();
for (idx, result) in results.enumerate() {
Expand Down Expand Up @@ -388,7 +386,7 @@ impl<'env> Client<'env> {
pub fn fetch_keys(&self) -> Result<HashMap<String, Vec<String>>, SnowflakeError> {
'retry: loop {
let query = "SHOW PRIMARY KEYS IN SCHEMA".to_string();
let results = exec_iter(self.pool.clone(), query);
let results = exec_iter(self.pool.clone(), query)?;
let mut keys: HashMap<String, Vec<String>> = HashMap::new();
for result in results {
let row_data = match result {
Expand Down Expand Up @@ -477,16 +475,14 @@ fn exec_first_exists(pool: &Pool, query: &str) -> Result<bool, Box<DiagnosticRec
}
}

fn exec_iter(pool: Pool, query: String) -> ExecIter {
fn exec_iter(pool: Pool, query: String) -> Result<ExecIter, SnowflakeError> {
use genawaiter::{
rc::{gen, Gen},
yield_,
};
use ExecIterResult::*;

let schema = Rc::new(RefCell::new(None::<Vec<ColumnDescriptor>>));
let schema_ref = schema.clone();

let mut generator: Gen<Vec<Field>, (), _> = gen!({
let mut generator: Gen<ExecIterResult, (), _> = gen!({
let mut cursor_position = 0u64;
'retry: loop {
let conn = pool.get_conn().map_err(QueryError)?;
Expand All @@ -498,23 +494,22 @@ fn exec_iter(pool: Pool, query: String) -> ExecIter {
None => break,
};
let cols = data.num_result_cols().map_err(|e| QueryError(e.into()))?;
let mut vec = Vec::new();
let mut schema = Vec::new();
for i in 1..(cols + 1) {
let value = i.try_into();
let column_descriptor = match value {
Ok(v) => data.describe_col(v).map_err(|e| QueryError(e.into()))?,
Err(e) => Err(SchemaConversionError(e))?,
};
vec.push(column_descriptor)
schema.push(column_descriptor)
}
schema.borrow_mut().replace(vec);
yield_!(Schema(schema.clone()));

while let Some(cursor) =
retry!(data.fetch(),'retry).map_err(|e| QueryError(e.into()))?
{
let fields =
get_fields_from_cursor(cursor, cols, schema.borrow().as_deref().unwrap())?;
yield_!(fields);
let fields = get_fields_from_cursor(cursor, cols, &schema)?;
yield_!(Row(fields));
cursor_position += 1;
}
}
Expand All @@ -524,7 +519,7 @@ fn exec_iter(pool: Pool, query: String) -> ExecIter {
Ok::<(), SnowflakeError>(())
});

let iterator = std::iter::from_fn(move || {
let mut iterator = std::iter::from_fn(move || {
use genawaiter::GeneratorState::*;
match generator.resume() {
Yielded(fields) => Some(Ok(fields)),
Expand All @@ -533,28 +528,51 @@ fn exec_iter(pool: Pool, query: String) -> ExecIter {
}
});

ExecIter {
let schema = match iterator.next() {
Some(Ok(Schema(schema))) => Some(schema),
Some(Err(err)) => Err(err)?,
None => None,
_ => unreachable!(),
};

Ok(ExecIter {
iterator: Box::new(iterator),
schema: schema_ref,
}
schema,
})
}

/// Item type produced by the generator inside `exec_iter`: either the column descriptors of
/// the result set or the fields of one fetched row.
enum ExecIterResult {
/// Column descriptors gathered by calling `describe_col` for every result column.
Schema(Vec<ColumnDescriptor>),
/// Fields of a single row, as built by `get_fields_from_cursor`.
Row(Vec<Field>),
}

/// Iterator over the rows of an executed query that also carries the column descriptors of
/// the result set.
// NOTE(review): each field appears twice below. This looks like the before/after pair of a
// diff (`Rc<RefCell<..>>` schema and `Vec<Field>` items vs. a plain `Option` schema and
// `ExecIterResult` items); confirm which pair is current.
pub struct ExecIter<'env> {
iterator: Box<dyn Iterator<Item = Result<Vec<Field>, SnowflakeError>> + 'env>,
schema: Rc<RefCell<Option<Vec<ColumnDescriptor>>>>,
iterator: Box<dyn Iterator<Item = Result<ExecIterResult, SnowflakeError>> + 'env>,
schema: Option<Vec<ColumnDescriptor>>,
}

impl<'env> ExecIter<'env> {
/// Returns the stored column descriptors, or `None` when no schema is stored.
// NOTE(review): two signature/body pairs are shown below. The first clones the descriptors
// out of a `RefCell`, the second hands out a borrow via `as_ref()`; this looks like the
// before/after pair of a diff, so confirm which one is current.
pub fn schema(&self) -> Option<Vec<ColumnDescriptor>> {
self.schema.borrow().deref().clone()
pub fn schema(&self) -> Option<&Vec<ColumnDescriptor>> {
self.schema.as_ref()
}
}

/// Yields one `Result<Vec<Field>, SnowflakeError>` per row; `Schema` items coming from the
/// inner iterator are never handed to the caller.
impl<'env> Iterator for ExecIter<'env> {
type Item = Result<Vec<Field>, SnowflakeError>;

fn next(&mut self) -> Option<Self::Item> {
// NOTE(review): the bare `self.iterator.next()` line directly below looks like the
// pre-change body left over from a diff view; the loop after it is the other version.
self.iterator.next()
use ExecIterResult::*;
loop {
// `?` ends iteration (returns `None`) once the inner iterator is exhausted.
let result = match self.iterator.next()? {
Ok(Schema(schema)) => {
// A schema item replaces the stored schema and is skipped: keep pulling
// until a row or an error shows up.
self.schema = Some(schema);
continue;
}
Ok(Row(row)) => Ok(row),
Err(err) => Err(err),
};
return Some(result);
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions dozer-ingestion/src/connectors/snowflake/stream_consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,10 @@ impl StreamConsumer {
client.exec(&query)?;
}

let rows = client.fetch(format!("SELECT * FROM {temp_table_name};"));
let rows = client.fetch(format!("SELECT * FROM {temp_table_name};"))?;
if let Some(schema) = rows.schema() {
let schema_len = schema.len();
let mut truncated_schema = schema;
let mut truncated_schema = schema.clone();
truncated_schema.truncate(schema_len - 3);

let columns_length = schema_len;
Expand Down

0 comments on commit 1d5f5a4

Please sign in to comment.