Skip to content

Commit

Permalink
Resolves PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
SirCipher committed Jun 26, 2024
1 parent e3ab5f8 commit 6082bc7
Show file tree
Hide file tree
Showing 11 changed files with 201 additions and 74 deletions.
5 changes: 3 additions & 2 deletions client/swimos_client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ edition = "2021"

[features]
default = []
tls = ["swimos_remote/tls"]
deflate = ["runtime/deflate"]
trust_dns = ["swimos_runtime/trust_dns"]
ring_provider = ["swimos_remote/ring_provider"]
aws_lc_rs_provider = ["swimos_remote/aws_lc_rs_provider"]

[dependencies]
runtime = { path = "../runtime" }
Expand All @@ -25,4 +26,4 @@ tokio = { workspace = true, features = ["sync"] }
futures = { workspace = true }
futures-util = { workspace = true }
tracing = { workspace = true }

rustls = { workspace = true }
119 changes: 78 additions & 41 deletions client/swimos_client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,101 +12,138 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#[cfg(not(feature = "deflate"))]
use ratchet::NoExtProvider;
use ratchet::WebSocketStream;
use std::marker::PhantomData;
use std::num::NonZeroUsize;
use swimos_remote::websocket::RatchetClient;
use std::{marker::PhantomData, num::NonZeroUsize, sync::Arc};

use futures_util::future::BoxFuture;
#[cfg(feature = "deflate")]
use ratchet::deflate::{DeflateConfig, DeflateExtProvider};
use ratchet::{
deflate::{DeflateConfig, DeflateExtProvider},
WebSocketStream,
};
use rustls::crypto::CryptoProvider;
use tokio::{sync::mpsc, sync::mpsc::error::SendError, sync::oneshot::error::RecvError};
pub use url::Url;

use runtime::{
start_runtime, ClientConfig, DownlinkRuntimeError, RawHandle, Transport, WebSocketConfig,
};
pub use runtime::{CommandError, Commander, RemotePath};
use std::sync::Arc;
pub use swimos_client_api::DownlinkConfig;
pub use swimos_downlink::lifecycle::{
BasicEventDownlinkLifecycle, BasicMapDownlinkLifecycle, BasicValueDownlinkLifecycle,
EventDownlinkLifecycle, MapDownlinkLifecycle, ValueDownlinkLifecycle,
pub use swimos_downlink::{
lifecycle::BasicEventDownlinkLifecycle, lifecycle::BasicMapDownlinkLifecycle,
lifecycle::BasicValueDownlinkLifecycle, lifecycle::EventDownlinkLifecycle,
lifecycle::MapDownlinkLifecycle, lifecycle::ValueDownlinkLifecycle,
};
use swimos_downlink::{
ChannelError, DownlinkTask, EventDownlinkModel, MapDownlinkHandle, MapDownlinkModel, MapKey,
MapValue, NotYetSyncedError, ValueDownlinkModel, ValueDownlinkSet,
};
use swimos_form::Form;
use swimos_remote::dns::Resolver;
use swimos_remote::plain::TokioPlainTextNetworking;
#[cfg(feature = "tls")]
use swimos_remote::tls::{ClientConfig as TlsConfig, RustlsClientNetworking, TlsError};
use swimos_remote::ClientConnections;
pub use swimos_remote::tls::ClientConfig as TlsConfig;
use swimos_remote::tls::TlsError;
use swimos_remote::{
dns::Resolver,
plain::TokioPlainTextNetworking,
tls::{CryptoProviderConfig, RustlsClientNetworking},
websocket::RatchetClient,
ClientConnections,
};
use swimos_runtime::downlink::{DownlinkOptions, DownlinkRuntimeConfig};
use swimos_utilities::trigger;
use swimos_utilities::trigger::promise;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::oneshot::error::RecvError;
pub use url::Url;
use swimos_utilities::{trigger, trigger::promise};

pub type DownlinkOperationResult<T> = Result<T, DownlinkRuntimeError>;

#[derive(Debug, Default)]
#[derive(Default)]
pub struct SwimClientBuilder {
config: ClientConfig,
client_config: ClientConfig,
}

impl SwimClientBuilder {
pub fn new(config: ClientConfig) -> SwimClientBuilder {
SwimClientBuilder { config }
pub fn new(client_config: ClientConfig) -> SwimClientBuilder {
SwimClientBuilder { client_config }
}

/// Sets the websocket configuration.
pub fn set_websocket_config(mut self, to: WebSocketConfig) -> SwimClientBuilder {
self.config.websocket = to;
self.client_config.websocket = to;
self
}

/// Size of the buffers to communicate with the socket.
pub fn set_remote_buffer_size(mut self, to: NonZeroUsize) -> SwimClientBuilder {
self.config.remote_buffer_size = to;
self.client_config.remote_buffer_size = to;
self
}

/// Sets the buffer size between the runtime and transport tasks.
pub fn set_transport_buffer_size(mut self, to: NonZeroUsize) -> SwimClientBuilder {
self.config.transport_buffer_size = to;
self.client_config.transport_buffer_size = to;
self
}

/// Sets the deflate extension configuration for WebSocket connections.
#[cfg(feature = "deflate")]
pub fn set_deflate_config(mut self, to: DeflateConfig) -> SwimClientBuilder {
self.config.websocket.deflate_config = Some(to);
self.client_config.websocket.deflate_config = Some(to);
self
}

/// Enables TLS support.
pub fn set_tls_config(self, tls_config: TlsConfig) -> SwimClientTlsBuilder {
SwimClientTlsBuilder {
client_config: self.client_config,
tls_config,
crypto_provider: Default::default(),
}
}

/// Builds the client.
pub async fn build(self) -> (SwimClient, BoxFuture<'static, ()>) {
let SwimClientBuilder { config } = self;
let SwimClientBuilder { client_config } = self;
open_client(
config,
client_config,
TokioPlainTextNetworking::new(Arc::new(Resolver::new().await)),
)
.await
}
}

pub struct SwimClientTlsBuilder {
client_config: ClientConfig,
tls_config: TlsConfig,
crypto_provider: CryptoProviderConfig,
}

impl SwimClientTlsBuilder {
/// Uses the process-default [`CryptoProvider`] for any TLS connections.
///
/// This is only used if the TLS configuration has been set.
pub fn with_default_crypto_provider(mut self) -> Self {
self.crypto_provider = CryptoProviderConfig::ProcessDefault;
self
}

/// Uses the provided [`CryptoProvider`] for any TLS connections.
///
/// This is only used if the TLS configuration has been set.
pub fn with_crypto_provider(mut self, provider: Arc<CryptoProvider>) -> Self {
self.crypto_provider = CryptoProviderConfig::Provided(provider);
self
}

/// Builds the client using the provided TLS configuration.
#[cfg(feature = "tls")]
pub async fn build_tls(
self,
tls_config: TlsConfig,
) -> Result<(SwimClient, BoxFuture<'static, ()>), TlsError> {
let SwimClientBuilder { config } = self;
pub async fn build(self) -> Result<(SwimClient, BoxFuture<'static, ()>), TlsError> {
let SwimClientTlsBuilder {
client_config,
tls_config,
crypto_provider,
} = self;
Ok(open_client(
config,
RustlsClientNetworking::try_from_config(Arc::new(Resolver::new().await), tls_config)?,
client_config,
RustlsClientNetworking::build(
Arc::new(Resolver::new().await),
tls_config,
crypto_provider.build(),
)?,
)
.await)
}
Expand Down
2 changes: 2 additions & 0 deletions runtime/swimos_remote/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ edition = "2021"
[features]
default = []
tls = ["rustls", "webpki", "webpki-roots", "tokio-rustls", "rustls-pemfile"]
ring_provider = []
aws_lc_rs_provider = []

[dependencies]
ratchet = { workspace = true, features = ["deflate", "split"] }
Expand Down
12 changes: 2 additions & 10 deletions runtime/swimos_remote/src/tls/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use rustls::crypto::CryptoProvider;
use std::sync::Arc;

/// Supported certificate formats for TLS connections.
pub enum CertFormat {
Pem,
Expand Down Expand Up @@ -87,17 +84,14 @@ pub struct ServerConfig {
/// `SSLKEYLOGFILE` environment variable, and writes keys into it. While this may be enabled,
/// if `SSLKEYLOGFILE` is not set, it will do nothing.
pub enable_log_file: bool,
/// [`CryptoProvider`] to use when building the [`rustls::ServerConfig`].
pub provider: Arc<CryptoProvider>,
}

impl ServerConfig {
pub fn new(chain: CertChain, key: PrivateKey, provider: Arc<CryptoProvider>) -> Self {
pub fn new(chain: CertChain, key: PrivateKey) -> Self {
ServerConfig {
chain,
key,
enable_log_file: false,
provider,
}
}
}
Expand All @@ -106,15 +100,13 @@ impl ServerConfig {
pub struct ClientConfig {
pub use_webpki_roots: bool,
pub custom_roots: Vec<CertificateFile>,
pub provider: Arc<CryptoProvider>,
}

impl ClientConfig {
pub fn new(custom_roots: Vec<CertificateFile>, provider: Arc<CryptoProvider>) -> Self {
pub fn new(custom_roots: Vec<CertificateFile>) -> Self {
ClientConfig {
use_webpki_roots: true,
custom_roots,
provider,
}
}
}
37 changes: 37 additions & 0 deletions runtime/swimos_remote/src/tls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,40 @@ pub use config::{
pub use errors::TlsError;
pub use maybe::MaybeTlsStream;
pub use net::{RustlsClientNetworking, RustlsListener, RustlsNetworking, RustlsServerNetworking};
use rustls::crypto::CryptoProvider;
use std::sync::Arc;

#[derive(Default)]
pub enum CryptoProviderConfig {
ProcessDefault,
#[default]
FromFeatureFlags,
Provided(Arc<CryptoProvider>),
}

impl CryptoProviderConfig {
pub fn build(self) -> Arc<CryptoProvider> {
match self {
CryptoProviderConfig::ProcessDefault => CryptoProvider::get_default()
.expect("No default cryptographic provider specified")
.clone(),
CryptoProviderConfig::FromFeatureFlags => {
#[cfg(all(feature = "ring_provider", not(feature = "aws_lc_rs_provider")))]
{
return Arc::new(rustls::crypto::ring::default_provider());
}

#[cfg(all(feature = "aws_lc_rs_provider", not(feature = "ring_provider")))]
{
return Arc::new(rustls::crypto::aws_lc_rs::default_provider());
}

#[allow(unreachable_code)]
{
panic!("Ambiguous cryptographic provider feature flags specified. Only \"ring_provider\" or \"aws_lc_rs_provider\" may be specified")
}
}
CryptoProviderConfig::Provided(provider) => provider,
}
}
}
5 changes: 3 additions & 2 deletions runtime/swimos_remote/src/tls/net/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
use std::{net::SocketAddr, sync::Arc};

use futures::{future::BoxFuture, FutureExt};
use rustls::crypto::CryptoProvider;
use rustls::pki_types::ServerName;
use rustls::RootCertStore;

Expand All @@ -40,14 +41,14 @@ impl RustlsClientNetworking {
}
}

pub fn try_from_config(
pub fn build(
resolver: Arc<Resolver>,
config: ClientConfig,
provider: Arc<CryptoProvider>,
) -> Result<Self, TlsError> {
let ClientConfig {
use_webpki_roots,
custom_roots,
provider,
} = config;
let mut root_store = RootCertStore::empty();
if use_webpki_roots {
Expand Down
11 changes: 5 additions & 6 deletions runtime/swimos_remote/src/tls/net/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use futures::{
stream::{unfold, BoxStream, FuturesUnordered},
Future, FutureExt, Stream, StreamExt, TryStreamExt,
};
use rustls::crypto::CryptoProvider;
use rustls::pki_types::PrivateKeyDer;
use rustls::KeyLogFile;
use rustls_pemfile::Item;
Expand Down Expand Up @@ -64,17 +65,15 @@ impl RustlsServerNetworking {
pub fn new(acceptor: TlsAcceptor) -> Self {
RustlsServerNetworking { acceptor }
}
}

impl TryFrom<ServerConfig> for RustlsServerNetworking {
type Error = TlsError;

fn try_from(config: ServerConfig) -> Result<Self, Self::Error> {
pub fn build(
config: ServerConfig,
provider: Arc<CryptoProvider>,
) -> Result<RustlsServerNetworking, TlsError> {
let ServerConfig {
chain: CertChain(certs),
key,
enable_log_file,
provider,
} = config;

let mut chain = vec![];
Expand Down
10 changes: 5 additions & 5 deletions runtime/swimos_remote/src/tls/net/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ fn make_server_config() -> ServerConfig {
chain,
key,
enable_log_file: false,
provider: Arc::new(aws_lc_rs::default_provider()),
}
}

Expand All @@ -62,17 +61,18 @@ fn make_client_config() -> ClientConfig {
ClientConfig {
use_webpki_roots: true,
custom_roots: vec![CertificateFile::der(ca_cert)],
provider: Arc::new(aws_lc_rs::default_provider()),
}
}

#[tokio::test]
async fn perform_handshake() {
let server_net =
RustlsServerNetworking::try_from(make_server_config()).expect("Invalid server config.");
let client_net = RustlsClientNetworking::try_from_config(
let crypto_provider = Arc::new(aws_lc_rs::default_provider());
let server_net = RustlsServerNetworking::build(make_server_config(), crypto_provider.clone())
.expect("Invalid server config.");
let client_net = RustlsClientNetworking::build(
Arc::new(Resolver::new().await),
make_client_config(),
crypto_provider,
)
.expect("Invalid client config.");

Expand Down
4 changes: 3 additions & 1 deletion server/swimos_server_app/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ authors = ["Swim Inc. developers [email protected]"]
edition = "2021"

[features]
default = []
default = ["aws_lc_rs_provider"]
rocks_store = ["swimos_rocks_store"]
trust_dns = ["swimos_runtime/trust_dns"]
ring_provider = ["swimos_remote/ring_provider"]
aws_lc_rs_provider = ["swimos_remote/aws_lc_rs_provider"]

[dependencies]
futures = { workspace = true }
Expand Down
Loading

0 comments on commit 6082bc7

Please sign in to comment.