refactor(sink): Use error instead of unwrap (#17777)

xxhZs authored Jul 26, 2024
1 parent 9fd5669 commit 321e8c0
Showing 7 changed files with 49 additions and 30 deletions.
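The commit's single theme is replacing panicking `.unwrap()` calls with errors that propagate through each sink's `Result`. As a reading aid, here is a minimal sketch (not the project's code) of the two conversions that recur in the hunks below, assuming a simplified stand-in for the real `SinkError` enum and its `Result` alias:

```rust
use anyhow::Context;

// Simplified stand-in for the connector's SinkError; the real enum has one
// variant per sink and is only assumed here, not copied.
#[derive(Debug)]
enum SinkError {
    ClickHouse(String),
    BigQuery(anyhow::Error),
}

type Result<T> = std::result::Result<T, SinkError>;

// Pattern 1: a missing Option value becomes a typed error instead of a panic.
fn pick_column(columns: &[String], idx: usize) -> Result<&String> {
    columns
        .get(idx)
        .ok_or_else(|| SinkError::ClickHouse(format!("no column at index {idx}")))
}

// Pattern 2: a foreign error gains an anyhow context message, then is wrapped
// into the sink's error type instead of being unwrapped.
fn parse_pool_size(raw: &str) -> Result<u32> {
    raw.parse::<u32>()
        .context("failed to parse pool size")
        .map_err(SinkError::BigQuery)
}

fn main() -> Result<()> {
    let columns = vec!["id".to_string()];
    println!("{}", pick_column(&columns, 0)?);
    println!("{}", parse_pool_size("16")?);
    Ok(())
}
```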
11 changes: 6 additions & 5 deletions src/connector/src/sink/big_query.rs
@@ -16,7 +16,7 @@ use core::time::Duration;
use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;

-use anyhow::anyhow;
+use anyhow::{anyhow, Context};
use async_trait::async_trait;
use gcp_bigquery_client::error::BQError;
use gcp_bigquery_client::model::query_request::QueryRequest;
@@ -489,7 +489,7 @@ impl BigQuerySinkWriter {
descriptor_proto.field.push(field);
}

-let descriptor_pool = build_protobuf_descriptor_pool(&descriptor_proto);
+let descriptor_pool = build_protobuf_descriptor_pool(&descriptor_proto)?;
let message_descriptor = descriptor_pool
.get_message_by_name(&config.common.table)
.ok_or_else(|| {
@@ -733,7 +733,7 @@ impl StorageWriterClient {
}
}

-fn build_protobuf_descriptor_pool(desc: &DescriptorProto) -> prost_reflect::DescriptorPool {
+fn build_protobuf_descriptor_pool(desc: &DescriptorProto) -> Result<prost_reflect::DescriptorPool> {
let file_descriptor = FileDescriptorProto {
message_type: vec![desc.clone()],
name: Some("bigquery".to_string()),
@@ -743,7 +743,8 @@ fn build_protobuf_descriptor_pool(desc: &DescriptorProto) -> prost_reflect::Desc
prost_reflect::DescriptorPool::from_file_descriptor_set(FileDescriptorSet {
file: vec![file_descriptor],
})
-.unwrap()
+.context("failed to build descriptor pool")
+.map_err(SinkError::BigQuery)
}

fn build_protobuf_schema<'a>(
@@ -876,7 +877,7 @@ mod test {
.iter()
.map(|f| (f.name.as_str(), &f.data_type));
let desc = build_protobuf_schema(fields, "t1".to_string()).unwrap();
-let pool = build_protobuf_descriptor_pool(&desc);
+let pool = build_protobuf_descriptor_pool(&desc).unwrap();
let t1_message = pool.get_message_by_name("t1").unwrap();
assert_matches!(
t1_message.get_field_by_name("v1").unwrap().kind(),
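One detail worth noting in the hunk above: only the library path stops unwrapping; the `mod test` call site gains a `.unwrap()` instead, since panicking on an unexpected error is still the usual choice inside `#[test]` functions. A tiny hedged sketch of that split, using a hypothetical `build_pool` helper rather than the real `build_protobuf_descriptor_pool`:

```rust
use anyhow::{anyhow, Result};

// Hypothetical helper standing in for build_protobuf_descriptor_pool:
// library code returns the error to its caller instead of panicking.
pub fn build_pool(valid: bool) -> Result<&'static str> {
    if valid {
        Ok("descriptor pool")
    } else {
        Err(anyhow!("failed to build descriptor pool"))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builds_pool() {
        // In a test, unwrapping keeps failures loud and the assertion short.
        let pool = build_pool(true).unwrap();
        assert_eq!(pool, "descriptor pool");
    }
}
```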
2 changes: 1 addition & 1 deletion src/connector/src/sink/clickhouse.rs
@@ -823,7 +823,7 @@ impl ClickHouseFieldWithNull {
) -> Result<Vec<ClickHouseFieldWithNull>> {
let clickhouse_schema_feature = clickhouse_schema_feature_vec
.get(clickhouse_schema_feature_index)
-.unwrap();
+.ok_or_else(|| SinkError::ClickHouse(format!("No column found from clickhouse table schema, index is {clickhouse_schema_feature_index}")))?;
if data.is_none() {
if !clickhouse_schema_feature.can_null {
return Err(SinkError::ClickHouse(
28 changes: 22 additions & 6 deletions src/connector/src/sink/doris.rs
@@ -268,11 +268,11 @@ impl DorisSinkWriter {
.build_get_client()
.get_schema_from_doris()
.await?;
-doris_schema.properties.iter().for_each(|s| {
-if let Some(v) = s.get_decimal_pre_scale() {
+for s in &doris_schema.properties {
+if let Some(v) = s.get_decimal_pre_scale()? {
decimal_map.insert(s.name.clone(), v);
}
-});
+}

let header_builder = HeaderBuilder::new()
.add_common_header()
@@ -491,11 +491,27 @@ pub struct DorisField {
aggregation_type: String,
}
impl DorisField {
-pub fn get_decimal_pre_scale(&self) -> Option<u8> {
+pub fn get_decimal_pre_scale(&self) -> Result<Option<u8>> {
if self.r#type.contains("DECIMAL") {
-Some(self.scale.clone().unwrap().parse::<u8>().unwrap())
+let scale = self
+    .scale
+    .as_ref()
+    .ok_or_else(|| {
+        SinkError::Doris(format!(
+            "In doris, the type of {} is DECIMAL, but `scale` is not found",
+            self.name
+        ))
+    })?
+    .parse::<u8>()
+    .map_err(|err| {
+        SinkError::Doris(format!(
+            "Unable to convert decimal's scale to u8. error: {:?}",
+            err.kind()
+        ))
+    })?;
+Ok(Some(scale))
} else {
-None
+Ok(None)
}
}
}
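The caller change in the first doris.rs hunk, from `.iter().for_each(|s| { ... })` to a plain `for` loop, is what allows the new `?` on `get_decimal_pre_scale()?`: inside a `for_each` closure the `?` would try to return from the closure (which yields `()`), while in a `for` loop it propagates out of the enclosing function. A minimal sketch of that distinction, with a simplified field type whose names are illustrative rather than the real `DorisField`:

```rust
use std::collections::HashMap;

type Result<T> = std::result::Result<T, String>;

struct Field {
    name: String,
    scale: Option<String>,
}

impl Field {
    // Simplified analogue of get_decimal_pre_scale: a bad scale is an Err,
    // a missing scale is simply Ok(None).
    fn decimal_scale(&self) -> Result<Option<u8>> {
        match &self.scale {
            Some(s) => s
                .parse::<u8>()
                .map(Some)
                .map_err(|e| format!("bad scale for column {}: {e}", self.name)),
            None => Ok(None),
        }
    }
}

fn collect_scales(fields: &[Field]) -> Result<HashMap<String, u8>> {
    let mut decimal_map = HashMap::new();
    // A `for` loop lets `?` bubble the error up to collect_scales' Result;
    // the same `?` inside a for_each closure would not compile, because the
    // closure is expected to return `()`.
    for f in fields {
        if let Some(v) = f.decimal_scale()? {
            decimal_map.insert(f.name.clone(), v);
        }
    }
    Ok(decimal_map)
}

fn main() -> Result<()> {
    let fields = vec![Field { name: "price".into(), scale: Some("2".into()) }];
    println!("{:?}", collect_scales(&fields)?);
    Ok(())
}
```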
23 changes: 10 additions & 13 deletions src/connector/src/sink/doris_starrocks_connector.rs
@@ -256,22 +256,22 @@ impl InserterInnerBuilder {
})
}

-fn build_request(&self, uri: String) -> RequestBuilder {
+fn build_request(&self, uri: String) -> Result<RequestBuilder> {
let client = Client::builder()
.pool_idle_timeout(POOL_IDLE_TIMEOUT)
.redirect(redirect::Policy::none()) // we handle redirect by ourselves
.build()
-.unwrap();
+.map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;

let mut builder = client.put(uri);
for (k, v) in &self.header {
builder = builder.header(k, v);
}
-builder
+Ok(builder)
}

pub async fn build(&self) -> Result<InserterInner> {
-let builder = self.build_request(self.url.clone());
+let builder = self.build_request(self.url.clone())?;
let resp = builder
.send()
.await
@@ -284,7 +284,7 @@ impl InserterInnerBuilder {
let body = Body::wrap_stream(
tokio_stream::wrappers::UnboundedReceiverStream::new(receiver).map(Ok::<_, Infallible>),
);
-let builder = self.build_request(be_url.into()).body(body);
+let builder = self.build_request(be_url.into())?.body(body);

let handle: JoinHandle<Result<Vec<u8>>> = tokio::spawn(async move {
let response = builder
@@ -321,7 +321,7 @@ type Sender = UnboundedSender<Bytes>;

pub struct InserterInner {
sender: Option<Sender>,
-join_handle: Option<JoinHandle<Result<Vec<u8>>>>,
+join_handle: JoinHandle<Result<Vec<u8>>>,
buffer: BytesMut,
stream_load_http_timeout: Duration,
}
@@ -333,7 +333,7 @@ impl InserterInner {
) -> Self {
Self {
sender: Some(sender),
-join_handle: Some(join_handle),
+join_handle,
buffer: BytesMut::with_capacity(BUFFER_SIZE),
stream_load_http_timeout,
}
@@ -365,11 +365,8 @@ impl InserterInner {
}

async fn wait_handle(&mut self) -> Result<Vec<u8>> {
-let res = match tokio::time::timeout(
-    self.stream_load_http_timeout,
-    self.join_handle.as_mut().unwrap(),
-)
-.await
+let res = match tokio::time::timeout(self.stream_load_http_timeout, &mut self.join_handle)
+    .await
{
Ok(res) => res.map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))??,
Err(err) => return Err(SinkError::DorisStarrocksConnect(anyhow!(err))),
@@ -480,7 +477,7 @@ impl StarrocksTxnRequestBuilder {
.pool_idle_timeout(POOL_IDLE_TIMEOUT)
.redirect(redirect::Policy::none())
.build()
-.unwrap();
+.map_err(|err| SinkError::DorisStarrocksConnect(anyhow!(err)))?;

Ok(Self {
url_begin,
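The `Option<JoinHandle<...>>` removal in this file works because the wait path never needed ownership of the handle: `JoinHandle` is `Unpin`, so `&mut JoinHandle<T>` is itself a `Future` and can be handed straight to `tokio::time::timeout`, which is what the rewritten `wait_handle` does. A small hedged sketch of the same shape (stand-in types, not the real `InserterInner`):

```rust
use std::time::Duration;

use tokio::task::JoinHandle;
use tokio::time::timeout;

// Stand-in for InserterInner: the handle is stored directly rather than as
// Option<JoinHandle<_>>, so waiting on it needs no as_mut().unwrap().
struct Waiter {
    join_handle: JoinHandle<Vec<u8>>,
    http_timeout: Duration,
}

impl Waiter {
    async fn wait_handle(&mut self) -> Result<Vec<u8>, String> {
        // `&mut JoinHandle<T>` implements Future (JoinHandle is Unpin), so it
        // can be passed to timeout by mutable reference.
        match timeout(self.http_timeout, &mut self.join_handle).await {
            Ok(joined) => joined.map_err(|e| format!("stream load task failed: {e}")),
            Err(elapsed) => Err(format!("stream load timed out: {elapsed}")),
        }
    }
}

#[tokio::main]
async fn main() {
    let join_handle = tokio::spawn(async { b"load response".to_vec() });
    let mut waiter = Waiter { join_handle, http_timeout: Duration::from_secs(1) };
    let body = waiter.wait_handle().await.unwrap();
    println!("{}", String::from_utf8_lossy(&body));
}
```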
2 changes: 1 addition & 1 deletion src/connector/src/sink/nats.rs
@@ -167,7 +167,7 @@ impl AsyncTruncateSinkWriter for NatsSinkWriter {
chunk: StreamChunk,
mut add_future: DeliveryFutureManagerAddFuture<'a, Self::DeliveryFuture>,
) -> Result<()> {
-let mut data = chunk_to_json(chunk, &self.json_encoder).unwrap();
+let mut data = chunk_to_json(chunk, &self.json_encoder)?;
for item in &mut data {
let publish_ack_future = Retry::spawn(
ExponentialBackoff::from_millis(100).map(jitter).take(3),
2 changes: 1 addition & 1 deletion src/connector/src/sink/redis.rs
@@ -299,7 +299,7 @@ impl FormattedSink for RedisSinkPayloadWriter {
type V = Vec<u8>;

async fn write_one(&mut self, k: Option<Self::K>, v: Option<Self::V>) -> Result<()> {
-let k = k.unwrap();
+let k = k.ok_or_else(|| SinkError::Redis("The redis key cannot be null".to_string()))?;
match v {
Some(v) => self.pipe.set(k, v),
None => self.pipe.del(k),
11 changes: 8 additions & 3 deletions src/connector/src/sink/sqlserver.rs
@@ -15,7 +15,7 @@
use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;

-use anyhow::anyhow;
+use anyhow::{anyhow, Context};
use async_trait::async_trait;
use risingwave_common::array::{Op, RowRef, StreamChunk};
use risingwave_common::bitmap::Bitmap;
@@ -504,8 +504,13 @@ impl SqlClient {
config.database(&msconfig.database);
config.trust_cert();

-let tcp = TcpStream::connect(config.get_addr()).await.unwrap();
-tcp.set_nodelay(true).unwrap();
+let tcp = TcpStream::connect(config.get_addr())
+    .await
+    .context("failed to connect to sql server")
+    .map_err(SinkError::SqlServer)?;
+tcp.set_nodelay(true)
+    .context("failed to setting nodelay when connecting to sql server")
+    .map_err(SinkError::SqlServer)?;
let client = Client::connect(config, tcp.compat_write()).await?;
Ok(Self { client })
}
