From 09defe69d3967179db48527b9679212e5a92c91b Mon Sep 17 00:00:00 2001 From: mnmandahalf Date: Mon, 13 May 2024 00:04:49 +0900 Subject: [PATCH] make column_name type optional --- src/binder.rs | 57 +++++++++++++++++---------------- src/executor/insert_executor.rs | 22 ++++++------- src/parser.rs | 24 +++++++------- src/plan.rs | 2 +- 4 files changed, 52 insertions(+), 53 deletions(-) diff --git a/src/binder.rs b/src/binder.rs index 8d60484..f4767ab 100644 --- a/src/binder.rs +++ b/src/binder.rs @@ -78,7 +78,7 @@ pub struct BoundSubqueryTableReferenceAST { #[derive(Debug, PartialEq, Eq, Clone)] pub struct BoundInsertStatementAST { pub table_name: String, - pub column_names: Vec, + pub column_names: Option>, pub values: Vec, pub first_page_id: PageID, pub table_schema: Schema, @@ -441,33 +441,36 @@ impl Binder { .lock() .map_err(|_| anyhow::anyhow!("lock error"))? .get_schema_by_table_name(&statement.table_name, self.txn_id)?; - if statement.column_names.len() > 0 { - if statement.column_names.len() != statement.values.len() { - return Err(anyhow::anyhow!( - "expected {} values, but got {}", - statement.column_names.len(), - statement.values.len() - )); - } - if statement.column_names.len() > schema.columns.len() { - return Err(anyhow::anyhow!( - "expected {} values, but got {}", - schema.columns.len(), - statement.column_names.len() - )); - } - for column_name in &statement.column_names { - if !schema.columns.iter().any(|column| column.name == *column_name) { - return Err(anyhow::anyhow!("column {} not found", column_name)); + match &statement.column_names { + Some(column_names) => { + if column_names.len() != statement.values.len() { + return Err(anyhow::anyhow!( + "expected {} values, but got {}", + column_names.len(), + statement.values.len() + )); + } + if column_names.len() > schema.columns.len() { + return Err(anyhow::anyhow!( + "expected {} values, but got {}", + schema.columns.len(), + column_names.len() + )); + } + for column_name in column_names { + if !schema.columns.iter().any(|column| column.name == *column_name) { + return Err(anyhow::anyhow!("column {} not found", column_name)); + } } } - } else { - if statement.values.len() != schema.columns.len() { - return Err(anyhow::anyhow!( - "expected {} values, but got {}", - schema.columns.len(), - statement.values.len() - )); + None => { + if statement.values.len() != schema.columns.len() { + return Err(anyhow::anyhow!( + "expected {} values, but got {}", + schema.columns.len(), + statement.values.len() + )); + } } } let mut values = Vec::new(); @@ -1307,7 +1310,7 @@ mod tests { bound_statement, BoundStatementAST::Insert(BoundInsertStatementAST { table_name: "t1".to_string(), - column_names: vec![], + column_names: None, values: vec![ BoundExpressionAST::Literal(BoundLiteralExpressionAST { value: Value::Integer(IntegerValue(1)), diff --git a/src/executor/insert_executor.rs b/src/executor/insert_executor.rs index c53ccfa..d13c44e 100644 --- a/src/executor/insert_executor.rs +++ b/src/executor/insert_executor.rs @@ -29,19 +29,17 @@ impl InsertExecutor<'_> { .enumerate() .map(|(i, c)| { let index; - if self.plan.column_names.len() > 0 { - let position = self - .plan - .column_names - .iter() - .position(|x| x == &c.name); - // TODO: support default value - match position { - Some(i) => index = i, - None => return Ok(Value::Null), + match &self.plan.column_names { + Some(column_names) => { + let position = column_names.iter().position(|x| x == &c.name); + match position { + Some(pos) => index = pos, + None => return Ok(Value::Null) + } + }, + None => { + index = i; } - } else { - index = i; } let raw_value = self.plan.values[index].eval( &vec![&Tuple::new(None, &[])], diff --git a/src/parser.rs b/src/parser.rs index da1031c..4ce0c90 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -95,7 +95,7 @@ pub struct LimitAST { #[derive(Debug, PartialEq, Eq, Clone)] pub struct InsertStatementAST { pub table_name: String, - pub column_names: Vec, + pub column_names: Option>, // TODO: support multiple rows pub values: Vec, } @@ -461,20 +461,18 @@ impl Parser { self.consume_token_or_error(Token::Keyword(Keyword::Insert))?; self.consume_token_or_error(Token::Keyword(Keyword::Into))?; let table_name = self.identifier()?; - // verify column names - let column_names = if self.consume_token(Token::LeftParen) { - let mut column_names = Vec::new(); + let mut column_names: Option> = None; + if self.consume_token(Token::LeftParen) { + let mut names = Vec::new(); loop { - column_names.push(self.identifier()?); + names.push(self.identifier()?); if !self.consume_token(Token::Comma) { break; } } self.consume_token_or_error(Token::RightParen)?; - column_names - } else { - Vec::new() - }; + column_names = Some(names); + } self.consume_token_or_error(Token::Keyword(Keyword::Values))?; self.consume_token_or_error(Token::LeftParen)?; let mut values = Vec::new(); @@ -971,7 +969,7 @@ mod tests { statement, StatementAST::Insert(InsertStatementAST { table_name: String::from("users"), - column_names: vec![], + column_names: None, values: vec![ ExpressionAST::Literal(LiteralExpressionAST { value: Value::Integer(IntegerValue(1)), @@ -989,7 +987,7 @@ mod tests { } #[test] - fn test_parse_insert_with_column_names() -> Result<()> { + fn test_parse_insert_with_columns() -> Result<()> { let sql = "INSERT INTO users (id, name, is_deleted) VALUES (1, 'foo', true)"; let mut parser = Parser::new(tokenize(&mut sql.chars().peekable())?); @@ -998,11 +996,11 @@ mod tests { statement, StatementAST::Insert(InsertStatementAST { table_name: String::from("users"), - column_names: vec![ + column_names: Some(vec![ String::from("id"), String::from("name"), String::from("is_deleted"), - ], + ]), values: vec![ ExpressionAST::Literal(LiteralExpressionAST { value: Value::Integer(IntegerValue(1)), diff --git a/src/plan.rs b/src/plan.rs index 0734327..e2d185b 100644 --- a/src/plan.rs +++ b/src/plan.rs @@ -134,7 +134,7 @@ pub struct EmptyRowPlan { pub struct InsertPlan { pub first_page_id: PageID, pub table_schema: Schema, - pub column_names: Vec, + pub column_names: Option>, pub values: Vec, pub schema: Schema, }