diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 15c48c54..009846ba 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -294,10 +294,7 @@ impl<'a, T: Transaction> Binder<'a, T> { try_default!(&table_name, column_name); } if let Some(table) = table_name.or(bind_table_name) { - let table_catalog = self - .context - .table(Arc::new(table.clone())) - .ok_or_else(|| DatabaseError::TableNotFound)?; + let table_catalog = self.context.bind_table(&table)?; let column_catalog = table_catalog .get_column_by_name(&column_name) diff --git a/src/binder/insert.rs b/src/binder/insert.rs index 95e44399..65f5ca79 100644 --- a/src/binder/insert.rs +++ b/src/binder/insert.rs @@ -24,84 +24,81 @@ impl<'a, T: Transaction> Binder<'a, T> { self.context.allow_default = true; let table_name = Arc::new(lower_case_name(name)?); - if let Some(table) = self.context.table(table_name.clone()) { - let mut _schema_ref = None; - let values_len = expr_rows[0].len(); + let table = self.context.table_and_bind(table_name.clone(), None)?; + let mut _schema_ref = None; + let values_len = expr_rows[0].len(); - if idents.is_empty() { - let temp_schema_ref = table.schema_ref().clone(); - if values_len > temp_schema_ref.len() { - return Err(DatabaseError::ValuesLenMismatch( - temp_schema_ref.len(), - values_len, - )); - } - _schema_ref = Some(temp_schema_ref); - } else { - let mut columns = Vec::with_capacity(idents.len()); - for ident in idents { - match self.bind_column_ref_from_identifiers( - slice::from_ref(ident), - Some(table_name.to_string()), - )? { - ScalarExpression::ColumnRef(catalog) => columns.push(catalog), - _ => return Err(DatabaseError::UnsupportedStmt(ident.to_string())), - } - } - if values_len != columns.len() { - return Err(DatabaseError::ValuesLenMismatch(columns.len(), values_len)); + if idents.is_empty() { + let temp_schema_ref = table.schema_ref().clone(); + if values_len > temp_schema_ref.len() { + return Err(DatabaseError::ValuesLenMismatch( + temp_schema_ref.len(), + values_len, + )); + } + _schema_ref = Some(temp_schema_ref); + } else { + let mut columns = Vec::with_capacity(idents.len()); + for ident in idents { + match self.bind_column_ref_from_identifiers( + slice::from_ref(ident), + Some(table_name.to_string()), + )? { + ScalarExpression::ColumnRef(catalog) => columns.push(catalog), + _ => return Err(DatabaseError::UnsupportedStmt(ident.to_string())), } - _schema_ref = Some(Arc::new(columns)); } - let schema_ref = _schema_ref.ok_or(DatabaseError::ColumnsEmpty)?; - let mut rows = Vec::with_capacity(expr_rows.len()); + if values_len != columns.len() { + return Err(DatabaseError::ValuesLenMismatch(columns.len(), values_len)); + } + _schema_ref = Some(Arc::new(columns)); + } + let schema_ref = _schema_ref.ok_or(DatabaseError::ColumnsEmpty)?; + let mut rows = Vec::with_capacity(expr_rows.len()); - for expr_row in expr_rows { - if expr_row.len() != values_len { - return Err(DatabaseError::ValuesLenMismatch(expr_row.len(), values_len)); - } - let mut row = Vec::with_capacity(expr_row.len()); + for expr_row in expr_rows { + if expr_row.len() != values_len { + return Err(DatabaseError::ValuesLenMismatch(expr_row.len(), values_len)); + } + let mut row = Vec::with_capacity(expr_row.len()); - for (i, expr) in expr_row.iter().enumerate() { - let mut expression = self.bind_expr(expr)?; + for (i, expr) in expr_row.iter().enumerate() { + let mut expression = self.bind_expr(expr)?; - expression.constant_calculation()?; - match expression { - ScalarExpression::Constant(mut value) => { - let ty = schema_ref[i].datatype(); - // Check if the value length is too long - value.check_len(ty)?; + expression.constant_calculation()?; + match expression { + ScalarExpression::Constant(mut value) => { + let ty = schema_ref[i].datatype(); + // Check if the value length is too long + value.check_len(ty)?; - if value.logical_type() != *ty { - value = Arc::new(DataValue::clone(&value).cast(ty)?); - } - row.push(value); + if value.logical_type() != *ty { + value = Arc::new(DataValue::clone(&value).cast(ty)?); } - ScalarExpression::Empty => { - row.push(schema_ref[i].default_value().ok_or_else(|| { - DatabaseError::InvalidColumn( - "column does not exist default".to_string(), - ) - })?); - } - _ => return Err(DatabaseError::UnsupportedStmt(expr.to_string())), + row.push(value); + } + ScalarExpression::Empty => { + row.push(schema_ref[i].default_value().ok_or_else(|| { + DatabaseError::InvalidColumn( + "column does not exist default".to_string(), + ) + })?); } + _ => return Err(DatabaseError::UnsupportedStmt(expr.to_string())), } - rows.push(row); } - self.context.allow_default = false; - let values_plan = self.bind_values(rows, schema_ref); - - Ok(LogicalPlan::new( - Operator::Insert(InsertOperator { - table_name, - is_overwrite, - }), - vec![values_plan], - )) - } else { - Err(DatabaseError::TableNotFound) + rows.push(row); } + self.context.allow_default = false; + let values_plan = self.bind_values(rows, schema_ref); + + Ok(LogicalPlan::new( + Operator::Insert(InsertOperator { + table_name, + is_overwrite, + }), + vec![values_plan], + )) } pub(crate) fn bind_values( diff --git a/src/binder/mod.rs b/src/binder/mod.rs index 94f3059c..76836dbe 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -151,6 +151,17 @@ impl<'a, T: Transaction> BinderContext<'a, T> { Ok(table) } + /// get table from bindings + pub fn bind_table(&self, table_name: &str) -> Result<&TableCatalog, DatabaseError> { + let default_name = Arc::new(table_name.to_owned()); + let real_name = self.table_aliases.get(table_name).unwrap_or(&default_name); + self.bind_table + .iter() + .find(|((t, _), _)| t == real_name) + .ok_or(DatabaseError::InvalidTable(table_name.into())) + .map(|v| *v.1) + } + // Tips: The order of this index is based on Aggregate being bound first. pub fn input_ref_index(&self, ty: InputRefType) -> usize { match ty {