diff --git a/src/connector/src/sink/big_query.rs b/src/connector/src/sink/big_query.rs index c89e200093473..df800f234831c 100644 --- a/src/connector/src/sink/big_query.rs +++ b/src/connector/src/sink/big_query.rs @@ -18,7 +18,11 @@ use std::sync::Arc; use anyhow::anyhow; use async_trait::async_trait; +use gcp_bigquery_client::error::BQError; use gcp_bigquery_client::model::query_request::QueryRequest; +use gcp_bigquery_client::model::table::Table; +use gcp_bigquery_client::model::table_field_schema::TableFieldSchema; +use gcp_bigquery_client::model::table_schema::TableSchema; use gcp_bigquery_client::Client; use google_cloud_bigquery::grpc::apiv1::bigquery_client::StreamingWriteClient; use google_cloud_bigquery::grpc::apiv1::conn_pool::{WriteConnectionManager, DOMAIN}; @@ -39,10 +43,11 @@ use prost_types::{ }; use risingwave_common::array::{Op, StreamChunk}; use risingwave_common::buffer::Bitmap; -use risingwave_common::catalog::Schema; +use risingwave_common::catalog::{Field, Schema}; use risingwave_common::types::DataType; use serde_derive::Deserialize; use serde_with::{serde_as, DisplayFromStr}; +use simd_json::prelude::ArrayTrait; use url::Url; use uuid::Uuid; use with_options::WithOptions; @@ -83,6 +88,9 @@ pub struct BigQueryCommon { #[serde(rename = "bigquery.retry_times", default = "default_retry_times")] #[serde_as(as = "DisplayFromStr")] pub retry_times: usize, + #[serde(default)] // default false + #[serde_as(as = "DisplayFromStr")] + pub auto_create: bool, } fn default_max_batch_rows() -> usize { @@ -255,6 +263,79 @@ impl BigQuerySink { ))), } } + + fn map_field(rw_field: &Field) -> Result { + let tfs = match &rw_field.data_type { + DataType::Boolean => TableFieldSchema::bool(&rw_field.name), + DataType::Int16 | DataType::Int32 | DataType::Int64 | DataType::Serial => { + TableFieldSchema::integer(&rw_field.name) + } + DataType::Float32 => { + return Err(SinkError::BigQuery(anyhow::anyhow!( + "Bigquery cannot support real" + ))) + } + DataType::Float64 => TableFieldSchema::float(&rw_field.name), + DataType::Decimal => TableFieldSchema::numeric(&rw_field.name), + DataType::Date => TableFieldSchema::date(&rw_field.name), + DataType::Varchar => TableFieldSchema::string(&rw_field.name), + DataType::Time => TableFieldSchema::time(&rw_field.name), + DataType::Timestamp => TableFieldSchema::date_time(&rw_field.name), + DataType::Timestamptz => TableFieldSchema::timestamp(&rw_field.name), + DataType::Interval => { + return Err(SinkError::BigQuery(anyhow::anyhow!( + "Bigquery cannot support Interval" + ))) + } + DataType::Struct(_) => { + let mut sub_fields = Vec::with_capacity(rw_field.sub_fields.len()); + for rw_field in &rw_field.sub_fields { + let field = Self::map_field(rw_field)?; + sub_fields.push(field) + } + TableFieldSchema::record(&rw_field.name, sub_fields) + } + DataType::List(dt) => { + let inner_field = Self::map_field(&Field::with_name(*dt.clone(), &rw_field.name))?; + TableFieldSchema { + mode: Some("REPEATED".to_string()), + ..inner_field + } + } + + DataType::Bytea => TableFieldSchema::bytes(&rw_field.name), + DataType::Jsonb => TableFieldSchema::json(&rw_field.name), + DataType::Int256 => { + return Err(SinkError::BigQuery(anyhow::anyhow!( + "Bigquery cannot support Int256" + ))) + } + }; + Ok(tfs) + } + + async fn create_table( + &self, + client: &Client, + project_id: &str, + dataset_id: &str, + table_id: &str, + fields: &Vec, + ) -> Result { + let dataset = client + .dataset() + .get(project_id, dataset_id) + .await + .map_err(|e| SinkError::BigQuery(e.into()))?; + let fields: Vec<_> = fields.iter().map(Self::map_field).collect::>()?; + let table = Table::from_dataset(&dataset, table_id, TableSchema::new(fields)); + + client + .table() + .create(table) + .await + .map_err(|e| SinkError::BigQuery(e.into())) + } } impl Sink for BigQuerySink { @@ -284,16 +365,47 @@ impl Sink for BigQuerySink { .common .build_client(&self.config.aws_auth_props) .await?; + let BigQueryCommon { + project: project_id, + dataset: dataset_id, + table: table_id, + .. + } = &self.config.common; + + if self.config.common.auto_create { + match client + .table() + .get(project_id, dataset_id, table_id, None) + .await + { + Err(BQError::RequestError(_)) => { + // early return: no need to query schema to check column and type + return self + .create_table( + &client, + project_id, + dataset_id, + table_id, + &self.schema.fields, + ) + .await + .map(|_| ()); + } + Err(e) => return Err(SinkError::BigQuery(e.into())), + _ => {} + } + } + let mut rs = client - .job() - .query( - &self.config.common.project, - QueryRequest::new(format!( - "SELECT column_name, data_type FROM `{}.{}.INFORMATION_SCHEMA.COLUMNS` WHERE table_name = '{}'" - ,self.config.common.project,self.config.common.dataset,self.config.common.table, - )), - ) - .await.map_err(|e| SinkError::BigQuery(e.into()))?; + .job() + .query( + &self.config.common.project, + QueryRequest::new(format!( + "SELECT column_name, data_type FROM `{}.{}.INFORMATION_SCHEMA.COLUMNS` WHERE table_name = '{}'", + project_id, dataset_id, table_id, + )), + ).await.map_err(|e| SinkError::BigQuery(e.into()))?; + let mut big_query_schema = HashMap::default(); while rs.next_row() { big_query_schema.insert( diff --git a/src/connector/with_options_sink.yaml b/src/connector/with_options_sink.yaml index 7a9aaa444400b..7deb67a524fcb 100644 --- a/src/connector/with_options_sink.yaml +++ b/src/connector/with_options_sink.yaml @@ -25,6 +25,10 @@ BigQueryConfig: field_type: usize required: false default: '5' + - name: auto_create + field_type: bool + required: false + default: Default::default - name: aws.region field_type: String required: false