diff --git a/dozer-sql/expression/src/scalar/common.rs b/dozer-sql/expression/src/scalar/common.rs index ec9d7106dc..1447c5d18e 100644 --- a/dozer-sql/expression/src/scalar/common.rs +++ b/dozer-sql/expression/src/scalar/common.rs @@ -10,7 +10,7 @@ use dozer_types::types::Record; use dozer_types::types::{Field, FieldType, Schema}; use std::fmt::{Display, Formatter}; -use super::string::evaluate_chr; +use super::string::{evaluate_chr, evaluate_substr, validate_substr}; #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum ScalarFunctionType { @@ -21,6 +21,7 @@ pub enum ScalarFunctionType { Length, ToChar, Chr, + Substr, } impl Display for ScalarFunctionType { @@ -33,6 +34,7 @@ impl Display for ScalarFunctionType { ScalarFunctionType::Length => f.write_str("LENGTH"), ScalarFunctionType::ToChar => f.write_str("TO_CHAR"), ScalarFunctionType::Chr => f.write_str("CHR"), + ScalarFunctionType::Substr => f.write_str("SUBSTR"), } } } @@ -78,6 +80,7 @@ pub(crate) fn get_scalar_function_type( } } ScalarFunctionType::Chr => validate_one_argument(args, schema, ScalarFunctionType::Chr), + ScalarFunctionType::Substr => validate_substr(args, schema), } } @@ -130,6 +133,18 @@ impl ScalarFunctionType { validate_num_arguments(1..2, args.len(), ScalarFunctionType::Chr)?; evaluate_chr(schema, &mut args[0], record) } + ScalarFunctionType::Substr => { + let (mut arg0, mut arg1, mut arg2) = if let Some(arg) = args.get(2) { + ( + args[0].clone(), + args[1].clone(), + Some(Box::new(arg.clone())), + ) + } else { + (args[0].clone(), args[1].clone(), None) + }; + evaluate_substr(schema, &mut arg0, &mut arg1, &mut arg2, record) + } } } } diff --git a/dozer-sql/expression/src/scalar/string.rs b/dozer-sql/expression/src/scalar/string.rs index 5d20d2e2f3..a4f054dfce 100644 --- a/dozer-sql/expression/src/scalar/string.rs +++ b/dozer-sql/expression/src/scalar/string.rs @@ -4,7 +4,7 @@ use std::fmt::{Display, Formatter}; use crate::execution::{Expression, ExpressionType}; -use crate::arg_utils::validate_arg_type; +use crate::arg_utils::{validate_arg_type, validate_num_arguments}; use crate::scalar::common::ScalarFunctionType; use dozer_types::types::Record; @@ -302,6 +302,119 @@ pub(crate) fn evaluate_chr( } } +pub fn validate_substr(args: &[Expression], schema: &Schema) -> Result { + validate_num_arguments(2..4, args.len(), ScalarFunctionType::Substr)?; + + if args.len() == 2 { + validate_arg_type( + &args[0], + vec![FieldType::String, FieldType::Text], + schema, + ScalarFunctionType::Substr, + 0, + )?; + validate_arg_type( + &args[1], + vec![ + FieldType::UInt, + FieldType::U128, + FieldType::Int, + FieldType::I128, + ], + schema, + ScalarFunctionType::Substr, + 1, + )?; + } else { + validate_arg_type( + &args[0], + vec![FieldType::String, FieldType::Text], + schema, + ScalarFunctionType::Substr, + 0, + )?; + validate_arg_type( + &args[1], + vec![ + FieldType::UInt, + FieldType::U128, + FieldType::Int, + FieldType::I128, + ], + schema, + ScalarFunctionType::Substr, + 1, + )?; + validate_arg_type( + &args[2], + vec![ + FieldType::UInt, + FieldType::U128, + FieldType::Int, + FieldType::I128, + ], + schema, + ScalarFunctionType::Substr, + 2, + )?; + } + + let ret_type = FieldType::String; + + Ok(ExpressionType::new( + ret_type, + false, + dozer_types::types::SourceDefinition::Dynamic, + false, + )) +} + +pub(crate) fn evaluate_substr( + schema: &Schema, + arg: &mut Expression, + position: &mut Expression, + length: &mut Option>, + record: &Record, +) -> Result { + let arg_field = arg.evaluate(record, schema)?; + let arg_value = arg_field.to_string(); + + let position_field = position.evaluate(record, schema)?; + let position_result = position_field.to_i128(); + if !position_result.is_some() { + return Err(Error::InvalidFunctionArgument { + function_name: "SUBSTR".to_string(), + argument_index: 1, + argument: position_field, + }); + } + let position_value = position_result.unwrap(); + + let length_value = match length { + Some(length_expr) => { + let length_field = length_expr.evaluate(record, schema)?; + let length_result = length_field.to_i128(); + if !length_result.is_some() { + return Err(Error::InvalidFunctionArgument { + function_name: "SUBSTR".to_string(), + argument_index: 2, + argument: length_field, + }); + } + length_result.unwrap() + } + None => arg_value.len() as i128, + }; + + let result = arg_value + .chars() + .skip(position_value as usize - 1) + .take(length_value as usize) + .collect::(); + + Ok(Field::String(result)) +} + #[cfg(test)] mod tests { use super::*;