diff --git a/Cargo.lock b/Cargo.lock index 48643c9cf1b65..c24b2d897d18d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10352,6 +10352,7 @@ dependencies = [ "thiserror", "tokio", "tokio-rustls", + "url", ] [[package]] diff --git a/integration_tests/mqtt/create_source.sql b/integration_tests/mqtt/create_source.sql index 04be605700c76..8c63216125280 100644 --- a/integration_tests/mqtt/create_source.sql +++ b/integration_tests/mqtt/create_source.sql @@ -8,7 +8,7 @@ CREATE TABLE mqtt_source_table ) WITH ( connector='mqtt', - host='mqtt-server', + url='tcp://mqtt-server', topic= 'test', qos = 'at_least_once', ) FORMAT PLAIN ENCODE JSON; @@ -20,7 +20,7 @@ FROM WITH ( connector='mqtt', - host='mqtt-server', + url='tcp://mqtt-server', topic= 'test', type = 'append-only', retain = 'false', diff --git a/src/connector/Cargo.toml b/src/connector/Cargo.toml index 164c2f52a66cf..7f7ef4e043eb1 100644 --- a/src/connector/Cargo.toml +++ b/src/connector/Cargo.toml @@ -113,7 +113,7 @@ risingwave_common = { workspace = true } risingwave_jni_core = { workspace = true } risingwave_pb = { workspace = true } risingwave_rpc_client = { workspace = true } -rumqttc = "0.22.0" +rumqttc = { version = "0.22.0", features = ["url"] } rust_decimal = "1" rustls-native-certs = "0.6" rustls-pemfile = "1" diff --git a/src/connector/src/common.rs b/src/connector/src/common.rs index b64d9e281eaff..2198585b99c99 100644 --- a/src/connector/src/common.rs +++ b/src/connector/src/common.rs @@ -695,29 +695,23 @@ pub enum QualityOfService { ExactlyOnce, } -#[derive(Debug, Clone, PartialEq, Display, Deserialize, EnumString)] -#[strum(serialize_all = "snake_case")] -pub enum Protocol { - Tcp, - Ssl, -} - #[serde_as] #[derive(Deserialize, Debug, Clone, WithOptions)] pub struct MqttCommon { - /// Protocol used for RisingWave to communicate with the mqtt brokers. Could be `tcp` or `ssl`, defaults to `tcp` - #[serde_as(as = "Option")] - pub protocol: Option, - - /// Hostname of the mqtt broker - pub host: String, - - /// Port of the mqtt broker, defaults to 1883 for tcp and 8883 for ssl - pub port: Option, + /// The url of the broker to connect to. e.g. tcp://localhost. + /// Must be prefixed with one of either `tcp://`, `mqtt://`, `ssl://`,`mqtts://`, + /// `ws://` or `wss://` to denote the protocol for establishing a connection with the broker. + /// `mqtts://`, `ssl://`, `wss://` + pub url: String, /// The topic name to subscribe or publish to. When subscribing, it can be a wildcard topic. e.g /topic/# pub topic: String, + /// The quality of service to use when publishing messages. Defaults to at_most_once. + /// Could be at_most_once, at_least_once or exactly_once + #[serde_as(as = "Option")] + pub qos: Option, + /// Username for the mqtt broker #[serde(rename = "username")] pub user: Option, @@ -759,64 +753,32 @@ pub struct MqttCommon { impl MqttCommon { pub(crate) fn build_client( &self, + actor_id: u32, id: u32, ) -> ConnectorResult<(rumqttc::v5::AsyncClient, rumqttc::v5::EventLoop)> { - let ssl = self - .protocol - .as_ref() - .map(|p| p == &Protocol::Ssl) - .unwrap_or_default(); - let client_id = format!( - "{}_{}{}", + "{}_{}_{}", self.client_prefix.as_deref().unwrap_or("risingwave"), - id, - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_millis() - % 100000, + actor_id, + id ); - let port = self.port.unwrap_or(if ssl { 8883 } else { 1883 }) as u16; + let mut url = url::Url::parse(&self.url)?; - let mut options = rumqttc::v5::MqttOptions::new(client_id, &self.host, port); + let ssl = match url.scheme() { + "mqtts" | "ssl" | "wss" => true, + _ => false, + }; + + url.query_pairs_mut().append_pair("client_id", &client_id); + + let mut options = rumqttc::v5::MqttOptions::try_from(url)?; options.set_keep_alive(std::time::Duration::from_secs(10)); options.set_clean_start(self.clean_start); if ssl { - let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty(); - if let Some(ca) = &self.ca { - let certificates = load_certs(ca)?; - for cert in certificates { - root_cert_store.add(&cert).unwrap(); - } - } else { - for cert in - rustls_native_certs::load_native_certs().expect("could not load platform certs") - { - root_cert_store - .add(&tokio_rustls::rustls::Certificate(cert.0)) - .unwrap(); - } - } - - let builder = tokio_rustls::rustls::ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_cert_store); - - let tls_config = if let (Some(client_cert), Some(client_key)) = - (self.client_cert.as_ref(), self.client_key.as_ref()) - { - let certs = load_certs(client_cert)?; - let key = load_private_key(client_key)?; - - builder.with_client_auth_cert(certs, key)? - } else { - builder.with_no_client_auth() - }; - + let tls_config = self.get_tls_config()?; options.set_transport(rumqttc::Transport::tls_with_config( rumqttc::TlsConfiguration::Rustls(std::sync::Arc::new(tls_config)), )); @@ -831,6 +793,52 @@ impl MqttCommon { self.inflight_messages.unwrap_or(100), )) } + + pub(crate) fn qos(&self) -> rumqttc::v5::mqttbytes::QoS { + self.qos + .as_ref() + .map(|qos| match qos { + QualityOfService::AtMostOnce => rumqttc::v5::mqttbytes::QoS::AtMostOnce, + QualityOfService::AtLeastOnce => rumqttc::v5::mqttbytes::QoS::AtLeastOnce, + QualityOfService::ExactlyOnce => rumqttc::v5::mqttbytes::QoS::ExactlyOnce, + }) + .unwrap_or(rumqttc::v5::mqttbytes::QoS::AtMostOnce) + } + + fn get_tls_config(&self) -> ConnectorResult { + let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty(); + if let Some(ca) = &self.ca { + let certificates = load_certs(ca)?; + for cert in certificates { + root_cert_store.add(&cert).unwrap(); + } + } else { + for cert in + rustls_native_certs::load_native_certs().expect("could not load platform certs") + { + root_cert_store + .add(&tokio_rustls::rustls::Certificate(cert.0)) + .unwrap(); + } + } + + let builder = tokio_rustls::rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_cert_store); + + let tls_config = if let (Some(client_cert), Some(client_key)) = + (self.client_cert.as_ref(), self.client_key.as_ref()) + { + let certs = load_certs(client_cert)?; + let key = load_private_key(client_key)?; + + builder.with_client_auth_cert(certs, key)? + } else { + builder.with_no_client_auth() + }; + + Ok(tls_config) + } } fn load_certs(certificates: &str) -> ConnectorResult> { diff --git a/src/connector/src/error.rs b/src/connector/src/error.rs index 1317981f88919..603b2c01f5810 100644 --- a/src/connector/src/error.rs +++ b/src/connector/src/error.rs @@ -60,6 +60,7 @@ def_anyhow_newtype! { google_cloud_pubsub::client::google_cloud_auth::error::Error => "Google Cloud error", tokio_rustls::rustls::Error => "TLS error", rumqttc::v5::ClientError => "MQTT error", + rumqttc::v5::OptionError => "MQTT error", } pub type ConnectorResult = std::result::Result; diff --git a/src/connector/src/sink/mqtt.rs b/src/connector/src/sink/mqtt.rs index 6442789b12617..040e62bfe9634 100644 --- a/src/connector/src/sink/mqtt.rs +++ b/src/connector/src/sink/mqtt.rs @@ -22,7 +22,7 @@ use risingwave_common::catalog::Schema; use rumqttc::v5::mqttbytes::QoS; use rumqttc::v5::ConnectionError; use serde_derive::Deserialize; -use serde_with::{serde_as, DisplayFromStr}; +use serde_with::serde_as; use thiserror_ext::AsReport; use with_options::WithOptions; @@ -30,7 +30,7 @@ use super::catalog::SinkFormatDesc; use super::formatter::SinkFormatterImpl; use super::writer::FormattedSink; use super::{DummySinkCommitCoordinator, SinkWriterParam}; -use crate::common::{MqttCommon, QualityOfService}; +use crate::common::MqttCommon; use crate::sink::catalog::desc::SinkDesc; use crate::sink::log_store::DeliveryFutureManagerAddFuture; use crate::sink::writer::{ @@ -47,11 +47,6 @@ pub struct MqttConfig { #[serde(flatten)] pub common: MqttCommon, - /// The quality of service to use when publishing messages. Defaults to at_most_once. - /// Could be at_most_once, at_least_once or exactly_once - #[serde_as(as = "Option")] - pub qos: Option, - /// Whether the message should be retained by the broker #[serde(default, deserialize_with = "deserialize_bool_from_string")] pub retain: bool, @@ -132,7 +127,7 @@ impl Sink for MqttSink { ))); } - let _client = (self.config.common.build_client(0)) + let _client = (self.config.common.build_client(0, 0)) .context("validate mqtt sink error") .map_err(SinkError::Mqtt)?; @@ -174,19 +169,11 @@ impl MqttSinkWriter { ) .await?; - let qos = config - .qos - .as_ref() - .map(|qos| match qos { - QualityOfService::AtMostOnce => QoS::AtMostOnce, - QualityOfService::AtLeastOnce => QoS::AtLeastOnce, - QualityOfService::ExactlyOnce => QoS::ExactlyOnce, - }) - .unwrap_or(QoS::AtMostOnce); + let qos = config.common.qos(); let (client, mut eventloop) = config .common - .build_client(id as u32) + .build_client(0, id as u32) .map_err(|e| SinkError::Mqtt(anyhow!(e)))?; let stopped = Arc::new(AtomicBool::new(false)); @@ -196,24 +183,23 @@ impl MqttSinkWriter { while !stopped_clone.load(std::sync::atomic::Ordering::Relaxed) { match eventloop.poll().await { Ok(_) => (), - Err(err) => { - if let ConnectionError::Timeout(_) = err { + Err(err) => match err { + ConnectionError::Timeout(_) => { continue; } - - if let ConnectionError::MqttState(rumqttc::v5::StateError::Io(err)) = err { - if err.kind() != std::io::ErrorKind::ConnectionAborted { - tracing::error!( - "Failed to poll mqtt eventloop: {}", - err.as_report() - ); - std::thread::sleep(std::time::Duration::from_secs(1)); - } - } else { + ConnectionError::MqttState(rumqttc::v5::StateError::Io(err)) + | ConnectionError::Io(err) + if err.kind() == std::io::ErrorKind::ConnectionAborted + || err.kind() == std::io::ErrorKind::ConnectionReset => + { + continue; + } + err => { + println!("Err: {:?}", err); tracing::error!("Failed to poll mqtt eventloop: {}", err.as_report()); std::thread::sleep(std::time::Duration::from_secs(1)); } - } + }, } } }); diff --git a/src/connector/src/source/mqtt/enumerator/mod.rs b/src/connector/src/source/mqtt/enumerator/mod.rs index 5cfd952ab0121..1013f31e07d5e 100644 --- a/src/connector/src/source/mqtt/enumerator/mod.rs +++ b/src/connector/src/source/mqtt/enumerator/mod.rs @@ -45,7 +45,7 @@ impl SplitEnumerator for MqttSplitEnumerator { properties: Self::Properties, context: SourceEnumeratorContextRef, ) -> ConnectorResult { - let (client, mut eventloop) = properties.common.build_client(context.info.source_id)?; + let (client, mut eventloop) = properties.common.build_client(context.info.source_id, 0)?; let topic = properties.common.topic.clone(); let mut topics = HashSet::new(); @@ -92,7 +92,7 @@ impl SplitEnumerator for MqttSplitEnumerator { continue; } tracing::error!( - "[Enumerator] Failed to subscribe to topic {}: {}", + "Failed to subscribe to topic {}: {}", topic, err.as_report(), ); @@ -127,7 +127,7 @@ impl SplitEnumerator for MqttSplitEnumerator { bail!("Failed to connect to mqtt broker"); } - tokio::time::sleep(std::time::Duration::from_millis(100)).await; + tokio::time::sleep(std::time::Duration::from_millis(500)).await; } } diff --git a/src/connector/src/source/mqtt/source/reader.rs b/src/connector/src/source/mqtt/source/reader.rs index 21a0e60289a56..cdc88f7702a86 100644 --- a/src/connector/src/source/mqtt/source/reader.rs +++ b/src/connector/src/source/mqtt/source/reader.rs @@ -21,7 +21,6 @@ use thiserror_ext::AsReport; use super::message::MqttMessage; use super::MqttSplit; -use crate::common::QualityOfService; use crate::error::ConnectorResult as Result; use crate::parser::ParserConfig; use crate::source::common::{into_chunk_stream, CommonSplitReader}; @@ -52,17 +51,9 @@ impl SplitReader for MqttSplitReader { ) -> Result { let (client, eventloop) = properties .common - .build_client(source_ctx.source_info.fragment_id)?; + .build_client(source_ctx.source_info.actor_id, source_ctx.source_info.fragment_id)?; - let qos = properties - .qos - .as_ref() - .map(|qos| match qos { - QualityOfService::AtMostOnce => QoS::AtMostOnce, - QualityOfService::AtLeastOnce => QoS::AtLeastOnce, - QualityOfService::ExactlyOnce => QoS::ExactlyOnce, - }) - .unwrap_or(QoS::AtMostOnce); + let qos = properties.common.qos(); client .subscribe_many( diff --git a/src/connector/src/with_options.rs b/src/connector/src/with_options.rs index c233fb6e3a33b..1bdda0b484ce4 100644 --- a/src/connector/src/with_options.rs +++ b/src/connector/src/with_options.rs @@ -53,7 +53,6 @@ impl WithOptions for i64 {} impl WithOptions for f64 {} impl WithOptions for std::time::Duration {} impl WithOptions for crate::common::QualityOfService {} -impl WithOptions for crate::common::Protocol {} impl WithOptions for crate::sink::kafka::CompressionCodec {} impl WithOptions for nexmark::config::RateShape {} impl WithOptions for nexmark::event::EventType {} diff --git a/src/connector/with_options_sink.yaml b/src/connector/with_options_sink.yaml index 70f1240297cd6..f0830f1da88d0 100644 --- a/src/connector/with_options_sink.yaml +++ b/src/connector/with_options_sink.yaml @@ -348,22 +348,18 @@ KinesisSinkConfig: alias: kinesis.assumerole.external_id MqttConfig: fields: - - name: protocol - field_type: Protocol - comments: Protocol used for RisingWave to communicate with the mqtt brokers. Could be `tcp` or `ssl`, defaults to `tcp` - required: false - - name: host + - name: url field_type: String - comments: Hostname of the mqtt broker + comments: The url of the broker to connect to. e.g. tcp://localhost. Must be prefixed with one of either `tcp://`, `mqtt://`, `ssl://`,`mqtts://`, `ws://` or `wss://` to denote the protocol for establishing a connection with the broker. `mqtts://`, `ssl://`, `wss://` required: true - - name: port - field_type: i32 - comments: Port of the mqtt broker, defaults to 1883 for tcp and 8883 for ssl - required: false - name: topic field_type: String comments: The topic name to subscribe or publish to. When subscribing, it can be a wildcard topic. e.g /topic/# required: true + - name: qos + field_type: QualityOfService + comments: The quality of service to use when publishing messages. Defaults to at_most_once. Could be at_most_once, at_least_once or exactly_once + required: false - name: username field_type: String comments: Username for the mqtt broker @@ -397,10 +393,6 @@ MqttConfig: field_type: String comments: Path to client's private key file (PEM). Required for client authentication. Can be a file path under fs:// or a string with the private key content. required: false - - name: qos - field_type: QualityOfService - comments: The quality of service to use when publishing messages. Defaults to at_most_once. Could be at_most_once, at_least_once or exactly_once - required: false - name: retain field_type: bool comments: Whether the message should be retained by the broker diff --git a/src/connector/with_options_source.yaml b/src/connector/with_options_source.yaml index cb929d32ddce4..b45b933e205ef 100644 --- a/src/connector/with_options_source.yaml +++ b/src/connector/with_options_source.yaml @@ -245,22 +245,18 @@ KinesisProperties: alias: kinesis.assumerole.external_id MqttProperties: fields: - - name: protocol - field_type: Protocol - comments: Protocol used for RisingWave to communicate with the mqtt brokers. Could be `tcp` or `ssl`, defaults to `tcp` - required: false - - name: host + - name: url field_type: String - comments: Hostname of the mqtt broker + comments: The url of the broker to connect to. e.g. tcp://localhost. Must be prefixed with one of either `tcp://`, `mqtt://`, `ssl://`,`mqtts://`, `ws://` or `wss://` to denote the protocol for establishing a connection with the broker. `mqtts://`, `ssl://`, `wss://` required: true - - name: port - field_type: i32 - comments: Port of the mqtt broker, defaults to 1883 for tcp and 8883 for ssl - required: false - name: topic field_type: String comments: The topic name to subscribe or publish to. When subscribing, it can be a wildcard topic. e.g /topic/# required: true + - name: qos + field_type: QualityOfService + comments: The quality of service to use when publishing messages. Defaults to at_most_once. Could be at_most_once, at_least_once or exactly_once + required: false - name: username field_type: String comments: Username for the mqtt broker