Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(expr): move several expressions to expr_impl #13923

Merged
merged 10 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ members = [
"src/expr/core",
"src/expr/impl",
"src/expr/macro",
"src/expr/udf",
"src/frontend",
"src/frontend/planner_test",
"src/java_binding",
Expand Down Expand Up @@ -48,7 +49,6 @@ members = [
"src/tests/simulation",
"src/tests/sqlsmith",
"src/tests/state_cleaning_test",
"src/udf",
"src/utils/local_stats_alloc",
"src/utils/pgwire",
"src/utils/runtime",
Expand Down Expand Up @@ -170,7 +170,7 @@ risingwave_sqlsmith = { path = "./src/tests/sqlsmith" }
risingwave_storage = { path = "./src/storage" }
risingwave_stream = { path = "./src/stream" }
risingwave_test_runner = { path = "./src/test_runner" }
risingwave_udf = { path = "./src/udf" }
risingwave_udf = { path = "./src/expr/udf" }
risingwave_variables = { path = "./src/utils/variables" }
risingwave_java_binding = { path = "./src/java_binding" }
risingwave_jni_core = { path = "src/jni_core" }
Expand Down
2 changes: 1 addition & 1 deletion ci/scripts/run-unit-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ set -euo pipefail
REPO_ROOT=${PWD}

echo "+++ Run python UDF SDK unit tests"
cd ${REPO_ROOT}/src/udf/python
cd ${REPO_ROOT}/src/expr/udf/python
python3 -m pytest
cd ${REPO_ROOT}

Expand Down
2 changes: 1 addition & 1 deletion e2e_test/udf/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Iterator, List, Optional, Tuple, Any
from decimal import Decimal

sys.path.append("src/udf/python") # noqa
sys.path.append("src/expr/udf/python") # noqa

from risingwave.udf import udf, udtf, UdfServer

Expand Down
8 changes: 3 additions & 5 deletions integration_tests/feature-store/server/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,20 @@
from typing import Iterator, List, Optional, Tuple, Any
from decimal import Decimal

sys.path.append("src/udf/python") # noqa
sys.path.append("src/expr/udf/python") # noqa

from risingwave.udf import udf, UdfServer



@udf(input_types=["INT", "VARCHAR"], result_type="INT")
def udf_sum(x: int, y: str) -> int:
if y=='mfa+':
if y == "mfa+":
return x
else:
return -x



if __name__ == "__main__":
server = UdfServer(location="0.0.0.0:8815")
server.add_function(udf_sum)
server.serve()
server.serve()
4 changes: 2 additions & 2 deletions src/common/src/array/data_chunk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ pub trait DataChunkTestExt {
/// // T: str
/// // TS: Timestamp
/// // SRL: Serial
/// // {i,f}: struct
/// // <i,f>: struct
/// ```
fn from_pretty(s: &str) -> Self;

Expand Down Expand Up @@ -783,7 +783,7 @@ impl DataChunkTestExt for DataChunk {
"TZ" => DataType::Timestamptz,
"T" => DataType::Varchar,
"SRL" => DataType::Serial,
array if array.starts_with('{') && array.ends_with('}') => {
array if array.starts_with('<') && array.ends_with('>') => {
DataType::Struct(StructType::unnamed(
array[1..array.len() - 1]
.split(',')
Expand Down
2 changes: 1 addition & 1 deletion src/common/src/array/stream_chunk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ impl StreamChunkTestExt for StreamChunk {
/// // TZ: Timestamptz
/// // SRL: Serial
/// // x[]: array of x
/// // {i,f}: struct
/// // <i,f>: struct
/// ```
fn from_pretty(s: &str) -> Self {
let mut chunk_str = String::new();
Expand Down
16 changes: 14 additions & 2 deletions src/common/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ pub enum DataType {
#[from_str(regex = "(?i)^interval$")]
Interval,
#[display("{0}")]
#[from_str(ignore)]
#[from_str(regex = "(?i)^(?P<0>.+)$")]
Struct(StructType),
#[display("{0}[]")]
#[from_str(regex = r"(?i)^(?P<0>.+)\[\]$")]
Expand Down Expand Up @@ -933,7 +933,7 @@ impl ScalarImpl {
Self::List(ListValue::new(builder.finish()))
}
DataType::Struct(s) => {
if !(str.starts_with('{') && str.ends_with('}')) {
if !(str.starts_with('(') && str.ends_with(')')) {
return Err(FromSqlError::from_text(str));
}
let mut fields = Vec::with_capacity(s.len());
Expand Down Expand Up @@ -1466,5 +1466,17 @@ mod tests {
DataType::from_str("interval[]").unwrap(),
DataType::List(Box::new(DataType::Interval))
);

assert_eq!(
DataType::from_str("record").unwrap(),
DataType::Struct(StructType::unnamed(vec![]))
);
assert_eq!(
DataType::from_str("struct<a int4, b varchar>").unwrap(),
DataType::Struct(StructType::new(vec![
("a", DataType::Int32),
("b", DataType::Varchar)
]))
);
}
}
7 changes: 5 additions & 2 deletions src/common/src/types/struct_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::fmt::{Debug, Display, Formatter};
use std::str::FromStr;
use std::sync::Arc;

use anyhow::anyhow;
use itertools::Itertools;

use super::DataType;
Expand Down Expand Up @@ -153,10 +154,12 @@ impl FromStr for StructType {
if s == "record" {
return Ok(StructType::unnamed(Vec::new()));
}
let s = s.trim_start_matches("struct<").trim_end_matches('>');
if !(s.starts_with("struct<") && s.ends_with('>')) {
return Err(anyhow!("expect struct<...>"));
};
let mut field_types = Vec::new();
let mut field_names = Vec::new();
for field in s.split(',') {
for field in s[7..s.len() - 1].split(',') {
let field = field.trim();
let mut iter = field.split_whitespace();
let field_name = iter.next().unwrap();
Expand Down
22 changes: 11 additions & 11 deletions src/connector/src/sink/formatter/debezium_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,17 +337,17 @@ mod tests {
#[test]
fn test_chunk_to_json() -> Result<()> {
let chunk = StreamChunk::from_pretty(
" i f {i,f}
+ 0 0.0 {0,0.0}
+ 1 1.0 {1,1.0}
+ 2 2.0 {2,2.0}
+ 3 3.0 {3,3.0}
+ 4 4.0 {4,4.0}
+ 5 5.0 {5,5.0}
+ 6 6.0 {6,6.0}
+ 7 7.0 {7,7.0}
+ 8 8.0 {8,8.0}
+ 9 9.0 {9,9.0}",
" i f <i,f>
+ 0 0.0 (0,0.0)
+ 1 1.0 (1,1.0)
+ 2 2.0 (2,2.0)
+ 3 3.0 (3,3.0)
+ 4 4.0 (4,4.0)
+ 5 5.0 (5,5.0)
+ 6 6.0 (6,6.0)
+ 7 7.0 (7,7.0)
+ 8 8.0 (8,8.0)
+ 9 9.0 (9,9.0)",
);

let schema = Schema::new(vec![
Expand Down
22 changes: 5 additions & 17 deletions src/expr/core/src/expr/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,8 @@ use risingwave_common::types::{DataType, ScalarImpl};
use risingwave_pb::expr::expr_node::{PbType, RexNode};
use risingwave_pb::expr::ExprNode;

use super::expr_array_transform::ArrayTransformExpression;
use super::expr_coalesce::CoalesceExpression;
use super::expr_field::FieldExpression;
use super::expr_in::InExpression;
use super::expr_some_all::SomeAllExpression;
use super::expr_udf::UdfExpression;
use super::expr_vnode::VnodeExpression;
use super::wrapper::checked::Checked;
use super::wrapper::non_strict::NonStrict;
use super::wrapper::EvalErrorReport;
Expand Down Expand Up @@ -112,10 +107,6 @@ where
RexNode::FuncCall(_) => match prost.function_type() {
// Dedicated types
E::All | E::Some => SomeAllExpression::build_boxed(prost, build_child),
E::In => InExpression::build_boxed(prost, build_child),
E::Coalesce => CoalesceExpression::build_boxed(prost, build_child),
E::Field => FieldExpression::build_boxed(prost, build_child),
E::Vnode => VnodeExpression::build_boxed(prost, build_child),

// General types, lookup in the function signature map
_ => FuncCallBuilder::build_boxed(prost, build_child),
Expand Down Expand Up @@ -193,12 +184,6 @@ pub fn build_func(
ret_type: DataType,
children: Vec<BoxedExpression>,
) -> Result<BoxedExpression> {
if func == PbType::ArrayTransform {
// TODO: The function framework can't handle the lambda arg now.
let [array, lambda] = <[BoxedExpression; 2]>::try_from(children).unwrap();
return Ok(ArrayTransformExpression { array, lambda }.boxed());
}

let args = children.iter().map(|c| c.return_type()).collect_vec();
let desc = FUNCTION_REGISTRY
.get(func, &args, &ret_type)
Expand Down Expand Up @@ -300,7 +285,7 @@ impl<Iter: Iterator<Item = Token>> Parser<Iter> {
assert_eq!(self.tokens.next(), Some(Token::Colon), "Expected a Colon");
let ty = self.parse_type();
let value = match value.as_str() {
"null" => None,
"null" | "NULL" => None,
_ => Some(
ScalarImpl::from_text(value.as_bytes(), &ty).expect_str("value", &value),
),
Expand All @@ -313,7 +298,10 @@ impl<Iter: Iterator<Item = Token>> Parser<Iter> {

fn parse_type(&mut self) -> DataType {
match self.tokens.next().expect("Unexpected end of input") {
Token::Literal(name) => name.parse::<DataType>().expect_str("type", &name),
Token::Literal(name) => name
.replace('_', " ")
.parse::<DataType>()
.expect_str("type", &name),
t => panic!("Expected a Literal, got {t:?}"),
}
}
Expand Down
Loading
Loading