Skip to content

Commit

Permalink
Add Substr
Browse files Browse the repository at this point in the history
  • Loading branch information
mediuminvader committed Feb 26, 2024
1 parent b09512b commit e28c841
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 2 deletions.
17 changes: 16 additions & 1 deletion dozer-sql/expression/src/scalar/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use dozer_types::types::Record;
use dozer_types::types::{Field, FieldType, Schema};
use std::fmt::{Display, Formatter};

use super::string::evaluate_chr;
use super::string::{evaluate_chr, evaluate_substr, validate_substr};

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum ScalarFunctionType {
Expand All @@ -21,6 +21,7 @@ pub enum ScalarFunctionType {
Length,
ToChar,
Chr,
Substr,
}

impl Display for ScalarFunctionType {
Expand All @@ -33,6 +34,7 @@ impl Display for ScalarFunctionType {
ScalarFunctionType::Length => f.write_str("LENGTH"),
ScalarFunctionType::ToChar => f.write_str("TO_CHAR"),
ScalarFunctionType::Chr => f.write_str("CHR"),
ScalarFunctionType::Substr => f.write_str("SUBSTR"),
}
}
}
Expand Down Expand Up @@ -78,6 +80,7 @@ pub(crate) fn get_scalar_function_type(
}
}
ScalarFunctionType::Chr => validate_one_argument(args, schema, ScalarFunctionType::Chr),
ScalarFunctionType::Substr => validate_substr(args, schema),
}
}

Expand Down Expand Up @@ -130,6 +133,18 @@ impl ScalarFunctionType {
validate_num_arguments(1..2, args.len(), ScalarFunctionType::Chr)?;
evaluate_chr(schema, &mut args[0], record)
}
ScalarFunctionType::Substr => {
let (mut arg0, mut arg1, mut arg2) = if let Some(arg) = args.get(2) {
(
args[0].clone(),
args[1].clone(),
Some(Box::new(arg.clone())),
)
} else {
(args[0].clone(), args[1].clone(), None)
};
evaluate_substr(schema, &mut arg0, &mut arg1, &mut arg2, record)
}
}
}
}
115 changes: 114 additions & 1 deletion dozer-sql/expression/src/scalar/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::fmt::{Display, Formatter};

use crate::execution::{Expression, ExpressionType};

use crate::arg_utils::validate_arg_type;
use crate::arg_utils::{validate_arg_type, validate_num_arguments};
use crate::scalar::common::ScalarFunctionType;

use dozer_types::types::Record;
Expand Down Expand Up @@ -302,6 +302,119 @@ pub(crate) fn evaluate_chr(
}
}

pub fn validate_substr(args: &[Expression], schema: &Schema) -> Result<ExpressionType, Error> {
validate_num_arguments(2..4, args.len(), ScalarFunctionType::Substr)?;

if args.len() == 2 {
validate_arg_type(
&args[0],
vec![FieldType::String, FieldType::Text],
schema,
ScalarFunctionType::Substr,
0,
)?;
validate_arg_type(
&args[1],
vec![
FieldType::UInt,
FieldType::U128,
FieldType::Int,
FieldType::I128,
],
schema,
ScalarFunctionType::Substr,
1,
)?;
} else {
validate_arg_type(
&args[0],
vec![FieldType::String, FieldType::Text],
schema,
ScalarFunctionType::Substr,
0,
)?;
validate_arg_type(
&args[1],
vec![
FieldType::UInt,
FieldType::U128,
FieldType::Int,
FieldType::I128,
],
schema,
ScalarFunctionType::Substr,
1,
)?;
validate_arg_type(
&args[2],
vec![
FieldType::UInt,
FieldType::U128,
FieldType::Int,
FieldType::I128,
],
schema,
ScalarFunctionType::Substr,
2,
)?;
}

let ret_type = FieldType::String;

Ok(ExpressionType::new(
ret_type,
false,
dozer_types::types::SourceDefinition::Dynamic,
false,
))
}

pub(crate) fn evaluate_substr(
schema: &Schema,
arg: &mut Expression,
position: &mut Expression,
length: &mut Option<Box<Expression>>,
record: &Record,
) -> Result<Field, Error> {
let arg_field = arg.evaluate(record, schema)?;
let arg_value = arg_field.to_string();

let position_field = position.evaluate(record, schema)?;
let position_result = position_field.to_i128();
if !position_result.is_some() {
return Err(Error::InvalidFunctionArgument {
function_name: "SUBSTR".to_string(),
argument_index: 1,
argument: position_field,
});
}
let position_value = position_result.unwrap();

let length_value = match length {
Some(length_expr) => {
let length_field = length_expr.evaluate(record, schema)?;
let length_result = length_field.to_i128();
if !length_result.is_some() {
return Err(Error::InvalidFunctionArgument {
function_name: "SUBSTR".to_string(),
argument_index: 2,
argument: length_field,
});
}
length_result.unwrap()
}
None => arg_value.len() as i128,
};

let result = arg_value
.chars()
.skip(position_value as usize - 1)
.take(length_value as usize)
.collect::<String>();

Ok(Field::String(result))
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit e28c841

Please sign in to comment.