diff --git a/src/execution/executor/dml/copy_from_file.rs b/src/execution/executor/dml/copy_from_file.rs index 4429a577..4bd42902 100644 --- a/src/execution/executor/dml/copy_from_file.rs +++ b/src/execution/executor/dml/copy_from_file.rs @@ -1,30 +1,23 @@ use crate::binder::copy::FileFormat; -use futures_async_stream::try_stream; use crate::execution::executor::{BoxedExecutor, Executor}; use crate::execution::ExecutorError; use crate::planner::operator::copy_from_file::CopyFromFileOperator; use crate::storage::{Storage, Transaction}; use crate::types::tuple::Tuple; +use crate::types::tuple_builder::TupleBuilder; +use futures_async_stream::try_stream; use std::fs::File; use std::io::BufReader; use tokio::sync::mpsc::Sender; -use crate::types::tuple_builder::TupleBuilder; - - pub struct CopyFromFile { op: CopyFromFileOperator, size: usize, - } - impl From for CopyFromFile { fn from(op: CopyFromFileOperator) -> Self { - CopyFromFile { - op, - size: 0, - } + CopyFromFile { op, size: 0 } } } @@ -34,31 +27,31 @@ impl Executor for CopyFromFile { } } - impl CopyFromFile { #[try_stream(boxed, ok = Tuple, error = ExecutorError)] pub async fn _execute(self, storage: S) { let (tx, mut rx) = tokio::sync::mpsc::channel(1); + let (tx1, mut rx1) = tokio::sync::mpsc::channel(1); // # Cancellation // When this stream is dropped, the `rx` is dropped, the spawned task will fail to send to // `tx`, then the task will finish. let table_name = self.op.table.clone(); - let mut txn = storage.transaction(&table_name).await.unwrap(); - let handle = tokio::task::spawn_blocking(|| self.read_file_blocking(tx)); - let mut size = 0 as usize; - while let Some(chunk) = rx.recv().await { - txn.append(chunk, false)?; - size += 1; - } - handle.await.unwrap()?; - txn.commit().await?; + if let Some(mut txn) = storage.transaction(&table_name).await { + let handle = tokio::task::spawn_blocking(|| self.read_file_blocking(tx)); + let mut size = 0 as usize; + while let Some(chunk) = rx.recv().await { + txn.append(chunk, false)?; + size += 1; + } + handle.await?; + txn.commit().await?; - let (tx1, mut rx1) = tokio::sync::mpsc::channel(1); - let handle = tokio::task::spawn_blocking(move || return_result(size.clone(), tx1)); - while let Some(chunk) = rx1.recv().await { - yield chunk; + let handle = tokio::task::spawn_blocking(move || return_result(size.clone(), tx1)); + while let Some(chunk) = rx1.recv().await { + yield chunk; + } + handle.await?; } - handle.await.unwrap()?; } /// Read records from file using blocking IO. /// @@ -82,8 +75,10 @@ impl CopyFromFile { let column_count = self.op.types.len(); let mut size_count = 0; + for record in reader.records() { - let mut tuple_builder = TupleBuilder::new(self.op.types.clone(), self.op.columns.clone()); + let mut tuple_builder = + TupleBuilder::new(self.op.types.clone(), self.op.columns.clone()); // read records and push raw str rows into data chunk builder let record = record?; @@ -110,24 +105,24 @@ impl CopyFromFile { fn return_result(size: usize, tx: Sender) -> Result<(), ExecutorError> { let tuple_builder = TupleBuilder::new_result(); - let tuple = tuple_builder.push_result("COPY FROM SOURCE", format!("import {} rows", size).as_str())?; + let tuple = + tuple_builder.push_result("COPY FROM SOURCE", format!("import {} rows", size).as_str())?; tx.blocking_send(tuple).map_err(|_| ExecutorError::Abort)?; Ok(()) } #[cfg(test)] mod tests { + use crate::catalog::{ColumnCatalog, ColumnDesc}; + use crate::db::Database; + use futures::StreamExt; use std::io::Write; use std::sync::Arc; - use futures::StreamExt; use tempfile::TempDir; - use crate::catalog::{ColumnCatalog, ColumnDesc}; - use crate::db::Database; use super::*; - use crate::types::LogicalType; use crate::binder::copy::ExtSource; - + use crate::types::LogicalType; #[tokio::test] async fn read_csv() { @@ -189,12 +184,15 @@ mod tests { let temp_dir = TempDir::new().unwrap(); let db = Database::with_kipdb(temp_dir.path()).await.unwrap(); - let _ = db.run("create table test_copy (a int primary key, b float, c varchar(10))").await; + let _ = db + .run("create table test_copy (a int primary key, b float, c varchar(10))") + .await; let actual = executor.execute(&db.storage).next().await.unwrap().unwrap(); - let tuple_builder = TupleBuilder::new_result(); - let expected = tuple_builder.push_result("COPY FROM SOURCE", format!("import {} rows", 2).as_str()).unwrap(); + let expected = tuple_builder + .push_result("COPY FROM SOURCE", format!("import {} rows", 2).as_str()) + .unwrap(); assert_eq!(actual, expected); } } diff --git a/src/execution/mod.rs b/src/execution/mod.rs index 727276a5..49bc412c 100644 --- a/src/execution/mod.rs +++ b/src/execution/mod.rs @@ -56,4 +56,12 @@ pub enum ExecutorError { LengthMismatch { expected: usize, actual: usize }, #[error("abort")] Abort, + #[error("unknown error")] + Unknown, + #[error("join error")] + JoinError( + #[from] + #[source] + tokio::task::JoinError, + ), } diff --git a/src/types/tuple_builder.rs b/src/types/tuple_builder.rs index 40288ca0..59bc59ee 100644 --- a/src/types/tuple_builder.rs +++ b/src/types/tuple_builder.rs @@ -1,10 +1,10 @@ -use std::collections::HashMap; -use std::sync::Arc; use crate::catalog::{ColumnCatalog, ColumnRef}; use crate::types::errors::TypeError; -use crate::types::LogicalType; use crate::types::tuple::Tuple; use crate::types::value::{DataValue, ValueRef}; +use crate::types::LogicalType; +use std::collections::HashMap; +use std::sync::Arc; pub struct TupleBuilder { data_types: Vec, @@ -29,13 +29,9 @@ impl TupleBuilder { } } - pub fn push_result(self,header: &str, message: &str) -> Result { - let columns: Vec = vec![ - Arc::new(ColumnCatalog::new_dummy(header.to_string())), - ]; - let values: Vec = vec![ - Arc::new(DataValue::Utf8(Some(String::from(message)))), - ]; + pub fn push_result(self, header: &str, message: &str) -> Result { + let columns: Vec = vec![Arc::new(ColumnCatalog::new_dummy(header.to_string()))]; + let values: Vec = vec![Arc::new(DataValue::Utf8(Some(String::from(message))))]; let t = Tuple { id: None, columns, @@ -46,32 +42,32 @@ impl TupleBuilder { pub fn push_str_row<'a>( &mut self, - row: impl IntoIterator, + row: impl IntoIterator, ) -> Result, TypeError> { let mut primary_key_index = None; let columns = self.columns.clone(); let mut tuple_map = HashMap::new(); + for (i, value) in row.into_iter().enumerate() { let data_value = DataValue::Utf8(Some(value.to_string())); - let cast_data_value = data_value.cast(&self.data_types[i]).unwrap(); + let cast_data_value = data_value.cast(&self.data_types[i])?; self.data_values.push(Arc::new(cast_data_value.clone())); let col = &columns[i]; - if let Some(col_id) = col.id { - tuple_map.insert(col_id, Arc::new(cast_data_value)); + col.id + .map(|col_id| tuple_map.insert(col_id, Arc::new(cast_data_value.clone()))); + if col.desc.is_primary { + primary_key_index = Some(i); } } + let primary_col_id = primary_key_index + .map(|i| columns[i].id.unwrap()) + .ok_or_else(|| TypeError::InvalidType)?; - let primary_col_id = primary_key_index.get_or_insert_with(|| { - self.columns.iter() - .find(|col| col.desc.is_primary) - .map(|col| col.id.unwrap()) - .unwrap() - }); - - let tuple_id = tuple_map.get(primary_col_id) - .cloned() - .unwrap(); + let tuple_id = tuple_map + .get(&primary_col_id) + .ok_or_else(|| TypeError::InvalidType)? + .clone(); let tuple = if self.data_values.len() == self.data_types.len() { Some(Tuple { @@ -84,4 +80,4 @@ impl TupleBuilder { }; Ok(tuple) } -} \ No newline at end of file +}