From 3b6e8e1c5e5f51b4a442a7a96ebdfd9061f1554f Mon Sep 17 00:00:00 2001 From: wugeer <1284057728@qq.com> Date: Thu, 5 Sep 2024 21:32:02 +0800 Subject: [PATCH] fix: `EXCEPT` not handled well --- src/formatter.rs | 2 +- src/lib.rs | 25 ++++++++++ src/tokenizer.rs | 119 +++++++++++++++++++++++++++++------------------ 3 files changed, 99 insertions(+), 47 deletions(-) diff --git a/src/formatter.rs b/src/formatter.rs index 2ccab77..c18cc04 100644 --- a/src/formatter.rs +++ b/src/formatter.rs @@ -94,7 +94,7 @@ impl<'a> Formatter<'a> { let previous_token = self.previous_token(1); if previous_token.is_some() - && previous_token.unwrap().value.contains("\n") + && previous_token.unwrap().value.contains('\n') && is_whitespace_followed_by_special_token { self.add_new_line(query); diff --git a/src/lib.rs b/src/lib.rs index 80f3c71..1800bf9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1633,4 +1633,29 @@ mod tests { ); assert_eq!(format(input, &QueryParams::None, options), expected); } + + #[test] + fn it_formats_except_on_columns() { + let input = indoc!( + "SELECT table_0.* EXCEPT (profit), + details.* EXCEPT (item_id), + table_0.profit + FROM table_0" + ); + let options = FormatOptions { + indent: Indent::Spaces(4), + ..Default::default() + }; + let expected = indoc!( + " + SELECT + table_0.* EXCEPT (profit), + details.* EXCEPT (item_id), + table_0.profit + FROM + table_0" + ); + + assert_eq!(format(input, &QueryParams::None, options), expected); + } } diff --git a/src/tokenizer.rs b/src/tokenizer.rs index 4fe572e..e0de6de 100644 --- a/src/tokenizer.rs +++ b/src/tokenizer.rs @@ -14,16 +14,20 @@ pub(crate) fn tokenize(mut input: &str, named_placeholders: bool) -> Vec = Vec::new(); let mut last_reserved_token = None; + let mut last_reserved_top_level_token = None; // Keep processing the string until it is empty while let Ok(result) = get_next_token( input, tokens.last().cloned(), last_reserved_token.clone(), + last_reserved_top_level_token.clone(), named_placeholders, ) { if result.1.kind == TokenKind::Reserved { last_reserved_token = Some(result.1.clone()); + } else if result.1.kind == TokenKind::ReservedTopLevel { + last_reserved_top_level_token = Some(result.1.clone()); } input = result.0; @@ -86,6 +90,7 @@ fn get_next_token<'a>( input: &'a str, previous_token: Option>, last_reserved_token: Option>, + last_reserved_top_level_token: Option>, named_placeholders: bool, ) -> IResult<&'a str, Token<'a>> { get_whitespace_token(input) @@ -94,7 +99,14 @@ fn get_next_token<'a>( .or_else(|_| get_open_paren_token(input)) .or_else(|_| get_close_paren_token(input)) .or_else(|_| get_number_token(input)) - .or_else(|_| get_reserved_word_token(input, previous_token, last_reserved_token)) + .or_else(|_| { + get_reserved_word_token( + input, + previous_token, + last_reserved_token, + last_reserved_top_level_token, + ) + }) .or_else(|_| get_placeholder_token(input, named_placeholders)) .or_else(|_| get_word_token(input)) .or_else(|_| get_operator_token(input)) @@ -422,6 +434,7 @@ fn get_reserved_word_token<'a>( input: &'a str, previous_token: Option>, last_reserved_token: Option>, + last_reserved_top_level_token: Option>, ) -> IResult<&'a str, Token<'a>> { // A reserved word cannot be preceded by a "." // this makes it so in "my_table.from", "from" is not considered a reserved word @@ -432,7 +445,7 @@ fn get_reserved_word_token<'a>( } alt(( - get_top_level_reserved_token, + get_top_level_reserved_token(last_reserved_top_level_token), get_newline_reserved_token(last_reserved_token), get_top_level_reserved_token_no_indent, get_plain_reserved_token, @@ -449,50 +462,64 @@ fn get_uc_words(input: &str, words: usize) -> String { .to_ascii_uppercase() } -fn get_top_level_reserved_token(input: &str) -> IResult<&str, Token<'_>> { - let uc_input = get_uc_words(input, 3); - let result: IResult<&str, &str> = alt(( - terminated(tag("ADD"), end_of_word), - terminated(tag("AFTER"), end_of_word), - terminated(tag("ALTER COLUMN"), end_of_word), - terminated(tag("ALTER TABLE"), end_of_word), - terminated(tag("DELETE FROM"), end_of_word), - terminated(tag("EXCEPT"), end_of_word), - terminated(tag("FETCH FIRST"), end_of_word), - terminated(tag("FROM"), end_of_word), - terminated(tag("GROUP BY"), end_of_word), - terminated(tag("GO"), end_of_word), - terminated(tag("HAVING"), end_of_word), - terminated(tag("INSERT INTO"), end_of_word), - terminated(tag("INSERT"), end_of_word), - terminated(tag("LIMIT"), end_of_word), - terminated(tag("MODIFY"), end_of_word), - terminated(tag("ORDER BY"), end_of_word), - terminated(tag("SELECT"), end_of_word), - terminated(tag("SET CURRENT SCHEMA"), end_of_word), - terminated(tag("SET SCHEMA"), end_of_word), - terminated(tag("SET"), end_of_word), - alt(( - terminated(tag("UPDATE"), end_of_word), - terminated(tag("VALUES"), end_of_word), - terminated(tag("WHERE"), end_of_word), - terminated(tag("RETURNING"), end_of_word), - )), - ))(&uc_input); - if let Ok((_, token)) = result { - let final_word = token.split(' ').last().unwrap(); - let input_end_pos = input.to_ascii_uppercase().find(final_word).unwrap() + final_word.len(); - let (token, input) = input.split_at(input_end_pos); - Ok(( - input, - Token { - kind: TokenKind::ReservedTopLevel, - value: token, - key: None, - }, - )) - } else { - Err(Err::Error(Error::new(input, ErrorKind::Alt))) +fn get_top_level_reserved_token<'a>( + last_reserved_top_level_token: Option>, +) -> impl FnMut(&'a str) -> IResult<&'a str, Token<'a>> { + move |input: &'a str| { + let uc_input: String = get_uc_words(input, 3); + let result: IResult<&str, &str> = alt(( + terminated(tag("ADD"), end_of_word), + terminated(tag("AFTER"), end_of_word), + terminated(tag("ALTER COLUMN"), end_of_word), + terminated(tag("ALTER TABLE"), end_of_word), + terminated(tag("DELETE FROM"), end_of_word), + terminated(tag("EXCEPT"), end_of_word), + terminated(tag("FETCH FIRST"), end_of_word), + terminated(tag("FROM"), end_of_word), + terminated(tag("GROUP BY"), end_of_word), + terminated(tag("GO"), end_of_word), + terminated(tag("HAVING"), end_of_word), + terminated(tag("INSERT INTO"), end_of_word), + terminated(tag("INSERT"), end_of_word), + terminated(tag("LIMIT"), end_of_word), + terminated(tag("MODIFY"), end_of_word), + terminated(tag("ORDER BY"), end_of_word), + terminated(tag("SELECT"), end_of_word), + terminated(tag("SET CURRENT SCHEMA"), end_of_word), + terminated(tag("SET SCHEMA"), end_of_word), + terminated(tag("SET"), end_of_word), + alt(( + terminated(tag("UPDATE"), end_of_word), + terminated(tag("VALUES"), end_of_word), + terminated(tag("WHERE"), end_of_word), + terminated(tag("RETURNING"), end_of_word), + )), + ))(&uc_input); + if let Ok((_, token)) = result { + let final_word = token.split(' ').last().unwrap(); + let input_end_pos = + input.to_ascii_uppercase().find(final_word).unwrap() + final_word.len(); + let (token, input) = input.split_at(input_end_pos); + let kind = if token == "EXCEPT" + && last_reserved_top_level_token.is_some() + && last_reserved_top_level_token.as_ref().unwrap().value == "SELECT" + { + // If the query statement before and after the except keyword is not complete, mark it a `Word` + TokenKind::Word + } else { + TokenKind::ReservedTopLevel + }; + Ok(( + input, + Token { + kind, + value: token, + key: None, + }, + )) + } else { + Err(Err::Error(Error::new(input, ErrorKind::Alt))) + } } }