Skip to content

Commit

Permalink
feat: Address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
bakjos committed Mar 5, 2024
1 parent c9a47a0 commit b8c2fc7
Show file tree
Hide file tree
Showing 9 changed files with 233 additions and 99 deletions.
7 changes: 4 additions & 3 deletions integration_tests/mqtt/create_source.sql
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ WITH (
connector='mqtt',
host='mqtt-server',
topic= 'test',
qos = '1'
qos = 'at_least_once',
) FORMAT PLAIN ENCODE JSON;


Expand All @@ -23,9 +23,10 @@ WITH
host='mqtt-server',
topic= 'test',
type = 'append-only',
force_append_only='true',
retain = 'false',
qos = '1'
qos = 'at_least_once',
) FORMAT PLAIN ENCODE JSON (
force_append_only='true',
);

INSERT INTO
Expand Down
69 changes: 64 additions & 5 deletions src/connector/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,17 @@ use risingwave_common::bail;
use serde_derive::Deserialize;
use serde_with::json::JsonString;
use serde_with::{serde_as, DisplayFromStr};
use strum_macros::{Display, EnumString};
use tempfile::NamedTempFile;
use time::OffsetDateTime;
use url::Url;
use with_options::WithOptions;

use crate::aws_utils::load_file_descriptor_from_s3;
use crate::deserialize_duration_from_string;
use crate::error::ConnectorResult;
use crate::sink::SinkError;
use crate::source::nats::source::NatsOffset;
use crate::{deserialize_bool_from_string, deserialize_duration_from_string};
// The file describes the common abstractions for each connector and can be used in both source and
// sink.

Expand Down Expand Up @@ -685,29 +686,81 @@ impl NatsCommon {
}
}

#[derive(Debug, Clone, PartialEq, Display, Deserialize, EnumString)]
#[strum(serialize_all = "snake_case")]
#[allow(clippy::enum_variant_names)]
pub enum QualityOfService {
AtLeastOnce,
AtMostOnce,
ExactlyOnce,
}

#[derive(Debug, Clone, PartialEq, Display, Deserialize, EnumString)]
#[strum(serialize_all = "snake_case")]
pub enum Protocol {
Tls,
Ssl,
}

#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct MqttCommon {
/// Protocol used for RisingWave to communicate with Kafka brokers. Could be tcp or ssl
/// Protocol used for RisingWave to communicate with the mqtt brokers. Could be tcp or ssl
#[serde_as(as = "Option<DisplayFromStr>")]
#[serde(rename = "protocol")]
pub protocol: Option<String>,
pub protocol: Option<Protocol>,

/// Hostname of the mqtt broker
#[serde(rename = "host")]
pub host: String,

/// Port of the mqtt broker, defaults to 1883 for tcp and 8883 for ssl
#[serde(rename = "port")]
pub port: Option<i32>,

/// The topic name to subscribe or publish to. When subscribing, it can be a wildcard topic. e.g /topic/#
#[serde(rename = "topic")]
pub topic: String,

/// Username for the mqtt broker
#[serde(rename = "username")]
pub user: Option<String>,

/// Password for the mqtt broker
#[serde(rename = "password")]
pub password: Option<String>,
#[serde(rename = "client_prefix")]

/// Prefix for the mqtt client id
pub client_prefix: Option<String>,

/// `clean_start = true` removes all the state from queues & instructs the broker
/// to clean all the client state when client disconnects.
///
/// When set `false`, broker will hold the client state and performs pending
/// operations on the client when reconnection with same `client_id`
/// happens. Local queue state is also held to retransmit packets after reconnection.
#[serde(rename = "clean_start")]
#[serde(default, deserialize_with = "deserialize_bool_from_string")]
pub clean_start: bool,

/// The maximum number of inflight messages. Defaults to 100
#[serde(rename = "inflight_messages")]
#[serde_as(as = "Option<DisplayFromStr>")]
pub inflight_messages: Option<usize>,
#[serde(rename = "tls.ca")]

/// Path to CA certificate file for verifying the broker's key.
pub ca: Option<String>,
#[serde(rename = "tls.client_cert")]

/// Path to client's certificate file (PEM). Required for client authentication.
/// Can be a file path under fs:// or a string with the certificate content.
pub client_cert: Option<String>,
#[serde(rename = "tls.client_key")]

/// Path to client's private key file (PEM). Required for client authentication.
/// Can be a file path under fs:// or a string with the private key content.
pub client_key: Option<String>,
}

Expand All @@ -719,7 +772,7 @@ impl MqttCommon {
let ssl = self
.protocol
.as_ref()
.map(|p| p == "ssl")
.map(|p| p == &Protocol::Ssl)
.unwrap_or_default();

let client_id = format!(
Expand All @@ -737,6 +790,9 @@ impl MqttCommon {

let mut options = rumqttc::v5::MqttOptions::new(client_id, &self.host, port);
options.set_keep_alive(std::time::Duration::from_secs(10));

options.set_clean_start(self.clean_start);

if ssl {
let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
if let Some(ca) = &self.ca {
Expand Down Expand Up @@ -778,7 +834,10 @@ impl MqttCommon {
options.set_credentials(user, self.password.as_deref().unwrap_or_default());
}

Ok(rumqttc::v5::AsyncClient::new(options, 100))
Ok(rumqttc::v5::AsyncClient::new(
options,
self.inflight_messages.unwrap_or(100),
))
}
}

Expand Down
Loading

0 comments on commit b8c2fc7

Please sign in to comment.