diff --git a/Cargo.lock b/Cargo.lock
index ef73103f3e020..a7295685d3685 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -14180,6 +14180,7 @@ dependencies = [
  "futures-util",
  "generic-array",
  "getrandom",
+ "google-cloud-googleapis",
  "governor",
  "hashbrown 0.13.2",
  "hashbrown 0.14.3",
@@ -14203,6 +14204,7 @@ dependencies = [
  "madsim-tokio",
  "md-5",
  "memchr",
+ "mime_guess",
  "mio",
  "moka",
  "nom",
diff --git a/src/connector/Cargo.toml b/src/connector/Cargo.toml
index dcccbe4ac4edd..533407a63fa51 100644
--- a/src/connector/Cargo.toml
+++ b/src/connector/Cargo.toml
@@ -57,6 +57,9 @@ futures = { version = "0.3", default-features = false, features = ["alloc"] }
 futures-async-stream = { workspace = true }
 gcp-bigquery-client = "0.18.0"
 glob = "0.3"
+google-cloud-bigquery = { version = "0.7.0", features = ["auth"] }
+google-cloud-gax = "0.17.0"
+google-cloud-googleapis = "0.12.0"
 google-cloud-pubsub = "0.23"
 http = "0.2"
 hyper = { version = "0.14", features = [
@@ -114,9 +117,6 @@ redis = { version = "0.24.0", features = [
 regex = "1.4"
 reqwest = { version = "0.11", features = ["json"] }
 risingwave_common = { workspace = true }
-google-cloud-bigquery = { version = "0.7.0", features = ["auth"] }
-google-cloud-gax = "0.17.0"
-google-cloud-googleapis = "0.12.0"
 risingwave_jni_core = { workspace = true }
 risingwave_pb = { workspace = true }
 risingwave_rpc_client = { workspace = true }
diff --git a/src/connector/src/lib.rs b/src/connector/src/lib.rs
index fb38f2db00c4f..a530cd681951d 100644
--- a/src/connector/src/lib.rs
+++ b/src/connector/src/lib.rs
@@ -33,6 +33,7 @@
 #![feature(try_blocks)]
 #![feature(error_generic_member_access)]
 #![feature(register_tool)]
+#![feature(assert_matches)]
 #![register_tool(rw)]
 #![recursion_limit = "256"]

diff --git a/src/connector/src/sink/big_query.rs b/src/connector/src/sink/big_query.rs
index 70f8f842664f0..96975c9967538 100644
--- a/src/connector/src/sink/big_query.rs
+++ b/src/connector/src/sink/big_query.rs
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

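Background for the `big_query.rs` rewrite that follows: instead of posting JSON rows through the REST `insertAll` endpoint, the sink now builds a protobuf descriptor from the sink schema at runtime, encodes every row as a `DynamicMessage`, and ships the serialized bytes through the BigQuery Storage Write API. A minimal sketch of that descriptor-plus-encode step, using the same `prost-types`/`prost-reflect` calls the patch uses (the `t1`/`v1` names are illustrative only, not taken from the patch):

```rust
use prost::Message;
use prost_reflect::{DescriptorPool, DynamicMessage, Value};
use prost_types::{
    field_descriptor_proto, DescriptorProto, FieldDescriptorProto, FileDescriptorProto,
    FileDescriptorSet,
};

fn main() {
    // Describe a one-column table ("v1: INT64") as a protobuf message type.
    let message = DescriptorProto {
        name: Some("t1".to_string()),
        field: vec![FieldDescriptorProto {
            name: Some("v1".to_string()),
            number: Some(1),
            r#type: Some(field_descriptor_proto::Type::Int64.into()),
            ..Default::default()
        }],
        ..Default::default()
    };

    // Wrap it in a file descriptor and load it into a pool, mirroring
    // `build_protobuf_descriptor_pool` in the patch.
    let file = FileDescriptorProto {
        name: Some("bigquery".to_string()),
        message_type: vec![message],
        ..Default::default()
    };
    let pool =
        DescriptorPool::from_file_descriptor_set(FileDescriptorSet { file: vec![file] }).unwrap();
    let desc = pool.get_message_by_name("t1").unwrap();

    // Encode one "row" as the bytes the Storage Write API expects in ProtoRows.
    let mut row = DynamicMessage::new(desc.clone());
    row.try_set_field(&desc.get_field_by_name("v1").unwrap(), Value::I64(42))
        .unwrap();
    let serialized_row: Vec<u8> = row.encode_to_vec();
    assert!(!serialized_row.is_empty());
}
```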
-use core::mem; use core::time::Duration; use std::collections::HashMap; use std::sync::Arc; @@ -20,30 +19,36 @@ use std::sync::Arc; use anyhow::anyhow; use async_trait::async_trait; use gcp_bigquery_client::model::query_request::QueryRequest; -use gcp_bigquery_client::model::table_data_insert_all_request::TableDataInsertAllRequest; -use gcp_bigquery_client::model::table_data_insert_all_request_rows::TableDataInsertAllRequestRows; use gcp_bigquery_client::Client; -use google_cloud_gax::grpc::{Request}; -use google_cloud_googleapis::cloud::bigquery::storage::v1::append_rows_request::Rows as AppendRowsRequestRows; -use google_cloud_gax::conn::{ConnectionOptions, Environment}; use google_cloud_bigquery::grpc::apiv1::bigquery_client::StreamingWriteClient; use google_cloud_bigquery::grpc::apiv1::conn_pool::{WriteConnectionManager, DOMAIN}; -use google_cloud_googleapis::cloud::bigquery::storage::v1::AppendRowsRequest; +use google_cloud_gax::conn::{ConnectionOptions, Environment}; +use google_cloud_gax::grpc::Request; +use google_cloud_googleapis::cloud::bigquery::storage::v1::append_rows_request::{ + ProtoData, Rows as AppendRowsRequestRows, +}; +use google_cloud_googleapis::cloud::bigquery::storage::v1::{ + AppendRowsRequest, ProtoRows, ProtoSchema, +}; use google_cloud_pubsub::client::google_cloud_auth; use google_cloud_pubsub::client::google_cloud_auth::credentials::CredentialsFile; +use prost_reflect::MessageDescriptor; +use prost_types::{ + field_descriptor_proto, DescriptorProto, FieldDescriptorProto, FileDescriptorProto, + FileDescriptorSet, +}; use risingwave_common::array::{Op, StreamChunk}; use risingwave_common::buffer::Bitmap; use risingwave_common::catalog::Schema; use risingwave_common::types::DataType; use serde_derive::Deserialize; -use serde_json::Value; -use serde_with::{serde_as, DisplayFromStr}; +use serde_with::serde_as; use url::Url; use uuid::Uuid; use with_options::WithOptions; use yup_oauth2::ServiceAccountKey; -use super::encoder::{JsonEncoder, RowEncoder}; +use super::encoder::{CustomProtoType, ProtoEncoder, ProtoHeader, RowEncoder, SerTo}; use super::writer::LogSinkerOf; use super::{SinkError, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT}; use crate::aws_utils::load_file_descriptor_from_s3; @@ -54,7 +59,10 @@ use crate::sink::{ }; pub const BIGQUERY_SINK: &str = "bigquery"; +pub const CHANGE_TYPE: &str = "_CHANGE_TYPE"; const DEFAULT_GRPC_CHANNEL_NUMS: usize = 4; +const CONNECT_TIMEOUT: Option = Some(Duration::from_secs(30)); +const CONNECTION_TIMEOUT: Option = None; #[serde_as] #[derive(Deserialize, Debug, Clone, WithOptions)] @@ -69,9 +77,6 @@ pub struct BigQueryCommon { pub dataset: String, #[serde(rename = "bigquery.table")] pub table: String, - #[serde(rename = "bigquery.max_batch_rows", default = "default_max_batch_rows")] - #[serde_as(as = "DisplayFromStr")] - pub max_batch_rows: usize, } fn default_max_batch_rows() -> usize { @@ -79,41 +84,45 @@ fn default_max_batch_rows() -> usize { } impl BigQueryCommon { - pub(crate) async fn build_client(&self, aws_auth_props: &AwsAuthProps) -> Result { + async fn build_client(&self, aws_auth_props: &AwsAuthProps) -> Result { let auth_json = self.get_auth_json_from_path(aws_auth_props).await?; - + let service_account = serde_json::from_str::(&auth_json) - .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?; + .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?; let client: Client = Client::from_service_account_key(service_account, false) .await .map_err(|err| 
SinkError::BigQuery(anyhow::anyhow!(err)))?; Ok(client) } - pub(crate) async fn build_writer_client(&self, aws_auth_props: &AwsAuthProps) -> Result { + async fn build_writer_client( + &self, + aws_auth_props: &AwsAuthProps, + ) -> Result { let auth_json = self.get_auth_json_from_path(aws_auth_props).await?; - - let credentials_file= CredentialsFile::new_from_str(&auth_json).await.unwrap(); + + let credentials_file = CredentialsFile::new_from_str(&auth_json) + .await + .map_err(|e| SinkError::BigQuery(e.into()))?; let client = StorageWriterClient::new(credentials_file).await?; Ok(client) } - async fn get_auth_json_from_path(&self, aws_auth_props: &AwsAuthProps) -> Result{ - if let Some(local_path) = &self.local_path{ + async fn get_auth_json_from_path(&self, aws_auth_props: &AwsAuthProps) -> Result { + if let Some(local_path) = &self.local_path { std::fs::read_to_string(local_path) .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err))) - }else if let Some(s3_path) = &self.s3_path { + } else if let Some(s3_path) = &self.s3_path { let url = Url::parse(s3_path).map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?; let auth_vec = load_file_descriptor_from_s3(&url, aws_auth_props) .await .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?; - Ok(String::from_utf8(auth_vec).unwrap()) - }else{ + Ok(String::from_utf8(auth_vec).map_err(|e| SinkError::BigQuery(e.into()))?) + } else { Err(SinkError::BigQuery(anyhow::anyhow!("`bigquery.local.path` and `bigquery.s3.path` set at least one, configure as needed."))) } } - } #[serde_as] @@ -211,9 +220,7 @@ impl BigQuerySink { DataType::Decimal => Ok("NUMERIC".to_owned()), DataType::Date => Ok("DATE".to_owned()), DataType::Varchar => Ok("STRING".to_owned()), - DataType::Time => Err(SinkError::BigQuery(anyhow::anyhow!( - "Bigquery cannot support Time" - ))), + DataType::Time => Ok("TIME".to_owned()), DataType::Timestamp => Ok("DATETIME".to_owned()), DataType::Timestamptz => Ok("TIMESTAMP".to_owned()), DataType::Interval => Ok("INTERVAL".to_owned()), @@ -258,12 +265,6 @@ impl Sink for BigQuerySink { } async fn validate(&self) -> Result<()> { - if !self.is_append_only { - return Err(SinkError::Config(anyhow!( - "BigQuery sink don't support upsert" - ))); - } - let client = self .config .common @@ -306,8 +307,10 @@ pub struct BigQuerySinkWriter { pk_indices: Vec, client: StorageWriterClient, is_append_only: bool, - insert_request: TableDataInsertAllRequest, - row_encoder: JsonEncoder, + row_encoder: ProtoEncoder, + writer_pb_schema: ProtoSchema, + message_descriptor: MessageDescriptor, + write_stream: String, } impl TryFrom for BigQuerySink { @@ -332,66 +335,126 @@ impl BigQuerySinkWriter { pk_indices: Vec, is_append_only: bool, ) -> Result { - let client = config.common.build_writer_client(&config.aws_auth_props).await?; + let client = config + .common + .build_writer_client(&config.aws_auth_props) + .await?; + let mut descriptor_proto = build_protobuf_schema( + schema + .fields() + .iter() + .map(|f| (f.name.as_str(), &f.data_type)), + config.common.table.clone(), + 1, + ); + + if !is_append_only { + let field = FieldDescriptorProto { + name: Some(CHANGE_TYPE.to_string()), + number: Some((schema.len() + 1) as i32), + r#type: Some(field_descriptor_proto::Type::String.into()), + ..Default::default() + }; + descriptor_proto.field.push(field); + } + + let descriptor_pool = build_protobuf_descriptor_pool(&descriptor_proto); + let message_descriptor = descriptor_pool + .get_message_by_name(&config.common.table) + .ok_or_else(|| { + 
SinkError::BigQuery(anyhow::anyhow!( + "Can't find message proto {}", + &config.common.table + )) + })?; + let row_encoder = ProtoEncoder::new( + schema.clone(), + None, + message_descriptor.clone(), + ProtoHeader::None, + CustomProtoType::BigQuery, + )?; Ok(Self { + write_stream: format!( + "projects/{}/datasets/{}/tables/{}/streams/_default", + config.common.project, config.common.dataset, config.common.table + ), config, - schema: schema.clone(), + schema, pk_indices, client, is_append_only, - insert_request: TableDataInsertAllRequest::new(), - row_encoder: JsonEncoder::new_with_bigquery(schema, None), + row_encoder, + message_descriptor, + writer_pb_schema: ProtoSchema { + proto_descriptor: Some(descriptor_proto), + }, }) } async fn append_only(&mut self, chunk: StreamChunk) -> Result<()> { - let mut insert_vec = Vec::with_capacity(chunk.capacity()); + let mut serialized_rows: Vec> = Vec::with_capacity(chunk.capacity()); for (op, row) in chunk.rows() { if op != Op::Insert { - return Err(SinkError::BigQuery(anyhow::anyhow!( - "BigQuery sink don't support upsert" - ))); + continue; } - insert_vec.push(TableDataInsertAllRequestRows { - insert_id: None, - json: Value::Object(self.row_encoder.encode(row)?), - }) - } - self.insert_request - .add_rows(insert_vec) - .map_err(|e| SinkError::BigQuery(e.into()))?; - if self - .insert_request - .len() - .ge(&self.config.common.max_batch_rows) - { - self.insert_data().await?; + + serialized_rows.push(self.row_encoder.encode(row)?.ser_to()?) } + let rows = AppendRowsRequestRows::ProtoRows(ProtoData { + writer_schema: Some(self.writer_pb_schema.clone()), + rows: Some(ProtoRows { serialized_rows }), + }); + self.client + .append_rows(vec![rows], self.write_stream.clone()) + .await?; Ok(()) } - async fn insert_data(&mut self) -> Result<()> { - if !self.insert_request.is_empty() { - let insert_request = - mem::replace(&mut self.insert_request, TableDataInsertAllRequest::new()); - let request = self - .client - .tabledata() - .insert_all( - &self.config.common.project, - &self.config.common.dataset, - &self.config.common.table, - insert_request, - ) - .await - .map_err(|e| SinkError::BigQuery(e.into()))?; - if let Some(error) = request.insert_errors { - return Err(SinkError::BigQuery(anyhow::anyhow!( - "Insert error: {:?}", - error - ))); - } + async fn upsert(&mut self, chunk: StreamChunk) -> Result<()> { + let mut serialized_rows: Vec> = Vec::with_capacity(chunk.capacity()); + for (op, row) in chunk.rows() { + let mut pb_row = self.row_encoder.encode(row)?; + let proto_field = self + .message_descriptor + .get_field_by_name(CHANGE_TYPE) + .ok_or_else(|| { + SinkError::BigQuery(anyhow::anyhow!("Can't find {}", CHANGE_TYPE)) + })?; + match op { + Op::Insert => pb_row + .message + .try_set_field( + &proto_field, + prost_reflect::Value::String("INSERT".to_string()), + ) + .map_err(|e| SinkError::BigQuery(e.into()))?, + Op::Delete => pb_row + .message + .try_set_field( + &proto_field, + prost_reflect::Value::String("DELETE".to_string()), + ) + .map_err(|e| SinkError::BigQuery(e.into()))?, + Op::UpdateDelete => continue, + Op::UpdateInsert => pb_row + .message + .try_set_field( + &proto_field, + prost_reflect::Value::String("UPSERT".to_string()), + ) + .map_err(|e| SinkError::BigQuery(e.into()))?, + }; + + serialized_rows.push(pb_row.ser_to()?) 
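The `upsert` path above drives BigQuery change data capture: each row appended to the default write stream carries a `_CHANGE_TYPE` pseudo-column telling BigQuery how to apply it. A small sketch of the op mapping implemented by the `match op` above (the helper name is hypothetical, not part of the patch); `UpdateDelete` is skipped because the paired `UpdateInsert` is written as a single `UPSERT`:

```rust
use risingwave_common::array::Op;

/// Hypothetical helper mirroring the `match op` in `upsert`:
/// returns the `_CHANGE_TYPE` value attached to the encoded row,
/// or `None` when the op produces no row of its own.
fn change_type(op: Op) -> Option<&'static str> {
    match op {
        Op::Insert => Some("INSERT"),
        Op::Delete => Some("DELETE"),
        // An update arrives as UpdateDelete + UpdateInsert; only the
        // UpdateInsert half is sent, as an UPSERT keyed on the primary key.
        Op::UpdateDelete => None,
        Op::UpdateInsert => Some("UPSERT"),
    }
}
```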
} + let rows = AppendRowsRequestRows::ProtoRows(ProtoData { + writer_schema: Some(self.writer_pb_schema.clone()), + rows: Some(ProtoRows { serialized_rows }), + }); + self.client + .append_rows(vec![rows], self.write_stream.clone()) + .await?; Ok(()) } } @@ -402,9 +465,7 @@ impl SinkWriter for BigQuerySinkWriter { if self.is_append_only { self.append_only(chunk).await } else { - Err(SinkError::BigQuery(anyhow::anyhow!( - "BigQuery sink don't support upsert" - ))) + self.upsert(chunk).await } } @@ -417,7 +478,7 @@ impl SinkWriter for BigQuerySinkWriter { } async fn barrier(&mut self, _is_checkpoint: bool) -> Result<()> { - self.insert_data().await + Ok(()) } async fn update_vnode_bitmap(&mut self, _vnode_bitmap: Arc) -> Result<()> { @@ -425,38 +486,71 @@ impl SinkWriter for BigQuerySinkWriter { } } -struct StorageWriterClient{ +struct StorageWriterClient { client: StreamingWriteClient, environment: Environment, } -impl StorageWriterClient{ - pub async fn new(credentials: CredentialsFile) -> Result{ - // let credentials = CredentialsFile::new_from_file("/home/xxhx/winter-dynamics-383822-9690ac19ce78.json".to_string()).await.unwrap(); +impl StorageWriterClient { + pub async fn new(credentials: CredentialsFile) -> Result { let ts_grpc = google_cloud_auth::token::DefaultTokenSourceProvider::new_with_credentials( Self::bigquery_grpc_auth_config(), Box::new(credentials), ) - .await.unwrap(); - let conn_options = ConnectionOptions{ - connect_timeout: Some(Duration::from_secs(30)), - timeout: None, + .await + .map_err(|e| SinkError::BigQuery(e.into()))?; + let conn_options = ConnectionOptions { + connect_timeout: CONNECT_TIMEOUT, + timeout: CONNECTION_TIMEOUT, }; let environment = Environment::GoogleCloud(Box::new(ts_grpc)); - let conn = WriteConnectionManager::new(DEFAULT_GRPC_CHANNEL_NUMS, &environment, DOMAIN, &conn_options).await.unwrap(); + let conn = WriteConnectionManager::new( + DEFAULT_GRPC_CHANNEL_NUMS, + &environment, + DOMAIN, + &conn_options, + ) + .await + .map_err(|e| SinkError::BigQuery(e.into()))?; let client = conn.conn(); - Ok(StorageWriterClient{ + Ok(StorageWriterClient { client, environment, }) } + pub async fn append_rows( &mut self, rows: Vec, write_stream: String, ) -> Result<()> { - let trace_id = Uuid::new_v4().hyphenated().to_string(); - let append_req:Vec = rows.into_iter().map(|row| AppendRowsRequest{ write_stream: write_stream.clone(), offset:None, trace_id: trace_id.clone(), missing_value_interpretations: HashMap::default(), rows: Some(row)}).collect(); - let a = self.client.append_rows(Request::new(tokio_stream::iter(append_req))).await; + let trace_id = Uuid::new_v4().hyphenated().to_string(); + let append_req: Vec = rows + .into_iter() + .map(|row| AppendRowsRequest { + write_stream: write_stream.clone(), + offset: None, + trace_id: trace_id.clone(), + missing_value_interpretations: HashMap::default(), + rows: Some(row), + }) + .collect(); + let resp = self + .client + .append_rows(Request::new(tokio_stream::iter(append_req))) + .await + .map_err(|e| SinkError::BigQuery(e.into()))? 
+ .into_inner() + .message() + .await + .map_err(|e| SinkError::BigQuery(e.into()))?; + if let Some(i) = resp { + if !i.row_errors.is_empty() { + return Err(SinkError::BigQuery(anyhow::anyhow!( + "Insert error {:?}", + i.row_errors + ))); + } + } Ok(()) } @@ -469,11 +563,99 @@ impl StorageWriterClient{ } } +fn build_protobuf_descriptor_pool(desc: &DescriptorProto) -> prost_reflect::DescriptorPool { + let file_descriptor = FileDescriptorProto { + message_type: vec![desc.clone()], + name: Some("bigquery".to_string()), + ..Default::default() + }; + + prost_reflect::DescriptorPool::from_file_descriptor_set(FileDescriptorSet { + file: vec![file_descriptor], + }) + .unwrap() +} + +fn build_protobuf_schema<'a>( + fields: impl Iterator, + name: String, + index: i32, +) -> DescriptorProto { + let mut proto = DescriptorProto { + name: Some(name), + ..Default::default() + }; + let mut index_mut = index; + let mut field_vec = vec![]; + let mut struct_vec = vec![]; + for (name, data_type) in fields { + let (field, des_proto) = build_protobuf_field(data_type, index_mut, name.to_string()); + field_vec.push(field); + if let Some(sv) = des_proto { + struct_vec.push(sv); + } + index_mut += 1; + } + proto.field = field_vec; + proto.nested_type = struct_vec; + proto +} + +fn build_protobuf_field( + data_type: &DataType, + index: i32, + name: String, +) -> (FieldDescriptorProto, Option) { + let mut field = FieldDescriptorProto { + name: Some(name.clone()), + number: Some(index), + ..Default::default() + }; + match data_type { + DataType::Boolean => field.r#type = Some(field_descriptor_proto::Type::Bool.into()), + DataType::Int32 => field.r#type = Some(field_descriptor_proto::Type::Int32.into()), + DataType::Int16 | DataType::Int64 => { + field.r#type = Some(field_descriptor_proto::Type::Int64.into()) + } + DataType::Float64 => field.r#type = Some(field_descriptor_proto::Type::Double.into()), + DataType::Decimal => field.r#type = Some(field_descriptor_proto::Type::String.into()), + DataType::Date => field.r#type = Some(field_descriptor_proto::Type::Int32.into()), + DataType::Varchar => field.r#type = Some(field_descriptor_proto::Type::String.into()), + DataType::Time => field.r#type = Some(field_descriptor_proto::Type::String.into()), + DataType::Timestamp => field.r#type = Some(field_descriptor_proto::Type::String.into()), + DataType::Timestamptz => field.r#type = Some(field_descriptor_proto::Type::String.into()), + DataType::Interval => field.r#type = Some(field_descriptor_proto::Type::String.into()), + DataType::Struct(s) => { + field.r#type = Some(field_descriptor_proto::Type::Message.into()); + let name = format!("Struct{}", name); + let sub_proto = build_protobuf_schema(s.iter(), name.clone(), 1); + field.type_name = Some(name); + return (field, Some(sub_proto)); + } + DataType::List(l) => { + let (mut field, proto) = build_protobuf_field(l.as_ref(), index, name.clone()); + field.label = Some(field_descriptor_proto::Label::Repeated.into()); + return (field, proto); + } + DataType::Bytea => field.r#type = Some(field_descriptor_proto::Type::Bytes.into()), + DataType::Jsonb => field.r#type = Some(field_descriptor_proto::Type::String.into()), + DataType::Serial => field.r#type = Some(field_descriptor_proto::Type::Int64.into()), + DataType::Float32 | DataType::Int256 => todo!(), + } + (field, None) +} + #[cfg(test)] mod test { + + use std::assert_matches::assert_matches; + + use risingwave_common::catalog::{Field, Schema}; use risingwave_common::types::{DataType, StructType}; - use 
crate::sink::big_query::BigQuerySink; + use crate::sink::big_query::{ + build_protobuf_descriptor_pool, build_protobuf_schema, BigQuerySink, + }; #[tokio::test] async fn test_type_check() { @@ -493,4 +675,63 @@ mod test { big_query_type_string ); } + + #[tokio::test] + async fn test_schema_check() { + let schema = Schema { + fields: vec![ + Field::with_name(DataType::Int64, "v1"), + Field::with_name(DataType::Float64, "v2"), + Field::with_name( + DataType::List(Box::new(DataType::Struct(StructType::new(vec![ + ("v1".to_owned(), DataType::List(Box::new(DataType::Int64))), + ( + "v3".to_owned(), + DataType::Struct(StructType::new(vec![ + ("v1".to_owned(), DataType::Int64), + ("v2".to_owned(), DataType::Int64), + ])), + ), + ])))), + "v3", + ), + ], + }; + let fields = schema + .fields() + .iter() + .map(|f| (f.name.as_str(), &f.data_type)); + let desc = build_protobuf_schema(fields, "t1".to_string(), 1); + let pool = build_protobuf_descriptor_pool(&desc); + let t1_message = pool.get_message_by_name("t1").unwrap(); + assert_matches!( + t1_message.get_field_by_name("v1").unwrap().kind(), + prost_reflect::Kind::Int64 + ); + assert_matches!( + t1_message.get_field_by_name("v2").unwrap().kind(), + prost_reflect::Kind::Double + ); + assert_matches!( + t1_message.get_field_by_name("v3").unwrap().kind(), + prost_reflect::Kind::Message(_) + ); + + let v3_message = pool.get_message_by_name("t1.Structv3").unwrap(); + assert_matches!( + v3_message.get_field_by_name("v1").unwrap().kind(), + prost_reflect::Kind::Int64 + ); + assert!(v3_message.get_field_by_name("v1").unwrap().is_list()); + + let v3_v3_message = pool.get_message_by_name("t1.Structv3.Structv3").unwrap(); + assert_matches!( + v3_v3_message.get_field_by_name("v1").unwrap().kind(), + prost_reflect::Kind::Int64 + ); + assert_matches!( + v3_v3_message.get_field_by_name("v2").unwrap().kind(), + prost_reflect::Kind::Int64 + ); + } } diff --git a/src/connector/src/sink/encoder/json.rs b/src/connector/src/sink/encoder/json.rs index 64a06ff70770f..006500c60914d 100644 --- a/src/connector/src/sink/encoder/json.rs +++ b/src/connector/src/sink/encoder/json.rs @@ -114,19 +114,6 @@ impl JsonEncoder { } } - pub fn new_with_bigquery(schema: Schema, col_indices: Option>) -> Self { - Self { - schema, - col_indices, - time_handling_mode: TimeHandlingMode::Milli, - date_handling_mode: DateHandlingMode::String, - timestamp_handling_mode: TimestampHandlingMode::String, - timestamptz_handling_mode: TimestamptzHandlingMode::UtcString, - custom_json_type: CustomJsonType::BigQuery, - kafka_connect: None, - } - } - pub fn with_kafka_connect(self, kafka_connect: KafkaConnectParams) -> Self { Self { kafka_connect: Some(Arc::new(kafka_connect)), @@ -204,14 +191,7 @@ fn datum_to_json_object( ) -> ArrayResult { let scalar_ref = match datum { None => { - if let CustomJsonType::BigQuery = custom_json_type - && matches!(field.data_type(), DataType::List(_)) - { - // Bigquery need to convert null of array to empty array https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types - return Ok(Value::Array(vec![])); - } else { - return Ok(Value::Null); - } + return Ok(Value::Null); } Some(datum) => datum, }; @@ -259,7 +239,7 @@ fn datum_to_json_object( } json!(v_string) } - CustomJsonType::Es | CustomJsonType::None | CustomJsonType::BigQuery => { + CustomJsonType::Es | CustomJsonType::None => { json!(v.to_text()) } }, @@ -311,7 +291,7 @@ fn datum_to_json_object( } (DataType::Jsonb, ScalarRefImpl::Jsonb(jsonb_ref)) => match custom_json_type { CustomJsonType::Es | 
CustomJsonType::StarRocks(_) => JsonbVal::from(jsonb_ref).take(), - CustomJsonType::Doris(_) | CustomJsonType::None | CustomJsonType::BigQuery => { + CustomJsonType::Doris(_) | CustomJsonType::None => { json!(jsonb_ref.to_string()) } }, @@ -362,7 +342,7 @@ fn datum_to_json_object( "starrocks can't support struct".to_string(), )); } - CustomJsonType::Es | CustomJsonType::None | CustomJsonType::BigQuery => { + CustomJsonType::Es | CustomJsonType::None => { let mut map = Map::with_capacity(st.len()); for (sub_datum_ref, sub_field) in struct_ref.iter_fields_ref().zip_eq_debug( st.iter() diff --git a/src/connector/src/sink/encoder/mod.rs b/src/connector/src/sink/encoder/mod.rs index 34dc4c8886448..4b4807f291bc0 100644 --- a/src/connector/src/sink/encoder/mod.rs +++ b/src/connector/src/sink/encoder/mod.rs @@ -144,7 +144,11 @@ pub enum CustomJsonType { Es, // starrocks' need jsonb is struct StarRocks(HashMap), - // bigquery need null array -> [] + None, +} + +#[derive(Clone)] +pub enum CustomProtoType { BigQuery, None, } diff --git a/src/connector/src/sink/encoder/proto.rs b/src/connector/src/sink/encoder/proto.rs index a5f1090dbafaf..4d464e488ece7 100644 --- a/src/connector/src/sink/encoder/proto.rs +++ b/src/connector/src/sink/encoder/proto.rs @@ -22,7 +22,7 @@ use risingwave_common::row::Row; use risingwave_common::types::{DataType, DatumRef, ScalarRefImpl, StructType}; use risingwave_common::util::iter_util::ZipEqDebug; -use super::{FieldEncodeError, Result as SinkResult, RowEncoder, SerTo}; +use super::{CustomProtoType, FieldEncodeError, Result as SinkResult, RowEncoder, SerTo}; type Result = std::result::Result; @@ -31,6 +31,7 @@ pub struct ProtoEncoder { col_indices: Option>, descriptor: MessageDescriptor, header: ProtoHeader, + custom_proto_type: CustomProtoType, } #[derive(Debug, Clone, Copy)] @@ -49,6 +50,7 @@ impl ProtoEncoder { col_indices: Option>, descriptor: MessageDescriptor, header: ProtoHeader, + custom_proto_type: CustomProtoType, ) -> SinkResult { match &col_indices { Some(col_indices) => validate_fields( @@ -57,6 +59,7 @@ impl ProtoEncoder { (f.name.as_str(), &f.data_type) }), &descriptor, + custom_proto_type.clone(), )?, None => validate_fields( schema @@ -64,6 +67,7 @@ impl ProtoEncoder { .iter() .map(|f| (f.name.as_str(), &f.data_type)), &descriptor, + custom_proto_type.clone(), )?, }; @@ -72,12 +76,13 @@ impl ProtoEncoder { col_indices, descriptor, header, + custom_proto_type, }) } } pub struct ProtoEncoded { - message: DynamicMessage, + pub message: DynamicMessage, header: ProtoHeader, } @@ -103,6 +108,7 @@ impl RowEncoder for ProtoEncoder { ((f.name.as_str(), &f.data_type), row.datum_at(idx)) }), &self.descriptor, + self.custom_proto_type.clone(), ) .map_err(Into::into) .map(|m| ProtoEncoded { @@ -180,9 +186,19 @@ trait MaybeData: std::fmt::Debug { fn on_base(self, f: impl FnOnce(ScalarRefImpl<'_>) -> Result) -> Result; - fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result; - - fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result; + fn on_struct( + self, + st: &StructType, + pb: &MessageDescriptor, + custom_proto_type: CustomProtoType, + ) -> Result; + + fn on_list( + self, + elem: &DataType, + pb: &FieldDescriptor, + custom_proto_type: CustomProtoType, + ) -> Result; } impl MaybeData for () { @@ -192,12 +208,22 @@ impl MaybeData for () { Ok(self) } - fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result { - validate_fields(st.iter(), pb) + fn on_struct( + self, + st: &StructType, + pb: &MessageDescriptor, + 
custom_proto_type: CustomProtoType, + ) -> Result { + validate_fields(st.iter(), pb, custom_proto_type) } - fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result { - encode_field(elem, (), pb, true) + fn on_list( + self, + elem: &DataType, + pb: &FieldDescriptor, + custom_proto_type: CustomProtoType, + ) -> Result { + encode_field(elem, (), pb, true, custom_proto_type) } } @@ -213,13 +239,27 @@ impl MaybeData for ScalarRefImpl<'_> { f(self) } - fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result { + fn on_struct( + self, + st: &StructType, + pb: &MessageDescriptor, + custom_proto_type: CustomProtoType, + ) -> Result { let d = self.into_struct(); - let message = encode_fields(st.iter().zip_eq_debug(d.iter_fields_ref()), pb)?; + let message = encode_fields( + st.iter().zip_eq_debug(d.iter_fields_ref()), + pb, + custom_proto_type, + )?; Ok(Value::Message(message)) } - fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result { + fn on_list( + self, + elem: &DataType, + pb: &FieldDescriptor, + custom_proto_type: CustomProtoType, + ) -> Result { let d = self.into_list(); let vs = d .iter() @@ -231,6 +271,7 @@ impl MaybeData for ScalarRefImpl<'_> { })?, pb, true, + custom_proto_type.clone(), ) }) .try_collect()?; @@ -241,6 +282,7 @@ impl MaybeData for ScalarRefImpl<'_> { fn validate_fields<'a>( fields: impl Iterator, descriptor: &MessageDescriptor, + custom_proto_type: CustomProtoType, ) -> Result<()> { for (name, t) in fields { let Some(proto_field) = descriptor.get_field_by_name(name) else { @@ -249,7 +291,8 @@ fn validate_fields<'a>( if proto_field.cardinality() == prost_reflect::Cardinality::Required { return Err(FieldEncodeError::new("`required` not supported").with_name(name)); } - encode_field(t, (), &proto_field, false).map_err(|e| e.with_name(name))?; + encode_field(t, (), &proto_field, false, custom_proto_type.clone()) + .map_err(|e| e.with_name(name))?; } Ok(()) } @@ -257,14 +300,15 @@ fn validate_fields<'a>( fn encode_fields<'a>( fields_with_datums: impl Iterator)>, descriptor: &MessageDescriptor, + custom_proto_type: CustomProtoType, ) -> Result { let mut message = DynamicMessage::new(descriptor.clone()); for ((name, t), d) in fields_with_datums { let proto_field = descriptor.get_field_by_name(name).unwrap(); // On `null`, simply skip setting the field. 
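The `is_big_query` arms added to `encode_field` below must agree with the schema-side mapping in `build_protobuf_field` earlier in this patch: the `Value` variant produced for a RisingWave type has to match the protobuf wire type declared in the generated descriptor. A condensed, illustrative summary of that assumed correspondence (not a helper that exists in the patch):

```rust
use prost_types::field_descriptor_proto::Type;
use risingwave_common::types::DataType;

/// Illustrative summary of the BigQuery-only encoding: the declared
/// protobuf type per RisingWave type, with the value form the encoder
/// writes noted in comments.
fn bigquery_wire_type(dt: &DataType) -> Option<Type> {
    Some(match dt {
        DataType::Boolean => Type::Bool,
        DataType::Int32 => Type::Int32,
        DataType::Int16 | DataType::Int64 | DataType::Serial => Type::Int64,
        DataType::Float64 => Type::Double,
        DataType::Date => Type::Int32, // days since the Unix epoch
        DataType::Varchar => Type::String,
        // All of these are sent as strings: decimal digits, time/timestamp
        // text, ISO 8601 interval, JSON text.
        DataType::Decimal
        | DataType::Time
        | DataType::Timestamp
        | DataType::Timestamptz
        | DataType::Interval
        | DataType::Jsonb => Type::String,
        DataType::Bytea => Type::Bytes,
        // Struct / List recurse into nested messages and repeated fields;
        // Float32 / Int256 are not handled by this patch.
        _ => return None,
    })
}
```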
if let Some(scalar) = d { - let value = - encode_field(t, scalar, &proto_field, false).map_err(|e| e.with_name(name))?; + let value = encode_field(t, scalar, &proto_field, false, custom_proto_type.clone()) + .map_err(|e| e.with_name(name))?; message .try_set_field(&proto_field, value) .map_err(|e| FieldEncodeError::new(e).with_name(name))?; @@ -284,6 +328,7 @@ fn encode_field( maybe: D, proto_field: &FieldDescriptor, in_repeated: bool, + custom_proto_type: CustomProtoType, ) -> Result { // Regarding (proto_field.is_list, in_repeated): // (F, T) => impossible @@ -307,7 +352,7 @@ fn encode_field( proto_field.kind() ))) }; - + let is_big_query = matches!(custom_proto_type, CustomProtoType::BigQuery); let value = match &data_type { // Group A: perfect match between RisingWave types and ProtoBuf types DataType::Boolean => match (expect_list, proto_field.kind()) { @@ -345,11 +390,11 @@ fn encode_field( _ => return no_match_err(), }, DataType::Struct(st) => match (expect_list, proto_field.kind()) { - (false, Kind::Message(pb)) => maybe.on_struct(st, &pb)?, + (false, Kind::Message(pb)) => maybe.on_struct(st, &pb, custom_proto_type)?, _ => return no_match_err(), }, DataType::List(elem) => match expect_list { - true => maybe.on_list(elem, proto_field)?, + true => maybe.on_list(elem, proto_field, custom_proto_type)?, false => return no_match_err(), }, // Group B: match between RisingWave types and ProtoBuf Well-Known types @@ -364,18 +409,61 @@ fn encode_field( Ok(Value::Message(message.transcode_to_dynamic())) })? } + (false, Kind::String) if is_big_query => { + maybe.on_base(|s| Ok(Value::String(s.into_timestamptz().to_string())))? + } + _ => return no_match_err(), + }, + DataType::Jsonb => match (expect_list, proto_field.kind()) { + (false, Kind::String) if is_big_query => { + maybe.on_base(|s| Ok(Value::String(s.into_jsonb().to_string())))? + } + _ => return no_match_err(), /* Value, NullValue, Struct (map), ListValue + * Group C: experimental */ + }, + DataType::Int16 => match (expect_list, proto_field.kind()) { + (false, Kind::String) if is_big_query => { + maybe.on_base(|s| Ok(Value::I64(s.into_int16() as i64)))? + } _ => return no_match_err(), }, - DataType::Jsonb => return no_match_err(), // Value, NullValue, Struct (map), ListValue - // Group C: experimental - DataType::Int16 => return no_match_err(), - DataType::Date => return no_match_err(), // google.type.Date - DataType::Time => return no_match_err(), // google.type.TimeOfDay - DataType::Timestamp => return no_match_err(), // google.type.DateTime - DataType::Decimal => return no_match_err(), // google.type.Decimal - DataType::Interval => return no_match_err(), - // Group D: unsupported - DataType::Serial | DataType::Int256 => { + DataType::Date => match (expect_list, proto_field.kind()) { + (false, Kind::Int32) if is_big_query => { + maybe.on_base(|s| Ok(Value::I32(s.into_date().get_nums_days_unix_epoch())))? + } + _ => return no_match_err(), // google.type.Date + }, + DataType::Time => match (expect_list, proto_field.kind()) { + (false, Kind::String) if is_big_query => { + maybe.on_base(|s| Ok(Value::String(s.into_time().to_string())))? + } + _ => return no_match_err(), // google.type.TimeOfDay + }, + DataType::Timestamp => match (expect_list, proto_field.kind()) { + (false, Kind::String) if is_big_query => { + maybe.on_base(|s| Ok(Value::String(s.into_timestamp().to_string())))? 
+ } + _ => return no_match_err(), // google.type.DateTime + }, + DataType::Decimal => match (expect_list, proto_field.kind()) { + (false, Kind::String) if is_big_query => { + maybe.on_base(|s| Ok(Value::String(s.into_decimal().to_string())))? + } + _ => return no_match_err(), // google.type.Decimal + }, + DataType::Interval => match (expect_list, proto_field.kind()) { + (false, Kind::String) if is_big_query => { + maybe.on_base(|s| Ok(Value::String(s.into_interval().as_iso_8601())))? + } + _ => return no_match_err(), // Group D: unsupported + }, + DataType::Serial => match (expect_list, proto_field.kind()) { + (false, Kind::Int64) if is_big_query => { + maybe.on_base(|s| Ok(Value::I64(s.into_serial().as_row_id())))? + } + _ => return no_match_err(), // Group D: unsupported + }, + DataType::Int256 => { return no_match_err(); } }; @@ -398,7 +486,7 @@ mod tests { let pool_bytes = std::fs::read(pool_path).unwrap(); let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap(); let descriptor = pool.get_message_by_name("recursive.AllTypes").unwrap(); - + println!("a{:?}", descriptor.descriptor_proto()); let schema = Schema::new(vec![ Field::with_name(DataType::Boolean, "bool_field"), Field::with_name(DataType::Varchar, "string_field"), @@ -441,8 +529,14 @@ mod tests { Some(ScalarImpl::Timestamptz(Timestamptz::from_micros(3))), ]); - let encoder = - ProtoEncoder::new(schema, None, descriptor.clone(), ProtoHeader::None).unwrap(); + let encoder = ProtoEncoder::new( + schema, + None, + descriptor.clone(), + ProtoHeader::None, + CustomProtoType::None, + ) + .unwrap(); let m = encoder.encode(row).unwrap(); let encoded: Vec = m.ser_to().unwrap(); assert_eq!( @@ -480,6 +574,7 @@ mod tests { .iter() .map(|f| (f.name.as_str(), &f.data_type)), &message_descriptor, + CustomProtoType::None, ) .unwrap_err(); assert_eq!( @@ -505,6 +600,7 @@ mod tests { .map(|f| (f.name.as_str(), &f.data_type)) .zip_eq_debug(row.iter()), &message_descriptor, + CustomProtoType::None, ) .unwrap_err(); assert_eq!( @@ -524,6 +620,7 @@ mod tests { let err = validate_fields( std::iter::once(("not_exists", &DataType::Int16)), &message_descriptor, + CustomProtoType::None, ) .unwrap_err(); assert_eq!( @@ -534,6 +631,7 @@ mod tests { let err = validate_fields( std::iter::once(("map_field", &DataType::Jsonb)), &message_descriptor, + CustomProtoType::None, ) .unwrap_err(); assert_eq!( diff --git a/src/connector/src/sink/formatter/mod.rs b/src/connector/src/sink/formatter/mod.rs index d923d337a3ffb..1ce6675d7d456 100644 --- a/src/connector/src/sink/formatter/mod.rs +++ b/src/connector/src/sink/formatter/mod.rs @@ -29,7 +29,8 @@ pub use upsert::UpsertFormatter; use super::catalog::{SinkEncode, SinkFormat, SinkFormatDesc}; use super::encoder::template::TemplateEncoder; use super::encoder::{ - DateHandlingMode, KafkaConnectParams, TimeHandlingMode, TimestamptzHandlingMode, + CustomProtoType, DateHandlingMode, KafkaConnectParams, TimeHandlingMode, + TimestamptzHandlingMode, }; use super::redis::{KEY_FORMAT, VALUE_FORMAT}; use crate::sink::encoder::{ @@ -134,7 +135,13 @@ impl SinkFormatterImpl { None => ProtoHeader::None, Some(sid) => ProtoHeader::ConfluentSchemaRegistry(sid), }; - let val_encoder = ProtoEncoder::new(schema, None, descriptor, header)?; + let val_encoder = ProtoEncoder::new( + schema, + None, + descriptor, + header, + CustomProtoType::None, + )?; let formatter = AppendOnlyFormatter::new(key_encoder, val_encoder); Ok(SinkFormatterImpl::AppendOnlyProto(formatter)) } diff --git 
a/src/connector/with_options_sink.yaml b/src/connector/with_options_sink.yaml index f9d459fddfd9c..9c7b87ade0ef2 100644 --- a/src/connector/with_options_sink.yaml +++ b/src/connector/with_options_sink.yaml @@ -17,10 +17,6 @@ BigQueryConfig: - name: bigquery.table field_type: String required: true - - name: bigquery.max_batch_rows - field_type: usize - required: false - default: '1024' - name: region field_type: String required: false diff --git a/src/workspace-hack/Cargo.toml b/src/workspace-hack/Cargo.toml index bd33f5268aedb..28eded0121a63 100644 --- a/src/workspace-hack/Cargo.toml +++ b/src/workspace-hack/Cargo.toml @@ -31,7 +31,7 @@ aws-smithy-runtime = { version = "1", default-features = false, features = ["cli aws-smithy-types = { version = "1", default-features = false, features = ["byte-stream-poll-next", "http-body-0-4-x", "hyper-0-14-x", "rt-tokio"] } axum = { version = "0.6" } base64 = { version = "0.21" } -bigdecimal = { version = "0.4" } +bigdecimal = { version = "0.4", features = ["serde"] } bit-vec = { version = "0.6" } bitflags = { version = "2", default-features = false, features = ["serde", "std"] } byteorder = { version = "1" } @@ -61,6 +61,7 @@ futures-task = { version = "0.3" } futures-util = { version = "0.3", features = ["channel", "io", "sink"] } generic-array = { version = "0.14", default-features = false, features = ["more_lengths", "zeroize"] } getrandom = { git = "https://github.com/madsim-rs/getrandom.git", rev = "e79a7ae", default-features = false, features = ["js", "rdrand", "std"] } +google-cloud-googleapis = { version = "0.12", default-features = false, features = ["bigquery", "pubsub"] } governor = { version = "0.6", default-features = false, features = ["dashmap", "jitter", "std"] } hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["nightly", "raw"] } hashbrown-594e8ee84c453af0 = { package = "hashbrown", version = "0.13", features = ["raw"] } @@ -84,6 +85,7 @@ madsim-rdkafka = { version = "0.3", features = ["cmake-build", "gssapi", "ssl-ve madsim-tokio = { version = "0.2", default-features = false, features = ["fs", "io-util", "macros", "net", "process", "rt", "rt-multi-thread", "signal", "sync", "time", "tracing"] } md-5 = { version = "0.10" } memchr = { version = "2" } +mime_guess = { version = "2" } mio = { version = "0.8", features = ["net", "os-ext"] } moka = { version = "0.12", features = ["future", "sync"] } nom = { version = "7" } @@ -112,7 +114,7 @@ redis = { version = "0.24", features = ["async-std-comp", "tokio-comp"] } regex = { version = "1" } regex-automata = { version = "0.4", default-features = false, features = ["dfa", "hybrid", "meta", "nfa", "perf", "unicode"] } regex-syntax = { version = "0.8" } -reqwest = { version = "0.11", features = ["blocking", "json", "rustls-tls"] } +reqwest = { version = "0.11", features = ["blocking", "json", "multipart", "rustls-tls", "stream"] } ring = { version = "0.16", features = ["std"] } rust_decimal = { version = "1", features = ["db-postgres", "maths"] } rustc-hash = { version = "1" } @@ -181,6 +183,7 @@ lazy_static = { version = "1", default-features = false, features = ["spin_no_st libc = { version = "0.2", features = ["extra_traits"] } log = { version = "0.4", default-features = false, features = ["kv_unstable", "std"] } memchr = { version = "2" } +mime_guess = { version = "2" } nom = { version = "7" } num-bigint = { version = "0.4" } num-integer = { version = "0.1", features = ["i128"] }
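Putting the writer-side pieces together, a hedged, module-internal sketch of the encode path (it mirrors what `BigQuerySinkWriter::new` and `append_only` do before the RPC; `build_protobuf_schema` and `build_protobuf_descriptor_pool` are private to `big_query.rs`, so this only compiles inside that module, e.g. in its test module):

```rust
use risingwave_common::catalog::{Field, Schema};
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, ScalarImpl};

use crate::sink::big_query::{build_protobuf_descriptor_pool, build_protobuf_schema};
use crate::sink::encoder::{CustomProtoType, ProtoEncoder, ProtoHeader, RowEncoder, SerTo};

fn encode_one_row() -> crate::sink::Result<Vec<u8>> {
    let schema = Schema::new(vec![
        Field::with_name(DataType::Int64, "v1"),
        Field::with_name(DataType::Varchar, "v2"),
    ]);

    // Schema -> DescriptorProto -> descriptor pool, as in BigQuerySinkWriter::new.
    let fields = schema
        .fields()
        .iter()
        .map(|f| (f.name.as_str(), &f.data_type));
    let desc = build_protobuf_schema(fields, "t1".to_string(), 1);
    let pool = build_protobuf_descriptor_pool(&desc);
    let message_descriptor = pool.get_message_by_name("t1").unwrap();

    // Rows are encoded with the BigQuery flavour of the proto encoder ...
    let encoder = ProtoEncoder::new(
        schema,
        None,
        message_descriptor,
        ProtoHeader::None,
        CustomProtoType::BigQuery,
    )?;
    let row = OwnedRow::new(vec![
        Some(ScalarImpl::Int64(1)),
        Some(ScalarImpl::Utf8("a".into())),
    ]);
    // ... and serialized to the bytes placed in ProtoRows::serialized_rows.
    encoder.encode(row)?.ser_to()
}
```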