diff --git a/src/batch/src/executor2/insert.rs b/src/batch/src/executor2/insert.rs index 32c7ce65bba85..06c7c92b389ed 100644 --- a/src/batch/src/executor2/insert.rs +++ b/src/batch/src/executor2/insert.rs @@ -177,7 +177,7 @@ mod tests { use std::sync::Arc; use futures::StreamExt; - use risingwave_common::array::{Array, I64Array}; + use risingwave_common::array::{Array, I32Array, I64Array, StructArray}; use risingwave_common::catalog::{schema_test_utils, ColumnDesc, ColumnId}; use risingwave_common::column_nonnull; use risingwave_common::types::DataType; @@ -194,13 +194,19 @@ mod tests { let source_manager = Arc::new(MemSourceManager::default()); let store = MemoryStateStore::new(); + // Make struct field + let struct_field = Field::unnamed(DataType::Struct { + fields: vec![DataType::Int32, DataType::Int32, DataType::Int32].into(), + }); + // Schema for mock executor. - let schema = schema_test_utils::ii(); + let mut schema = schema_test_utils::ii(); + schema.fields.push(struct_field.clone()); let mut mock_executor = MockExecutor::new(schema.clone()); // Schema of the table - let schema = schema_test_utils::iii(); - + let mut schema = schema_test_utils::iii(); + schema.fields.push(struct_field); let table_columns: Vec<_> = schema .fields .iter() @@ -216,7 +222,19 @@ mod tests { let col1 = column_nonnull! { I64Array, [1, 3, 5, 7, 9] }; let col2 = column_nonnull! { I64Array, [2, 4, 6, 8, 10] }; - let data_chunk: DataChunk = DataChunk::builder().columns(vec![col1, col2]).build(); + let array = StructArray::from_slices( + &[true, false, false, false, false], + vec![ + array! { I32Array, [Some(1),None,None,None,None] }.into(), + array! { I32Array, [Some(2),None,None,None,None] }.into(), + array! { I32Array, [Some(3),None,None,None,None] }.into(), + ], + vec![DataType::Int32, DataType::Int32, DataType::Int32], + ) + .map(|x| Arc::new(x.into())) + .unwrap(); + let col3 = Column::new(array); + let data_chunk: DataChunk = DataChunk::builder().columns(vec![col1, col2, col3]).build(); mock_executor.add(data_chunk.clone()); // Create the table. @@ -227,7 +245,7 @@ mod tests { let source_desc = source_manager.get_source(&table_id)?; let source = source_desc.source.as_table_v2().unwrap(); let mut reader = source - .stream_reader(vec![0.into(), 1.into(), 2.into()]) + .stream_reader(vec![0.into(), 1.into(), 2.into(), 3.into()]) .await?; // Insert @@ -276,6 +294,19 @@ mod tests { vec![Some(2), Some(4), Some(6), Some(8), Some(10)] ); + let array: ArrayImpl = StructArray::from_slices( + &[true, false, false, false, false], + vec![ + array! { I32Array, [Some(1),None,None,None,None] }.into(), + array! { I32Array, [Some(2),None,None,None,None] }.into(), + array! { I32Array, [Some(3),None,None,None,None] }.into(), + ], + vec![DataType::Int32, DataType::Int32, DataType::Int32], + ) + .unwrap() + .into(); + assert_eq!(*chunk.chunk.columns()[2].array(), array); + // There's nothing in store since `TableSourceV2` has no side effect. // Data will be materialized in associated streaming task. let epoch = u64::MAX; diff --git a/src/batch/src/executor2/values.rs b/src/batch/src/executor2/values.rs index c0e8eaed0dd70..9f42c10cf58f7 100644 --- a/src/batch/src/executor2/values.rs +++ b/src/batch/src/executor2/values.rs @@ -143,9 +143,12 @@ impl BoxedExecutor2Builder for ValuesExecutor2 { #[cfg(test)] mod tests { + use futures::stream::StreamExt; use risingwave_common::array; - use risingwave_common::array::{I16Array, I32Array, I64Array}; + use risingwave_common::array::{ + ArrayImpl, I16Array, I32Array, I64Array, StructArray, StructValue, + }; use risingwave_common::catalog::{Field, Schema}; use risingwave_common::types::{DataType, ScalarImpl}; use risingwave_expr::expr::{BoxedExpression, LiteralExpression}; @@ -154,6 +157,11 @@ mod tests { #[tokio::test] async fn test_values_executor() { + let value = StructValue::new(vec![ + Some(ScalarImpl::Int32(1)), + Some(ScalarImpl::Int32(2)), + Some(ScalarImpl::Int32(3)), + ]); let exprs = vec![ Box::new(LiteralExpression::new( DataType::Int16, @@ -167,6 +175,12 @@ mod tests { DataType::Int64, Some(ScalarImpl::Int64(3)), )), + Box::new(LiteralExpression::new( + DataType::Struct { + fields: vec![DataType::Int32, DataType::Int32, DataType::Int32].into(), + }, + Some(ScalarImpl::Struct(value)), + )) as BoxedExpression, ]; let fields = exprs @@ -185,9 +199,26 @@ mod tests { assert_eq!(fields[0].data_type, DataType::Int16); assert_eq!(fields[1].data_type, DataType::Int32); assert_eq!(fields[2].data_type, DataType::Int64); + assert_eq!( + fields[3].data_type, + DataType::Struct { + fields: vec![DataType::Int32, DataType::Int32, DataType::Int32].into() + } + ); let mut stream = values_executor.execute(); let result = stream.next().await.unwrap(); + let array: ArrayImpl = StructArray::from_slices( + &[true], + vec![ + array! { I32Array, [Some(1)] }.into(), + array! { I32Array, [Some(2)] }.into(), + array! { I32Array, [Some(3)] }.into(), + ], + vec![DataType::Int32, DataType::Int32, DataType::Int32], + ) + .unwrap() + .into(); if let Ok(result) = result { assert_eq!( @@ -202,6 +233,7 @@ mod tests { *result.column_at(2).array(), array! {I64Array, [Some(3)]}.into() ); + assert_eq!(*result.column_at(3).array(), array); } } diff --git a/src/common/src/array/list_array.rs b/src/common/src/array/list_array.rs index 24c2436f83470..6ccfb8360b619 100644 --- a/src/common/src/array/list_array.rs +++ b/src/common/src/array/list_array.rs @@ -288,6 +288,10 @@ impl ListValue { pub fn new(values: Vec) -> Self { Self { values } } + + pub fn values(&self) -> &[Datum] { + &self.values + } } #[derive(Copy, Clone)] diff --git a/src/common/src/array/struct_array.rs b/src/common/src/array/struct_array.rs index 7865affd8c23d..5980be39c6b8a 100644 --- a/src/common/src/array/struct_array.rs +++ b/src/common/src/array/struct_array.rs @@ -275,8 +275,12 @@ pub struct StructValue { } impl fmt::Display for StructValue { - fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result { - Ok(()) + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{{{}}}", + self.fields.iter().map(|f| format!("{:?}", f)).join(", ") + ) } } @@ -296,6 +300,10 @@ impl StructValue { pub fn new(fields: Vec) -> Self { Self { fields } } + + pub fn fields(&self) -> &[Datum] { + &self.fields + } } #[derive(Copy, Clone)] diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index d32df607a7ebc..4cc675d4b98c8 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -19,8 +19,7 @@ use bytes::{Buf, BufMut}; use risingwave_pb::data::DataType as ProstDataType; use serde::{Deserialize, Serialize}; -use crate::error::{ErrorCode, Result, RwError}; - +use crate::error::{internal_error, ErrorCode, Result, RwError}; mod native_type; mod scalar_impl; @@ -327,6 +326,16 @@ for_all_scalar_variants! { scalar_impl_enum } pub type Datum = Option; pub type DatumRef<'a> = Option>; +pub fn get_data_type_from_datum(datum: &Datum) -> Result { + match datum { + // TODO: Predicate data type from None Datum + None => Err(internal_error( + "cannot get data type from None Datum".to_string(), + )), + Some(scalar) => scalar.data_type(), + } +} + /// Convert a [`Datum`] to a [`DatumRef`]. pub fn to_datum_ref(datum: &Datum) -> DatumRef<'_> { datum.as_ref().map(|d| d.as_scalar_ref_impl()) diff --git a/src/common/src/types/scalar_impl.rs b/src/common/src/types/scalar_impl.rs index c1fc04344e8a6..f4dc72fc608be 100644 --- a/src/common/src/types/scalar_impl.rs +++ b/src/common/src/types/scalar_impl.rs @@ -300,6 +300,40 @@ impl ScalarImpl { } for_all_scalar_variants! { impl_all_get_ident, self } } + + pub(crate) fn data_type(&self) -> Result { + let data_type = match self { + ScalarImpl::Int16(_) => DataType::Int16, + ScalarImpl::Int32(_) => DataType::Int32, + ScalarImpl::Int64(_) => DataType::Int64, + ScalarImpl::Float32(_) => DataType::Float32, + ScalarImpl::Float64(_) => DataType::Float64, + ScalarImpl::Utf8(_) => DataType::Varchar, + ScalarImpl::Bool(_) => DataType::Boolean, + ScalarImpl::Decimal(_) => DataType::Decimal, + ScalarImpl::Interval(_) => DataType::Interval, + ScalarImpl::NaiveDate(_) => DataType::Date, + ScalarImpl::NaiveDateTime(_) => DataType::Timestamp, + ScalarImpl::NaiveTime(_) => DataType::Time, + ScalarImpl::Struct(data) => { + let types = data + .fields() + .iter() + .map(get_data_type_from_datum) + .collect::>>()?; + DataType::Struct { + fields: types.into(), + } + } + ScalarImpl::List(data) => { + let data = data.values().get(0).ok_or_else(|| { + internal_error("cannot get data type from empty list".to_string()) + })?; + get_data_type_from_datum(data)? + } + }; + Ok(data_type) + } } impl<'scalar> ScalarRefImpl<'scalar> { diff --git a/src/expr/src/expr/expr_literal.rs b/src/expr/src/expr/expr_literal.rs index 117f735b1dd50..6a5e5669616ae 100644 --- a/src/expr/src/expr/expr_literal.rs +++ b/src/expr/src/expr/expr_literal.rs @@ -103,6 +103,7 @@ fn literal_type_match(return_type: &DataType, literal: Option<&ScalarImpl>) -> b | (DataType::Timestamp, ScalarImpl::NaiveDateTime(_)) | (DataType::Decimal, ScalarImpl::Decimal(_)) | (DataType::Interval, ScalarImpl::Interval(_)) + | (DataType::Struct { .. }, ScalarImpl::Struct(_)) ) } None => true, diff --git a/src/frontend/src/binder/expr/mod.rs b/src/frontend/src/binder/expr/mod.rs index 0887a7bf6928d..c8bb5bec5136e 100644 --- a/src/frontend/src/binder/expr/mod.rs +++ b/src/frontend/src/binder/expr/mod.rs @@ -102,6 +102,7 @@ impl Binder { list, negated, } => self.bind_in_list(*expr, list, negated), + Expr::Row(exprs) => Ok(ExprImpl::Literal(Box::new(self.bind_row(&exprs)?))), _ => Err(ErrorCode::NotImplemented( format!("unsupported expression {:?}", expr), 112.into(), diff --git a/src/frontend/src/binder/expr/value.rs b/src/frontend/src/binder/expr/value.rs index 2b92015cf6f94..8dde1e5a0dc1d 100644 --- a/src/frontend/src/binder/expr/value.rs +++ b/src/frontend/src/binder/expr/value.rs @@ -21,7 +21,7 @@ use crate::binder::Binder; use crate::expr::Literal; impl Binder { - pub(super) fn bind_value(&mut self, value: Value) -> Result { + pub fn bind_value(&mut self, value: Value) -> Result { match value { Value::Number(s, b) => self.bind_number(s, b), Value::SingleQuotedString(s) => self.bind_string(s), diff --git a/src/frontend/src/binder/values.rs b/src/frontend/src/binder/values.rs index 6efaacbaf131e..dc400ad478ad6 100644 --- a/src/frontend/src/binder/values.rs +++ b/src/frontend/src/binder/values.rs @@ -13,14 +13,15 @@ // limitations under the License. use itertools::Itertools; +use risingwave_common::array::StructValue; use risingwave_common::catalog::{Field, Schema}; -use risingwave_common::error::{ErrorCode, Result}; -use risingwave_common::types::DataType; -use risingwave_sqlparser::ast::Values; +use risingwave_common::error::{ErrorCode, Result, TrackingIssue}; +use risingwave_common::types::{get_data_type_from_datum, DataType, Datum, Scalar}; +use risingwave_sqlparser::ast::{Expr, Values}; use super::bind_context::Clause; use crate::binder::Binder; -use crate::expr::{align_types, ExprImpl}; +use crate::expr::{align_types, ExprImpl, Literal}; #[derive(Debug)] pub struct BoundValues { @@ -80,6 +81,34 @@ impl Binder { schema, }) } + + /// Bind row to `struct_value` for nested column, + /// e.g. Row(1,2,(1,2,3)). + /// Only accept value and row expr in row. + pub fn bind_row(&mut self, exprs: &[Expr]) -> Result { + let datums = exprs + .iter() + .map(|e| match e { + Expr::Value(value) => Ok(self.bind_value(value.clone())?.get_data().clone()), + Expr::Row(expr) => Ok(self.bind_row(expr)?.get_data().clone()), + _ => Err(ErrorCode::NotImplemented( + format!("unsupported expression {:?}", e), + TrackingIssue::none(), + ) + .into()), + }) + .collect::>>()?; + let value = StructValue::new(datums); + let data_type = DataType::Struct { + fields: value + .fields() + .iter() + .map(get_data_type_from_datum) + .collect::>>()? + .into(), + }; + Ok(Literal::new(Some(value.to_scalar_value()), data_type)) + } } #[cfg(test)] diff --git a/src/frontend/src/expr/type_inference.rs b/src/frontend/src/expr/type_inference.rs index 2c5809f46f8ab..1259a6a86e0e3 100644 --- a/src/frontend/src/expr/type_inference.rs +++ b/src/frontend/src/expr/type_inference.rs @@ -206,6 +206,7 @@ fn build_type_derive_map() -> HashMap { E::GreaterThanOrEqual, ]; build_binary_cmp_funcs(&mut map, cmp_exprs, &num_types); + build_binary_cmp_funcs(&mut map, cmp_exprs, &[T::Struct]); build_binary_cmp_funcs(&mut map, cmp_exprs, &[T::Date, T::Timestamp, T::Timestampz]); build_binary_cmp_funcs(&mut map, cmp_exprs, &[T::Time, T::Interval]); for e in cmp_exprs { diff --git a/src/frontend/test_runner/src/lib.rs b/src/frontend/test_runner/src/lib.rs index ecb37786b3223..68388656fb9b4 100644 --- a/src/frontend/test_runner/src/lib.rs +++ b/src/frontend/test_runner/src/lib.rs @@ -93,6 +93,7 @@ pub struct CreateSource { row_format: String, name: String, file: Option, + materialized: Option, } #[serde_with::skip_serializing_none] @@ -200,11 +201,16 @@ impl TestCase { match self.create_source.clone() { Some(source) => { if let Some(content) = source.file { + let materialized = if let Some(true) = source.materialized { + "materialized".to_string() + } else { + "".to_string() + }; let sql = format!( - r#"CREATE SOURCE {} + r#"CREATE {} SOURCE {} WITH ('kafka.topic' = 'abc', 'kafka.servers' = 'localhost:1001') ROW FORMAT {} MESSAGE '.test.TestRecord' ROW SCHEMA LOCATION 'file://"#, - source.name, source.row_format + materialized, source.name, source.row_format ); let temp_file = create_proto_file(content.as_str()); self.run_sql( diff --git a/src/frontend/test_runner/tests/testdata/struct_query.yaml b/src/frontend/test_runner/tests/testdata/struct_query.yaml index 92b8daade7ecd..6e2ed4b319493 100644 --- a/src/frontend/test_runner/tests/testdata/struct_query.yaml +++ b/src/frontend/test_runner/tests/testdata/struct_query.yaml @@ -315,3 +315,58 @@ string address = 1; string zipcode = 3; } +- sql: | + insert into s values (1,2,(1,2,(1,2,3))); + logical_plan: | + LogicalInsert { table: s } + LogicalValues { rows: [[1:Int32, 2:Int32, {Some(Int32(1)), Some(Int32(2)), Some(Struct(StructValue { fields: [Some(Int32(1)), Some(Int32(2)), Some(Int32(3))] }))}:Struct { fields: [Int32, Int32, Struct { fields: [Int32, Int32, Int32] }] }]], schema: Schema { fields: [:Int32, :Int32, :Struct { fields: [Int32, Int32, Struct { fields: [Int32, Int32, Int32] }] }] } } + create_source: + row_format: protobuf + name: s + materialized: true + file: | + syntax = "proto3"; + package test; + message TestRecord { + int32 v1 = 1; + int32 v2 = 2; + V v3 = 3; + } + message V { + int32 v1 = 1; + int32 v2 = 2; + U v3 = 3; + } + message U { + int32 v1 = 1; + int32 v2 = 2; + int32 v3 = 3; + } +- sql: | + select * from s where s.v3 = (1,2,(1,2,3)); + logical_plan: | + LogicalProject { exprs: [$1, $2, $3], expr_alias: [v1, v2, v3] } + LogicalFilter { predicate: ($3 = {Some(Int32(1)), Some(Int32(2)), Some(Struct(StructValue { fields: [Some(Int32(1)), Some(Int32(2)), Some(Int32(3))] }))}:Struct { fields: [Int32, Int32, Struct { fields: [Int32, Int32, Int32] }] }) } + LogicalScan { table: s, columns: [_row_id#0, v1, v2, v3] } + create_source: + row_format: protobuf + name: s + materialized: true + file: | + syntax = "proto3"; + package test; + message TestRecord { + int32 v1 = 1; + int32 v2 = 2; + V v3 = 3; + } + message V { + int32 v1 = 1; + int32 v2 = 2; + U v3 = 3; + } + message U { + int32 v1 = 1; + int32 v2 = 2; + int32 v3 = 3; + } \ No newline at end of file