feat: support create table with map type
xxchan committed Aug 6, 2024
1 parent b6ac6f2 commit aa33f2f
Showing 13 changed files with 130 additions and 38 deletions.
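
In short: the parser and binder now accept MAP(key_type, value_type) in column definitions, including maps nested inside arrays and structs. A minimal sketch of the enabled syntax, condensed from the new e2e_test/batch/types/map.slt.part below (key types are limited to string and integral types; other keys are rejected by the binder):

create table t (m map(varchar, float));
insert into t values (map_from_entries(array['a','b'], array[1.0,2.0]::float[]));
select * from t;  -- {"a":1,"b":2}
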
2 changes: 1 addition & 1 deletion e2e_test/batch/distribution_mode.slt
@@ -4,7 +4,7 @@ SET RW_IMPLICIT_FLUSH TO true;
statement ok
SET QUERY_MODE TO distributed;

-include ./basic/*.slt.part
+include ./basic/**/*.slt.part
include ./duckdb/all.slt.part
include ./order/*.slt.part
include ./join/*.slt.part
2 changes: 1 addition & 1 deletion e2e_test/batch/local_mode.slt
@@ -4,7 +4,7 @@ SET RW_IMPLICIT_FLUSH TO true;
statement ok
SET QUERY_MODE TO local;

-include ./basic/*.slt.part
+include ./basic/**/*.slt.part
include ./duckdb/all.slt.part
include ./order/*.slt.part
include ./join/*.slt.part
2 changes: 0 additions & 2 deletions e2e_test/batch/types/list.slt.part

This file was deleted.

57 changes: 57 additions & 0 deletions e2e_test/batch/types/map.slt.part
@@ -0,0 +1,57 @@
statement ok
SET RW_IMPLICIT_FLUSH TO true;


statement error
create table t (m map (float, float));
----
db error: ERROR: Failed to run the query

Caused by:
invalid map key type: double precision


statement ok
create table t (
m1 map(varchar, float),
m2 map(int, bool),
m3 map(varchar, map(varchar, varchar)),
l map(varchar,int)[],
s struct<m map(varchar, struct<x int>)>,
);


statement ok
insert into t values (
map_from_entries(array['a','b','c'], array[1.0,2.0,3.0]::float[]),
map_from_entries(array[1,2,3], array[true,false,true]),
map_from_entries(array['a','b'],
array[
map_from_entries(array['a1'], array['a2']),
map_from_entries(array['b1'], array['b2'])
]
),
array[
map_from_entries(array['a','b','c'], array[1,2,3]),
map_from_entries(array['d','e','f'], array[4,5,6])
],
row(
map_from_entries(array['a','b','c'], array[row(1),row(2),row(3)]::struct<x int>[])
)
);

query ?????
select * from t;
----
{"a":1,"b":2,"c":3} {"1":t,"2":f,"3":t} {"a":{"a1":a2},"b":{"b1":b2}} {"{\"a\":1,\"b\":2,\"c\":3}","{\"d\":4,\"e\":5,\"f\":6}"} ("{""a"":(1),""b"":(2),""c"":(3)}")


# FIXME: The cast should be supported
statement error
insert into t(m1) values (map_from_entries(array['a','b','c'], array[1,2,3]));
----
db error: ERROR: Failed to run the query

Caused by these errors (recent errors listed first):
1: failed to cast the 1st column
2: cannot cast type "map(character varying,integer)" to "map(character varying,double precision)" in Assign context
1 change: 0 additions & 1 deletion e2e_test/batch/types/struct.slt.part

This file was deleted.

9 changes: 5 additions & 4 deletions src/common/src/types/map_type.rs
@@ -50,10 +50,11 @@ impl MapType {
.iter()
.collect_tuple()
.expect("the underlying struct for map must have exactly two fields");
-        if cfg!(debug_assertions) {
-            // the field names are not strictly enforced
-            itertools::assert_equal(struct_type.names(), ["key", "value"]);
-        }
+        // the field names are not strictly enforced
+        // Currently this panics for SELECT * FROM t
+        // if cfg!(debug_assertions) {
+        //     itertools::assert_equal(struct_type.names(), ["key", "value"]);
+        // }
Self::from_kv(k.1.clone(), v.1.clone())
}

7 changes: 6 additions & 1 deletion src/frontend/src/binder/expr/mod.rs
@@ -14,7 +14,7 @@

use itertools::Itertools;
use risingwave_common::catalog::{ColumnDesc, ColumnId, PG_CATALOG_SCHEMA_NAME};
-use risingwave_common::types::DataType;
+use risingwave_common::types::{DataType, MapType};
use risingwave_common::util::iter_util::zip_eq_fast;
use risingwave_common::{bail_no_function, bail_not_implemented, not_implemented};
use risingwave_pb::plan_common::{AdditionalColumn, ColumnDescVersion};
@@ -999,6 +999,11 @@ pub fn bind_data_type(data_type: &AstDataType) -> Result<DataType> {
.collect::<Result<Vec<_>>>()?,
types.iter().map(|f| f.name.real_value()).collect_vec(),
),
AstDataType::Map(kv) => {
let key = bind_data_type(&kv.0)?;
let value = bind_data_type(&kv.1)?;
DataType::Map(MapType::try_from_kv(key, value)?)
}
AstDataType::Custom(qualified_type_name) => {
let idents = qualified_type_name
.0
17 changes: 13 additions & 4 deletions src/frontend/src/binder/insert.rs
@@ -14,6 +14,7 @@

use std::collections::{BTreeMap, HashMap, HashSet};

use anyhow::Context;
use itertools::Itertools;
use risingwave_common::catalog::{ColumnCatalog, Schema, TableVersionId};
use risingwave_common::types::DataType;
@@ -26,6 +27,7 @@ use crate::binder::{Binder, Clause};
use crate::catalog::TableId;
use crate::error::{ErrorCode, Result, RwError};
use crate::expr::{ExprImpl, InputRef};
use crate::handler::create_mv::ordinal;
use crate::user::UserId;

#[derive(Debug, Clone)]
@@ -197,7 +199,7 @@ impl Binder {
let bound_query;
let cast_exprs;

-let bounded_column_nums = match source.as_simple_values() {
+let bound_column_nums = match source.as_simple_values() {
None => {
bound_query = self.bind_query(source)?;
let actual_types = bound_query.data_types();
@@ -234,7 +236,7 @@
cols_to_insert_in_table.len()
};

-let (err_msg, default_column_indices) = match num_target_cols.cmp(&bounded_column_nums) {
+let (err_msg, default_column_indices) = match num_target_cols.cmp(&bound_column_nums) {
std::cmp::Ordering::Equal => (None, default_column_indices),
std::cmp::Ordering::Greater => {
if has_user_specified_columns {
@@ -248,7 +250,7 @@
// insert into t values (7)
// this kind of usage is fine, null values will be provided
// implicitly.
-(None, col_indices_to_insert.split_off(bounded_column_nums))
+(None, col_indices_to_insert.split_off(bound_column_nums))
}
}
std::cmp::Ordering::Less => {
@@ -315,7 +317,14 @@
return exprs
.into_iter()
.zip_eq_fast(expected_types.iter().take(expr_num))
-.map(|(e, t)| e.cast_assign(t.clone()).map_err(Into::into))
+.enumerate()
+.map(|(i, (e, t))| {
+    e.cast_assign(t.clone())
+        .with_context(|| {
+            format!("failed to cast the {} column", ordinal(i + 1))
+        })
+        .map_err(Into::into)
+})
.try_collect();
}
};
4 changes: 2 additions & 2 deletions src/frontend/src/binder/select.rs
@@ -714,8 +714,8 @@ fn data_type_to_alias(data_type: &AstDataType) -> Option<String> {
AstDataType::Jsonb => "jsonb".to_string(),
AstDataType::Array(ty) => return data_type_to_alias(ty),
AstDataType::Custom(ty) => format!("{}", ty),
-AstDataType::Struct(_) => {
-    // Note: Postgres doesn't have anonymous structs
+AstDataType::Struct(_) | AstDataType::Map(_) => {
+    // Note: Postgres doesn't have maps and anonymous structs
return None;
}
};
2 changes: 1 addition & 1 deletion src/frontend/src/handler/create_mv.rs
@@ -278,7 +278,7 @@ It only indicates the physical clustering of the data, which may improve the per
))
}

-fn ordinal(i: usize) -> String {
+pub fn ordinal(i: usize) -> String {
let s = i.to_string();
let suffix = if s.ends_with('1') && !s.ends_with("11") {
"st"
5 changes: 5 additions & 0 deletions src/sqlparser/src/ast/data_type.rs
@@ -70,6 +70,8 @@ pub enum DataType {
Array(Box<DataType>),
/// Structs
Struct(Vec<StructField>),
/// Map(key_type, value_type)
Map(Box<(DataType, DataType)>),
}

impl fmt::Display for DataType {
@@ -110,6 +112,9 @@ impl fmt::Display for DataType {
DataType::Struct(defs) => {
write!(f, "STRUCT<{}>", display_comma_separated(defs))
}
DataType::Map(kv) => {
write!(f, "MAP({},{})", kv.0, kv.1)
}
}
}
}
3 changes: 1 addition & 2 deletions src/sqlparser/src/parser.rs
@@ -3632,8 +3632,7 @@ impl Parser<'_> {
.parse_next(self)
}

-/// Parse a SQL datatype (in the context of a CREATE TABLE statement for example) and convert
-/// into an array of that datatype if needed
+/// Parse a SQL datatype (in the context of a CREATE TABLE statement for example)
pub fn parse_data_type(&mut self) -> PResult<DataType> {
parser_v2::data_type(self)
}
57 changes: 38 additions & 19 deletions src/sqlparser/src/parser_v2/data_type.rs
@@ -115,6 +115,11 @@
/// Consume a data type definition.
///
/// The parser is the main entry point for data type parsing.
///
/// Note: in recursion, we should use `data_type_stateful` instead of `data_type`,
/// otherwise the type parameter will recurse like `Stateful<Stateful<Stateful<...>>>`.
/// Also note that we cannot use `Parser<'_>` directly to avoid misuse, because we need
/// generics `<S>` to parameterize over `Parser<'_>` and `Stateful<Parser<'_>>`.
pub fn data_type<S>(input: &mut S) -> PResult<DataType>
where
S: TokenStream,
@@ -166,6 +171,14 @@ fn data_type_stateful_inner<S>(input: &mut StatefulStream<S>) -> PResult<DataTyp
where
S: TokenStream,
{
trace(
"data_type_inner",
alt((keyword_datatype, non_keyword_datatype)),
)
.parse_next(input)
}

fn keyword_datatype<S: TokenStream>(input: &mut StatefulStream<S>) -> PResult<DataType> {
let with_time_zone = || {
opt(alt((
(Keyword::WITH, Keyword::TIME, Keyword::ZONE).value(true),
@@ -186,7 +199,7 @@
})
};

-let keywords = dispatch! {keyword;
+let mut ty = dispatch! {keyword;
Keyword::BOOLEAN | Keyword::BOOL => empty.value(DataType::Boolean),
Keyword::FLOAT => opt(precision_in_range(1..54)).map(DataType::Float),
Keyword::REAL => empty.value(DataType::Real),
@@ -211,26 +224,32 @@ where
Keyword::NUMERIC | Keyword::DECIMAL | Keyword::DEC => cut_err(precision_and_scale()).map(|(precision, scale)| {
DataType::Decimal(precision, scale)
}),
-_ => fail,
+_ => fail
};

-trace(
-    "data_type_inner",
-    alt((
-        keywords,
-        trace(
-            "non_keyword_data_type",
-            object_name.map(
-                |name| match name.to_string().to_ascii_lowercase().as_str() {
-                    // PostgreSQL built-in data types that are not keywords.
-                    "jsonb" => DataType::Jsonb,
-                    "regclass" => DataType::Regclass,
-                    "regproc" => DataType::Regproc,
-                    _ => DataType::Custom(name),
-                },
-            ),
-        ),
-    )),
+ty.parse_next(input)
+}
+
+fn non_keyword_datatype<S: TokenStream>(input: &mut StatefulStream<S>) -> PResult<DataType> {
+    let type_name = object_name.parse_next(input)?;
+    match type_name.to_string().to_ascii_lowercase().as_str() {
+        // PostgreSQL built-in data types that are not keywords.
+        "jsonb" => Ok(DataType::Jsonb),
+        "regclass" => Ok(DataType::Regclass),
+        "regproc" => Ok(DataType::Regproc),
+        "map" => cut_err(map_type_arguments).parse_next(input),
+        _ => Ok(DataType::Custom(type_name)),
+    }
+}
+
+fn map_type_arguments<S: TokenStream>(input: &mut StatefulStream<S>) -> PResult<DataType> {
+    delimited(
+        Token::LParen,
+        // key is string or integral type. value is arbitrary type.
+        // We don't validate here, but in binder bind_data_type
+        seq!(keyword_datatype, _:Token::Comma, data_type_stateful),
+        Token::RParen,
+    )
+    .map(|(k, v)| DataType::Map(Box::new((k, v))))
+    .parse_next(input)
+}
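
Note on the grammar above: the key position only accepts a keyword type (via keyword_datatype), while the value can be any data_type_stateful; semantic validation of the key happens later in bind_data_type. A small SQL illustration of that split, based on the new map.slt.part test above:

create table ok (m map(varchar, map(varchar, varchar)));  -- nested map values parse and bind fine
create table bad (m map(float, float));                   -- parses, but the binder rejects it: invalid map key type: double precision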
