Skip to content

Commit

Permalink
refactor(common): consolidate StructType constructors (Part 2/2) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
xiangjinwu authored Nov 27, 2024
1 parent 0669783 commit 3faa0bf
Show file tree
Hide file tree
Showing 12 changed files with 59 additions and 67 deletions.
2 changes: 1 addition & 1 deletion src/common/src/array/arrow/arrow_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ pub trait FromArrow {
fields
.iter()
.map(|f| Ok((f.name().clone(), self.from_field(f)?)))
.try_collect::<_, _, ArrayError>()?,
.try_collect::<_, Vec<_>, ArrayError>()?,
))
}

Expand Down
8 changes: 3 additions & 5 deletions src/common/src/catalog/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use risingwave_pb::plan_common::{

use super::{row_id_column_desc, rw_timestamp_column_desc, USER_COLUMN_ID_OFFSET};
use crate::catalog::{cdc_table_name_column_desc, offset_column_desc, Field, ROW_ID_COLUMN_ID};
use crate::types::DataType;
use crate::types::{DataType, StructType};
use crate::util::value_encoding::DatumToProtoExt;

/// Column ID is the unique identifier of a column in a table. Different from table ID, column ID is
Expand Down Expand Up @@ -270,10 +270,8 @@ impl ColumnDesc {
type_name: &str,
fields: Vec<ColumnDesc>,
) -> Self {
let data_type = DataType::new_struct(
fields.iter().map(|f| f.data_type.clone()).collect_vec(),
fields.iter().map(|f| f.name.clone()).collect_vec(),
);
let data_type =
StructType::new(fields.iter().map(|f| (&f.name, f.data_type.clone()))).into();
Self {
data_type,
column_id: ColumnId::new(column_id),
Expand Down
11 changes: 6 additions & 5 deletions src/common/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use postgres_types::{FromSql, IsNull, ToSql, Type};
use risingwave_common_estimate_size::{EstimateSize, ZeroHeapSize};
use risingwave_pb::data::data_type::PbTypeName;
use risingwave_pb::data::PbDataType;
use rw_iter_util::ZipEqFast as _;
use serde::{Deserialize, Serialize, Serializer};
use strum_macros::EnumDiscriminants;
use thiserror_ext::AsReport;
Expand Down Expand Up @@ -241,7 +242,11 @@ impl From<&PbDataType> for DataType {
PbTypeName::Struct => {
let fields: Vec<DataType> = proto.field_type.iter().map(|f| f.into()).collect_vec();
let field_names: Vec<String> = proto.field_names.iter().cloned().collect_vec();
DataType::new_struct(fields, field_names)
if proto.field_names.is_empty() {
StructType::unnamed(fields).into()
} else {
StructType::new(field_names.into_iter().zip_eq_fast(fields)).into()
}
}
PbTypeName::List => DataType::List(
// The first (and only) item is the list element type.
Expand Down Expand Up @@ -405,10 +410,6 @@ impl DataType {
}
}

pub fn new_struct(fields: Vec<DataType>, field_names: Vec<String>) -> Self {
Self::Struct(StructType::from_parts(field_names, fields))
}

pub fn as_struct(&self) -> &StructType {
match self {
DataType::Struct(t) => t,
Expand Down
18 changes: 5 additions & 13 deletions src/common/src/types/struct_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,11 @@ struct StructTypeInner {

impl StructType {
/// Creates a struct type with named fields.
pub fn new(named_fields: Vec<(impl Into<String>, DataType)>) -> Self {
let mut field_types = Vec::with_capacity(named_fields.len());
let mut field_names = Vec::with_capacity(named_fields.len());
for (name, ty) in named_fields {
pub fn new(named_fields: impl IntoIterator<Item = (impl Into<String>, DataType)>) -> Self {
let iter = named_fields.into_iter();
let mut field_types = Vec::with_capacity(iter.size_hint().0);
let mut field_names = Vec::with_capacity(iter.size_hint().0);
for (name, ty) in iter {
field_names.push(name.into());
field_types.push(ty);
}
Expand All @@ -70,15 +71,6 @@ impl StructType {
}))
}

pub(super) fn from_parts(field_names: Vec<String>, field_types: Vec<DataType>) -> Self {
// TODO: enable this assertion
// debug_assert!(field_names.len() == field_types.len());
Self(Arc::new(StructTypeInner {
field_types: field_types.into(),
field_names: field_names.into(),
}))
}

/// Creates a struct type with unnamed fields.
pub fn unnamed(fields: Vec<DataType>) -> Self {
Self(Arc::new(StructTypeInner {
Expand Down
23 changes: 12 additions & 11 deletions src/connector/codec/src/decoder/avro/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use apache_avro::AvroResult;
use itertools::Itertools;
use risingwave_common::error::NotImplemented;
use risingwave_common::log::LogSuppresser;
use risingwave_common::types::{DataType, Decimal, MapType};
use risingwave_common::types::{DataType, Decimal, MapType, StructType};
use risingwave_common::{bail, bail_not_implemented};
use risingwave_pb::plan_common::{AdditionalColumn, ColumnDesc, ColumnDescVersion};

Expand Down Expand Up @@ -190,12 +190,13 @@ fn avro_type_mapping(
return Ok(DataType::Decimal);
}

let struct_fields = fields
.iter()
.map(|f| avro_type_mapping(&f.schema, map_handling))
.collect::<anyhow::Result<_>>()?;
let struct_names = fields.iter().map(|f| f.name.clone()).collect_vec();
DataType::new_struct(struct_fields, struct_names)
StructType::new(
fields
.iter()
.map(|f| Ok((&f.name, avro_type_mapping(&f.schema, map_handling)?)))
.collect::<anyhow::Result<Vec<_>>>()?,
)
.into()
}
Schema::Array(item_schema) => {
let item_type = avro_type_mapping(item_schema.as_ref(), map_handling)?;
Expand Down Expand Up @@ -225,21 +226,21 @@ fn avro_type_mapping(
// Note: Avro union's variant tag is type name, not field name (unlike Rust enum, or Protobuf oneof).

// XXX: do we need to introduce union.handling.mode?
let (fields, field_names) = union_schema
let fields = union_schema
.variants()
.iter()
// null will mean the whole struct is null
.filter(|variant| !matches!(variant, &&Schema::Null))
.map(|variant| {
avro_type_mapping(variant, map_handling).and_then(|t| {
let name = avro_schema_to_struct_field_name(variant)?;
Ok((t, name))
Ok((name, t))
})
})
.process_results(|it| it.unzip::<_, _, Vec<_>, Vec<_>>())
.try_collect::<_, Vec<_>, _>()
.context("failed to convert Avro union to struct")?;

DataType::new_struct(fields, field_names)
StructType::new(fields).into()
}
}
}
Expand Down
15 changes: 10 additions & 5 deletions src/connector/codec/src/decoder/protobuf/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ use itertools::Itertools;
use prost_reflect::{Cardinality, FieldDescriptor, Kind, MessageDescriptor, ReflectMessage, Value};
use risingwave_common::array::{ListValue, StructValue};
use risingwave_common::types::{
DataType, DatumCow, Decimal, JsonbVal, MapType, MapValue, ScalarImpl, ToOwnedDatum, F32, F64,
DataType, DatumCow, Decimal, JsonbVal, MapType, MapValue, ScalarImpl, StructType, ToOwnedDatum,
F32, F64,
};
use risingwave_pb::plan_common::{AdditionalColumn, ColumnDesc, ColumnDescVersion};
use thiserror::Error;
Expand Down Expand Up @@ -257,10 +258,14 @@ fn protobuf_type_mapping(
} else {
let fields = m
.fields()
.map(|f| protobuf_type_mapping(&f, parse_trace))
.try_collect()?;
let field_names = m.fields().map(|f| f.name().to_string()).collect_vec();
DataType::new_struct(fields, field_names)
.map(|f| {
Ok((
f.name().to_string(),
protobuf_type_mapping(&f, parse_trace)?,
))
})
.try_collect::<_, Vec<_>, _>()?;
StructType::new(fields).into()
}
}
Kind::Enum(_) => DataType::Varchar,
Expand Down
10 changes: 5 additions & 5 deletions src/frontend/src/binder/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

use itertools::Itertools;
use risingwave_common::catalog::{ColumnDesc, ColumnId, PG_CATALOG_SCHEMA_NAME};
use risingwave_common::types::{DataType, MapType};
use risingwave_common::types::{DataType, MapType, StructType};
use risingwave_common::util::iter_util::zip_eq_fast;
use risingwave_common::{bail_no_function, bail_not_implemented, not_implemented};
use risingwave_pb::plan_common::{AdditionalColumn, ColumnDescVersion};
Expand Down Expand Up @@ -1008,13 +1008,13 @@ pub fn bind_data_type(data_type: &AstDataType) -> Result<DataType> {
AstDataType::Char(..) => {
bail_not_implemented!("CHAR is not supported, please use VARCHAR instead")
}
AstDataType::Struct(types) => DataType::new_struct(
AstDataType::Struct(types) => StructType::new(
types
.iter()
.map(|f| bind_data_type(&f.data_type))
.map(|f| Ok((f.name.real_value(), bind_data_type(&f.data_type)?)))
.collect::<Result<Vec<_>>>()?,
types.iter().map(|f| f.name.real_value()).collect_vec(),
),
)
.into(),
AstDataType::Map(kv) => {
let key = bind_data_type(&kv.0)?;
let value = bind_data_type(&kv.1)?;
Expand Down
12 changes: 7 additions & 5 deletions src/frontend/src/binder/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::collections::{BTreeMap, HashMap};
use fixedbitset::FixedBitSet;
use itertools::Itertools;
use risingwave_common::catalog::{Schema, TableVersionId};
use risingwave_common::types::DataType;
use risingwave_common::types::StructType;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_sqlparser::ast::{Assignment, AssignmentValue, Expr, ObjectName, SelectItem};

Expand Down Expand Up @@ -206,10 +206,12 @@ impl Binder {
bail_bind_error!("number of columns does not match number of values");
}

let target_type = DataType::new_struct(
ids.iter().map(|id| id.return_type()).collect(),
id.iter().map(|id| id.real_value()).collect(),
);
let target_type = StructType::new(
id.iter()
.zip_eq_fast(ids)
.map(|(id, expr)| (id.real_value(), expr.return_type())),
)
.into();
let expr = expr.cast_assign(target_type)?;

exprs.push(expr);
Expand Down
4 changes: 1 addition & 3 deletions src/frontend/src/expr/subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,7 @@ impl Expr for Subquery {
StructType::unnamed(self.query.data_types())
} else {
StructType::new(
(schema.fields().iter().cloned())
.map(|f| (f.name, f.data_type))
.collect(),
(schema.fields().iter().cloned()).map(|f| (f.name, f.data_type)),
)
};
DataType::Struct(struct_type)
Expand Down
8 changes: 4 additions & 4 deletions src/frontend/src/handler/create_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

use anyhow::Context;
use risingwave_common::catalog::FunctionId;
use risingwave_common::types::DataType;
use risingwave_common::types::StructType;
use risingwave_expr::sig::{CreateFunctionOptions, UdfKind};
use risingwave_pb::catalog::function::{Kind, ScalarFunction, TableFunction};
use risingwave_pb::catalog::Function;
Expand Down Expand Up @@ -83,9 +83,9 @@ pub async fn handle_create_function(
// return type is a struct for multiple columns
let it = columns
.into_iter()
.map(|c| bind_data_type(&c.data_type).map(|ty| (ty, c.name.real_value())));
let (datatypes, names) = itertools::process_results(it, |it| it.unzip())?;
return_type = DataType::new_struct(datatypes, names);
.map(|c| bind_data_type(&c.data_type).map(|ty| (c.name.real_value(), ty)));
let fields = it.try_collect::<_, Vec<_>, _>()?;
return_type = StructType::new(fields).into();
}
Kind::Table(TableFunction {})
}
Expand Down
12 changes: 4 additions & 8 deletions src/frontend/src/handler/create_sql_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use std::collections::HashMap;

use fancy_regex::Regex;
use risingwave_common::catalog::FunctionId;
use risingwave_common::types::DataType;
use risingwave_common::types::{DataType, StructType};
use risingwave_pb::catalog::function::{Kind, ScalarFunction, TableFunction};
use risingwave_pb::catalog::Function;
use risingwave_sqlparser::parser::{Parser, ParserError};
Expand Down Expand Up @@ -188,15 +188,11 @@ pub async fn handle_create_sql_function(
return_type = bind_data_type(&columns[0].data_type)?;
} else {
// return type is a struct for multiple columns
let datatypes = columns
let fields = columns
.iter()
.map(|c| bind_data_type(&c.data_type))
.map(|c| Ok((c.name.real_value(), bind_data_type(&c.data_type)?)))
.collect::<Result<Vec<_>>>()?;
let names = columns
.iter()
.map(|c| c.name.real_value())
.collect::<Vec<_>>();
return_type = DataType::new_struct(datatypes, names);
return_type = StructType::new(fields).into();
}
Kind::Table(TableFunction {})
}
Expand Down
3 changes: 1 addition & 2 deletions src/tests/sqlsmith/src/sql_gen/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,7 @@ impl<R: Rng> SqlGenerator<'_, R> {
DataType::Struct(StructType::new(
STRUCT_FIELD_NAMES[0..num_fields]
.iter()
.map(|s| (s.to_string(), self.gen_data_type_inner(depth)))
.collect(),
.map(|s| (s.to_string(), self.gen_data_type_inner(depth))),
))
}

Expand Down

0 comments on commit 3faa0bf

Please sign in to comment.