diff --git a/src/common/src/types/decimal.rs b/src/common/src/types/decimal.rs index 4e9ba74f2db63..249ee7c2306f0 100644 --- a/src/common/src/types/decimal.rs +++ b/src/common/src/types/decimal.rs @@ -78,6 +78,15 @@ impl Decimal { let decimal = RustDecimal::from_scientific(value).ok()?; Some(Normalized(decimal)) } + + pub fn from_str_radix(s: &str, radix: u32) -> rust_decimal::Result { + match s.to_ascii_lowercase().as_str() { + "nan" => Ok(Decimal::NaN), + "inf" | "+inf" | "infinity" | "+infinity" => Ok(Decimal::PositiveInf), + "-inf" | "-infinity" => Ok(Decimal::NegativeInf), + s => RustDecimal::from_str_radix(s, radix).map(Decimal::Normalized), + } + } } impl ToBinary for Decimal { diff --git a/src/frontend/planner_test/tests/testdata/input/expr.yaml b/src/frontend/planner_test/tests/testdata/input/expr.yaml index 1a020c433d0ab..cc8f104e1ac88 100644 --- a/src/frontend/planner_test/tests/testdata/input/expr.yaml +++ b/src/frontend/planner_test/tests/testdata/input/expr.yaml @@ -40,6 +40,11 @@ SELECT null < null; expected_outputs: - logical_plan +- name: hex bitwise-or bin + sql: | + SELECT 0x25 | 0b110; + expected_outputs: + - logical_plan - name: bind is distinct from sql: | SELECT 1 IS DISTINCT FROM 2 diff --git a/src/frontend/planner_test/tests/testdata/output/expr.yaml b/src/frontend/planner_test/tests/testdata/output/expr.yaml index 016bff05efe90..ccb63d024aa43 100644 --- a/src/frontend/planner_test/tests/testdata/output/expr.yaml +++ b/src/frontend/planner_test/tests/testdata/output/expr.yaml @@ -50,6 +50,12 @@ logical_plan: |- LogicalProject { exprs: [(null:Varchar < null:Varchar) as $expr1] } └─LogicalValues { rows: [[]], schema: Schema { fields: [] } } +- name: hex bitwise-or bin + sql: | + SELECT 0x25 | 0b110; + logical_plan: |- + LogicalProject { exprs: [(37:Int32 | 6:Int32) as $expr1] } + └─LogicalValues { rows: [[]], schema: Schema { fields: [] } } - name: bind is distinct from sql: | SELECT 1 IS DISTINCT FROM 2 diff --git a/src/frontend/src/binder/expr/value.rs b/src/frontend/src/binder/expr/value.rs index ee3c772b0b504..f5faff3de2470 100644 --- a/src/frontend/src/binder/expr/value.rs +++ b/src/frontend/src/binder/expr/value.rs @@ -52,12 +52,30 @@ impl Binder { Ok(Literal::new(Some(ScalarImpl::Bool(b)), DataType::Boolean)) } - fn bind_number(&mut self, s: String) -> Result { - let (data, data_type) = if let Ok(int_32) = s.parse::() { + fn bind_number(&mut self, mut s: String) -> Result { + let prefix_start = match s.starts_with('-') { + true => 1, + false => 0, + }; + let base = match prefix_start + 2 <= s.len() { + true => match &s[prefix_start..prefix_start + 2] { + // tokenizer already converts them to lowercase + "0x" => 16, + "0o" => 8, + "0b" => 2, + _ => 10, + }, + false => 10, + }; + if base != 10 { + s.replace_range(prefix_start..prefix_start + 2, ""); + } + + let (data, data_type) = if let Ok(int_32) = i32::from_str_radix(&s, base) { (Some(ScalarImpl::Int32(int_32)), DataType::Int32) - } else if let Ok(int_64) = s.parse::() { + } else if let Ok(int_64) = i64::from_str_radix(&s, base) { (Some(ScalarImpl::Int64(int_64)), DataType::Int64) - } else if let Ok(decimal) = s.parse::() { + } else if let Ok(decimal) = Decimal::from_str_radix(&s, base) { // Notice: when the length of decimal exceeds 29(>= 30), it will be rounded up. (Some(ScalarImpl::Decimal(decimal)), DataType::Decimal) } else if let Some(scientific) = Decimal::from_scientific(&s) { @@ -207,6 +225,7 @@ mod tests { use risingwave_expr::expr::build_from_prost; use risingwave_sqlparser::ast::Value::Number; + use super::*; use crate::binder::test_utils::mock_binder; use crate::expr::{Expr, ExprImpl, ExprType, FunctionCall}; @@ -214,8 +233,6 @@ mod tests { async fn test_bind_value() { use std::str::FromStr; - use super::*; - let mut binder = mock_binder(); let values = [ "1", @@ -254,12 +271,33 @@ mod tests { } } + #[tokio::test] + async fn test_bind_radix() { + let mut binder = mock_binder(); + + for (input, expected) in [ + ("0x42e3", ScalarImpl::Int32(0x42e3)), + ("-0x40", ScalarImpl::Int32(-0x40)), + ("0b1101", ScalarImpl::Int32(0b1101)), + ("-0b101", ScalarImpl::Int32(-0b101)), + ("0o664", ScalarImpl::Int32(0o664)), + ("-0o755", ScalarImpl::Int32(-0o755)), + ("2147483647", ScalarImpl::Int32(2147483647)), + ("2147483648", ScalarImpl::Int64(2147483648)), + ("-2147483648", ScalarImpl::Int32(-2147483648)), + ("0x7fffffff", ScalarImpl::Int32(0x7fffffff)), + ("0x80000000", ScalarImpl::Int64(0x80000000)), + ("-0x80000000", ScalarImpl::Int32(-0x80000000)), + ] { + let lit = binder.bind_number(input.into()).unwrap(); + assert_eq!(lit.get_data().as_ref().unwrap(), &expected); + } + } + #[tokio::test] async fn test_bind_scientific_number() { use std::str::FromStr; - use super::*; - let mut binder = mock_binder(); let values = [ ("1e6"), @@ -336,8 +374,6 @@ mod tests { #[tokio::test] async fn test_bind_interval() { - use super::*; - let mut binder = mock_binder(); let values = [ "1 hour", diff --git a/src/sqlparser/src/tokenizer.rs b/src/sqlparser/src/tokenizer.rs index 7ca5c230418cf..03b794b06c62b 100644 --- a/src/sqlparser/src/tokenizer.rs +++ b/src/sqlparser/src/tokenizer.rs @@ -539,14 +539,6 @@ impl<'a> Tokenizer<'a> { chars.next(); // consume the first char let s = self.tokenize_word(ch, chars); - if s.chars().all(|x| x.is_ascii_digit() || x == '.') { - let mut s = peeking_take_while(&mut s.chars().peekable(), |ch| { - ch.is_ascii_digit() || ch == '.' - }); - let s2 = peeking_take_while(chars, |ch| ch.is_ascii_digit() || ch == '.'); - s += s2.as_str(); - return Ok(Some(Token::Number(s))); - } Ok(Some(Token::make_word(&s, None))) } // string @@ -574,10 +566,24 @@ impl<'a> Tokenizer<'a> { let mut s = peeking_take_while(chars, |ch| ch.is_ascii_digit()); // match binary literal that starts with 0x - if s == "0" && chars.peek() == Some(&'x') { + if s == "0" + && let Some(&radix) = chars.peek() + && "xob".contains(radix.to_ascii_lowercase()) + { chars.next(); - let s2 = peeking_take_while(chars, |ch| ch.is_ascii_hexdigit()); - return Ok(Some(Token::HexStringLiteral(s2))); + let radix = radix.to_ascii_lowercase(); + let base = match radix { + 'x' => 16, + 'o' => 8, + 'b' => 2, + _ => unreachable!(), + }; + let s2 = peeking_take_while(chars, |ch| ch.is_digit(base)); + if s2.is_empty() { + return self.tokenizer_error("incomplete integer literal"); + } + self.reject_number_junk(chars)?; + return Ok(Some(Token::Number(format!("0{radix}{s2}")))); } // match one period @@ -603,11 +609,13 @@ impl<'a> Tokenizer<'a> { chars.next(); } s += &peeking_take_while(chars, |ch| ch.is_ascii_digit()); + self.reject_number_junk(chars)?; return Ok(Some(Token::Number(s))); } // Not a scientific number _ => {} }; + self.reject_number_junk(chars)?; Ok(Some(Token::Number(s))) } // punctuation @@ -901,6 +909,15 @@ impl<'a> Tokenizer<'a> { }) } + fn reject_number_junk(&self, chars: &mut Peekable>) -> Result<(), TokenizerError> { + if let Some(ch) = chars.peek() + && is_identifier_start(*ch) + { + return self.tokenizer_error("trailing junk after numeric literal"); + } + Ok(()) + } + // Consume characters until newline fn tokenize_single_line_comment(&self, chars: &mut Peekable>) -> String { let mut comment = peeking_take_while(chars, |ch| ch != '\n'); diff --git a/src/sqlparser/tests/testdata/select.yaml b/src/sqlparser/tests/testdata/select.yaml index 6aed3d2a4dc4c..1fb897166a1ab 100644 --- a/src/sqlparser/tests/testdata/select.yaml +++ b/src/sqlparser/tests/testdata/select.yaml @@ -77,6 +77,8 @@ - input: SELECT timestamp with time zone '2022-10-01 12:00:00Z' AT TIME ZONE 'US/Pacific' formatted_sql: SELECT TIMESTAMP WITH TIME ZONE '2022-10-01 12:00:00Z' AT TIME ZONE 'US/Pacific' formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(AtTimeZone { timestamp: TypedString { data_type: Timestamp(true), value: "2022-10-01 12:00:00Z" }, time_zone: "US/Pacific" })], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' +- input: SELECT 0c6 + error_msg: 'sql parser error: trailing junk after numeric literal at Line: 1, Column 8' - input: SELECT 1e6 formatted_sql: SELECT 1e6 formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Value(Number("1e6")))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' @@ -89,6 +91,34 @@ - input: SELECT -1e6 formatted_sql: SELECT -1e6 formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Value(Number("-1e6")))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' +- input: SELECT 0x42e3 + formatted_sql: SELECT 0x42e3 + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Value(Number("0x42e3")))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' +- input: SELECT -0X40 + formatted_sql: SELECT -0x40 + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Value(Number("-0x40")))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' +- input: SELECT 0B1101 + formatted_sql: SELECT 0b1101 + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Value(Number("0b1101")))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' +- input: SELECT -0b101 + formatted_sql: SELECT -0b101 + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Value(Number("-0b101")))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' +- input: SELECT 0o664 + formatted_sql: SELECT 0o664 + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Value(Number("0o664")))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' +- input: SELECT -0O755 + formatted_sql: SELECT -0o755 + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Value(Number("-0o755")))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' +- input: SELECT 0o129 + error_msg: |- + sql parser error: Expected end of statement, found: 9 at line:1, column:13 + Near "SELECT 0o12" +- input: SELECT 0o3.5 + error_msg: |- + sql parser error: Expected end of statement, found: .5 at line:1, column:13 + Near "SELECT 0o3" +- input: SELECT 0x + error_msg: 'sql parser error: incomplete integer literal at Line: 1, Column 8' - input: SELECT 1::float(0) error_msg: 'sql parser error: precision for type float must be at least 1 bit' - input: SELECT 1::float(54)