diff --git a/src/binder.rs b/src/binder.rs index 4d535b9..f4767ab 100644 --- a/src/binder.rs +++ b/src/binder.rs @@ -78,6 +78,7 @@ pub struct BoundSubqueryTableReferenceAST { #[derive(Debug, PartialEq, Eq, Clone)] pub struct BoundInsertStatementAST { pub table_name: String, + pub column_names: Option>, pub values: Vec, pub first_page_id: PageID, pub table_schema: Schema, @@ -440,12 +441,37 @@ impl Binder { .lock() .map_err(|_| anyhow::anyhow!("lock error"))? .get_schema_by_table_name(&statement.table_name, self.txn_id)?; - if statement.values.len() != schema.columns.len() { - return Err(anyhow::anyhow!( - "expected {} values, but got {}", - schema.columns.len(), - statement.values.len() - )); + match &statement.column_names { + Some(column_names) => { + if column_names.len() != statement.values.len() { + return Err(anyhow::anyhow!( + "expected {} values, but got {}", + column_names.len(), + statement.values.len() + )); + } + if column_names.len() > schema.columns.len() { + return Err(anyhow::anyhow!( + "expected {} values, but got {}", + schema.columns.len(), + column_names.len() + )); + } + for column_name in column_names { + if !schema.columns.iter().any(|column| column.name == *column_name) { + return Err(anyhow::anyhow!("column {} not found", column_name)); + } + } + } + None => { + if statement.values.len() != schema.columns.len() { + return Err(anyhow::anyhow!( + "expected {} values, but got {}", + schema.columns.len(), + statement.values.len() + )); + } + } } let mut values = Vec::new(); for value in &statement.values { @@ -453,6 +479,7 @@ impl Binder { } Ok(BoundStatementAST::Insert(BoundInsertStatementAST { table_name: statement.table_name.clone(), + column_names: statement.column_names.clone(), values, first_page_id, table_schema: schema, @@ -1283,6 +1310,7 @@ mod tests { bound_statement, BoundStatementAST::Insert(BoundInsertStatementAST { table_name: "t1".to_string(), + column_names: None, values: vec![ BoundExpressionAST::Literal(BoundLiteralExpressionAST { value: Value::Integer(IntegerValue(1)), diff --git a/src/executor/insert_executor.rs b/src/executor/insert_executor.rs index e0f4090..db9f4b1 100644 --- a/src/executor/insert_executor.rs +++ b/src/executor/insert_executor.rs @@ -30,7 +30,20 @@ impl InsertExecutor<'_> { .iter() .enumerate() .map(|(i, c)| { - let raw_value = self.plan.values[i].eval( + let index; + match &self.plan.column_names { + Some(column_names) => { + let position = column_names.iter().position(|x| x == &c.name); + match position { + Some(pos) => index = pos, + None => return Ok(Value::Null) + } + }, + None => { + index = i; + } + } + let raw_value = self.plan.values[index].eval( &vec![&Tuple::new(None, &[])], &vec![&Schema { columns: vec![] }], )?; diff --git a/src/parser.rs b/src/parser.rs index af6360f..4ce0c90 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -95,6 +95,7 @@ pub struct LimitAST { #[derive(Debug, PartialEq, Eq, Clone)] pub struct InsertStatementAST { pub table_name: String, + pub column_names: Option>, // TODO: support multiple rows pub values: Vec, } @@ -460,6 +461,18 @@ impl Parser { self.consume_token_or_error(Token::Keyword(Keyword::Insert))?; self.consume_token_or_error(Token::Keyword(Keyword::Into))?; let table_name = self.identifier()?; + let mut column_names: Option> = None; + if self.consume_token(Token::LeftParen) { + let mut names = Vec::new(); + loop { + names.push(self.identifier()?); + if !self.consume_token(Token::Comma) { + break; + } + } + self.consume_token_or_error(Token::RightParen)?; + column_names = Some(names); + } self.consume_token_or_error(Token::Keyword(Keyword::Values))?; self.consume_token_or_error(Token::LeftParen)?; let mut values = Vec::new(); @@ -470,7 +483,7 @@ impl Parser { } } self.consume_token_or_error(Token::RightParen)?; - Ok(InsertStatementAST { table_name, values }) + Ok(InsertStatementAST { table_name, column_names, values }) } fn delete_statement(&mut self) -> Result { self.consume_token_or_error(Token::Keyword(Keyword::Delete))?; @@ -956,6 +969,38 @@ mod tests { statement, StatementAST::Insert(InsertStatementAST { table_name: String::from("users"), + column_names: None, + values: vec![ + ExpressionAST::Literal(LiteralExpressionAST { + value: Value::Integer(IntegerValue(1)), + }), + ExpressionAST::Literal(LiteralExpressionAST { + value: Value::Varchar(VarcharValue(String::from("foo"))), + }), + ExpressionAST::Literal(LiteralExpressionAST { + value: Value::Boolean(BooleanValue(true)), + }), + ], + }) + ); + Ok(()) + } + + #[test] + fn test_parse_insert_with_columns() -> Result<()> { + let sql = "INSERT INTO users (id, name, is_deleted) VALUES (1, 'foo', true)"; + let mut parser = Parser::new(tokenize(&mut sql.chars().peekable())?); + + let statement = parser.parse()?; + assert_eq!( + statement, + StatementAST::Insert(InsertStatementAST { + table_name: String::from("users"), + column_names: Some(vec![ + String::from("id"), + String::from("name"), + String::from("is_deleted"), + ]), values: vec![ ExpressionAST::Literal(LiteralExpressionAST { value: Value::Integer(IntegerValue(1)), diff --git a/src/plan.rs b/src/plan.rs index 1e3d1ba..24cbd46 100644 --- a/src/plan.rs +++ b/src/plan.rs @@ -137,6 +137,7 @@ pub struct EmptyRowPlan { pub struct InsertPlan { pub first_page_id: PageID, pub table_schema: Schema, + pub column_names: Option>, pub values: Vec, pub schema: Schema, pub table_name: String, @@ -382,6 +383,7 @@ impl Planner { Plan::Insert(InsertPlan { first_page_id: insert_statement.first_page_id, table_schema: insert_statement.table_schema.clone(), + column_names: insert_statement.column_names.clone(), values: insert_statement.values.clone(), schema: Schema { columns: vec![Column {