diff --git a/brush-core/src/arithmetic.rs b/brush-core/src/arithmetic.rs index 9ae0cb44..825874ef 100644 --- a/brush-core/src/arithmetic.rs +++ b/brush-core/src/arithmetic.rs @@ -10,6 +10,10 @@ pub enum EvalError { #[error("division by zero")] DivideByZero, + /// Negative exponent. + #[error("exponent less than 0")] + NegativeExponent, + /// Failed to tokenize an arithmetic expression. #[error("failed to tokenize expression")] FailedToTokenizeExpression, @@ -204,13 +208,19 @@ async fn apply_binary_op( #[allow(clippy::cast_possible_truncation)] #[allow(clippy::cast_sign_loss)] match op { - ast::BinaryOperator::Power => Ok(left.pow(right as u32)), - ast::BinaryOperator::Multiply => Ok(left * right), + ast::BinaryOperator::Power => { + if right >= 0 { + Ok(left.wrapping_pow(right as u32)) + } else { + Err(EvalError::NegativeExponent) + } + } + ast::BinaryOperator::Multiply => Ok(left.wrapping_mul(right)), ast::BinaryOperator::Divide => { if right == 0 { Err(EvalError::DivideByZero) } else { - Ok(left / right) + Ok(left.wrapping_div(right)) } } ast::BinaryOperator::Modulo => { @@ -221,10 +231,10 @@ async fn apply_binary_op( } } ast::BinaryOperator::Comma => Ok(right), - ast::BinaryOperator::Add => Ok(left + right), - ast::BinaryOperator::Subtract => Ok(left - right), - ast::BinaryOperator::ShiftLeft => Ok(left << right), - ast::BinaryOperator::ShiftRight => Ok(left >> right), + ast::BinaryOperator::Add => Ok(left.wrapping_add(right)), + ast::BinaryOperator::Subtract => Ok(left.wrapping_sub(right)), + ast::BinaryOperator::ShiftLeft => Ok(left.wrapping_shl(right as u32)), + ast::BinaryOperator::ShiftRight => Ok(left.wrapping_shr(right as u32)), ast::BinaryOperator::LessThan => Ok(bool_to_i64(left < right)), ast::BinaryOperator::LessThanOrEqualTo => Ok(bool_to_i64(left <= right)), ast::BinaryOperator::GreaterThan => Ok(bool_to_i64(left > right)), diff --git a/brush-parser/src/arithmetic.rs b/brush-parser/src/arithmetic.rs index ce3cdf70..bf6abe64 100644 --- a/brush-parser/src/arithmetic.rs +++ b/brush-parser/src/arithmetic.rs @@ -68,7 +68,7 @@ peg::parser! { x:(@) _ "%" _ y:@ { ast::ArithmeticExpr::BinaryOp(ast::BinaryOperator::Modulo, Box::new(x), Box::new(y)) } x:(@) _ "/" _ y:@ { ast::ArithmeticExpr::BinaryOp(ast::BinaryOperator::Divide, Box::new(x), Box::new(y)) } -- - x:(@) _ "**" _ y:@ { ast::ArithmeticExpr::BinaryOp(ast::BinaryOperator::Power, Box::new(x), Box::new(y)) } + x:@ _ "**" _ y:(@) { ast::ArithmeticExpr::BinaryOp(ast::BinaryOperator::Power, Box::new(x), Box::new(y)) } -- "!" x:(@) { ast::ArithmeticExpr::UnaryOp(ast::UnaryOperator::LogicalNot, Box::new(x)) } "~" x:(@) { ast::ArithmeticExpr::UnaryOp(ast::UnaryOperator::BitwiseNot, Box::new(x)) } @@ -96,14 +96,14 @@ peg::parser! { } rule variable_name() -> &'input str = - $(['a'..='z' | 'A'..='Z' | '_']+) + $(['a'..='z' | 'A'..='Z' | '_'](['a'..='z' | 'A'..='Z' | '_' | '0'..='9']*)) rule _() -> () = quiet!{[' ' | '\t' | '\n' | '\r']*} {} rule literal_number() -> i64 = // TODO: handle binary? - "0" ['x' | 'X'] s:$(['0'..='9']*) {? i64::from_str_radix(s, 16).or(Err("i64")) } / - s:$("0" ['0'..='9']*) {? i64::from_str_radix(s, 8).or(Err("i64")) } / + "0" ['x' | 'X'] s:$(['0'..='9' | 'a'..='f' | 'A'..='F']*) {? i64::from_str_radix(s, 16).or(Err("i64")) } / + s:$("0" ['0'..='8']*) {? i64::from_str_radix(s, 8).or(Err("i64")) } / s:$(['1'..='9'] ['0'..='9']*) {? s.parse().or(Err("i64")) } } } diff --git a/brush-parser/src/parser.rs b/brush-parser/src/parser.rs index 165b516b..efccc27d 100644 --- a/brush-parser/src/parser.rs +++ b/brush-parser/src/parser.rs @@ -708,8 +708,13 @@ peg::parser! { rule array_element() -> &'input String = linebreak() [Token::Word(e, _)] linebreak() { e } + // N.B. An I/O number must be a string of only digits, and it must be + // followed by a '<' or '>' character (but not consume them). rule io_number() -> u32 = - w:[Token::Word(_, _)] {? w.to_str().parse().or(Err("io_number u32")) } + [Token::Word(w, _) if w.chars().all(|c: char| c.is_ascii_digit())] + &([Token::Operator(o, _) if o.starts_with('<') || o.starts_with('>')]) { + w.parse().unwrap() + } // // Helpers diff --git a/fuzz/fuzz_targets/fuzz_arithmetic.rs b/fuzz/fuzz_targets/fuzz_arithmetic.rs index 4cb7264d..afa60c1d 100644 --- a/fuzz/fuzz_targets/fuzz_arithmetic.rs +++ b/fuzz/fuzz_targets/fuzz_arithmetic.rs @@ -38,7 +38,7 @@ async fn eval_arithmetic_async(input: ast::ArithmeticExpr) -> Result<()> { const DEFAULT_TIMEOUT_IN_SECONDS: u64 = 15; oracle_cmd.timeout(std::time::Duration::from_secs(DEFAULT_TIMEOUT_IN_SECONDS)); - let input = std::format!("echo \"$(( {input} ))\"\n"); + let input = std::format!("echo \"$(( {input_str} ))\"\n"); oracle_cmd.write_stdin(input.as_bytes()); let oracle_result = oracle_cmd.output()?; @@ -54,7 +54,7 @@ async fn eval_arithmetic_async(input: ast::ArithmeticExpr) -> Result<()> { // if our_eval_result != oracle_eval_result { Err(anyhow::anyhow!( - "Mismatched eval results: {oracle_eval_result:?} from oracle vs. {our_eval_result:?} from our test (expr: '{input}', oracle result: {oracle_result:?})" + "Mismatched eval results: {oracle_eval_result:?} from oracle vs. {our_eval_result:?} from our test (expr: '{input_str}', oracle result: {oracle_result:?})" )) } else { Ok(()) @@ -69,6 +69,8 @@ fuzz_target!(|input: ast::ArithmeticExpr| { if s.contains("+ 0") || s.is_empty() || s.contains(|c: char| c.is_ascii_control() || !c.is_ascii()) + || s.contains("$[") + // old deprecated form of arithmetic expansion { return; }