diff --git a/src/connector/src/sink/postgres.rs b/src/connector/src/sink/postgres.rs index 3aa6708848d9f..b541744a13d19 100644 --- a/src/connector/src/sink/postgres.rs +++ b/src/connector/src/sink/postgres.rs @@ -125,7 +125,75 @@ impl Sink for PostgresSink { "Primary key not defined for upsert Postgres sink (please define in `primary_key` field)"))); } - // TODO(kwannoel): Add more validation - see sqlserver. Check type compatibility, etc. + for field in self.schema.fields() { + check_data_type_compatibility(field.data_type())?; + } + + // Verify pg table schema matches rw table schema, and pk indices are valid + let table_name = &self.config.table; + let connection_string = format!( + "host={} port={} user={} password={} dbname={}", + self.config.host, + self.config.port, + self.config.user, + self.config.password, + self.config.database + ); + let (client, connection) = + tokio_postgres::connect(&connection_string, tokio_postgres::NoTls) + .await + .context("Failed to connect to Postgres for Sinking")?; + tokio::spawn(async move { + if let Err(e) = connection.await { + tracing::error!("connection error: {}", e); + } + }); + + let result = client + .query( + " + SELECT a.attname as col_name, i.indisprimary AS is_pk + FROM pg_index i + JOIN pg_attribute a ON a.attrelid = i.indrelid + AND a.attnum = ANY(i.indkey) + WHERE i.indrelid = $1::regclass", + &[&table_name], + ) + .await + .context("Failed to query Postgres for Sinking")?; + + let mut pg_schema = BTreeMap::new(); + for row in result { + let col_name: String = row.get(0); + let is_pk: bool = row.get(1); + pg_schema.insert(col_name, is_pk); + } + + for (i, field) in self.schema.fields().iter().enumerate() { + let col_name = &field.name; + let is_pk = pg_schema.get(col_name); + match is_pk { + None => return Err(SinkError::Config(anyhow!( + "Column `{}` not found in Postgres table `{}`", + col_name, + table_name + ))), + Some(is_pk) => + match (*is_pk, self.pk_indices.contains(&i)) { + (false, false) | (true, true) => continue, + (false, true) => return Err(SinkError::Config(anyhow!( + "Column `{}` in Postgres table `{}` is not a primary key, but RW schema defines it as a primary key", + col_name, + table_name + ))), + (true, false) => return Err(SinkError::Config(anyhow!( + "Column `{}` in Postgres table `{}` is a primary key, but RW schema does not define it as a primary key", + col_name, + table_name + ))), + } + } + } Ok(()) } @@ -223,10 +291,12 @@ impl PostgresSinkWriter { .map(|i| schema.fields()[*i].data_type().to_pg_type()) .collect_vec(); let delete_sql = create_delete_sql(&schema, &config.table, &pk_indices); - Some(client - .prepare_typed(&delete_sql, &delete_types) - .await - .context("Failed to prepare delete statement")?) + Some( + client + .prepare_typed(&delete_sql, &delete_types) + .await + .context("Failed to prepare delete statement")?, + ) }; let merge_statement = if is_append_only { @@ -238,10 +308,12 @@ impl PostgresSinkWriter { .map(|field| field.data_type().to_pg_type()) .collect_vec(); let merge_sql = create_merge_sql(&schema, &config.table, &pk_indices); - Some(client - .prepare_typed(&merge_sql, &merge_types) - .await - .context("Failed to prepare merge statement")?) + Some( + client + .prepare_typed(&merge_sql, &merge_types) + .await + .context("Failed to prepare merge statement")?, + ) }; let writer = Self { @@ -345,7 +417,7 @@ fn data_type_not_supported(data_type_name: &str) -> SinkError { ))) } -fn check_data_type_compatibility(data_type: &DataType) -> Result<()> { +fn check_data_type_compatibility(data_type: DataType) -> Result<()> { match data_type { DataType::Boolean | DataType::Int16 @@ -359,11 +431,11 @@ fn check_data_type_compatibility(data_type: &DataType) -> Result<()> { | DataType::Time | DataType::Timestamp | DataType::Timestamptz + | DataType::Jsonb + | DataType::Interval | DataType::Bytea => Ok(()), - DataType::Interval => Err(data_type_not_supported("Interval")), DataType::Struct(_) => Err(data_type_not_supported("Struct")), DataType::List(_) => Err(data_type_not_supported("List")), - DataType::Jsonb => Err(data_type_not_supported("Jsonb")), DataType::Serial => Err(data_type_not_supported("Serial")), DataType::Int256 => Err(data_type_not_supported("Int256")), DataType::Map(_) => Err(data_type_not_supported("Map")),