diff --git a/src/common/src/array/arrow/arrow_impl.rs b/src/common/src/array/arrow/arrow_impl.rs index f72a1d8739be5..a341e218878c9 100644 --- a/src/common/src/array/arrow/arrow_impl.rs +++ b/src/common/src/array/arrow/arrow_impl.rs @@ -485,7 +485,7 @@ pub trait FromArrow { fields .iter() .map(|f| Ok((f.name().clone(), self.from_field(f)?))) - .try_collect::<_, _, ArrayError>()?, + .try_collect::<_, Vec<_>, ArrayError>()?, )) } diff --git a/src/common/src/catalog/column.rs b/src/common/src/catalog/column.rs index f7c9260e0d1ab..7b1387a5feffc 100644 --- a/src/common/src/catalog/column.rs +++ b/src/common/src/catalog/column.rs @@ -25,7 +25,7 @@ use risingwave_pb::plan_common::{ use super::{row_id_column_desc, rw_timestamp_column_desc, USER_COLUMN_ID_OFFSET}; use crate::catalog::{cdc_table_name_column_desc, offset_column_desc, Field, ROW_ID_COLUMN_ID}; -use crate::types::DataType; +use crate::types::{DataType, StructType}; use crate::util::value_encoding::DatumToProtoExt; /// Column ID is the unique identifier of a column in a table. Different from table ID, column ID is @@ -270,10 +270,8 @@ impl ColumnDesc { type_name: &str, fields: Vec, ) -> Self { - let data_type = DataType::new_struct( - fields.iter().map(|f| f.data_type.clone()).collect_vec(), - fields.iter().map(|f| f.name.clone()).collect_vec(), - ); + let data_type = + StructType::new(fields.iter().map(|f| (&f.name, f.data_type.clone()))).into(); Self { data_type, column_id: ColumnId::new(column_id), diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index 5e6264cd933d0..0b0d63aa4b300 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -30,6 +30,7 @@ use postgres_types::{FromSql, IsNull, ToSql, Type}; use risingwave_common_estimate_size::{EstimateSize, ZeroHeapSize}; use risingwave_pb::data::data_type::PbTypeName; use risingwave_pb::data::PbDataType; +use rw_iter_util::ZipEqFast as _; use serde::{Deserialize, Serialize, Serializer}; use strum_macros::EnumDiscriminants; use thiserror_ext::AsReport; @@ -241,7 +242,11 @@ impl From<&PbDataType> for DataType { PbTypeName::Struct => { let fields: Vec = proto.field_type.iter().map(|f| f.into()).collect_vec(); let field_names: Vec = proto.field_names.iter().cloned().collect_vec(); - DataType::new_struct(fields, field_names) + if proto.field_names.is_empty() { + StructType::unnamed(fields).into() + } else { + StructType::new(field_names.into_iter().zip_eq_fast(fields)).into() + } } PbTypeName::List => DataType::List( // The first (and only) item is the list element type. @@ -405,10 +410,6 @@ impl DataType { } } - pub fn new_struct(fields: Vec, field_names: Vec) -> Self { - Self::Struct(StructType::from_parts(field_names, fields)) - } - pub fn as_struct(&self) -> &StructType { match self { DataType::Struct(t) => t, diff --git a/src/common/src/types/struct_type.rs b/src/common/src/types/struct_type.rs index edc4b73311533..cc1980b34f830 100644 --- a/src/common/src/types/struct_type.rs +++ b/src/common/src/types/struct_type.rs @@ -48,10 +48,11 @@ struct StructTypeInner { impl StructType { /// Creates a struct type with named fields. - pub fn new(named_fields: Vec<(impl Into, DataType)>) -> Self { - let mut field_types = Vec::with_capacity(named_fields.len()); - let mut field_names = Vec::with_capacity(named_fields.len()); - for (name, ty) in named_fields { + pub fn new(named_fields: impl IntoIterator, DataType)>) -> Self { + let iter = named_fields.into_iter(); + let mut field_types = Vec::with_capacity(iter.size_hint().0); + let mut field_names = Vec::with_capacity(iter.size_hint().0); + for (name, ty) in iter { field_names.push(name.into()); field_types.push(ty); } @@ -70,15 +71,6 @@ impl StructType { })) } - pub(super) fn from_parts(field_names: Vec, field_types: Vec) -> Self { - // TODO: enable this assertion - // debug_assert!(field_names.len() == field_types.len()); - Self(Arc::new(StructTypeInner { - field_types: field_types.into(), - field_names: field_names.into(), - })) - } - /// Creates a struct type with unnamed fields. pub fn unnamed(fields: Vec) -> Self { Self(Arc::new(StructTypeInner { diff --git a/src/connector/codec/src/decoder/avro/schema.rs b/src/connector/codec/src/decoder/avro/schema.rs index 7e86a1cc11dd1..f523147d8175d 100644 --- a/src/connector/codec/src/decoder/avro/schema.rs +++ b/src/connector/codec/src/decoder/avro/schema.rs @@ -20,7 +20,7 @@ use apache_avro::AvroResult; use itertools::Itertools; use risingwave_common::error::NotImplemented; use risingwave_common::log::LogSuppresser; -use risingwave_common::types::{DataType, Decimal, MapType}; +use risingwave_common::types::{DataType, Decimal, MapType, StructType}; use risingwave_common::{bail, bail_not_implemented}; use risingwave_pb::plan_common::{AdditionalColumn, ColumnDesc, ColumnDescVersion}; @@ -190,12 +190,13 @@ fn avro_type_mapping( return Ok(DataType::Decimal); } - let struct_fields = fields - .iter() - .map(|f| avro_type_mapping(&f.schema, map_handling)) - .collect::>()?; - let struct_names = fields.iter().map(|f| f.name.clone()).collect_vec(); - DataType::new_struct(struct_fields, struct_names) + StructType::new( + fields + .iter() + .map(|f| Ok((&f.name, avro_type_mapping(&f.schema, map_handling)?))) + .collect::>>()?, + ) + .into() } Schema::Array(item_schema) => { let item_type = avro_type_mapping(item_schema.as_ref(), map_handling)?; @@ -225,7 +226,7 @@ fn avro_type_mapping( // Note: Avro union's variant tag is type name, not field name (unlike Rust enum, or Protobuf oneof). // XXX: do we need to introduce union.handling.mode? - let (fields, field_names) = union_schema + let fields = union_schema .variants() .iter() // null will mean the whole struct is null @@ -233,13 +234,13 @@ fn avro_type_mapping( .map(|variant| { avro_type_mapping(variant, map_handling).and_then(|t| { let name = avro_schema_to_struct_field_name(variant)?; - Ok((t, name)) + Ok((name, t)) }) }) - .process_results(|it| it.unzip::<_, _, Vec<_>, Vec<_>>()) + .try_collect::<_, Vec<_>, _>() .context("failed to convert Avro union to struct")?; - DataType::new_struct(fields, field_names) + StructType::new(fields).into() } } } diff --git a/src/connector/codec/src/decoder/protobuf/parser.rs b/src/connector/codec/src/decoder/protobuf/parser.rs index 852fa9cca48d6..f249d7db72e13 100644 --- a/src/connector/codec/src/decoder/protobuf/parser.rs +++ b/src/connector/codec/src/decoder/protobuf/parser.rs @@ -17,7 +17,8 @@ use itertools::Itertools; use prost_reflect::{Cardinality, FieldDescriptor, Kind, MessageDescriptor, ReflectMessage, Value}; use risingwave_common::array::{ListValue, StructValue}; use risingwave_common::types::{ - DataType, DatumCow, Decimal, JsonbVal, MapType, MapValue, ScalarImpl, ToOwnedDatum, F32, F64, + DataType, DatumCow, Decimal, JsonbVal, MapType, MapValue, ScalarImpl, StructType, ToOwnedDatum, + F32, F64, }; use risingwave_pb::plan_common::{AdditionalColumn, ColumnDesc, ColumnDescVersion}; use thiserror::Error; @@ -257,10 +258,14 @@ fn protobuf_type_mapping( } else { let fields = m .fields() - .map(|f| protobuf_type_mapping(&f, parse_trace)) - .try_collect()?; - let field_names = m.fields().map(|f| f.name().to_string()).collect_vec(); - DataType::new_struct(fields, field_names) + .map(|f| { + Ok(( + f.name().to_string(), + protobuf_type_mapping(&f, parse_trace)?, + )) + }) + .try_collect::<_, Vec<_>, _>()?; + StructType::new(fields).into() } } Kind::Enum(_) => DataType::Varchar, diff --git a/src/frontend/src/binder/expr/mod.rs b/src/frontend/src/binder/expr/mod.rs index 86631767998bd..59aa42be0e79b 100644 --- a/src/frontend/src/binder/expr/mod.rs +++ b/src/frontend/src/binder/expr/mod.rs @@ -14,7 +14,7 @@ use itertools::Itertools; use risingwave_common::catalog::{ColumnDesc, ColumnId, PG_CATALOG_SCHEMA_NAME}; -use risingwave_common::types::{DataType, MapType}; +use risingwave_common::types::{DataType, MapType, StructType}; use risingwave_common::util::iter_util::zip_eq_fast; use risingwave_common::{bail_no_function, bail_not_implemented, not_implemented}; use risingwave_pb::plan_common::{AdditionalColumn, ColumnDescVersion}; @@ -1008,13 +1008,13 @@ pub fn bind_data_type(data_type: &AstDataType) -> Result { AstDataType::Char(..) => { bail_not_implemented!("CHAR is not supported, please use VARCHAR instead") } - AstDataType::Struct(types) => DataType::new_struct( + AstDataType::Struct(types) => StructType::new( types .iter() - .map(|f| bind_data_type(&f.data_type)) + .map(|f| Ok((f.name.real_value(), bind_data_type(&f.data_type)?))) .collect::>>()?, - types.iter().map(|f| f.name.real_value()).collect_vec(), - ), + ) + .into(), AstDataType::Map(kv) => { let key = bind_data_type(&kv.0)?; let value = bind_data_type(&kv.1)?; diff --git a/src/frontend/src/binder/update.rs b/src/frontend/src/binder/update.rs index a2038a4d471d9..d16ece284fa3b 100644 --- a/src/frontend/src/binder/update.rs +++ b/src/frontend/src/binder/update.rs @@ -17,7 +17,7 @@ use std::collections::{BTreeMap, HashMap}; use fixedbitset::FixedBitSet; use itertools::Itertools; use risingwave_common::catalog::{Schema, TableVersionId}; -use risingwave_common::types::DataType; +use risingwave_common::types::StructType; use risingwave_common::util::iter_util::ZipEqFast; use risingwave_sqlparser::ast::{Assignment, AssignmentValue, Expr, ObjectName, SelectItem}; @@ -206,10 +206,12 @@ impl Binder { bail_bind_error!("number of columns does not match number of values"); } - let target_type = DataType::new_struct( - ids.iter().map(|id| id.return_type()).collect(), - id.iter().map(|id| id.real_value()).collect(), - ); + let target_type = StructType::new( + id.iter() + .zip_eq_fast(ids) + .map(|(id, expr)| (id.real_value(), expr.return_type())), + ) + .into(); let expr = expr.cast_assign(target_type)?; exprs.push(expr); diff --git a/src/frontend/src/expr/subquery.rs b/src/frontend/src/expr/subquery.rs index fc904638790e2..6f5b419a46d64 100644 --- a/src/frontend/src/expr/subquery.rs +++ b/src/frontend/src/expr/subquery.rs @@ -97,9 +97,7 @@ impl Expr for Subquery { StructType::unnamed(self.query.data_types()) } else { StructType::new( - (schema.fields().iter().cloned()) - .map(|f| (f.name, f.data_type)) - .collect(), + (schema.fields().iter().cloned()).map(|f| (f.name, f.data_type)), ) }; DataType::Struct(struct_type) diff --git a/src/frontend/src/handler/create_function.rs b/src/frontend/src/handler/create_function.rs index b81d2b4514edf..6d529e9ccb356 100644 --- a/src/frontend/src/handler/create_function.rs +++ b/src/frontend/src/handler/create_function.rs @@ -14,7 +14,7 @@ use anyhow::Context; use risingwave_common::catalog::FunctionId; -use risingwave_common::types::DataType; +use risingwave_common::types::StructType; use risingwave_expr::sig::{CreateFunctionOptions, UdfKind}; use risingwave_pb::catalog::function::{Kind, ScalarFunction, TableFunction}; use risingwave_pb::catalog::Function; @@ -83,9 +83,9 @@ pub async fn handle_create_function( // return type is a struct for multiple columns let it = columns .into_iter() - .map(|c| bind_data_type(&c.data_type).map(|ty| (ty, c.name.real_value()))); - let (datatypes, names) = itertools::process_results(it, |it| it.unzip())?; - return_type = DataType::new_struct(datatypes, names); + .map(|c| bind_data_type(&c.data_type).map(|ty| (c.name.real_value(), ty))); + let fields = it.try_collect::<_, Vec<_>, _>()?; + return_type = StructType::new(fields).into(); } Kind::Table(TableFunction {}) } diff --git a/src/frontend/src/handler/create_sql_function.rs b/src/frontend/src/handler/create_sql_function.rs index c733f603a3c44..b48b06942005e 100644 --- a/src/frontend/src/handler/create_sql_function.rs +++ b/src/frontend/src/handler/create_sql_function.rs @@ -16,7 +16,7 @@ use std::collections::HashMap; use fancy_regex::Regex; use risingwave_common::catalog::FunctionId; -use risingwave_common::types::DataType; +use risingwave_common::types::{DataType, StructType}; use risingwave_pb::catalog::function::{Kind, ScalarFunction, TableFunction}; use risingwave_pb::catalog::Function; use risingwave_sqlparser::parser::{Parser, ParserError}; @@ -188,15 +188,11 @@ pub async fn handle_create_sql_function( return_type = bind_data_type(&columns[0].data_type)?; } else { // return type is a struct for multiple columns - let datatypes = columns + let fields = columns .iter() - .map(|c| bind_data_type(&c.data_type)) + .map(|c| Ok((c.name.real_value(), bind_data_type(&c.data_type)?))) .collect::>>()?; - let names = columns - .iter() - .map(|c| c.name.real_value()) - .collect::>(); - return_type = DataType::new_struct(datatypes, names); + return_type = StructType::new(fields).into(); } Kind::Table(TableFunction {}) } diff --git a/src/tests/sqlsmith/src/sql_gen/expr.rs b/src/tests/sqlsmith/src/sql_gen/expr.rs index 4625727f67dca..f5d93d2b6d6f9 100644 --- a/src/tests/sqlsmith/src/sql_gen/expr.rs +++ b/src/tests/sqlsmith/src/sql_gen/expr.rs @@ -168,8 +168,7 @@ impl SqlGenerator<'_, R> { DataType::Struct(StructType::new( STRUCT_FIELD_NAMES[0..num_fields] .iter() - .map(|s| (s.to_string(), self.gen_data_type_inner(depth))) - .collect(), + .map(|s| (s.to_string(), self.gen_data_type_inner(depth))), )) }