Skip to content

Commit

Permalink
fix(frontend): check data type in column id generator
Browse files Browse the repository at this point in the history
Signed-off-by: Bugen Zhao <[email protected]>
  • Loading branch information
BugenZhao committed Dec 17, 2024
1 parent 6f14e79 commit d9698d7
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 32 deletions.
21 changes: 21 additions & 0 deletions src/common/src/catalog/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use risingwave_pb::plan_common::{
AdditionalColumn, ColumnDescVersion, DefaultColumnDesc, PbColumnCatalog, PbColumnDesc,
};

use super::schema::FieldLike;
use super::{
iceberg_sequence_num_column_desc, row_id_column_desc, rw_timestamp_column_desc,
USER_COLUMN_ID_OFFSET,
Expand Down Expand Up @@ -523,6 +524,26 @@ impl ColumnCatalog {
}
}

impl FieldLike for ColumnDesc {
fn data_type(&self) -> &DataType {
&self.data_type
}

fn name(&self) -> &str {
&self.name
}
}

impl FieldLike for ColumnCatalog {
fn data_type(&self) -> &DataType {
&self.column_desc.data_type
}

fn name(&self) -> &str {
&self.column_desc.name
}
}

pub fn columns_extend(preserved_columns: &mut Vec<ColumnCatalog>, columns: Vec<ColumnCatalog>) {
debug_assert_eq!(ROW_ID_COLUMN_ID.get_id(), 0);
let mut max_incoming_column_id = ROW_ID_COLUMN_ID.get_id();
Expand Down
2 changes: 1 addition & 1 deletion src/common/src/catalog/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use risingwave_pb::catalog::{
StreamJobStatus as PbStreamJobStatus,
};
use risingwave_pb::plan_common::ColumnDescVersion;
pub use schema::{test_utils as schema_test_utils, Field, FieldDisplay, Schema};
pub use schema::{test_utils as schema_test_utils, Field, FieldDisplay, FieldLike, Schema};
use serde::{Deserialize, Serialize};

use crate::array::DataChunk;
Expand Down
16 changes: 16 additions & 0 deletions src/common/src/catalog/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,22 @@ impl From<&PbColumnDesc> for Field {
}
}

#[auto_impl::auto_impl(&, &mut)]
pub trait FieldLike {
fn data_type(&self) -> &DataType;
fn name(&self) -> &str;
}

impl FieldLike for Field {
fn data_type(&self) -> &DataType {
&self.data_type
}

fn name(&self) -> &str {
&self.name
}
}

pub struct FieldDisplay<'a>(pub &'a Field);

impl std::fmt::Debug for FieldDisplay<'_> {
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/handler/create_source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ pub async fn bind_create_source_or_table_with_connector(
// XXX: why do we use col_id_gen here? It doesn't seem to be very necessary.
// XXX: should we also chenge the col id for struct fields?
for c in &mut columns {
c.column_desc.column_id = col_id_gen.generate(c.name())
c.column_desc.column_id = col_id_gen.generate(&*c)
}
debug_assert_column_ids_distinct(&columns);

Expand Down
120 changes: 92 additions & 28 deletions src/frontend/src/handler/create_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@ use fixedbitset::FixedBitSet;
use itertools::Itertools;
use pgwire::pg_response::{PgResponse, StatementType};
use risingwave_common::catalog::{
CdcTableDesc, ColumnCatalog, ColumnDesc, Engine, TableId, TableVersionId, DEFAULT_SCHEMA_NAME,
INITIAL_TABLE_VERSION_ID, RISINGWAVE_ICEBERG_ROW_ID, ROWID_PREFIX,
CdcTableDesc, ColumnCatalog, ColumnDesc, Engine, FieldLike, TableId, TableVersionId,
DEFAULT_SCHEMA_NAME, INITIAL_TABLE_VERSION_ID, RISINGWAVE_ICEBERG_ROW_ID, ROWID_PREFIX,
};
use risingwave_common::config::MetaBackend;
use risingwave_common::license::Feature;
use risingwave_common::session_config::sink_decouple::SinkDecouple;
use risingwave_common::system_param::reader::SystemParamsRead;
use risingwave_common::types::DataType;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
use risingwave_common::util::value_encoding::DatumToProtoExt;
Expand All @@ -52,9 +53,9 @@ use risingwave_pb::secret::PbSecretRef;
use risingwave_pb::stream_plan::StreamFragmentGraph;
use risingwave_sqlparser::ast::{
CdcTableInfo, ColumnDef, ColumnOption, CompatibleFormatEncode, CreateSink, CreateSinkStatement,
CreateSourceStatement, DataType, DataType as AstDataType, ExplainOptions, Format,
FormatEncodeOptions, Ident, ObjectName, OnConflict, SecretRefAsType, SourceWatermark,
Statement, TableConstraint, WebhookSourceInfo, WithProperties,
CreateSourceStatement, DataType as AstDataType, ExplainOptions, Format, FormatEncodeOptions,
Ident, ObjectName, OnConflict, SecretRefAsType, SourceWatermark, Statement, TableConstraint,
WebhookSourceInfo, WithProperties,
};
use risingwave_sqlparser::parser::{IncludeOption, Parser};
use thiserror_ext::AsReport;
Expand All @@ -77,6 +78,7 @@ use crate::optimizer::plan_node::generic::{CdcScanOptions, SourceNodeKind};
use crate::optimizer::plan_node::{LogicalCdcScan, LogicalSource};
use crate::optimizer::property::{Order, RequiredDist};
use crate::optimizer::{OptimizerContext, OptimizerContextRef, PlanRef, PlanRoot};
use crate::session::current::notice_to_user;
use crate::session::SessionImpl;
use crate::stream_fragmenter::build_graph;
use crate::utils::OverwriteOptions;
Expand All @@ -91,7 +93,7 @@ pub struct ColumnIdGenerator {
/// exists, its ID is reused. Otherwise, a new ID is generated.
///
/// For a new table, this is empty.
pub existing: HashMap<String, ColumnId>,
pub existing: HashMap<String, (ColumnId, DataType)>,

/// The next column ID to generate, used for new columns that do not exist in `existing`.
pub next_column_id: ColumnId,
Expand All @@ -109,7 +111,12 @@ impl ColumnIdGenerator {
let existing = original
.columns()
.iter()
.map(|col| (col.name().to_owned(), col.column_id()))
.map(|col| {
(
col.name().to_owned(),
(col.column_id(), col.data_type().clone()),
)
})
.collect();

let version = original.version().expect("version field not set");
Expand All @@ -130,15 +137,29 @@ impl ColumnIdGenerator {
}
}

/// Generates a new [`ColumnId`] for a column with the given name.
pub fn generate(&mut self, name: &str) -> ColumnId {
if let Some(id) = self.existing.get(name) {
*id
} else {
let id = self.next_column_id;
self.next_column_id = self.next_column_id.next();
id
/// Generates a new [`ColumnId`] for a column with the given field.
pub fn generate(&mut self, field: impl FieldLike) -> ColumnId {
if let Some((id, original_type)) = self.existing.get(field.name()) {
// Intentionally not using `datatype_equals` here because we want nested types to be
// exactly the same, **NOT** ignoring field names as they may be referenced in expressions
// of generated columns or downstream jobs.
if original_type == field.data_type() {
return *id;
} else {
notice_to_user(format!(
"The data type of column \"{}\" has been changed from {} to {}. \
This is currently not supported, even if it could be a compatible change in external systems. \
The original column will be dropped and a new column will be created.",
field.name(),
original_type,
field.data_type()
));
}
}

let id = self.next_column_id;
self.next_column_id = self.next_column_id.next();
id
}

/// Consume this generator and return a [`TableVersion`] for the table to be created or altered.
Expand Down Expand Up @@ -564,7 +585,7 @@ pub(crate) fn gen_create_table_plan(
let definition = context.normalized_sql().to_owned();
let mut columns = bind_sql_columns(&column_defs)?;
for c in &mut columns {
c.column_desc.column_id = col_id_gen.generate(c.name())
c.column_desc.column_id = col_id_gen.generate(&*c)
}

let (_, secret_refs, connection_refs) = context.with_options().clone().into_parts();
Expand Down Expand Up @@ -817,7 +838,7 @@ pub(crate) fn gen_create_table_plan_for_cdc_table(
)?;

for c in &mut columns {
c.column_desc.column_id = col_id_gen.generate(c.name())
c.column_desc.column_id = col_id_gen.generate(&*c)
}

let (mut columns, pk_column_ids, _row_id_index) =
Expand Down Expand Up @@ -1901,7 +1922,8 @@ fn bind_webhook_info(
webhook_info: WebhookSourceInfo,
) -> Result<PbWebhookSourceInfo> {
// validate columns
if columns_defs.len() != 1 || columns_defs[0].data_type.as_ref().unwrap() != &DataType::Jsonb {
if columns_defs.len() != 1 || columns_defs[0].data_type.as_ref().unwrap() != &AstDataType::Jsonb
{
return Err(ErrorCode::InvalidInputSyntax(
"Table with webhook source should have exactly one JSONB column".to_owned(),
)
Expand Down Expand Up @@ -1963,12 +1985,28 @@ mod tests {
use super::*;
use crate::test_utils::{create_proto_file, LocalFrontend, PROTO_FILE_DATA};

struct BrandNewColumn(&'static str);
use BrandNewColumn as B;

impl FieldLike for BrandNewColumn {
fn name(&self) -> &str {
self.0
}

fn data_type(&self) -> &DataType {
unreachable!("for brand new columns, data type will not be accessed")
}
}

#[test]
fn test_col_id_gen() {
fn test_col_id_gen_initial() {
let mut gen = ColumnIdGenerator::new_initial();
assert_eq!(gen.generate("v1"), ColumnId::new(1));
assert_eq!(gen.generate("v2"), ColumnId::new(2));
assert_eq!(gen.generate(B("v1")), ColumnId::new(1));
assert_eq!(gen.generate(B("v2")), ColumnId::new(2));
}

#[test]
fn test_col_id_gen_alter() {
let mut gen = ColumnIdGenerator::new_alter(&TableCatalog {
columns: vec![
ColumnCatalog {
Expand All @@ -1985,16 +2023,42 @@ mod tests {
),
is_hidden: false,
},
ColumnCatalog {
column_desc: ColumnDesc::from_field_with_column_id(
&Field::with_name(
StructType::new([("f1", DataType::Int32)]).into(),
"nested",
),
3,
),
is_hidden: false,
},
],
version: Some(TableVersion::new_initial_for_test(ColumnId::new(2))),
version: Some(TableVersion::new_initial_for_test(ColumnId::new(3))),
..Default::default()
});

assert_eq!(gen.generate("v1"), ColumnId::new(3));
assert_eq!(gen.generate("v2"), ColumnId::new(4));
assert_eq!(gen.generate("f32"), ColumnId::new(1));
assert_eq!(gen.generate("f64"), ColumnId::new(2));
assert_eq!(gen.generate("v3"), ColumnId::new(5));
assert_eq!(gen.generate(B("v1")), ColumnId::new(4));
assert_eq!(gen.generate(B("v2")), ColumnId::new(5));
assert_eq!(
gen.generate(Field::new("f32", DataType::Float32)),
ColumnId::new(1)
);
assert_eq!(
// mismatched data type, will generate a new column id
gen.generate(Field::new("f64", DataType::Float32)),
ColumnId::new(6)
);
assert_eq!(
// mismatched data type, will generate a new column id
// we require the nested data type to be exactly the same
gen.generate(Field::new(
"nested",
StructType::new([("f1", DataType::Int32), ("f2", DataType::Int64)]).into()
)),
ColumnId::new(7)
);
assert_eq!(gen.generate(B("v3")), ColumnId::new(8));
}

#[tokio::test]
Expand Down Expand Up @@ -2086,7 +2150,7 @@ mod tests {
let mut columns = bind_sql_columns(&column_defs)?;
let mut col_id_gen = ColumnIdGenerator::new_initial();
for c in &mut columns {
c.column_desc.column_id = col_id_gen.generate(c.name())
c.column_desc.column_id = col_id_gen.generate(&*c)
}

let pk_names =
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/handler/create_table_as.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ pub async fn handle_create_as(
.fields()
.iter()
.map(|field| {
let id = col_id_gen.generate(&field.name);
let id = col_id_gen.generate(field);
ColumnCatalog {
column_desc: ColumnDesc::from_field_with_column_id(field, id.get_id()),
is_hidden: false,
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1145,7 +1145,7 @@ impl SessionImpl {

pub fn notice_to_user(&self, str: impl Into<String>) {
let notice = str.into();
tracing::trace!("notice to user:{}", notice);
tracing::trace!(notice, "notice to user");
self.notices.write().push(notice);
}

Expand Down

0 comments on commit d9698d7

Please sign in to comment.