diff --git a/.config/nextest.toml b/.config/nextest.toml index 2d2b2d93f..418f5463a 100644 --- a/.config/nextest.toml +++ b/.config/nextest.toml @@ -10,6 +10,8 @@ retries = { backoff = "exponential", count = 2, delay = "5s" } # the string "num-cpus". Can be overridden through the `--test-threads` option. test-threads = 1 +slow-timeout = { period = "80s", terminate-after = 3, grace-period = "0s" } + # Show these test statuses in the output. # # The possible values this can take are: @@ -25,10 +27,10 @@ test-threads = 1 # failed and retried tests. # # Can be overridden through the `--status-level` flag. -status-level = "pass" +status-level = "all" # Similar to status-level, show these test statuses at the end of the run. -final-status-level = "none" +final-status-level = "all" # "failure-output" defines when standard output and standard error for failing tests are produced. # Accepted values are diff --git a/.github/workflows/validate.yml b/.github/workflows/validate.yml index a9fafe182..22951e2a5 100644 --- a/.github/workflows/validate.yml +++ b/.github/workflows/validate.yml @@ -12,6 +12,7 @@ env: # 40 MiB stack RUST_MIN_STACK: 40971520 RUST_LOG: "citadel=warn" + IN_CI: "true" jobs: core_libs: @@ -95,14 +96,17 @@ jobs: - uses: Avarok-Cybersecurity/gh-actions-deps@master - uses: taiki-e/install-action@nextest - name: Single-threaded testing - run: cargo nextest run --package citadel_sdk --features=localhost-testing,localhost-testing-loopback-only + run: cargo nextest run --package citadel_sdk --features=localhost-testing if: ${{ !startsWith(matrix.os, 'windows') }} - name: Single-threaded testing (windows only) - run: cargo nextest run --package citadel_sdk --features=localhost-testing,localhost-testing-loopback-only,vendored + run: cargo nextest run --package citadel_sdk --features=localhost-testing,vendored + if: startsWith(matrix.os, 'windows') + - name: Multi-threaded testing (windows only) + run: cargo nextest run --package citadel_sdk --features=multi-threaded,localhost-testing,vendored if: startsWith(matrix.os, 'windows') - name: Multi-threaded testing - if: startsWith(matrix.os, 'ubuntu') - run: cargo nextest run --package citadel_sdk --features=multi-threaded,localhost-testing,localhost-testing-loopback-only + if: ${{ !startsWith(matrix.os, 'windows') }} + run: cargo nextest run --package citadel_sdk --features=multi-threaded,localhost-testing citadel_sdk_release: strategy: @@ -114,9 +118,9 @@ jobs: - uses: Avarok-Cybersecurity/gh-actions-deps@master - uses: taiki-e/install-action@nextest - name: Single-threaded testing - run: cargo nextest run --package citadel_sdk --features=localhost-testing,localhost-testing-loopback-only --release + run: cargo nextest run --package citadel_sdk --features=localhost-testing --release - name: Multi-threaded testing - run: cargo nextest run --package citadel_sdk --features=multi-threaded,localhost-testing,localhost-testing-loopback-only --release + run: cargo nextest run --package citadel_sdk --features=multi-threaded,localhost-testing --release misc_checks: name: miscellaneous diff --git a/Cargo.toml b/Cargo.toml index 266125454..4381f0d77 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,8 +62,8 @@ rand = { default-features = false, version = "0.8.5" } async-stream = { default-features = false, version = "0.3.3" } sync_wrapper = { default-features = false, version = "1.0.0" } async-recursion = { version = "1.0.4" } -rstest = { version = "0.18.2" } -bincode2 = { default-features = false, version = "2.0.1" } +rstest = { version = "0.23.0" } +bincode 
= { default-features = false, version = "1.3.3" } serde = { version="1.0.152", default-features = false } futures = { version = "0.3.25", default-features = false } byteorder = { version = "1.4.3", default-features=false } @@ -82,7 +82,7 @@ async-trait-with-sync = { default-features = false, version = "0.1.36" } uuid = { version = "1.2.2", default-features = false } tracing = { version = "0.1.37", default-features = false } lazy_static = { default-features = false, version = "1.4.0" } -socket2 = { version = "0.5.1", default-features = false } +socket2 = { version = "0.5.7", default-features = false } rustls-native-certs = { version = "0.6.2", default-features = false } igd = { version = "^0.12.0", default-features = false } quinn = { version = "0.10.2", default-features = false } diff --git a/async_ip/src/lib.rs b/async_ip/src/lib.rs index aae3a4ff8..9a27e78d0 100644 --- a/async_ip/src/lib.rs +++ b/async_ip/src/lib.rs @@ -121,17 +121,7 @@ pub async fn get_ip_from(client: Option, addr: &str) -> Result { pub update_in_progress: Arc, // if local is initiator, then in the case both nodes send a FastMessage at the same time (causing an update to the keys), the initiator takes preference, and the non-initiator's upgrade attempt gets dropped (if update_in_progress) pub local_is_initiator: bool, - pub rolling_object_id: u64, + pub rolling_object_id: ObjectId, pub rolling_group_id: u64, pub lock_set_by_alice: Option, /// Alice sends to Bob, then bob updates internally the toolset. However. Bob can't send packets to Alice quite yet using that newest version. He must first wait from Alice to commit on her end and wait for an ACK. @@ -39,7 +41,7 @@ impl PeerSessionCrypto { toolset, update_in_progress: Arc::new(AtomicBool::new(false)), local_is_initiator, - rolling_object_id: 1, + rolling_object_id: ObjectId::random(), rolling_group_id: 0, lock_set_by_alice: None, latest_usable_version: 0, @@ -200,9 +202,8 @@ impl PeerSessionCrypto { self.rolling_group_id.wrapping_sub(1) } - pub fn get_and_increment_object_id(&mut self) -> u64 { - self.rolling_object_id = self.rolling_object_id.wrapping_add(1); - self.rolling_object_id.wrapping_sub(1) + pub fn get_next_object_id(&mut self) -> ObjectId { + Uuid::new_v4().as_u128().into() } /// Returns a new constructor only if a concurrent update isn't occurring @@ -230,7 +231,7 @@ impl PeerSessionCrypto { self.update_in_progress = Arc::new(AtomicBool::new(false)); self.lock_set_by_alice = None; self.rolling_group_id = 0; - self.rolling_object_id = 0; + self.rolling_object_id = ObjectId::random(); } /// Gets the parameters used at registrations diff --git a/citadel_crypt/src/entropy_bank.rs b/citadel_crypt/src/entropy_bank.rs index 5838c5a78..7555ea0f7 100644 --- a/citadel_crypt/src/entropy_bank.rs +++ b/citadel_crypt/src/entropy_bank.rs @@ -246,12 +246,12 @@ impl EntropyBank { /// Serializes self to a vector pub fn serialize_to_vec(&self) -> Result, CryptError> { - bincode2::serialize(self).map_err(|err| CryptError::DrillUpdateError(err.to_string())) + bincode::serialize(self).map_err(|err| CryptError::DrillUpdateError(err.to_string())) } /// Deserializes self from a set of bytes pub fn deserialize_from>(drill: T) -> Result> { - bincode2::deserialize(drill.as_ref()) + bincode::deserialize(drill.as_ref()) .map_err(|err| CryptError::DrillUpdateError(err.to_string())) } } diff --git a/citadel_crypt/src/lib.rs b/citadel_crypt/src/lib.rs index 985c41187..ea39dc3a3 100644 --- a/citadel_crypt/src/lib.rs +++ b/citadel_crypt/src/lib.rs @@ -18,6 +18,7 @@ pub mod prelude { pub 
use crate::entropy_bank::EntropyBank; pub use crate::misc::CryptError; pub use crate::packet_vector::PacketVector; + pub use crate::streaming_crypt_scrambler::FixedSizedSource; pub use crate::toolset::Toolset; pub use citadel_types::crypto::SecBuffer; pub use citadel_types::crypto::SecurityLevel; diff --git a/citadel_crypt/src/scramble/crypt_splitter.rs b/citadel_crypt/src/scramble/crypt_splitter.rs index df875328d..0afb49742 100644 --- a/citadel_crypt/src/scramble/crypt_splitter.rs +++ b/citadel_crypt/src/scramble/crypt_splitter.rs @@ -13,6 +13,7 @@ use crate::entropy_bank::EntropyBank; use crate::packet_vector::{generate_packet_vector, PacketVector}; use crate::prelude::CryptError; use crate::stacked_ratchet::Ratchet; +pub use citadel_types::prelude::ObjectId; #[cfg(not(target_family = "wasm"))] use rayon::prelude::*; @@ -62,7 +63,7 @@ pub fn generate_scrambler_metadata>( header_size_bytes: usize, security_level: SecurityLevel, group_id: u64, - object_id: u64, + object_id: ObjectId, enx: EncryptionAlgorithm, sig_alg: SigAlgorithm, transfer_type: &TransferType, @@ -141,7 +142,7 @@ fn get_scramble_encrypt_config<'a, R: Ratchet>( header_size_bytes: usize, security_level: SecurityLevel, group_id: u64, - object_id: u64, + object_id: ObjectId, transfer_type: &TransferType, empty_transfer: bool, ) -> Result< @@ -190,13 +191,13 @@ pub fn par_scramble_encrypt_group, R: Ratchet, F, const N: usize> static_aux_ratchet: &R, header_size_bytes: usize, target_cid: u64, - object_id: u64, + object_id: ObjectId, group_id: u64, transfer_type: TransferType, header_inscriber: F, ) -> Result, CryptError> where - F: Fn(&PacketVector, &EntropyBank, u64, u64, &mut BytesMut) + Send + Sync, + F: Fn(&PacketVector, &EntropyBank, ObjectId, u64, &mut BytesMut) + Send + Sync, { let mut plain_text = Cow::Borrowed(plain_text.as_ref()); @@ -211,10 +212,9 @@ where } if let TransferType::RemoteEncryptedVirtualFilesystem { security_level, .. } = &transfer_type { - log::trace!(target: "citadel", "Detected REVFS. Locally encrypting w/level {security_level:?} | Ratchet used: {} w/version {}", static_aux_ratchet.get_cid(), static_aux_ratchet.version()); + log::trace!(target: "citadel", "Detected REVFS. Locally encrypting object {object_id} w/level {security_level:?} | Ratchet used: {} w/version {}", static_aux_ratchet.get_cid(), static_aux_ratchet.version()); // pre-encrypt let local_encrypted = static_aux_ratchet.local_encrypt(plain_text, *security_level)?; - plain_text = Cow::Owned(local_encrypted); } @@ -303,9 +303,9 @@ fn scramble_encrypt_wave( msg_pqc: &PostQuantumContainer, scramble_drill: &EntropyBank, target_cid: u64, - object_id: u64, + object_id: ObjectId, header_size_bytes: usize, - header_inscriber: impl Fn(&PacketVector, &EntropyBank, u64, u64, &mut BytesMut) + Send + Sync, + header_inscriber: impl Fn(&PacketVector, &EntropyBank, ObjectId, u64, &mut BytesMut) + Send + Sync, ) -> Vec<(usize, PacketCoordinate)> { let ciphertext = msg_drill .encrypt(msg_pqc, bytes_to_encrypt_for_this_wave) @@ -336,7 +336,7 @@ pub fn oneshot_unencrypted_group_unified( plain_text: SecureMessagePacket, header_size_bytes: usize, group_id: u64, - object_id: u64, + object_id: ObjectId, empty_transfer: bool, ) -> Result, CryptError> { let len = plain_text.message_len() as u64; @@ -435,7 +435,7 @@ pub struct GroupReceiverConfig { // this is NOT inscribed; only for transmission pub header_size_bytes: u64, pub group_id: u64, - pub object_id: u64, + pub object_id: ObjectId, // only relevant for files. 
Note: if transfer type is RemoteVirtualFileystem, then, // the receiving endpoint won't decrypt the first level of encryption since the goal // is to keep the file remotely encrypted @@ -450,7 +450,7 @@ impl GroupReceiverConfig { #[allow(clippy::too_many_arguments)] pub fn new_refresh( group_id: u64, - object_id: u64, + object_id: ObjectId, header_size_bytes: u64, plaintext_length: u64, max_packet_payload_size: u32, diff --git a/citadel_crypt/src/stacked_ratchet.rs b/citadel_crypt/src/stacked_ratchet.rs index 7051c21da..ad73b7a03 100644 --- a/citadel_crypt/src/stacked_ratchet.rs +++ b/citadel_crypt/src/stacked_ratchet.rs @@ -586,23 +586,23 @@ pub mod constructor { impl BobToAliceTransfer { pub fn serialize_into(&self, buf: &mut BytesMut) -> Option<()> { - let len = bincode2::serialized_size(self).ok()?; + let len = bincode::serialized_size(self).ok()?; buf.reserve(len as usize); - bincode2::serialize_into(buf.writer(), self).ok() + bincode::serialize_into(buf.writer(), self).ok() } pub fn deserialize_from>(source: T) -> Option { - bincode2::deserialize(source.as_ref()).ok() + bincode::deserialize(source.as_ref()).ok() } } impl AliceToBobTransfer { pub fn serialize_to_vec(&self) -> Option> { - bincode2::serialize(self).ok() + bincode::serialize(self).ok() } pub fn deserialize_from(source: &[u8]) -> Option { - bincode2::deserialize(source).ok() + bincode::deserialize(source).ok() } /// Gets the declared new version diff --git a/citadel_crypt/src/streaming_crypt_scrambler.rs b/citadel_crypt/src/streaming_crypt_scrambler.rs index c466f2526..edad75ce1 100644 --- a/citadel_crypt/src/streaming_crypt_scrambler.rs +++ b/citadel_crypt/src/streaming_crypt_scrambler.rs @@ -15,6 +15,7 @@ use crate::stacked_ratchet::StackedRatchet; use citadel_io::Mutex; use citadel_io::{BlockingSpawn, BlockingSpawnError}; use citadel_types::crypto::SecurityLevel; +use citadel_types::prelude::ObjectId; use citadel_types::proto::TransferType; use futures::Future; use num_integer::Integer; @@ -41,11 +42,14 @@ impl FixedSizedSource for std::fs::File { /// Generic function for inscribing headers on packets pub trait HeaderInscriberFn: - for<'a> Fn(&'a PacketVector, &'a EntropyBank, u64, u64, &'a mut BytesMut) + Send + Sync + 'static + for<'a> Fn(&'a PacketVector, &'a EntropyBank, ObjectId, u64, &'a mut BytesMut) + + Send + + Sync + + 'static { } impl< - T: for<'a> Fn(&'a PacketVector, &'a EntropyBank, u64, u64, &'a mut BytesMut) + T: for<'a> Fn(&'a PacketVector, &'a EntropyBank, ObjectId, u64, &'a mut BytesMut) + Send + Sync + 'static, @@ -162,7 +166,7 @@ impl>> From for BytesSource { pub fn scramble_encrypt_source( mut source: S, max_group_size: Option, - object_id: u64, + object_id: ObjectId, group_sender: GroupChanneler, CryptError>>, stop: Receiver<()>, security_level: SecurityLevel, @@ -266,7 +270,7 @@ struct AsyncCryptScrambler { transfer_type: TransferType, file_len: usize, read_cursor: usize, - object_id: u64, + object_id: ObjectId, header_size_bytes: usize, target_cid: u64, group_id: u64, diff --git a/citadel_crypt/src/toolset.rs b/citadel_crypt/src/toolset.rs index b995a2171..17abe5c8e 100644 --- a/citadel_crypt/src/toolset.rs +++ b/citadel_crypt/src/toolset.rs @@ -239,12 +239,12 @@ impl Toolset { /// Serializes the toolset to a buffer pub fn serialize_to_vec(&self) -> Result, CryptError> { - bincode2::serialize(self).map_err(|err| CryptError::DrillUpdateError(err.to_string())) + bincode::serialize(self).map_err(|err| CryptError::DrillUpdateError(err.to_string())) } /// Deserializes from a slice of bytes pub fn 
deserialize_from_bytes>(input: T) -> Result> { - bincode2::deserialize(input.as_ref()) + bincode::deserialize(input.as_ref()) .map_err(|err| CryptError::DrillUpdateError(err.to_string())) } diff --git a/citadel_crypt/tests/primary.rs b/citadel_crypt/tests/primary.rs index 49d021c6a..61b779ea6 100644 --- a/citadel_crypt/tests/primary.rs +++ b/citadel_crypt/tests/primary.rs @@ -16,7 +16,7 @@ mod tests { AlgorithmsExt, CryptoParameters, EncryptionAlgorithm, KemAlgorithm, SecBuffer, SigAlgorithm, KEM_ALGORITHM_COUNT, }; - use citadel_types::proto::TransferType; + use citadel_types::proto::{ObjectId, TransferType}; use rstest::rstest; #[cfg(not(target_family = "wasm"))] use std::path::PathBuf; @@ -128,9 +128,9 @@ mod tests { #[test] fn test_sec_buffer() { let buf = SecBuffer::from("Hello, world!"); - let serde = bincode2::serialize(&buf).unwrap(); + let serde = bincode::serialize(&buf).unwrap(); std::mem::drop(buf); - let buf = bincode2::deserialize::(&serde).unwrap(); + let buf = bincode::deserialize::(&serde).unwrap(); assert_eq!(buf.as_ref(), b"Hello, world!"); let cloned = buf.clone(); @@ -197,9 +197,9 @@ mod tests { fn secbytes() { citadel_logging::setup_log(); let buf = SecBuffer::from("Hello, world!"); - let serde = bincode2::serialize(&buf).unwrap(); + let serde = bincode::serialize(&buf).unwrap(); std::mem::drop(buf); - let buf = bincode2::deserialize::(&serde).unwrap(); + let buf = bincode::deserialize::(&serde).unwrap(); assert_eq!(buf.as_ref(), b"Hello, world!"); let cloned = buf.clone(); @@ -728,7 +728,7 @@ mod tests { &pseudo_static_aux_ratchet_alice, HEADER_SIZE_BYTES, 0, - 0, + ObjectId::zero(), 0, transfer_type.clone(), |_vec, _drill, _target_cid, _, buffer| { @@ -765,7 +765,13 @@ mod tests { } const HEADER_LEN: usize = 52; - fn header_inscribe(_: &PacketVector, _: &EntropyBank, _: u64, _: u64, packet: &mut BytesMut) { + fn header_inscribe( + _: &PacketVector, + _: &EntropyBank, + _: ObjectId, + _: u64, + packet: &mut BytesMut, + ) { for x in 0..HEADER_LEN { packet.put_u8((x % 255) as u8) } @@ -896,7 +902,7 @@ mod tests { let (bytes, _num_groups, _mxbpg) = scramble_encrypt_source::<_, _, HEADER_LEN>( source, None, - 99, + ObjectId::zero(), group_sender_tx, stop_rx, security_level, diff --git a/citadel_logging/src/lib.rs b/citadel_logging/src/lib.rs index eba10aa86..c27a25511 100644 --- a/citadel_logging/src/lib.rs +++ b/citadel_logging/src/lib.rs @@ -7,7 +7,7 @@ use tracing_subscriber::EnvFilter; /// Sets up the logging for any crate pub fn setup_log() { std::panic::set_hook(Box::new(|info| { - error!(target: "citadel", "Panic occurred: {:#?}", info); + error!(target: "citadel", "Panic occurred: {}", info); std::process::exit(1); })); @@ -18,7 +18,7 @@ pub fn setup_log_no_panic_hook() { let _ = SubscriberBuilder::default() .with_line_number(true) .with_file(true) - .with_span_events(FmtSpan::FULL) + .with_span_events(FmtSpan::NONE) .with_env_filter(EnvFilter::from_default_env()) .finish() .try_init(); diff --git a/citadel_pqcrypto/Cargo.toml b/citadel_pqcrypto/Cargo.toml index 6aca8d2ed..5f84a7c8e 100644 --- a/citadel_pqcrypto/Cargo.toml +++ b/citadel_pqcrypto/Cargo.toml @@ -37,7 +37,7 @@ wasm = [] [dependencies] generic-array = { workspace = true, features = ["serde"] } serde = { workspace = true, features = ["derive", "rc"] } -bincode2 = { workspace = true } +bincode = { workspace = true } aes-gcm = { workspace = true, features = ["heapless", "aes", "alloc"]} chacha20poly1305 = { workspace = true, features = ["heapless", "alloc"] } bytes = { workspace = true } diff --git 
a/citadel_pqcrypto/src/encryption.rs b/citadel_pqcrypto/src/encryption.rs index f2bf697e3..24f06881c 100644 --- a/citadel_pqcrypto/src/encryption.rs +++ b/citadel_pqcrypto/src/encryption.rs @@ -295,7 +295,7 @@ pub(crate) mod kyber_module { // encrypt the 32-byte scramble dict using post-quantum pke let scram_crypt_ser = - bincode2::serialize(&scram_crypt_dict).map_err(|err| Error::Other(err.to_string()))?; + bincode::serialize(&scram_crypt_dict).map_err(|err| Error::Other(err.to_string()))?; let encrypted_scramble_dict = encrypt_pke(kem_alg, public_key, scram_crypt_ser, nonce)?; input @@ -333,7 +333,7 @@ pub(crate) mod kyber_module { let (_, encrypted_scramble_dict) = input.as_ref().split_at(split_pt); let decrypted_scramble_dict = decrypt_pke(kem_alg, local_sk, encrypted_scramble_dict)?; let scram_crypt_dict: ScramCryptDictionary<32> = - bincode2::deserialize(&decrypted_scramble_dict) + bincode::deserialize(&decrypted_scramble_dict) .map_err(|err| Error::Other(err.to_string()))?; // remove the encrypted scramble data from the input buf let truncate_point = input.len().saturating_sub(encrypted_scramble_dict_len); diff --git a/citadel_pqcrypto/src/lib.rs b/citadel_pqcrypto/src/lib.rs index ea2b3575d..363648e11 100644 --- a/citadel_pqcrypto/src/lib.rs +++ b/citadel_pqcrypto/src/lib.rs @@ -528,13 +528,13 @@ impl PostQuantumContainer { /// Serializes the entire package to a vector pub fn serialize_to_vector(&self) -> Result, Error> { - bincode2::serialize(self).map_err(|_err| Error::Generic("Deserialization failure")) + bincode::serialize(self).map_err(|_err| Error::Generic("Deserialization failure")) } /// Attempts to deserialize the input bytes presumed to be of type [PostQuantumExport], /// into a [PostQuantumContainer] pub fn deserialize_from_bytes>(bytes: B) -> Result { - bincode2::deserialize::(bytes.as_ref()) + bincode::deserialize::(bytes.as_ref()) .map_err(|_err| Error::Generic("Deserialization failure")) } diff --git a/citadel_pqcrypto/tests/primary.rs b/citadel_pqcrypto/tests/primary.rs index 8a7210e58..2a0432070 100644 --- a/citadel_pqcrypto/tests/primary.rs +++ b/citadel_pqcrypto/tests/primary.rs @@ -438,8 +438,8 @@ mod tests { .unwrap() } - #[should_panic] #[test] + #[should_panic] fn test_kyber_bad_psks() { citadel_logging::setup_log_no_panic_hook(); run( diff --git a/citadel_proto/Cargo.toml b/citadel_proto/Cargo.toml index 0fade6770..73c3fc6d5 100644 --- a/citadel_proto/Cargo.toml +++ b/citadel_proto/Cargo.toml @@ -20,7 +20,6 @@ redis = ["citadel_user/redis"] webrtc = ["webrtc-util"] localhost-testing = ["citadel_wire/localhost-testing", "citadel_user/localhost-testing", "tracing"] localhost-testing-assert-no-proxy = ["localhost-testing"] -localhost-testing-loopback-only = ["citadel_wire/localhost-testing-loopback-only"] google-services = ["citadel_user/google-services"] vendored = ["citadel_user/vendored", "citadel_wire/vendored"] diff --git a/citadel_proto/src/lib.rs b/citadel_proto/src/lib.rs index 71d4a9d07..c3820ef65 100644 --- a/citadel_proto/src/lib.rs +++ b/citadel_proto/src/lib.rs @@ -27,7 +27,7 @@ pub const fn build_tag() -> &'static str { pub mod macros { use either::Either; - use crate::proto::session::HdpSessionInner; + use crate::proto::session::CitadelSessionInner; pub type OwnedReadGuard<'a, T> = std::cell::Ref<'a, T>; pub type OwnedWriteGuard<'a, T> = std::cell::RefMut<'a, T>; @@ -44,7 +44,7 @@ pub mod macros { impl SyncContextRequirements for T {} pub type WeakBorrowType = std::rc::Weak>; - pub type SessionBorrow<'a> = std::cell::RefMut<'a, HdpSessionInner>; 
+ pub type SessionBorrow<'a> = std::cell::RefMut<'a, CitadelSessionInner>; pub struct WeakBorrow { pub inner: std::rc::Weak>, @@ -177,7 +177,7 @@ pub mod macros { pub mod macros { use either::Either; - use crate::proto::session::HdpSessionInner; + use crate::proto::session::CitadelSessionInner; pub type OwnedReadGuard<'a, T> = citadel_io::RwLockReadGuard<'a, T>; pub type OwnedWriteGuard<'a, T> = citadel_io::RwLockWriteGuard<'a, T>; @@ -194,7 +194,7 @@ pub mod macros { impl SyncContextRequirements for T {} pub type WeakBorrowType = std::sync::Weak>; - pub type SessionBorrow<'a> = citadel_io::RwLockWriteGuard<'a, HdpSessionInner>; + pub type SessionBorrow<'a> = citadel_io::RwLockWriteGuard<'a, CitadelSessionInner>; pub struct WeakBorrow { pub inner: std::sync::Weak>, diff --git a/citadel_proto/src/proto/misc/underlying_proto.rs b/citadel_proto/src/proto/misc/underlying_proto.rs index 1a92f9ecf..ce1f3f80a 100644 --- a/citadel_proto/src/proto/misc/underlying_proto.rs +++ b/citadel_proto/src/proto/misc/underlying_proto.rs @@ -5,7 +5,7 @@ use citadel_user::re_exports::__private::Formatter; use citadel_wire::exports::{Certificate, PrivateKey}; use citadel_wire::tls::TLSQUICInterop; use std::fmt::Debug; -use std::net::{SocketAddr, TcpListener}; +use std::net::{SocketAddr, TcpListener, ToSocketAddrs}; use std::path::Path; use std::sync::Arc; @@ -23,12 +23,15 @@ impl ServerUnderlyingProtocol { Self::Tcp(None) } + pub fn new_tcp(bind_addr: T) -> Result { + let listener = citadel_wire::socket_helpers::get_tcp_listener(bind_addr)?; + Self::from_tokio_tcp_listener(listener) + } + /// Creates a new [`ServerUnderlyingProtocol`] with a preset [`std::net::TcpListener`] - pub fn from_tcp_listener(listener: TcpListener) -> Result { + pub fn from_std_tcp_listener(listener: TcpListener) -> Result { listener.set_nonblocking(true)?; - Ok(Self::Tcp(Some(Arc::new(Mutex::new(Some( - tokio::net::TcpListener::from_std(listener)?, - )))))) + Self::from_tokio_tcp_listener(tokio::net::TcpListener::from_std(listener)?) 
} /// Creates a new [`ServerUnderlyingProtocol`] with a preset [`tokio::net::TcpListener`] diff --git a/citadel_proto/src/proto/mod.rs b/citadel_proto/src/proto/mod.rs index 25adb91fb..c316cb9b7 100644 --- a/citadel_proto/src/proto/mod.rs +++ b/citadel_proto/src/proto/mod.rs @@ -1,6 +1,6 @@ use crate::proto::outbound_sender::OutboundPrimaryStreamSender; use crate::proto::packet::HdpHeader; -use crate::proto::session::HdpSession; +use crate::proto::session::CitadelSession; use crate::proto::state_container::StateContainerInner; use bytes::BytesMut; @@ -39,7 +39,7 @@ pub(crate) mod validation; /// Returns the preferred primary stream for returning a response pub(crate) fn get_preferred_primary_stream( header: &HdpHeader, - session: &HdpSession, + session: &CitadelSession, state_container: &StateContainerInner, ) -> Option { if header.target_cid.get() != 0 { diff --git a/citadel_proto/src/proto/node.rs b/citadel_proto/src/proto/node.rs index 507dfe98e..57f4e2863 100644 --- a/citadel_proto/src/proto/node.rs +++ b/citadel_proto/src/proto/node.rs @@ -33,7 +33,7 @@ use crate::proto::outbound_sender::{unbounded, BoundedReceiver, BoundedSender, U use crate::proto::packet_processor::includes::Duration; use crate::proto::peer::p2p_conn_handler::generic_error; use crate::proto::remote::{NodeRemote, Ticket}; -use crate::proto::session::{HdpSession, HdpSessionInitMode}; +use crate::proto::session::{CitadelSession, HdpSessionInitMode}; use crate::proto::session_manager::HdpSessionManager; use citadel_wire::exports::tokio_rustls::rustls::{ClientConfig, ServerName}; use citadel_wire::exports::Endpoint; @@ -160,7 +160,7 @@ impl Node { let node_type = read.local_node_type; let (session_spawner_tx, session_spawner_rx) = unbounded(); - let session_spawner = HdpSession::session_future_receiver(session_spawner_rx); + let session_spawner = CitadelSession::session_future_receiver(session_spawner_rx); let (outbound_send_request_tx, outbound_send_request_rx) = BoundedSender::new(MAX_OUTGOING_UNPROCESSED_REQUESTS); // for the Hdp remote diff --git a/citadel_proto/src/proto/packet_crafter.rs b/citadel_proto/src/proto/packet_crafter.rs index 9fac62a9b..21a398d86 100644 --- a/citadel_proto/src/proto/packet_crafter.rs +++ b/citadel_proto/src/proto/packet_crafter.rs @@ -12,6 +12,7 @@ use crate::proto::state_container::VirtualTargetType; use citadel_crypt::scramble::crypt_splitter::oneshot_unencrypted_group_unified; use citadel_crypt::secure_buffer::sec_packet::SecureMessagePacket; use citadel_crypt::stacked_ratchet::{Ratchet, StackedRatchet}; +use citadel_types::prelude::ObjectId; #[derive(Debug)] /// A thin wrapper used for convenient creation of zero-copy outgoing buffers @@ -63,7 +64,7 @@ pub struct GroupTransmitter { /// Contained within Self::group_transmitter, but is here for convenience group_config: GroupReceiverConfig, /// The ID of the object that is being transmitted - pub object_id: u64, + pub object_id: ObjectId, pub group_id: u64, /// For interfacing with the higher-level kernel ticket: Ticket, @@ -98,7 +99,7 @@ impl GroupTransmitter { to_primary_stream: OutboundPrimaryStreamSender, group_sender: GroupSenderDevice, hyper_ratchet: RatchetPacketCrafterContainer, - object_id: u64, + object_id: ObjectId, ticket: Ticket, security_level: SecurityLevel, time_tracker: TimeTracker, @@ -126,7 +127,7 @@ impl GroupTransmitter { #[allow(clippy::too_many_arguments)] pub fn new_message( to_primary_stream: OutboundPrimaryStreamSender, - object_id: u64, + object_id: ObjectId, hyper_ratchet: RatchetPacketCrafterContainer, 
input_packet: SecureProtocolPacket, security_level: SecurityLevel, @@ -241,6 +242,7 @@ pub(crate) mod group { use crate::proto::validation::group::{GroupHeader, GroupHeaderAck, WaveAck}; use citadel_crypt::endpoint_crypto_container::KemTransferStatus; use citadel_crypt::stacked_ratchet::StackedRatchet; + use citadel_types::proto::ObjectId; use citadel_user::serialization::SyncIO; use std::ops::RangeInclusive; @@ -296,7 +298,7 @@ pub(crate) mod group { packet }; - packet.put_u64(processor.object_id); + packet.put_u128(processor.object_id.0); processor .hyper_ratchet_container @@ -320,7 +322,7 @@ pub(crate) mod group { hyper_ratchet: &StackedRatchet, group_id: u64, target_cid: u64, - object_id: u64, + object_id: ObjectId, ticket: Ticket, initial_wave_window: Option>, fast_msg: bool, @@ -367,7 +369,7 @@ pub(crate) mod group { pub(crate) fn craft_wave_payload_packet_into( coords: &PacketVector, scramble_drill: &EntropyBank, - object_id: u64, + object_id: ObjectId, target_cid: u64, mut buffer: &mut BytesMut, ) { @@ -377,7 +379,7 @@ pub(crate) mod group { cmd_aux: packet_flags::cmd::aux::group::GROUP_PAYLOAD, algorithm: 0, security_level: 0, // Irrelevant; supplied by the wave header anyways - context_info: U128::new(object_id as _), + context_info: U128::new(object_id.0), group: U64::new(coords.group_id), wave_id: U32::new(coords.wave_id), session_cid: U64::new(scramble_drill.get_cid()), @@ -400,7 +402,7 @@ pub(crate) mod group { #[allow(clippy::too_many_arguments)] pub(crate) fn craft_wave_ack( hyper_ratchet: &StackedRatchet, - object_id: u32, + object_id: ObjectId, target_cid: u64, group_id: u64, wave_id: u32, @@ -414,7 +416,7 @@ pub(crate) mod group { cmd_aux: packet_flags::cmd::aux::group::WAVE_ACK, algorithm: 0, security_level: security_level.value(), - context_info: U128::new(object_id as _), + context_info: U128::new(object_id.0), group: U64::new(group_id), wave_id: U32::new(wave_id), session_cid: U64::new(hyper_ratchet.get_cid()), @@ -1598,7 +1600,7 @@ pub(crate) mod file { use citadel_crypt::stacked_ratchet::StackedRatchet; use citadel_types::crypto::SecurityLevel; use citadel_types::prelude::TransferType; - use citadel_types::proto::VirtualObjectMetadata; + use citadel_types::proto::{ObjectId, VirtualObjectMetadata}; use citadel_user::serialization::SyncIO; use serde::{Deserialize, Serialize}; use std::path::PathBuf; @@ -1607,7 +1609,7 @@ pub(crate) mod file { #[derive(Serialize, Deserialize, Debug)] pub struct FileTransferErrorPacket { pub error_message: String, - pub object_id: u64, + pub object_id: ObjectId, } pub(crate) fn craft_file_error_packet( @@ -1617,7 +1619,7 @@ pub(crate) mod file { virtual_target: VirtualTargetType, timestamp: i64, error_message: String, - object_id: u64, + object_id: ObjectId, ) -> BytesMut { let header = HdpHeader { protocol_version: (*crate::constants::PROTOCOL_VERSION).into(), @@ -1704,7 +1706,7 @@ pub(crate) mod file { pub struct FileHeaderAckPacket { pub success: bool, pub virtual_target: VirtualTargetType, - pub object_id: u64, + pub object_id: ObjectId, pub transfer_type: TransferType, } @@ -1712,7 +1714,7 @@ pub(crate) mod file { pub(crate) fn craft_file_header_ack_packet( hyper_ratchet: &StackedRatchet, success: bool, - object_id: u64, + object_id: ObjectId, target_cid: u64, ticket: Ticket, security_level: SecurityLevel, diff --git a/citadel_proto/src/proto/packet_processor/connect_packet.rs b/citadel_proto/src/proto/packet_processor/connect_packet.rs index c396c568a..9850ffd46 100644 --- 
a/citadel_proto/src/proto/packet_processor/connect_packet.rs +++ b/citadel_proto/src/proto/packet_processor/connect_packet.rs @@ -10,7 +10,7 @@ use std::sync::atomic::Ordering; /// This will optionally return an HdpPacket as a response if deemed necessary #[cfg_attr(feature = "localhost-testing", tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err, fields(is_server = sess_ref.is_server, src = packet.parse().unwrap().0.session_cid.get(), target = packet.parse().unwrap().0.target_cid.get())))] pub async fn process_connect( - sess_ref: &HdpSession, + sess_ref: &CitadelSession, packet: HdpPacket, header_drill_vers: u32, ) -> Result { diff --git a/citadel_proto/src/proto/packet_processor/deregister_packet.rs b/citadel_proto/src/proto/packet_processor/deregister_packet.rs index f868e3e5b..6dcd8dbf2 100644 --- a/citadel_proto/src/proto/packet_processor/deregister_packet.rs +++ b/citadel_proto/src/proto/packet_processor/deregister_packet.rs @@ -8,7 +8,7 @@ use std::sync::atomic::Ordering; /// processes a deregister packet. The client must be connected to the HyperLAN Server in order to DeRegister #[cfg_attr(feature = "localhost-testing", tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err, fields(is_server = session_ref.is_server, src = packet.parse().unwrap().0.session_cid.get(), target = packet.parse().unwrap().0.target_cid.get())))] pub async fn process_deregister( - session_ref: &HdpSession, + session_ref: &CitadelSession, packet: HdpPacket, header_drill_vers: u32, ) -> Result { @@ -90,7 +90,7 @@ pub async fn process_deregister( async fn deregister_client_from_self( implicated_cid: u64, - session_ref: &HdpSession, + session_ref: &CitadelSession, hyper_ratchet: &StackedRatchet, timestamp: i64, security_level: SecurityLevel, @@ -159,7 +159,7 @@ async fn deregister_client_from_self( async fn deregister_from_hyperlan_server_as_client( implicated_cid: u64, - session_ref: &HdpSession, + session_ref: &CitadelSession, ) -> Result { let session = session_ref; let (acc_manager, dereg_ticket) = { diff --git a/citadel_proto/src/proto/packet_processor/disconnect_packet.rs b/citadel_proto/src/proto/packet_processor/disconnect_packet.rs index e3f0c6fce..165073c33 100644 --- a/citadel_proto/src/proto/packet_processor/disconnect_packet.rs +++ b/citadel_proto/src/proto/packet_processor/disconnect_packet.rs @@ -9,7 +9,7 @@ pub const SUCCESS_DISCONNECT: &str = "Successfully Disconnected"; /// Stage 1: Bob sends Alice an FINAL, whereafter Alice may disconnect #[cfg_attr(feature = "localhost-testing", tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err, fields(is_server = session.is_server, src = packet.parse().unwrap().0.session_cid.get(), target = packet.parse().unwrap().0.target_cid.get())))] pub async fn process_disconnect( - session: &HdpSession, + session: &CitadelSession, packet: HdpPacket, header_drill_vers: u32, ) -> Result { diff --git a/citadel_proto/src/proto/packet_processor/file_packet.rs b/citadel_proto/src/proto/packet_processor/file_packet.rs index 0bd0a7a87..39f32cdfe 100644 --- a/citadel_proto/src/proto/packet_processor/file_packet.rs +++ b/citadel_proto/src/proto/packet_processor/file_packet.rs @@ -12,7 +12,7 @@ use std::sync::atomic::Ordering; #[cfg_attr(feature = "localhost-testing", tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err, fields(is_server = session.is_server, src = packet.parse().unwrap().0.session_cid.get(), target = packet.parse().unwrap().0.target_cid.get())))] pub fn 
process_file_packet( - session: &HdpSession, + session: &CitadelSession, packet: HdpPacket, proxy_cid_info: Option<(u64, u64)>, ) -> Result { @@ -203,8 +203,15 @@ pub fn process_file_packet( .revfs_get_file_info(revfs_cid, packet.virtual_path) .await { - Ok((source, local_encryption_level)) => { + Ok((source, metadata)) => { let transfer_type = TransferType::FileTransfer; // use a basic file transfer since we don't need to data to be locally encrypted when sending it back + let Some(local_encryption_level) = + metadata.get_security_level() + else { + log::error!(target: "citadel", "The requested file was not designated as a RE-VFS type, yet, a metadata file existed for it"); + return; + }; + match session.process_outbound_file( ticket, None, @@ -213,6 +220,7 @@ pub fn process_file_packet( packet.security_level, transfer_type, Some(local_encryption_level), + Some(metadata), move |source| { if delete_on_pull { spawn!(tokio::fs::remove_file(source)); diff --git a/citadel_proto/src/proto/packet_processor/hole_punch.rs b/citadel_proto/src/proto/packet_processor/hole_punch.rs index 5c8e52993..1164f83ab 100644 --- a/citadel_proto/src/proto/packet_processor/hole_punch.rs +++ b/citadel_proto/src/proto/packet_processor/hole_punch.rs @@ -7,7 +7,7 @@ use crate::proto::packet_processor::primary_group_packet::{ /// This will handle an inbound group packet #[cfg_attr(feature = "localhost-testing", tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err, fields(is_server = session.is_server, src = packet.parse().unwrap().0.session_cid.get(), target = packet.parse().unwrap().0.target_cid.get())))] pub fn process_hole_punch( - session: &HdpSession, + session: &CitadelSession, packet: HdpPacket, hr_version: u32, proxy_cid_info: Option<(u64, u64)>, diff --git a/citadel_proto/src/proto/packet_processor/keep_alive_packet.rs b/citadel_proto/src/proto/packet_processor/keep_alive_packet.rs index a75c402b1..30a467e91 100644 --- a/citadel_proto/src/proto/packet_processor/keep_alive_packet.rs +++ b/citadel_proto/src/proto/packet_processor/keep_alive_packet.rs @@ -8,7 +8,7 @@ use std::sync::atomic::Ordering; #[allow(unused_results, unused_must_use)] #[cfg_attr(feature = "localhost-testing", tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err, fields(is_server = session.is_server, src = packet.parse().unwrap().0.session_cid.get(), target = packet.parse().unwrap().0.target_cid.get())))] pub async fn process_keep_alive( - session: &HdpSession, + session: &CitadelSession, packet: HdpPacket, header_drill_vers: u32, ) -> Result { diff --git a/citadel_proto/src/proto/packet_processor/mod.rs b/citadel_proto/src/proto/packet_processor/mod.rs index 10b365920..845364e2d 100644 --- a/citadel_proto/src/proto/packet_processor/mod.rs +++ b/citadel_proto/src/proto/packet_processor/mod.rs @@ -18,7 +18,7 @@ pub mod includes { pub use crate::proto::node_result::NodeResult; pub(crate) use crate::proto::packet::packet_flags; pub use crate::proto::packet::{HdpHeader, HdpPacket}; - pub use crate::proto::session::{HdpSession, HdpSessionInner, SessionState}; + pub use crate::proto::session::{CitadelSession, CitadelSessionInner, SessionState}; pub(crate) use crate::proto::{packet_crafter, validation}; pub use super::super::state_container::VirtualConnectionType; diff --git a/citadel_proto/src/proto/packet_processor/peer/group_broadcast.rs b/citadel_proto/src/proto/packet_processor/peer/group_broadcast.rs index 5072654f2..a884e7a91 100644 --- 
a/citadel_proto/src/proto/packet_processor/peer/group_broadcast.rs +++ b/citadel_proto/src/proto/packet_processor/peer/group_broadcast.rs @@ -114,7 +114,7 @@ pub enum GroupBroadcast { #[cfg_attr(feature = "localhost-testing", tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err, fields(is_server = session_ref.is_server, src = header.session_cid.get(), target = header.target_cid.get())))] pub async fn process_group_broadcast( - session_ref: &HdpSession, + session_ref: &CitadelSession, header: Ref<&[u8], HdpHeader>, payload: &[u8], sess_hyper_ratchet: &StackedRatchet, @@ -702,7 +702,7 @@ pub async fn process_group_broadcast( fn create_group_channel( ticket: Ticket, key: MessageGroupKey, - session: &HdpSession, + session: &CitadelSession, ) -> Result { let channel = inner_mut_state!(session.state_container) .setup_group_channel_endpoints(key, ticket, session)?; @@ -732,7 +732,7 @@ impl From for GroupBroadcastPayload { } fn forward_signal( - session: &HdpSession, + session: &CitadelSession, ticket: Ticket, key: Option, broadcast: GroupBroadcast, diff --git a/citadel_proto/src/proto/packet_processor/peer/mod.rs b/citadel_proto/src/proto/packet_processor/peer/mod.rs index 0b02b0911..5aade0a40 100644 --- a/citadel_proto/src/proto/packet_processor/peer/mod.rs +++ b/citadel_proto/src/proto/packet_processor/peer/mod.rs @@ -1,6 +1,6 @@ use crate::error::NetworkError; use crate::prelude::{ConnectFail, NodeResult, Ticket}; -use crate::proto::session::HdpSession; +use crate::proto::session::CitadelSession; pub mod group_broadcast; pub mod peer_cmd_packet; @@ -8,7 +8,7 @@ pub mod server; pub mod signal_handler_interface; pub(crate) fn send_dc_signal_peer>( - session: &HdpSession, + session: &CitadelSession, ticket: Ticket, err: T, ) -> Result<(), NetworkError> { diff --git a/citadel_proto/src/proto/packet_processor/peer/peer_cmd_packet.rs b/citadel_proto/src/proto/packet_processor/peer/peer_cmd_packet.rs index 0d6766284..a9fadb32b 100644 --- a/citadel_proto/src/proto/packet_processor/peer/peer_cmd_packet.rs +++ b/citadel_proto/src/proto/packet_processor/peer/peer_cmd_packet.rs @@ -29,7 +29,7 @@ use crate::proto::peer::hole_punch_compat_sink_stream::ReliableOrderedCompatStre use crate::proto::peer::p2p_conn_handler::attempt_simultaneous_hole_punch; use crate::proto::peer::peer_crypt::{KeyExchangeProcess, PeerNatInfo}; use crate::proto::peer::peer_layer::{ - NodeConnectionType, PeerConnectionType, PeerResponse, PeerSignal, + HyperNodePeerLayerInner, NodeConnectionType, PeerConnectionType, PeerResponse, PeerSignal, }; use crate::proto::remote::Ticket; use crate::proto::session_manager::HdpSessionManager; @@ -41,7 +41,7 @@ use netbeam::sync::network_endpoint::NetworkEndpoint; /// HyperLAN client and the HyperLAN Server #[cfg_attr(feature = "localhost-testing", tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err, fields(is_server = session_orig.is_server, src = packet.parse().unwrap().0.session_cid.get(), target = packet.parse().unwrap().0.target_cid.get())))] pub async fn process_peer_cmd( - session_orig: &HdpSession, + session_orig: &CitadelSession, aux_cmd: u8, packet: HdpPacket, header_drill_version: u32, @@ -833,6 +833,8 @@ pub async fn process_peer_cmd( _ => {} } + log::trace!(target: "citadel", "Forwarding signal {signal:?} to kernel"); + session .kernel_tx .unbounded_send(NodeResult::PeerEvent(PeerEvent { @@ -868,7 +870,7 @@ pub async fn process_peer_cmd( } async fn process_signal_command_as_server( - sess_ref: &HdpSession, + sess_ref: &CitadelSession, 
signal: PeerSignal, ticket: Ticket, sess_hyper_ratchet: StackedRatchet, @@ -1005,7 +1007,7 @@ async fn process_signal_command_as_server( if let Some(ticket_new) = peer_layer.check_simultaneous_register(implicated_cid, target_cid) { - log::trace!(target: "citadel", "Simultaneous register detected! Simulating implicated_cid={} sent an accept_register to target={}", implicated_cid, target_cid); + log::info!(target: "citadel", "Simultaneous register detected! Simulating implicated_cid={} sent an accept_register to target={}", implicated_cid, target_cid); peer_layer.insert_mapped_ticket(implicated_cid, ticket_new, ticket); // route signal to peer drop(peer_layer); @@ -1055,7 +1057,6 @@ async fn process_signal_command_as_server( let to_primary_stream = return_if_none!(session.to_primary_stream.clone()); let sess_mgr = session.session_manager.clone(); - drop(peer_layer); route_signal_and_register_ticket_forwards( PeerSignal::PostRegister { peer_conn_type, @@ -1073,6 +1074,7 @@ async fn process_signal_command_as_server( &sess_mgr, &sess_hyper_ratchet, security_level, + &mut peer_layer, ) .await } @@ -1253,7 +1255,6 @@ async fn process_signal_command_as_server( .await?; Ok(PrimaryProcessorResult::Void) } else { - drop(peer_layer); route_signal_and_register_ticket_forwards( PeerSignal::PostConnect { peer_conn_type, @@ -1272,6 +1273,7 @@ async fn process_signal_command_as_server( &sess_mgr, &sess_hyper_ratchet, security_level, + &mut peer_layer, ) .await } @@ -1663,12 +1665,13 @@ pub(crate) async fn route_signal_and_register_ticket_forwards( sess_mgr: &HdpSessionManager, sess_hyper_ratchet: &StackedRatchet, security_level: SecurityLevel, + peer_layer: &mut HyperNodePeerLayerInner, ) -> Result { let sess_hyper_ratchet_2 = sess_hyper_ratchet.clone(); let to_primary_stream = to_primary_stream.clone(); // Give the target_cid 10 seconds to respond - let res = sess_mgr.route_signal_primary(implicated_cid, target_cid, ticket, signal.clone(), move |peer_hyper_ratchet| { + let res = sess_mgr.route_signal_primary(peer_layer, implicated_cid, target_cid, ticket, signal.clone(), move |peer_hyper_ratchet| { packet_crafter::peer_cmd::craft_peer_signal(peer_hyper_ratchet, signal.clone(), ticket, timestamp, security_level) }, timeout, move |stale_signal| { // on timeout, run this @@ -1708,9 +1711,9 @@ pub(crate) async fn route_signal_response( target_cid: u64, timestamp: i64, ticket: Ticket, - session: HdpSession, + session: CitadelSession, sess_hyper_ratchet: &StackedRatchet, - on_route_finished: impl FnOnce(&HdpSession, &HdpSession, PeerSignal), + on_route_finished: impl FnOnce(&CitadelSession, &CitadelSession, PeerSignal), security_level: SecurityLevel, ) -> Result { trace!(target: "citadel", "Routing signal {:?} | impl: {} | target: {}", signal, implicated_cid, target_cid); diff --git a/citadel_proto/src/proto/packet_processor/peer/server/post_connect.rs b/citadel_proto/src/proto/packet_processor/peer/server/post_connect.rs index c2f5e85d3..9dc6d5f41 100644 --- a/citadel_proto/src/proto/packet_processor/peer/server/post_connect.rs +++ b/citadel_proto/src/proto/packet_processor/peer/server/post_connect.rs @@ -4,7 +4,7 @@ use crate::proto::packet_processor::includes::VirtualConnectionType; use crate::proto::packet_processor::peer::peer_cmd_packet::route_signal_response; use crate::proto::packet_processor::PrimaryProcessorResult; use crate::proto::remote::Ticket; -use crate::proto::session::HdpSession; +use crate::proto::session::CitadelSession; use citadel_crypt::stacked_ratchet::StackedRatchet; use 
citadel_types::crypto::SecurityLevel; use citadel_types::proto::{SessionSecuritySettings, UdpMode}; @@ -20,7 +20,7 @@ pub(crate) async fn handle_response_phase_post_connect( implicated_cid: u64, target_cid: u64, timestamp: i64, - session: &HdpSession, + session: &CitadelSession, sess_hyper_ratchet: &StackedRatchet, security_level: SecurityLevel, ) -> Result { diff --git a/citadel_proto/src/proto/packet_processor/peer/server/post_register.rs b/citadel_proto/src/proto/packet_processor/peer/server/post_register.rs index 925057cec..7c663b823 100644 --- a/citadel_proto/src/proto/packet_processor/peer/server/post_register.rs +++ b/citadel_proto/src/proto/packet_processor/peer/server/post_register.rs @@ -4,7 +4,7 @@ use crate::proto::packet_processor::peer::peer_cmd_packet::route_signal_response use crate::proto::packet_processor::PrimaryProcessorResult; use crate::proto::peer::peer_layer::Username; use crate::proto::remote::Ticket; -use crate::proto::session::HdpSession; +use crate::proto::session::CitadelSession; use citadel_crypt::stacked_ratchet::StackedRatchet; use citadel_types::crypto::SecurityLevel; @@ -18,7 +18,7 @@ pub async fn handle_response_phase_post_register( implicated_cid: u64, target_cid: u64, timestamp: i64, - session: &HdpSession, + session: &CitadelSession, sess_hyper_ratchet: &StackedRatchet, security_level: SecurityLevel, ) -> Result { diff --git a/citadel_proto/src/proto/packet_processor/preconnect_packet.rs b/citadel_proto/src/proto/packet_processor/preconnect_packet.rs index d9ac72d17..1fbeeae1f 100644 --- a/citadel_proto/src/proto/packet_processor/preconnect_packet.rs +++ b/citadel_proto/src/proto/packet_processor/preconnect_packet.rs @@ -26,7 +26,7 @@ use std::sync::atomic::Ordering; /// Handles preconnect packets. Handles the NAT traversal #[cfg_attr(feature = "localhost-testing", tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err, fields(is_server = session_orig.is_server, src = packet.parse().unwrap().0.session_cid.get(), target = packet.parse().unwrap().0.target_cid.get())))] pub async fn process_preconnect( - session_orig: &HdpSession, + session_orig: &CitadelSession, packet: HdpPacket, header_drill_vers: u32, ) -> Result { @@ -524,7 +524,7 @@ pub async fn process_preconnect( } fn begin_connect_process( - session: &HdpSession, + session: &CitadelSession, hyper_ratchet: &StackedRatchet, security_level: SecurityLevel, ) -> Result { @@ -560,7 +560,7 @@ fn begin_connect_process( fn send_success_as_initiator( udp_splittable: Option, hyper_ratchet: &StackedRatchet, - session: &HdpSession, + session: &CitadelSession, security_level: SecurityLevel, implicated_cid: u64, state_container: &mut StateContainerInner, @@ -579,7 +579,7 @@ fn send_success_as_initiator( fn handle_success_as_receiver( udp_splittable: Option, - session: &HdpSession, + session: &CitadelSession, implicated_cid: u64, state_container: &mut StateContainerInner, ) -> Result { @@ -601,7 +601,7 @@ fn handle_success_as_receiver( if let Some(udp_splittable) = udp_splittable { let peer_addr = udp_splittable.peer_addr(); // the UDP subsystem will automatically engage at this point - HdpSession::udp_socket_loader( + CitadelSession::udp_socket_loader( session.clone(), VirtualTargetType::LocalGroupServer { implicated_cid }, udp_splittable, @@ -675,10 +675,8 @@ fn proto_version_out_of_sync(adjacent_proto_version: u32) -> Result UdpSplittableTypes { log::trace!(target: "citadel", "Will use Raw UDP for UDP transmission"); - UdpSplittableTypes::Raw(RawUdpSocketConnector::new( - socket.socket, - 
socket.addr.send_address, - )) + let send_addr = socket.addr.send_address; + UdpSplittableTypes::Raw(RawUdpSocketConnector::new(socket.into_socket(), send_addr)) } fn get_quic_udp_interface(quic_conn: Connection, local_addr: SocketAddr) -> UdpSplittableTypes { diff --git a/citadel_proto/src/proto/packet_processor/primary_group_packet.rs b/citadel_proto/src/proto/packet_processor/primary_group_packet.rs index ba8bd1a17..012bf702f 100644 --- a/citadel_proto/src/proto/packet_processor/primary_group_packet.rs +++ b/citadel_proto/src/proto/packet_processor/primary_group_packet.rs @@ -17,6 +17,7 @@ use citadel_crypt::misc::CryptError; use citadel_crypt::stacked_ratchet::constructor::{AliceToBobTransferType, ConstructorType}; use citadel_crypt::stacked_ratchet::{Ratchet, RatchetType, StackedRatchet}; use citadel_types::crypto::SecrecyMode; +use citadel_types::prelude::ObjectId; use citadel_types::proto::UdpMode; use std::ops::Deref; use std::sync::atomic::Ordering; @@ -30,14 +31,14 @@ use std::sync::atomic::Ordering; /// will be provided. In this case, we must use the virtual conn's crypto #[cfg_attr(feature = "localhost-testing", tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err, fields(is_server = session_ref.is_server, src = packet.parse().unwrap().0.session_cid.get(), target = packet.parse().unwrap().0.target_cid.get())))] pub fn process_primary_packet( - session_ref: &HdpSession, + session_ref: &CitadelSession, cmd_aux: u8, packet: HdpPacket, proxy_cid_info: Option<(u64, u64)>, ) -> Result { let session = session_ref; - let HdpSessionInner { + let CitadelSessionInner { time_tracker, state_container, state, @@ -130,7 +131,7 @@ pub fn process_primary_packet( if is_message { let (plaintext, transfer, object_id) = return_if_none!( validation::group::validate_message(&mut payload), - "Bad message packet" + "Bad GROUP HEADER packet" ); log::trace!(target: "citadel", "Recv FastMessage. version {} w/ CID {} (local CID: {})", hyper_ratchet.version(), hyper_ratchet.get_cid(), header.session_cid.get()); // Here, we do not go through all the fiasco like above. We just forward the message to the kernel, then send an ACK @@ -227,11 +228,11 @@ pub fn process_primary_packet( if group.has_begun { if group.receiver.has_expired(GROUP_EXPIRE_TIME_MS) { if state_container.meta_expiry_state.expired() { - log::error!(target: "citadel", "Inbound group {} has expired; removing for {}.", group_id, peer_cid); + log::warn!(target: "citadel", "Inbound group {} has expired; removing for {}.", group_id, peer_cid); if let Some(group) = state_container.inbound_groups.remove(&key) { - if group.object_id != 0 { + if group.object_id != ObjectId::zero() { // belongs to a file. Delete file; stop transmission - let key = FileKey::new(peer_cid, group.object_id); + let key = FileKey::new(group.object_id); if let Some(_file) = state_container.inbound_files.remove(&key) { // dropping this will automatically drop the future streaming to HD log::warn!(target: "citadel", "File transfer expired"); @@ -675,7 +676,7 @@ impl ToolsetUpdate<'_> { /// target_cid: from header.target_cid /// Returns: Ok(latest_hyper_ratchet) pub(crate) fn attempt_kem_as_alice_finish( - session: &HdpSession, + session: &CitadelSession, base_session_secrecy_mode: SecrecyMode, peer_cid: u64, target_cid: u64, @@ -786,7 +787,7 @@ pub(crate) fn attempt_kem_as_alice_finish( /// NOTE! 
Assumes the `hr` passed is the latest version IF the transfer is some pub(crate) fn attempt_kem_as_bob( - session: &HdpSession, + session: &CitadelSession, resp_target_cid: u64, header: &Ref<&[u8], HdpHeader>, transfer: Option, diff --git a/citadel_proto/src/proto/packet_processor/raw_primary_packet.rs b/citadel_proto/src/proto/packet_processor/raw_primary_packet.rs index ffdba5e34..b03800251 100644 --- a/citadel_proto/src/proto/packet_processor/raw_primary_packet.rs +++ b/citadel_proto/src/proto/packet_processor/raw_primary_packet.rs @@ -9,7 +9,7 @@ use crate::error::NetworkError; #[cfg_attr(feature = "localhost-testing", tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err, fields(implicated_cid=this_implicated_cid, is_server=session.is_server, packet_len=packet.len())))] pub async fn process_raw_packet( this_implicated_cid: Option, - session: &HdpSession, + session: &CitadelSession, remote_peer: SocketAddr, local_primary_port: u16, packet: BytesMut, @@ -129,7 +129,7 @@ pub(crate) fn check_proxy( cmd_aux: u8, header_session_cid: u64, target_cid: u64, - session: &HdpSession, + session: &CitadelSession, endpoint_cid_info: &mut Option<(u64, u64)>, recv_port_type: ReceivePortType, packet: HdpPacket, diff --git a/citadel_proto/src/proto/packet_processor/register_packet.rs b/citadel_proto/src/proto/packet_processor/register_packet.rs index cb4e25969..824d42a86 100644 --- a/citadel_proto/src/proto/packet_processor/register_packet.rs +++ b/citadel_proto/src/proto/packet_processor/register_packet.rs @@ -10,7 +10,7 @@ use std::sync::atomic::Ordering; /// This will handle a registration packet #[cfg_attr(feature = "localhost-testing", tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err, fields(is_server = session_ref.is_server, src = packet.parse().unwrap().0.session_cid.get(), target = packet.parse().unwrap().0.target_cid.get())))] pub async fn process_register( - session_ref: &HdpSession, + session_ref: &CitadelSession, packet: HdpPacket, remote_addr: SocketAddr, ) -> Result { @@ -113,7 +113,9 @@ pub async fn process_register( .state .store(SessionState::NeedsRegister, Ordering::Relaxed); - return Ok(PrimaryProcessorResult::Void); + return Ok(PrimaryProcessorResult::EndSession( + "Unable to validate STAGE0_REGISTER packet", + )); } } } else { @@ -324,7 +326,7 @@ pub async fn process_register( { Ok(new_cnac) => { if passwordless { - HdpSession::begin_connect(&session, &new_cnac)?; + CitadelSession::begin_connect(&session, &new_cnac)?; inner_mut_state!(session.state_container).cnac = Some(new_cnac); // begin_connect will handle the connection process from here on out diff --git a/citadel_proto/src/proto/packet_processor/rekey_packet.rs b/citadel_proto/src/proto/packet_processor/rekey_packet.rs index c5cd20405..55b9ae1d8 100644 --- a/citadel_proto/src/proto/packet_processor/rekey_packet.rs +++ b/citadel_proto/src/proto/packet_processor/rekey_packet.rs @@ -15,7 +15,7 @@ use std::sync::atomic::Ordering; #[cfg_attr(feature = "localhost-testing", tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err, fields(is_server = session.is_server, src = packet.parse().unwrap().0.session_cid.get(), target = packet.parse().unwrap().0.target_cid.get())))] pub fn process_rekey( - session: &HdpSession, + session: &CitadelSession, packet: HdpPacket, header_drill_vers: u32, proxy_cid_info: Option<(u64, u64)>, @@ -25,7 +25,7 @@ pub fn process_rekey( return Ok(PrimaryProcessorResult::Void); } - let HdpSessionInner { + let CitadelSessionInner { 
state_container, time_tracker, .. diff --git a/citadel_proto/src/proto/packet_processor/udp_packet.rs b/citadel_proto/src/proto/packet_processor/udp_packet.rs index 35640f777..cbe00fba3 100644 --- a/citadel_proto/src/proto/packet_processor/udp_packet.rs +++ b/citadel_proto/src/proto/packet_processor/udp_packet.rs @@ -6,7 +6,7 @@ use crate::proto::packet_processor::primary_group_packet::get_resp_target_cid_fr /// This will handle an inbound group packet #[cfg_attr(feature = "localhost-testing", tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err, fields(is_server = _session.is_server, src = packet.parse().unwrap().0.session_cid.get(), target = packet.parse().unwrap().0.target_cid.get())))] pub fn process_udp_packet( - _session: &HdpSession, + _session: &CitadelSession, packet: HdpPacket, hr_version: u32, accessor: &EndpointCryptoAccessor, diff --git a/citadel_proto/src/proto/peer/p2p_conn_handler.rs b/citadel_proto/src/proto/peer/p2p_conn_handler.rs index 01d0beb5c..a878f524d 100644 --- a/citadel_proto/src/proto/peer/p2p_conn_handler.rs +++ b/citadel_proto/src/proto/peer/p2p_conn_handler.rs @@ -16,7 +16,7 @@ use crate::proto::packet_processor::includes::{Duration, Instant, SocketAddr}; use crate::proto::peer::peer_crypt::PeerNatInfo; use crate::proto::peer::peer_layer::PeerConnectionType; use crate::proto::remote::Ticket; -use crate::proto::session::HdpSession; +use crate::proto::session::CitadelSession; use crate::proto::state_container::VirtualConnectionType; use citadel_user::re_exports::__private::Formatter; use citadel_wire::exports::tokio_rustls::rustls; @@ -71,7 +71,7 @@ impl Drop for DirectP2PRemote { async fn setup_listener_non_initiator( local_bind_addr: SocketAddr, remote_addr: SocketAddr, - session: HdpSession, + session: CitadelSession, v_conn: VirtualConnectionType, hole_punched_addr: TargettedSocketAddr, ticket: Ticket, @@ -96,7 +96,7 @@ async fn setup_listener_non_initiator( async fn p2p_conn_handler( mut p2p_listener: GenericNetworkListener, - session: HdpSession, + session: CitadelSession, _necessary_remote_addr: SocketAddr, v_conn: VirtualConnectionType, hole_punched_addr: TargettedSocketAddr, @@ -112,7 +112,7 @@ async fn p2p_conn_handler( match p2p_listener.next().await { Some(Ok((p2p_stream, _))) => { - let session = HdpSession::upgrade_weak(weak) + let session = CitadelSession::upgrade_weak(weak) .ok_or(NetworkError::InternalError("HdpSession dropped"))?; /* @@ -153,7 +153,7 @@ async fn p2p_conn_handler( fn handle_p2p_stream( mut p2p_stream: GenericNetworkStream, implicated_cid: DualRwLock>, - session: HdpSession, + session: CitadelSession, kernel_tx: UnboundedSender, from_listener: bool, v_conn: VirtualConnectionType, @@ -191,9 +191,9 @@ fn handle_p2p_stream( p2p_primary_stream_tx.clone(), peer_cid, ); - let writer_future = HdpSession::outbound_stream(p2p_primary_stream_rx, sink); + let writer_future = CitadelSession::outbound_stream(p2p_primary_stream_rx, sink); let reader_future = - HdpSession::execute_inbound_stream(stream, session.clone(), Some(p2p_handle)); + CitadelSession::execute_inbound_stream(stream, session.clone(), Some(p2p_handle)); let stopper_future = p2p_stopper(stopper_rx); let direct_p2p_remote = DirectP2PRemote::new(stopper_tx, p2p_primary_stream_tx, from_listener); @@ -205,7 +205,7 @@ fn handle_p2p_stream( state_container .insert_direct_p2p_connection(direct_p2p_remote, v_conn.get_target_cid()) .map_err(|err| generic_error(err.into_string()))?; - HdpSession::udp_socket_loader( + CitadelSession::udp_socket_loader( 
sess.clone(), v_conn, UdpSplittableTypes::Quic(udp_conn), @@ -288,7 +288,7 @@ async fn p2p_stopper(receiver: Receiver<()>) -> Result<(), NetworkError> { pub(crate) async fn attempt_simultaneous_hole_punch( peer_connection_type: PeerConnectionType, ticket: Ticket, - session: HdpSession, + session: CitadelSession, peer_nat_info: PeerNatInfo, implicated_cid: DualRwLock<Option<u64>>, kernel_tx: UnboundedSender, @@ -311,7 +311,7 @@ pub(crate) async fn attempt_simultaneous_hole_punch( .map_err(generic_error)?; let remote_connect_addr = hole_punched_socket.addr.send_address; let addr = hole_punched_socket.addr; - let local_addr = hole_punched_socket.socket.local_addr()?; + let local_addr = hole_punched_socket.local_addr()?; log::trace!(target: "citadel", "~!@ P2P UDP Hole-punch finished @!~ | is initiator: {}", is_initiator); app.sync().await.map_err(generic_error)?; @@ -319,8 +319,9 @@ pub(crate) async fn attempt_simultaneous_hole_punch( // if local IS the initiator, then start connecting. It should work if is_initiator { // give time for non-initiator to setup local bind + // TODO: Replace with biconn channel logic tokio::time::sleep(Duration::from_millis(200)).await; - let socket = hole_punched_socket.socket; + let socket = hole_punched_socket.into_socket(); let quic_endpoint = citadel_wire::quic::QuicClient::new_with_config(socket, client_config.clone()) .map_err(generic_error)?; diff --git a/citadel_proto/src/proto/peer/peer_layer.rs b/citadel_proto/src/proto/peer/peer_layer.rs index e9f4dbe20..308ee2dc6 100644 --- a/citadel_proto/src/proto/peer/peer_layer.rs +++ b/citadel_proto/src/proto/peer/peer_layer.rs @@ -466,7 +466,7 @@ impl HyperNodePeerLayerInner { ) -> Option { let this = self.inner.read(); let peer_map = this.observed_postings.get(&peer_cid)?; - log::trace!(target: "citadel", "[simultaneous checking] peer_map len: {}", peer_map.len()); + log::trace!(target: "citadel", "[simultaneous checking] peer_map len: {} | {:?}", peer_map.len(), peer_map.values().map(|r| &r.signal).collect::<Vec<_>>()); peer_map .iter() .find(|(_, posting)| (fx)(posting)) diff --git a/citadel_proto/src/proto/session.rs b/citadel_proto/src/proto/session.rs index 8679727c3..2cb316130 100644 --- a/citadel_proto/src/proto/session.rs +++ b/citadel_proto/src/proto/session.rs @@ -70,7 +70,7 @@ use crate::proto::state_subcontainers::preconnect_state_container::UdpChannelSen use crate::proto::state_subcontainers::rekey_container::calculate_update_frequency; use crate::proto::transfer_stats::TransferStats; use atomic::Atomic; -use citadel_crypt::prelude::ConstructorOpts; +use citadel_crypt::prelude::{ConstructorOpts, FixedSizedSource}; use citadel_crypt::streaming_crypt_scrambler::{scramble_encrypt_source, ObjectSource}; use citadel_types::proto::TransferType; use citadel_user::backend::PersistenceHandler; @@ -94,11 +94,11 @@ use citadel_types::crypto::SecurityLevel; //define_outer_struct_wrapper!(HdpSession, HdpSessionInner); /// Allows a connection stream to be worked on by a single worker -pub struct HdpSession { +pub struct CitadelSession { #[cfg(not(feature = "multi-threaded"))] - pub inner: std::rc::Rc<HdpSessionInner>, + pub inner: std::rc::Rc<CitadelSessionInner>, #[cfg(feature = "multi-threaded")] - pub inner: std::sync::Arc<HdpSessionInner>, + pub inner: std::sync::Arc<CitadelSessionInner>, } enum SessionShutdownReason { @@ -106,7 +106,7 @@ enum SessionShutdownReason { Error(NetworkError), } -impl HdpSession { +impl CitadelSession { pub fn strong_count(&self) -> usize { #[cfg(not(feature = "multi-threaded"))] { @@ -120,28 +120,28 @@ impl HdpSession { } #[cfg(not(feature = "multi-threaded"))] - pub fn 
as_weak(&self) -> std::rc::Weak<HdpSessionInner> { + pub fn as_weak(&self) -> std::rc::Weak<CitadelSessionInner> { std::rc::Rc::downgrade(&self.inner) } #[cfg(feature = "multi-threaded")] - pub fn as_weak(&self) -> std::sync::Weak<HdpSessionInner> { + pub fn as_weak(&self) -> std::sync::Weak<CitadelSessionInner> { std::sync::Arc::downgrade(&self.inner) } #[cfg(feature = "multi-threaded")] - pub fn upgrade_weak(this: &std::sync::Weak<HdpSessionInner>) -> Option<Self> { + pub fn upgrade_weak(this: &std::sync::Weak<CitadelSessionInner>) -> Option<Self> { this.upgrade().map(|inner| Self { inner }) } #[cfg(not(feature = "multi-threaded"))] - pub fn upgrade_weak(this: &std::rc::Weak<HdpSessionInner>) -> Option<Self> { + pub fn upgrade_weak(this: &std::rc::Weak<CitadelSessionInner>) -> Option<Self> { this.upgrade().map(|inner| Self { inner }) } } -impl From<HdpSessionInner> for HdpSession { - fn from(inner: HdpSessionInner) -> Self { +impl From<CitadelSessionInner> for CitadelSession { + fn from(inner: CitadelSessionInner) -> Self { #[cfg(not(feature = "multi-threaded"))] { Self { @@ -158,15 +158,15 @@ impl From<HdpSessionInner> for HdpSession { } } -impl Deref for HdpSession { - type Target = HdpSessionInner; +impl Deref for CitadelSession { + type Target = CitadelSessionInner; fn deref(&self) -> &Self::Target { self.inner.deref() } } -impl Clone for HdpSession { +impl Clone for CitadelSession { fn clone(&self) -> Self { Self { inner: self.inner.clone(), @@ -178,7 +178,7 @@ impl Clone for HdpSession { /// Structure for holding and keep track of packets, as well as basic connection information #[allow(unused)] -pub struct HdpSessionInner { +pub struct CitadelSessionInner { pub(super) implicated_cid: DualRwLock<Option<u64>>, pub(super) kernel_ticket: DualCell, pub(super) remote_peer: SocketAddr, @@ -273,7 +273,7 @@ pub(crate) struct ClientOnlySessionInitSettings { pub connect_mode: Option, } -impl HdpSession { +impl CitadelSession { pub(crate) fn new( session_init_params: SessionInitParams, ) -> Result<(tokio::sync::broadcast::Sender<()>, Self), NetworkError> { @@ -356,7 +356,7 @@ impl HdpSession { let init_time = session_init_params.init_time; let session_password = session_init_params.session_password; - let mut inner = HdpSessionInner { + let mut inner = CitadelSessionInner { hypernode_peer_layer, connect_mode: DualRwLock::from(connect_mode), primary_stream_quic_conn: DualRwLock::from(None), @@ -410,7 +410,7 @@ impl HdpSession { Ok((stopper_tx, Self::from(inner))) } - /// Once the [HdpSession] is created, it can then be executed to begin handling a periodic connection handler. + /// Once the [CitadelSession] is created, it can then be executed to begin handling a periodic connection handler. /// This will automatically stop running once the internal state is set to Disconnected /// `tcp_stream`: this goes to the adjacent HyperNode /// `p2p_listener`: This is TCP listener bound to the same local_addr as tcp_stream. Required for TCP hole-punching @@ -550,7 +550,7 @@ impl HdpSession { zero_packet: Option, persistence_handler: PersistenceHandler, to_outbound: OutboundPrimaryStreamSender, - session: HdpSession, + session: CitadelSession, state: SessionState, timestamp: i64, cnac: Option<ClientNetworkAccount>, @@ -644,7 +644,7 @@ impl HdpSession { } pub(crate) fn begin_connect( - session: &HdpSession, + session: &CitadelSession, cnac: &ClientNetworkAccount, ) -> Result<(), NetworkError> { log::trace!(target: "citadel", "Beginning pre-connect subroutine!"); @@ -716,7 +716,7 @@ impl HdpSession { // tcp_conn_awaiter must be provided in order to know when the begin loading the UDP conn for the user.
The TCP connection must first be loaded in order to place the udp conn inside the virtual_conn hashmap pub(crate) fn udp_socket_loader( - this: HdpSession, + this: CitadelSession, v_target: VirtualTargetType, udp_conn: UdpSplittableTypes, addr: TargettedSocketAddr, @@ -727,7 +727,7 @@ impl HdpSession { std::mem::drop(this); let task = async move { let (listener, udp_sender_future, stopper_rx) = { - let this = HdpSession::upgrade_weak(&this_weak) + let this = CitadelSession::upgrade_weak(&this_weak) .ok_or(NetworkError::InternalError("HdpSession no longer exists"))?; let sess = this; @@ -758,7 +758,7 @@ impl HdpSession { .map_err(|err| NetworkError::Generic(err.to_string()))?; } - let sess = HdpSession::upgrade_weak(&this_weak) + let sess = CitadelSession::upgrade_weak(&this_weak) .ok_or(NetworkError::InternalError("HdpSession no longer exists"))?; let accessor = match v_target { @@ -911,7 +911,7 @@ impl HdpSession { )] pub async fn execute_inbound_stream( mut reader: CleanShutdownStream, - this_main: HdpSession, + this_main: CitadelSession, p2p_handle: Option, ) -> Result<(), NetworkError> { let this_main = &this_main; @@ -957,12 +957,12 @@ impl HdpSession { result: Result, primary_stream: &OutboundPrimaryStreamSender, kernel_tx: &UnboundedSender, - session: &HdpSession, + session: &CitadelSession, cid_opt: Option, ) -> std::io::Result<()> { match result { Ok(PrimaryProcessorResult::ReplyToSender(return_packet)) => { - HdpSession::send_to_primary_stream_closure( + CitadelSession::send_to_primary_stream_closure( primary_stream, kernel_tx, return_packet, @@ -998,7 +998,7 @@ impl HdpSession { } fn handle_session_terminating_error( - session: &HdpSession, + session: &CitadelSession, err: std::io::Error, is_server: bool, peer_cid: Option, @@ -1107,7 +1107,7 @@ impl HdpSession { feature = "localhost-testing", tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err(Debug)) )] - async fn execute_queue_worker(this_main: HdpSession) -> Result<(), NetworkError> { + async fn execute_queue_worker(this_main: CitadelSession) -> Result<(), NetworkError> { log::trace!(target: "citadel", "HdpSession async timer subroutine executed"); let queue_worker = { @@ -1386,6 +1386,7 @@ impl HdpSession { security_level: SecurityLevel, transfer_type: TransferType, local_encryption_level: Option, + virtual_object_metadata: Option, post_close_hook: impl for<'a> FnOnce(PathBuf) + Send + 'static, ) -> Result<(), NetworkError> { let this = self; @@ -1395,6 +1396,20 @@ impl HdpSession { let file = File::open(&source_path).map_err(|err| NetworkError::Generic(err.to_string()))?; + + if let Some(virtual_object_metadata) = &virtual_object_metadata { + let expected_min_length = virtual_object_metadata.plaintext_length; + let file_length = file + .length() + .map_err(|err| NetworkError::Generic(err.to_string()))?; + if file_length < expected_min_length as u64 { + log::warn!(target: "citadel", "The REVFS file cannot be pulled since it has not yet synchronized with the filesystem: Current file length: {file_length}, expected min length: {expected_min_length}"); + return Err(NetworkError::InternalError( + "The REVFS file cannot be pulled since it has not yet synchronized with the filesystem", + )); + } + } + let file_metadata = file .metadata() .map_err(|err| NetworkError::Generic(err.to_string()))?; @@ -1438,7 +1453,10 @@ impl HdpSession { .as_mut() .unwrap() .peer_session_crypto; - let object_id = crypt_container.get_and_increment_object_id(); + let object_id = virtual_object_metadata + .as_ref() + .map(|r| 
r.object_id) + .unwrap_or_else(|| crypt_container.get_next_object_id()); let group_id_start = crypt_container.get_and_increment_group_id(); let latest_hr = crypt_container.get_hyper_ratchet(None).cloned().unwrap(); let static_aux_ratchet = crypt_container @@ -1531,9 +1549,11 @@ impl HdpSession { return Err(NetworkError::msg("File transfer is not enabled for this p2p session. Both nodes must use a filesystem backend")); } - let object_id = endpoint_container - .endpoint_crypto - .get_and_increment_object_id(); + let object_id = virtual_object_metadata + .as_ref() + .map(|r| r.object_id) + .unwrap_or_else(|| endpoint_container.endpoint_crypto.get_next_object_id()); + // reserve group ids let start_group_id = endpoint_container .endpoint_crypto @@ -1544,12 +1564,6 @@ impl HdpSession { .get_hyper_ratchet(None) .unwrap(); - /*let static_aux_ratchet = endpoint_container - .endpoint_crypto - .toolset - .get_static_auxiliary_ratchet() - .clone();*/ - let preferred_primary_stream = endpoint_container .get_direct_p2p_primary_stream() .cloned() @@ -1648,7 +1662,7 @@ impl HdpSession { next_gs_alerter: next_gs_alerter.clone(), start: Some(start), }; - let file_key = FileKey::new(key_cid, object_id); + let file_key = FileKey::new(object_id); let _ = state_container .outbound_files .insert(file_key, outbound_file_transfer_container); @@ -1864,7 +1878,7 @@ impl HdpSession { // TODO: Make a generic version to allow requests the ability to bypass the session manager pub(crate) fn spawn_message_sender_function( - this: HdpSession, + this: CitadelSession, mut rx: tokio::sync::mpsc::Receiver, ) { let task = async move { @@ -2063,7 +2077,7 @@ impl HdpSession { } async fn listen_udp_port( - this: HdpSession, + this: CitadelSession, _hole_punched_addr_ip: IpAddr, local_port: u16, mut stream: S, @@ -2215,7 +2229,7 @@ impl HdpSession { } } -impl HdpSessionInner { +impl CitadelSessionInner { /// Stores the proposed credentials into the register state container pub(crate) fn store_proposed_credentials(&mut self, proposed_credentials: ProposedCredentials) { let mut state_container = inner_mut_state!(self.state_container); @@ -2370,7 +2384,7 @@ impl HdpSessionInner { } } -impl Drop for HdpSessionInner { +impl Drop for CitadelSessionInner { fn drop(&mut self) { log::trace!(target: "citadel", "*** Dropping HdpSession {:?} ***", self.implicated_cid.get()); self.send_session_dc_signal(None, false, "Session dropped"); diff --git a/citadel_proto/src/proto/session_manager.rs b/citadel_proto/src/proto/session_manager.rs index faea43786..368c06d0a 100644 --- a/citadel_proto/src/proto/session_manager.rs +++ b/citadel_proto/src/proto/session_manager.rs @@ -31,11 +31,12 @@ use crate::proto::packet_processor::includes::{Duration, Instant}; use crate::proto::packet_processor::peer::group_broadcast::GroupBroadcast; use crate::proto::packet_processor::PrimaryProcessorResult; use crate::proto::peer::peer_layer::{ - HyperNodePeerLayer, MailboxTransfer, PeerConnectionType, PeerResponse, PeerSignal, + HyperNodePeerLayer, HyperNodePeerLayerInner, MailboxTransfer, PeerConnectionType, PeerResponse, + PeerSignal, }; use crate::proto::remote::{NodeRemote, Ticket}; use crate::proto::session::{ - ClientOnlySessionInitSettings, HdpSession, HdpSessionInitMode, SessionInitParams, + CitadelSession, ClientOnlySessionInitSettings, HdpSessionInitMode, SessionInitParams, }; use crate::proto::state_container::{VirtualConnectionType, VirtualTargetType}; use citadel_crypt::streaming_crypt_scrambler::ObjectSource; @@ -56,7 +57,7 @@ 
define_outer_struct_wrapper!(HdpSessionManager, HdpSessionManagerInner); /// Used for handling stateful connections between two peer pub struct HdpSessionManagerInner { local_node_type: NodeType, - pub(crate) sessions: HashMap, HdpSession)>, + pub(crate) sessions: HashMap, CitadelSession)>, account_manager: AccountManager, pub(crate) hypernode_peer_layer: HyperNodePeerLayer, server_remote: Option, @@ -64,7 +65,7 @@ pub struct HdpSessionManagerInner { /// Connections which have no implicated CID go herein. They are strictly expected to be /// in the state of NeedsRegister. Once they leave that state, they are eventually polled /// by the [HdpSessionManager] and thereafter placed inside an appropriate session - pub provisional_connections: HashMap, HdpSession)>, + pub provisional_connections: HashMap, CitadelSession)>, kernel_tx: UnboundedSender, time_tracker: TimeTracker, clean_shutdown_tracker_tx: UnboundedSender<()>, @@ -311,7 +312,7 @@ impl HdpSessionManager { session_password, }; - let (stopper, new_session) = HdpSession::new(session_init_params)?; + let (stopper, new_session) = CitadelSession::new(session_init_params)?; if let Some((_prev_conn_init_time, _stopper, lingering_session)) = inner_mut!(self) .provisional_connections @@ -343,7 +344,7 @@ impl HdpSessionManager { #[cfg_attr(feature = "localhost-testing", tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err, fields(implicated_cid=new_session.implicated_cid.get(), is_server=new_session.is_server, peer_addr=peer_addr.to_string())))] async fn execute_session_with_safe_shutdown( session_manager: HdpSessionManager, - new_session: HdpSession, + new_session: CitadelSession, peer_addr: SocketAddr, tcp_stream: GenericNetworkStream, ) -> Result<(), NetworkError> { @@ -517,7 +518,7 @@ impl HdpSessionManager { session_password, }; - let (stopper, new_session) = HdpSession::new(session_init_params)?; + let (stopper, new_session) = CitadelSession::new(session_init_params)?; this.provisional_connections .insert(peer_addr, (init_time, stopper, new_session.clone())); drop(this); @@ -573,6 +574,7 @@ impl HdpSessionManager { security_level, transfer_type, local_encryption_level, + None, |_| {}, ) } else { @@ -1079,6 +1081,7 @@ impl HdpSessionManager { #[allow(clippy::too_many_arguments)] pub async fn route_signal_primary( &self, + peer_layer: &mut HyperNodePeerLayerInner, implicated_cid: u64, target_cid: u64, ticket: Ticket, @@ -1107,11 +1110,7 @@ impl HdpSessionManager { // get the target cid's session if let Some(ref sess_ref) = sess { - sess_ref - .hypernode_peer_layer - .inner - .write() - .await + peer_layer .insert_tracked_posting(implicated_cid, timeout, ticket, signal, on_timeout) .await; let peer_sender = sess_ref.to_primary_stream.as_ref().unwrap(); @@ -1126,11 +1125,7 @@ impl HdpSessionManager { // session is not active, but user is registered (thus offline). Setup return ticket tracker on implicated_cid // and deliver to the mailbox of target_cid, that way target_cid receives mail on connect. 
TODO: external svc route, if available { - let peer_layer = { inner!(self).hypernode_peer_layer.clone() }; peer_layer - .inner - .write() - .await .insert_tracked_posting( implicated_cid, timeout, @@ -1254,9 +1249,12 @@ impl HdpSessionManager { implicated_cid: u64, target_cid: u64, ticket: Ticket, - session: &HdpSession, + session: &CitadelSession, packet: impl FnOnce(&StackedRatchet) -> BytesMut, - post_send: impl FnOnce(&HdpSession, PeerSignal) -> Result, + post_send: impl FnOnce( + &CitadelSession, + PeerSignal, + ) -> Result, ) -> Result, String> { // Instead of checking for registration, check the `implicated_cid`'s timed queue for a ticket corresponding to Ticket. let tracked_posting = { diff --git a/citadel_proto/src/proto/state_container.rs b/citadel_proto/src/proto/state_container.rs index a1bcc388f..83a4e0024 100644 --- a/citadel_proto/src/proto/state_container.rs +++ b/citadel_proto/src/proto/state_container.rs @@ -36,7 +36,7 @@ use crate::proto::packet_crafter::peer_cmd::C2S_ENCRYPTION_ONLY; use crate::proto::packet_crafter::{ GroupTransmitter, RatchetPacketCrafterContainer, SecureProtocolPacket, }; -use crate::proto::packet_processor::includes::{HdpSession, Instant, SocketAddr}; +use crate::proto::packet_processor::includes::{CitadelSession, Instant, SocketAddr}; use crate::proto::packet_processor::peer::group_broadcast::GroupBroadcast; use crate::proto::packet_processor::PrimaryProcessorResult; use crate::proto::peer::channel::{PeerChannel, UdpChannel}; @@ -62,6 +62,7 @@ use citadel_crypt::stacked_ratchet::{Ratchet, StackedRatchet}; use citadel_types::crypto::SecBuffer; use citadel_types::crypto::SecrecyMode; use citadel_types::crypto::SecurityLevel; +use citadel_types::prelude::ObjectId; use citadel_types::proto::{ MessageGroupKey, ObjectTransferOrientation, ObjectTransferStatus, SessionSecuritySettings, TransferType, UdpMode, VirtualObjectMetadata, @@ -136,21 +137,20 @@ pub struct StateContainerInner { pub(crate) struct GroupKey { target_cid: u64, group_id: u64, - object_id: u64, + object_id: ObjectId, } #[derive(Copy, Clone, Hash, Eq, PartialEq, Debug)] pub struct FileKey { - pub target_cid: u64, // wave payload get the object id inscribed - pub object_id: u64, + pub object_id: ObjectId, } /// when the GROUP_HEADER comes inbound with virtual file metadata, this should be created alongside /// an async task fired-up on the threadpool #[allow(dead_code)] pub(crate) struct InboundFileTransfer { - pub object_id: u64, + pub object_id: ObjectId, pub total_groups: usize, pub groups_rendered: usize, pub last_group_window_len: usize, @@ -176,7 +176,7 @@ pub(crate) struct OutboundFileTransfer { } impl GroupKey { - pub fn new(target_cid: u64, group_id: u64, object_id: u64) -> Self { + pub fn new(target_cid: u64, group_id: u64, object_id: ObjectId) -> Self { Self { target_cid, group_id, @@ -186,11 +186,8 @@ impl GroupKey { } impl FileKey { - pub fn new(target_cid: u64, object_id: u64) -> Self { - Self { - target_cid, - object_id, - } + pub fn new(object_id: ObjectId) -> Self { + Self { object_id } } } @@ -392,6 +389,22 @@ impl VirtualConnectionType { _ => {} } } + + pub fn set_implicated_cid(&mut self, cid: u64) { + match self { + VirtualConnectionType::LocalGroupPeer { + implicated_cid, + peer_cid: _, + } + | VirtualConnectionType::ExternalGroupPeer { + implicated_cid, + interserver_cid: _, + peer_cid: _, + } => *implicated_cid = cid, + + _ => {} + } + } } impl Display for VirtualConnectionType { @@ -510,12 +523,12 @@ pub(crate) struct GroupReceiverContainer { max_window_size: usize, 
window_drift: isize, waves_in_window_finished: usize, - pub object_id: u64, + pub object_id: ObjectId, } impl GroupReceiverContainer { pub fn new( - object_id: u64, + object_id: ObjectId, receiver: GroupReceiver, virtual_target: VirtualTargetType, security_level: SecurityLevel, @@ -805,7 +818,7 @@ impl StateContainerInner { target_cid: u64, connection_type: VirtualConnectionType, endpoint_crypto: PeerSessionCrypto, - sess: &HdpSession, + sess: &CitadelSession, file_transfer_compatible: bool, ) -> PeerChannel { let (channel_tx, channel_rx) = unbounded(); @@ -827,7 +840,7 @@ impl StateContainerInner { tx, ); let to_channel = OrderedChannel::new(channel_tx); - HdpSession::spawn_message_sender_function(sess.clone(), rx); + CitadelSession::spawn_message_sender_function(sess.clone(), rx); let endpoint_container = Some(EndpointChannelContainer { default_security_settings, @@ -861,7 +874,7 @@ impl StateContainerInner { security_level: SecurityLevel, channel_ticket: Ticket, implicated_cid: u64, - session: &HdpSession, + session: &CitadelSession, ) -> PeerChannel { let (channel_tx, channel_rx) = unbounded(); let (tx, rx) = crate::proto::outbound_sender::channel(MAX_OUTGOING_UNPROCESSED_REQUESTS); @@ -876,7 +889,7 @@ impl StateContainerInner { channel_rx, tx, ); - HdpSession::spawn_message_sender_function(session.clone(), rx); + CitadelSession::spawn_message_sender_function(session.clone(), rx); let c2s = C2SChannelContainer { to_channel: OrderedChannel::new(channel_tx), @@ -1048,9 +1061,9 @@ impl StateContainerInner { ticket.into(), ); // check to see if we need to copy the last wave window - let last_window_size = if object_id != 0 { + let last_window_size = if object_id != ObjectId::zero() { // copy previous window - let file_key = FileKey::new(header.session_cid.get(), object_id); + let file_key = FileKey::new(object_id); if let Some(inbound_file_transfer) = self.inbound_files.get(&file_key) { inbound_file_transfer.last_group_window_len } else { @@ -1095,22 +1108,26 @@ impl StateContainerInner { pers: &PersistenceHandler, state_container: StateContainer, hyper_ratchet: StackedRatchet, - target_cid: u64, + _target_cid: u64, v_target_flipped: VirtualTargetType, preferred_primary_stream: OutboundPrimaryStreamSender, local_encryption_level: Option, ) -> bool { - let key = FileKey::new(header.session_cid.get(), metadata_orig.object_id); + let target_cid = v_target_flipped.get_target_cid(); + let implicated_cid = v_target_flipped.get_implicated_cid(); + let object_id = metadata_orig.object_id; + + let key = FileKey::new(object_id); let ticket = header.context_info.get().into(); let is_revfs_pull = local_encryption_level.is_some(); + log::trace!(target: "citadel", "File header {implicated_cid}: {key:?} | revfs_pull: {is_revfs_pull}"); + if let std::collections::hash_map::Entry::Vacant(e) = self.inbound_files.entry(key) { let (stream_to_hd, stream_to_hd_rx) = unbounded::>(); - let (start_recv_tx, start_recv_rx) = tokio::sync::oneshot::channel::(); let security_level_rebound: SecurityLevel = header.security_level.into(); let timestamp = self.time_tracker.get_global_time_ns(); - let object_id = metadata_orig.object_id; let pers = pers.clone(); let metadata = metadata_orig.clone(); let tt = self.time_tracker; @@ -1129,39 +1146,47 @@ impl StateContainerInner { local_encryption_level, }; + let (start_recv_tx, start_recv_rx) = if !is_revfs_pull { + let (tx, rx) = tokio::sync::oneshot::channel(); + (Some(tx), Some(rx)) + } else { + (None, None) + }; + e.insert(entry); let (handle, tx_status) = 
ObjectTransferHandler::new( - header.session_cid.get(), - header.target_cid.get(), + target_cid, + implicated_cid, metadata.clone(), ObjectTransferOrientation::Receiver { is_revfs_pull }, - Some(start_recv_tx), - ); - self.file_transfer_handles.insert( - key, - crate::proto::outbound_sender::UnboundedSender(tx_status.clone()), + start_recv_tx, ); + self.file_transfer_handles + .insert(key, UnboundedSender(tx_status.clone())); // finally, alert the kernel (receiver) if let Err(err) = self .kernel_tx .unbounded_send(NodeResult::ObjectTransferHandle(ObjectTransferHandle { ticket, handle, - implicated_cid: hyper_ratchet.get_cid(), + implicated_cid, })) { log::error!(target: "citadel", "Failed to send the ObjectTransferHandle to the kernel: {err:?}"); } + let is_server = self.is_server; + let task = async move { - let res = if is_revfs_pull { - // auto-accept for revfs pull requests - log::trace!(target: "citadel", "Auto-accepting for REVFS pull request"); - Ok(true) + log::info!(target: "citadel", "File transfer initiated, awaiting acceptance ... | revfs_pull: {is_revfs_pull}"); + let res = if let Some(start_rx) = start_recv_rx { + start_rx.await } else { - start_recv_rx.await + Ok(true) }; + log::info!(target: "citadel", "File transfer initiated! | revfs_pull: {is_revfs_pull}"); + let accepted = res.as_ref().map(|r| *r).unwrap_or(false); // first, send a rebound signal immediately to the sender // to ensure the sender knows if the user accepted or not @@ -1196,13 +1221,15 @@ impl StateContainerInner { .await { Ok(()) => { - log::info!(target: "citadel", "Successfully synced file to backend | {is_revfs_pull}"); + // TODO: Consider adding a function that waits for the actual file size to be equal to the metadata plaintext length + // in order to not allow the kernel logic to prematurely read the file contents while still syncing. + log::info!(target: "citadel", "Successfully synced file to backend | revfs_pull: {is_revfs_pull} | is_server: {is_server}"); let status = match success_receiving_rx.await { Ok(header) => { // write the header let wave_ack = packet_crafter::group::craft_wave_ack( &hyper_ratchet, - header.context_info.get() as u32, + object_id, get_resp_target_cid_from_header(&header), header.group.get(), header.wave_id.get(), @@ -1234,7 +1261,13 @@ impl StateContainerInner { } } } else { + if let Err(err) = tx_status.send(ObjectTransferStatus::Fail( + "User did not accept file transfer".to_string(), + )) { + log::error!(target: "citadel", "Unable to send object transfer status to handle: {err:?}"); + } // user did not accept. 
cleanup local + log::warn!(target: "citadel", "User did not accept file transfer"); let mut state_container = inner_mut_state!(state_container); let _ = state_container.inbound_files.remove(&key); let _ = state_container.file_transfer_handles.remove(&key); @@ -1272,7 +1305,7 @@ impl StateContainerInner { success: bool, implicated_cid: u64, ticket: Ticket, - object_id: u64, + object_id: ObjectId, v_target: VirtualTargetType, _transfer_type: TransferType, ) -> Option<()> { @@ -1283,11 +1316,11 @@ impl StateContainerInner { } => { let receiver_cid = implicated_cid; // since the order hasn't flipped yet, get the implicated cid - (FileKey::new(implicated_cid, object_id), receiver_cid) + (FileKey::new(object_id), receiver_cid) } VirtualConnectionType::LocalGroupServer { implicated_cid } => { - (FileKey::new(implicated_cid, object_id), implicated_cid) + (FileKey::new(object_id), implicated_cid) } _ => { @@ -1357,12 +1390,12 @@ impl StateContainerInner { #[allow(clippy::too_many_arguments)] pub fn on_group_header_ack_received( &mut self, - session: &HdpSession, + session: &CitadelSession, base_session_secrecy_mode: SecrecyMode, peer_cid: u64, target_cid: u64, group_id: u64, - object_id: u64, + object_id: ObjectId, next_window: Option>, transfer: KemTransferStatus, fast_msg: bool, @@ -1412,7 +1445,7 @@ impl StateContainerInner { &mut self, header: &HdpHeader, error_message: T, - object_id: u64, + object_id: ObjectId, ) -> Result<(), NetworkError> { let target_cid = header.session_cid.get(); self.notify_object_transfer_handle_failure_with(target_cid, object_id, error_message) @@ -1420,12 +1453,12 @@ impl StateContainerInner { pub fn notify_object_transfer_handle_failure_with>( &mut self, - target_cid: u64, - object_id: u64, + _target_cid: u64, + object_id: ObjectId, error_message: T, ) -> Result<(), NetworkError> { // let group_key = GroupKey::new(target_cid, group_id, object_id); - let file_key = FileKey::new(target_cid, object_id); + let file_key = FileKey::new(object_id); let file_transfer_handle = self.file_transfer_handles .get_mut(&file_key) @@ -1445,10 +1478,10 @@ impl StateContainerInner { header: &HdpHeader, payload: Bytes, hr: &StackedRatchet, - ) -> Result { + ) -> Result { let target_cid = header.session_cid.get(); let group_id = header.group.get(); - let object_id = header.context_info.get() as u64; + let object_id = header.context_info.get().into(); let group_key = GroupKey::new(target_cid, group_id, object_id); let grc = self.inbound_groups.get_mut(&group_key).ok_or_else(|| { ( @@ -1456,12 +1489,12 @@ impl StateContainerInner { "inbound_groups does not contain key for {group_key:?}" )), Ticket(0), - 0, + 0.into(), ) })?; let ticket = grc.ticket; - let file_key = FileKey::new(target_cid, grc.object_id); + let file_key = FileKey::new(grc.object_id); let file_container = self.inbound_files.get_mut(&file_key).ok_or_else(|| { ( NetworkError::msg(format!( @@ -1521,9 +1554,10 @@ impl StateContainerInner { GroupReceiverStatus::GROUP_COMPLETE(_last_wid) => { let receiver = self.inbound_groups.remove(&group_key).unwrap().receiver; let mut chunk = receiver.finalize(); - log::info!(target: "citadel", "GROUP {} COMPLETE. Total groups: {} | Plaintext len: {} | Received plaintext len: {}", group_id, file_container.total_groups, file_container.metadata.plaintext_length, chunk.len()); + log::trace!(target: "citadel", "GROUP {} COMPLETE. 
Total groups: {} | Plaintext len: {} | Received plaintext len: {}", group_id, file_container.total_groups, file_container.metadata.plaintext_length, chunk.len()); if let Some(local_encryption_level) = file_container.local_encryption_level { + log::trace!(target: "citadel", "Detected REVFS. Locally decrypting object {object_id} with level {local_encryption_level:?} | Ratchet used: {} w/version {}", hr.get_cid(), hr.version()); // which static hr do we need? Since we are receiving this chunk, always our local account's let static_aux_hr = self .cnac @@ -1543,7 +1577,7 @@ impl StateContainerInner { send_wave_ack = true; - if group_id as usize == file_container.total_groups.saturating_sub(1) { + if group_id as usize >= file_container.total_groups.saturating_sub(1) { complete = true; let file_container = self.inbound_files.remove(&file_key).unwrap(); // status of reception complete now located where the streaming to HD completes @@ -1561,12 +1595,20 @@ impl StateContainerInner { ) })?; } else { - file_container.last_group_finish_time = Instant::now(); - // TODO: Compute Mb/s + let now = Instant::now(); + let bytes_per_sec = file_container.metadata.plaintext_length as f32 + / now + .duration_since(file_container.last_group_finish_time) + .as_secs_f32() + .round(); + let mb_per_sec = bytes_per_sec / (1024.0f32 * 1024.0f32); + log::info!(target: "citadel", "Sending reception tick for group {} of {} | {:.2} MB/s", group_id, file_container.total_groups, mb_per_sec); + + file_container.last_group_finish_time = now; let status = ObjectTransferStatus::ReceptionTick( group_id as usize, file_container.total_groups, - 0 as f32, + mb_per_sec, ); // sending the wave ack will complete the group on the initiator side file_transfer_handle.unbounded_send(status).map_err(|err| { @@ -1599,7 +1641,7 @@ impl StateContainerInner { if !complete { let wave_ack = packet_crafter::group::craft_wave_ack( hr, - header.context_info.get() as u32, + header.context_info.get().into(), get_resp_target_cid_from_header(header), header.group.get(), header.wave_id.get(), @@ -1622,10 +1664,10 @@ impl StateContainerInner { #[allow(unused_results)] pub fn on_wave_ack_received( &mut self, - _implicated_cid: u64, + implicated_cid: u64, header: &Ref<&[u8], HdpHeader>, ) -> bool { - let object_id = header.context_info.get() as u64; + let object_id = header.context_info.get().into(); let group = header.group.get(); let wave_id = header.wave_id.get(); let target_cid = header.session_cid.get(); @@ -1655,7 +1697,7 @@ impl StateContainerInner { log::trace!(target: "citadel", "Notified object sender to begin sending the next group"); } - let file_key = FileKey::new(target_cid, object_id); + let file_key = FileKey::new(object_id); if let Some(tx) = self.file_transfer_handles.get(&file_key) { let status = if relative_group_id as usize @@ -1672,6 +1714,7 @@ impl StateContainerInner { ObjectTransferStatus::TransferComplete }; + log::trace!(target: "citadel", "Transmitter {implicated_cid}: {file_key:?} received final wave ack. Sending status to local node: {:?}", status); if let Err(err) = tx.unbounded_send(status.clone()) { // if the server is using an accept-only policy with no further responses, this branch // will be reached @@ -1682,11 +1725,11 @@ impl StateContainerInner { if matches!(status, ObjectTransferStatus::TransferComplete) { // remove the transmitter. Dropping will stop related futures - log::trace!(target: "citadel", "FileTransfer is complete!"); + log::trace!(target: "citadel", "FileTransfer is complete! Local is server? 
{}", self.is_server); let _ = self.file_transfer_handles.remove(&file_key); } } else { - log::error!(target: "citadel", "Unable to find ObjectTransferHandle for {:?}", file_key); + log::error!(target: "citadel", "Unable to find ObjectTransferHandle for {:?} | Local is {implicated_cid} | FileKeys available: {:?}", file_key, self.file_transfer_handles.keys().copied().collect::>()); } delete_group = true; @@ -1859,7 +1902,7 @@ impl StateContainerInner { } // object singleton == 0 implies that the data does not belong to a file - const OBJECT_SINGLETON: u64 = 0; + const OBJECT_SINGLETON: ObjectId = ObjectId::zero(); // Drop this to ensure that it doesn't block other async closures from accessing the inner device // std::mem::drop(this); let (mut transmitter, group_id, target_cid) = match virtual_target { @@ -2184,7 +2227,7 @@ impl StateContainerInner { let to_primary_stream = self.get_primary_stream().unwrap(); let kernel_tx = &self.kernel_tx; - HdpSession::send_to_primary_stream_closure( + CitadelSession::send_to_primary_stream_closure( to_primary_stream, kernel_tx, stage0_packet, @@ -2347,7 +2390,7 @@ impl StateContainerInner { &mut self, key: MessageGroupKey, ticket: Ticket, - session: &HdpSession, + session: &CitadelSession, ) -> Result { let (tx, rx) = unbounded(); let implicated_cid = self @@ -2367,7 +2410,7 @@ impl StateContainerInner { let (to_session_tx, to_session_rx) = crate::proto::outbound_sender::channel(MAX_OUTGOING_UNPROCESSED_REQUESTS); - HdpSession::spawn_message_sender_function(session.clone(), to_session_rx); + CitadelSession::spawn_message_sender_function(session.clone(), to_session_rx); Ok(GroupChannel::new( self.hdp_server_remote.clone(), diff --git a/citadel_proto/src/proto/validation.rs b/citadel_proto/src/proto/validation.rs index 83b4999fc..e3c8ec72d 100644 --- a/citadel_proto/src/proto/validation.rs +++ b/citadel_proto/src/proto/validation.rs @@ -44,6 +44,7 @@ pub(crate) mod group { use citadel_crypt::stacked_ratchet::StackedRatchet; use citadel_types::crypto::SecBuffer; use citadel_types::crypto::SecurityLevel; + use citadel_types::proto::ObjectId; use citadel_user::serialization::SyncIO; use serde::{Deserialize, Serialize}; @@ -84,13 +85,14 @@ pub(crate) mod group { pub(crate) fn validate_message( payload_orig: &mut BytesMut, - ) -> Option<(SecBuffer, Option, u64)> { + ) -> Option<(SecBuffer, Option, ObjectId)> { // Safely check that there are 8 bytes in length, then, split at the end - 8 - if payload_orig.len() < 8 { + if payload_orig.len() < std::mem::size_of::() { return None; } - let mut payload = payload_orig.split_to(payload_orig.len() - 8); - let object_id = payload_orig.reader().read_u64::().ok()?; + let mut payload = + payload_orig.split_to(payload_orig.len() - std::mem::size_of::()); + let object_id = payload_orig.reader().read_u128::().ok()?.into(); let message = SecureProtocolPacket::extract_message(&mut payload).ok()?; let deser = SyncIO::deserialize_from_vector(&payload[..]).ok()?; Some((message.into(), deser, object_id)) @@ -103,11 +105,11 @@ pub(crate) mod group { fast_msg: bool, initial_window: Option>, transfer: KemTransferStatus, - object_id: u64, + object_id: ObjectId, }, NotReady { fast_msg: bool, - object_id: u64, + object_id: ObjectId, }, } diff --git a/citadel_proto/tests/connections.rs b/citadel_proto/tests/connections.rs index a9e4ab693..7e29d65f8 100644 --- a/citadel_proto/tests/connections.rs +++ b/citadel_proto/tests/connections.rs @@ -57,7 +57,7 @@ pub mod tests { #[rstest] #[case("127.0.0.1:0")] #[case("[::1]:0")] - 
#[timeout(Duration::from_secs(240))] + #[timeout(Duration::from_secs(60))] #[tokio::test(flavor = "multi_thread")] async fn test_tcp_or_tls( #[case] addr: SocketAddr, @@ -109,7 +109,7 @@ pub mod tests { #[rstest] #[case("127.0.0.1:0")] #[case("[::1]:0")] - #[timeout(Duration::from_secs(240))] + #[timeout(Duration::from_secs(60))] #[tokio::test(flavor = "multi_thread")] async fn test_many_proto_conns( #[case] addr: SocketAddr, @@ -125,11 +125,6 @@ pub mod tests { let count = 32; // keep this value low to ensure that runners don't get exhausted and run out of FD's for proto in protocols { - if matches!(proto, ServerUnderlyingProtocol::Tls(..)) && cfg!(windows) { - citadel_logging::warn!(target: "citadel", "Will skip test since self-signed certs may not necessarily work on windows runner"); - continue; - } - log::trace!(target: "citadel", "Testing proto {:?}", &proto); let cnt = &AtomicUsize::new(0); diff --git a/citadel_sdk/Cargo.toml b/citadel_sdk/Cargo.toml index bfc083bda..388a7bed4 100644 --- a/citadel_sdk/Cargo.toml +++ b/citadel_sdk/Cargo.toml @@ -28,7 +28,6 @@ vendored = ["citadel_proto/vendored"] # for testing only localhost-testing = ["citadel_proto/localhost-testing", "tracing", "citadel_io/deadlock-detection"] localhost-testing-assert-no-proxy = ["citadel_proto/localhost-testing-assert-no-proxy"] -localhost-testing-loopback-only = ["citadel_proto/localhost-testing-loopback-only"] doc-images = ["embed-doc-image"] @@ -48,6 +47,7 @@ citadel_logging = { workspace = true } anyhow = { workspace = true } bytes = { workspace = true } citadel_types = { workspace = true } +citadel_wire = { workspace = true } [dev-dependencies] tokio = { workspace = true, features = ["rt"] } diff --git a/citadel_sdk/src/fs.rs b/citadel_sdk/src/fs.rs index 6797fb4e0..4831b6775 100644 --- a/citadel_sdk/src/fs.rs +++ b/citadel_sdk/src/fs.rs @@ -78,12 +78,13 @@ pub async fn delete + Send>( #[cfg(test)] mod tests { use crate::prefabs::client::single_connection::SingleClientServerConnectionKernel; - use crate::prefabs::server::accept_file_transfer_kernel::AcceptFileTransferKernel; + use crate::prefabs::server::accept_file_transfer_kernel::{ + exhaust_file_transfer, AcceptFileTransferKernel, + }; use crate::prefabs::client::peer_connection::{FileTransferHandleRx, PeerConnectionKernel}; use crate::prelude::*; use crate::test_common::wait_for_peers; - use futures::StreamExt; use rstest::rstest; use std::net::SocketAddr; use std::path::PathBuf; @@ -96,7 +97,6 @@ mod tests { } #[rstest] - #[timeout(std::time::Duration::from_secs(90))] #[case( EncryptionAlgorithm::AES_GCM_256, KemAlgorithm::Kyber, @@ -107,8 +107,9 @@ mod tests { KemAlgorithm::Kyber, SigAlgorithm::Falcon1024 )] + #[timeout(std::time::Duration::from_secs(90))] #[tokio::test] - async fn test_c2s_file_transfer_revfsq( + async fn test_c2s_file_transfer_revfs( #[case] enx: EncryptionAlgorithm, #[case] kem: KemAlgorithm, #[case] sig: SigAlgorithm, @@ -133,7 +134,7 @@ mod tests { UdpMode::Disabled, session_security_settings, None, - |_channel, remote| async move { + |_success, remote| async move { log::trace!(target: "citadel", "***CLIENT LOGIN SUCCESS :: File transfer next ***"); let virtual_path = PathBuf::from("/home/john.doe/TheBridge.pdf"); // write to file to the RE-VFS @@ -144,15 +145,15 @@ mod tests { &virtual_path, ) .await?; - log::trace!(target: "citadel", "***CLIENT FILE TRANSFER SUCCESS***"); + log::info!(target: "citadel", "***CLIENT FILE TRANSFER SUCCESS***"); // now, pull it let save_dir = crate::fs::read(&remote, virtual_path).await?; // now, 
compare bytes - log::trace!(target: "citadel", "***CLIENT REVFS PULL SUCCESS"); + log::info!(target: "citadel", "***CLIENT REVFS PULL SUCCESS"); let original_bytes = tokio::fs::read(&source_dir).await.unwrap(); let revfs_pulled_bytes = tokio::fs::read(&save_dir).await.unwrap(); assert_eq!(original_bytes, revfs_pulled_bytes); - log::trace!(target: "citadel", "***CLIENT REVFS PULL COMPARE SUCCESS"); + log::info!(target: "citadel", "***CLIENT REVFS PULL COMPARE SUCCESS"); client_success.store(true, Ordering::Relaxed); remote.shutdown_kernel().await }, @@ -172,12 +173,12 @@ mod tests { } #[rstest] - #[timeout(std::time::Duration::from_secs(90))] #[case( EncryptionAlgorithm::AES_GCM_256, KemAlgorithm::Kyber, SigAlgorithm::None )] + #[timeout(std::time::Duration::from_secs(90))] #[tokio::test] async fn test_c2s_file_transfer_revfs_take( #[case] enx: EncryptionAlgorithm, @@ -245,12 +246,12 @@ mod tests { } #[rstest] - #[timeout(std::time::Duration::from_secs(90))] #[case( EncryptionAlgorithm::AES_GCM_256, KemAlgorithm::Kyber, SigAlgorithm::None )] + #[timeout(std::time::Duration::from_secs(90))] #[tokio::test] async fn test_c2s_file_transfer_revfs_delete( #[case] enx: EncryptionAlgorithm, @@ -320,7 +321,7 @@ mod tests { #[rstest] #[case(SecrecyMode::BestEffort)] - #[timeout(std::time::Duration::from_secs(240))] + #[timeout(Duration::from_secs(60))] #[tokio::test(flavor = "multi_thread")] async fn test_p2p_file_transfer_revfs( #[case] secrecy_mode: SecrecyMode, @@ -358,14 +359,14 @@ mod tests { move |mut connection, remote_outer| async move { wait_for_peers().await; let mut connection = connection.recv().await.unwrap()?; + let cid = connection.channel.get_implicated_cid(); wait_for_peers().await; // The other peer will send the file first - log::trace!(target: "citadel", "***CLIENT LOGIN SUCCESS :: File transfer next ***"); + log::info!(target: "citadel", "***CLIENT A {cid} LOGIN SUCCESS :: File transfer next ***"); + let remote = connection.remote.clone(); let handle_orig = connection.incoming_object_transfer_handles.take().unwrap(); accept_all(handle_orig); - wait_for_peers().await; - /* let virtual_path = PathBuf::from("/home/john.doe/TheBridge.pdf"); // write the file to the RE-VFS crate::fs::write_with_security_level( @@ -375,17 +376,18 @@ mod tests { &virtual_path, ) .await?; - log::error!(target: "citadel", "X01"); - log::trace!(target: "citadel", "***CLIENT FILE TRANSFER SUCCESS***"); + log::info!(target: "citadel", "***CLIENT A {cid} FILE TRANSFER SUCCESS***"); + tokio::time::sleep(Duration::from_secs(1)).await; + wait_for_peers().await; // now, pull it let save_dir = crate::fs::read(&remote, virtual_path).await?; // now, compare bytes - log::trace!(target: "citadel", "***CLIENT REVFS PULL SUCCESS"); + log::info!(target: "citadel", "***CLIENT A {cid} REVFS PULL SUCCESS"); let original_bytes = tokio::fs::read(&source_dir).await.unwrap(); let revfs_pulled_bytes = tokio::fs::read(&save_dir).await.unwrap(); assert_eq!(original_bytes, revfs_pulled_bytes); - log::trace!(target: "citadel", "***CLIENT REVFS PULL COMPARE SUCCESS"); - wait_for_peers().await;*/ + log::info!(target: "citadel", "***CLIENT A {cid} REVFS PULL COMPARE SUCCESS"); + wait_for_peers().await; client0_success.store(true, Ordering::Relaxed); remote_outer.shutdown_kernel().await }, @@ -401,10 +403,13 @@ mod tests { None, move |mut connection, remote_outer| async move { wait_for_peers().await; - let connection = connection.recv().await.unwrap()?; + let mut connection = connection.recv().await.unwrap()?; + let cid = 
connection.channel.get_implicated_cid(); wait_for_peers().await; let remote = connection.remote.clone(); - log::trace!(target: "citadel", "***CLIENT LOGIN SUCCESS :: File transfer next ***"); + let handle_orig = connection.incoming_object_transfer_handles.take().unwrap(); + accept_all(handle_orig); + log::info!(target: "citadel", "***CLIENT B {cid} LOGIN SUCCESS :: File transfer next ***"); let virtual_path = PathBuf::from("/home/john.doe/TheBridge.pdf"); // write the file to the RE-VFS crate::fs::write_with_security_level( @@ -414,22 +419,20 @@ mod tests { &virtual_path, ) .await?; - log::trace!(target: "citadel", "***CLIENT FILE TRANSFER SUCCESS***"); + log::info!(target: "citadel", "***CLIENT B {cid} FILE TRANSFER SUCCESS***"); + // Wait some time for the file to synchronize + tokio::time::sleep(Duration::from_secs(1)).await; + tokio::time::sleep(Duration::from_secs(1)).await; + wait_for_peers().await; // now, pull it let save_dir = crate::fs::read(&remote, virtual_path).await?; // now, compare bytes - log::trace!(target: "citadel", "***CLIENT REVFS PULL SUCCESS"); + log::info!(target: "citadel", "***CLIENT B {cid} REVFS PULL SUCCESS"); let original_bytes = tokio::fs::read(&source_dir).await.unwrap(); let revfs_pulled_bytes = tokio::fs::read(&save_dir).await.unwrap(); assert_eq!(original_bytes, revfs_pulled_bytes); - log::trace!(target: "citadel", "***CLIENT REVFS PULL COMPARE SUCCESS"); + log::info!(target: "citadel", "***CLIENT B {cid} REVFS PULL COMPARE SUCCESS"); wait_for_peers().await; - /* - // Now, accept the peer's incoming handle - let handle_orig = connection.incoming_object_transfer_handles.take().unwrap(); - accept_all(handle_orig); - - wait_for_peers().await;*/ client1_success.store(true, Ordering::Relaxed); remote_outer.shutdown_kernel().await }, @@ -458,18 +461,11 @@ mod tests { fn accept_all(mut rx: FileTransferHandleRx) { let handle = tokio::task::spawn(async move { while let Some(mut handle) = rx.recv().await { - let _ = handle.accept(); - // Exhaust the stream - let handle = tokio::task::spawn(async move { - while let Some(evt) = handle.next().await { - if let ObjectTransferStatus::Fail(err) = evt { - log::error!(target: "citadel", "File Transfer Failed: {err:?}"); - std::process::exit(1); - } - } - }); - - drop(handle); + if let Err(err) = handle.accept() { + log::error!(target: "citadel", "Failed to accept file transfer: {err:?}"); + } + + exhaust_file_transfer(handle); } }); diff --git a/citadel_sdk/src/prefabs/client/peer_connection.rs b/citadel_sdk/src/prefabs/client/peer_connection.rs index f55c05148..8cd487206 100644 --- a/citadel_sdk/src/prefabs/client/peer_connection.rs +++ b/citadel_sdk/src/prefabs/client/peer_connection.rs @@ -39,7 +39,7 @@ struct PeerContext { #[derive(Debug)] pub struct FileTransferHandleRx { pub inner: tokio::sync::mpsc::UnboundedReceiver, - pub peer_conn: PeerConnectionType, + pub conn_type: VirtualTargetType, } impl std::ops::Deref for FileTransferHandleRx { @@ -58,7 +58,7 @@ impl std::ops::DerefMut for FileTransferHandleRx { impl Drop for FileTransferHandleRx { fn drop(&mut self) { - log::trace!(target: "citadel", "Dropping file transfer handle receiver {:?}", self.peer_conn); + log::trace!(target: "citadel", "Dropping file transfer handle receiver {:?}", self.conn_type); } } @@ -364,7 +364,7 @@ where } let _reg_success = handle.register_to_peer().await?; - log::trace!(target: "citadel", "Peer {:?} registered || success -> now connecting", id); + log::info!(target: "citadel", "Peer {:?} registered || success -> now connecting", id); 
handle }; @@ -384,7 +384,7 @@ where // add an incoming file transfer receiver success.incoming_object_transfer_handles = Some(FileTransferHandleRx { inner: file_transfer_rx, - peer_conn, + conn_type: peer_conn.as_virtual_connection(), }); let _ = shared .active_peer_conns diff --git a/citadel_sdk/src/prefabs/client/single_connection.rs b/citadel_sdk/src/prefabs/client/single_connection.rs index 0c5b27074..f64775cdb 100644 --- a/citadel_sdk/src/prefabs/client/single_connection.rs +++ b/citadel_sdk/src/prefabs/client/single_connection.rs @@ -1,3 +1,4 @@ +use crate::prefabs::client::peer_connection::FileTransferHandleRx; use crate::prefabs::{get_socket_addr, ClientServerRemote}; use crate::remote_ext::ConnectionSuccess; use crate::remote_ext::ProtocolRemoteExt; @@ -6,7 +7,6 @@ use citadel_proto::prelude::*; use futures::Future; use std::marker::PhantomData; use std::net::{SocketAddr, ToSocketAddrs}; -use std::sync::Arc; use uuid::Uuid; /// A kernel that connects with the given credentials. If the credentials are not yet registered, then the [`Self::new_register`] function may be used, which will register the account before connecting. @@ -21,6 +21,8 @@ pub struct SingleClientServerConnectionKernel { unprocessed_signal_filter_tx: Mutex>>, remote: Option, server_password: Option, + rx_incoming_object_transfer_handle: Mutex>, + tx_incoming_object_transfer_handle: tokio::sync::mpsc::UnboundedSender, // by using fn() -> Fut, the future does not need to be Sync _pd: PhantomData Fut>, } @@ -48,6 +50,18 @@ where F: FnOnce(ConnectionSuccess, ClientServerRemote) -> Fut + Send, Fut: Future> + Send, { + fn generate_object_transfer_handle() -> ( + tokio::sync::mpsc::UnboundedSender, + Mutex>, + ) { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let rx = FileTransferHandleRx { + inner: rx, + conn_type: VirtualTargetType::LocalGroupServer { implicated_cid: 0 }, + }; + (tx, Mutex::new(Some(rx))) + } + /// Creates a new connection with a central server entailed by the user information pub fn new_connect, P: Into>( username: T, @@ -57,6 +71,8 @@ where server_password: Option, on_channel_received: F, ) -> Self { + let (tx_incoming_object_transfer_handle, rx_incoming_object_transfer_handle) = + Self::generate_object_transfer_handle(); Self { handler: Mutex::new(Some(on_channel_received)), udp_mode, @@ -64,6 +80,8 @@ where username: username.into(), password: password.into(), })), + rx_incoming_object_transfer_handle, + tx_incoming_object_transfer_handle, session_security_settings, unprocessed_signal_filter_tx: Default::default(), remote: None, @@ -101,6 +119,8 @@ where on_channel_received: F, ) -> Result { let server_addr = get_socket_addr(server_addr)?; + let (tx_incoming_object_transfer_handle, rx_incoming_object_transfer_handle) = + Self::generate_object_transfer_handle(); Ok(Self { handler: Mutex::new(Some(on_channel_received)), udp_mode, @@ -112,6 +132,8 @@ where })), session_security_settings, unprocessed_signal_filter_tx: Default::default(), + rx_incoming_object_transfer_handle, + tx_incoming_object_transfer_handle, remote: None, server_password, _pd: Default::default(), @@ -153,12 +175,16 @@ where on_channel_received: F, ) -> Result { let server_addr = get_socket_addr(server_addr)?; + let (tx_incoming_object_transfer_handle, rx_incoming_object_transfer_handle) = + Self::generate_object_transfer_handle(); Ok(Self { handler: Mutex::new(Some(on_channel_received)), udp_mode, auth_info: Mutex::new(Some(ConnectionType::Passwordless { uuid, server_addr })), session_security_settings, 
unprocessed_signal_filter_tx: Default::default(), + rx_incoming_object_transfer_handle, + tx_incoming_object_transfer_handle, server_password, remote: None, _pd: Default::default(), @@ -258,31 +284,44 @@ where implicated_cid: connect_success.cid, }; - let unprocessed_signal_filter = if cfg!(feature = "localhost-testing") { - let (reroute_tx, reroute_rx) = tokio::sync::mpsc::unbounded_channel(); - *self.unprocessed_signal_filter_tx.lock() = Some(reroute_tx); - Some(reroute_rx) - } else { - None + let mut handle = { + let mut lock = self.rx_incoming_object_transfer_handle.lock(); + lock.take().expect("Should not have been called before") }; - (handler)( + handle.conn_type.set_implicated_cid(connect_success.cid); + + let (reroute_tx, reroute_rx) = tokio::sync::mpsc::unbounded_channel(); + *self.unprocessed_signal_filter_tx.lock() = Some(reroute_tx); + + handler( connect_success, - ClientServerRemote { - inner: remote, - unprocessed_signals_rx: Arc::new(Mutex::new(unprocessed_signal_filter)), + ClientServerRemote::new( conn_type, + remote, session_security_settings, - }, + Some(reroute_rx), + Some(handle), + ), ) .await } async fn on_node_event_received(&self, message: NodeResult) -> Result<(), NetworkError> { - if let Some(val) = self.unprocessed_signal_filter_tx.lock().as_ref() { - log::info!(target: "citadel", "Will forward message {:?}", val); - if let Err(err) = val.send(message) { - log::warn!(target: "citadel", "failed to send unprocessed NodeResult: {:?}", err) + match message { + NodeResult::ObjectTransferHandle(handle) => { + if let Err(err) = self.tx_incoming_object_transfer_handle.send(handle.handle) { + log::warn!(target: "citadel", "failed to send unprocessed NodeResult: {:?}", err) + } + } + + message => { + if let Some(val) = self.unprocessed_signal_filter_tx.lock().as_ref() { + log::trace!(target: "citadel", "Will forward message {:?}", val); + if let Err(err) = val.send(message) { + log::warn!(target: "citadel", "failed to send unprocessed NodeResult: {:?}", err) + } + } } } diff --git a/citadel_sdk/src/prefabs/mod.rs b/citadel_sdk/src/prefabs/mod.rs index 0fa0e7b3d..575aeed67 100644 --- a/citadel_sdk/src/prefabs/mod.rs +++ b/citadel_sdk/src/prefabs/mod.rs @@ -1,4 +1,5 @@ use crate::impl_remote; +use crate::prefabs::client::peer_connection::FileTransferHandleRx; use citadel_io::Mutex; use citadel_proto::prelude::*; use std::net::{SocketAddr, ToSocketAddrs}; @@ -20,6 +21,7 @@ use crate::remote_ext::ProtocolRemoteExt; pub struct ClientServerRemote { pub(crate) inner: NodeRemote, pub(crate) unprocessed_signals_rx: Arc>>>, + pub(crate) file_transfer_handle_rx: Arc>>, conn_type: VirtualTargetType, session_security_settings: SessionSecuritySettings, } @@ -32,20 +34,32 @@ impl ClientServerRemote { conn_type: VirtualTargetType, remote: NodeRemote, session_security_settings: SessionSecuritySettings, + unprocessed_signals_rx: Option>, + file_transfer_handle_rx: Option, ) -> Self { + // TODO: Add handles, only the server calls this Self { inner: remote, - unprocessed_signals_rx: Default::default(), + unprocessed_signals_rx: Arc::new(Mutex::new(unprocessed_signals_rx)), + file_transfer_handle_rx: Arc::new(Mutex::new(file_transfer_handle_rx)), conn_type, session_security_settings, } } /// Can only be called once per remote. 
Allows receiving events - pub fn get_unprocessed_signals_receiver( - &self, - ) -> Option> { + pub fn get_unprocessed_signals_receiver(&self) -> Option> { self.unprocessed_signals_rx.lock().take() } + + /// Obtains a receiver which yields incoming file/object transfer handles + pub fn get_incoming_file_transfer_handle(&self) -> Result { + self.file_transfer_handle_rx + .lock() + .take() + .ok_or(NetworkError::InternalError( + "This function has already been called", + )) + } } impl TargetLockedRemote for ClientServerRemote { diff --git a/citadel_sdk/src/prefabs/server/accept_file_transfer_kernel.rs b/citadel_sdk/src/prefabs/server/accept_file_transfer_kernel.rs index 70e46266e..382d95ab3 100644 --- a/citadel_sdk/src/prefabs/server/accept_file_transfer_kernel.rs +++ b/citadel_sdk/src/prefabs/server/accept_file_transfer_kernel.rs @@ -1,4 +1,5 @@ use crate::prelude::*; +use futures::StreamExt; #[derive(Default)] pub struct AcceptFileTransferKernel; @@ -19,6 +20,7 @@ impl NetKernel for AcceptFileTransferKernel { .handle .accept() .map_err(|err| NetworkError::Generic(err.into_string()))?; + exhaust_file_transfer(handle.handle); } Ok(()) @@ -28,3 +30,20 @@ impl NetKernel for AcceptFileTransferKernel { Ok(()) } } + +pub fn exhaust_file_transfer(mut handle: ObjectTransferHandler) { + // Exhaust the stream + let handle = tokio::task::spawn(async move { + while let Some(evt) = handle.next().await { + log::info!(target: "citadel", "File Transfer Event: {evt:?}"); + if let ObjectTransferStatus::Fail(err) = &evt { + log::error!(target: "citadel", "File Transfer Failed: {err:?}"); + std::process::exit(1); + } else if let ObjectTransferStatus::TransferComplete = &evt { + break; + } + } + }); + + drop(handle); +} diff --git a/citadel_sdk/src/prefabs/server/client_connect_listener.rs b/citadel_sdk/src/prefabs/server/client_connect_listener.rs index baf356e69..03617cefb 100644 --- a/citadel_sdk/src/prefabs/server/client_connect_listener.rs +++ b/citadel_sdk/src/prefabs/server/client_connect_listener.rs @@ -59,6 +59,8 @@ where conn_type, self.node_remote.clone().unwrap(), session_security_settings, + None, // TODO: Add real handles + None, ); (self.on_channel_received)( ConnectionSuccess { diff --git a/citadel_sdk/src/prefabs/server/internal_service.rs b/citadel_sdk/src/prefabs/server/internal_service.rs index 1f5bc4588..3c4c8d9bb 100644 --- a/citadel_sdk/src/prefabs/server/internal_service.rs +++ b/citadel_sdk/src/prefabs/server/internal_service.rs @@ -63,6 +63,7 @@ mod test { use hyper::server::conn::Http; use hyper::service::service_fn; use hyper::{Body, Error, Request, Response, StatusCode}; + use rstest::rstest; use std::convert::Infallible; use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; @@ -105,13 +106,16 @@ mod test { Ok(()) } + #[rstest] + #[timeout(Duration::from_secs(60))] #[tokio::test] async fn test_internal_service_basic_bytes() { setup_log(); let barrier = &TestBarrier::new(2); let success_count = &AtomicUsize::new(0); let message = &(0..4096).map(|r| (r % 256) as u8).collect::>(); - let server_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let server_listener = citadel_wire::socket_helpers::get_tcp_listener("0.0.0.0:0") + .expect("Failed to get TCP listener"); let server_bind_addr = server_listener.local_addr().unwrap(); let server_kernel = InternalServiceKernel::new(|mut internal_server_communicator| async move { @@ -153,7 +157,7 @@ mod test { let server = NodeBuilder::default() .with_node_type(NodeType::Server(server_bind_addr)) 
.with_underlying_protocol( - ServerUnderlyingProtocol::from_tcp_listener(server_listener).unwrap(), + ServerUnderlyingProtocol::from_tokio_tcp_listener(server_listener).unwrap(), ) .build(server_kernel) .unwrap(); @@ -175,12 +179,15 @@ mod test { assert_eq!(success_count.load(Ordering::SeqCst), 2); } + #[rstest] + #[timeout(Duration::from_secs(60))] #[tokio::test] async fn test_internal_service_http() { setup_log(); let barrier = &TestBarrier::new(2); let success_count = &AtomicUsize::new(0); - let server_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let server_listener = citadel_wire::socket_helpers::get_tcp_listener("0.0.0.0:0") + .expect("Failed to get TCP listener"); let server_bind_addr = server_listener.local_addr().unwrap(); let server_kernel = InternalServiceKernel::new(|internal_server_communicator| async move { @@ -255,7 +262,7 @@ mod test { let server = NodeBuilder::default() .with_node_type(NodeType::Server(server_bind_addr)) .with_underlying_protocol( - ServerUnderlyingProtocol::from_tcp_listener(server_listener).unwrap(), + ServerUnderlyingProtocol::from_tokio_tcp_listener(server_listener).unwrap(), ) .build(server_kernel) .unwrap(); diff --git a/citadel_sdk/src/remote_ext.rs b/citadel_sdk/src/remote_ext.rs index 0703b74a2..678646e61 100644 --- a/citadel_sdk/src/remote_ext.rs +++ b/citadel_sdk/src/remote_ext.rs @@ -655,7 +655,6 @@ pub trait ProtocolRemoteTargetExt: TargetLockedRemote { NodeResult::ObjectTransferHandle(ObjectTransferHandle { mut handle, .. }) => { let mut local_path = None; while let Some(res) = handle.next().await { - log::trace!(target: "citadel", "REVFS PULL EVENT {:?}", res); match res { ObjectTransferStatus::ReceptionBeginning(path, _) => { local_path = Some(path) @@ -1342,7 +1341,6 @@ mod tests { } #[rstest] - #[timeout(std::time::Duration::from_secs(90))] #[case( EncryptionAlgorithm::AES_GCM_256, KemAlgorithm::Kyber, @@ -1353,6 +1351,7 @@ mod tests { KemAlgorithm::Kyber, SigAlgorithm::Falcon1024 )] + #[timeout(std::time::Duration::from_secs(90))] #[tokio::test] async fn test_c2s_file_transfer( #[case] enx: EncryptionAlgorithm, diff --git a/citadel_sdk/src/test_common.rs b/citadel_sdk/src/test_common.rs index f8dda075d..ed1fe306b 100644 --- a/citadel_sdk/src/test_common.rs +++ b/citadel_sdk/src/test_common.rs @@ -17,12 +17,13 @@ pub fn server_test_node<'a, K: NetKernel + 'a>( opts: impl FnOnce(&mut NodeBuilder), ) -> (NodeFuture<'a, K>, SocketAddr) { let mut builder = NodeBuilder::default(); - let tcp_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let tcp_listener = citadel_wire::socket_helpers::get_tcp_listener("127.0.0.1:0") + .expect("Failed to create TCP listener"); let bind_addr = tcp_listener.local_addr().unwrap(); let builder = builder .with_node_type(NodeType::Server(bind_addr)) .with_underlying_protocol( - ServerUnderlyingProtocol::from_tcp_listener(tcp_listener).unwrap(), + ServerUnderlyingProtocol::from_tokio_tcp_listener(tcp_listener).unwrap(), ); (opts)(builder); @@ -44,15 +45,15 @@ pub fn server_info<'a>() -> (NodeFuture<'a, EmptyKernel>, SocketAddr) { #[allow(dead_code)] #[cfg(feature = "localhost-testing")] -pub fn server_info_reactive<'a, F: 'a, Fut: 'a>( +pub fn server_info_reactive<'a, F, Fut>( f: F, opts: impl FnOnce(&mut NodeBuilder), ) -> (NodeFuture<'a, Box>, SocketAddr) where - F: Fn(ConnectionSuccess, ClientServerRemote) -> Fut + Send + Sync, - Fut: Future> + Send + Sync, + F: Fn(ConnectionSuccess, ClientServerRemote) -> Fut + Send + Sync + 'a, + Fut: Future> + Send + Sync + 'a, { - 
crate::test_common::server_test_node( + server_test_node( Box::new(ClientConnectListenerKernel::new(f)) as Box, opts, ) diff --git a/citadel_sdk/tests/stress_tests.rs b/citadel_sdk/tests/stress_tests.rs index a7ec08e8d..bed22d7f0 100644 --- a/citadel_sdk/tests/stress_tests.rs +++ b/citadel_sdk/tests/stress_tests.rs @@ -208,6 +208,11 @@ mod tests { enx: EncryptionAlgorithm, ) { citadel_logging::setup_log(); + + if windows_pipeline_check(kem, secrecy_mode) { + return; + } + citadel_sdk::test_common::TestBarrier::setup(2); static CLIENT_SUCCESS: AtomicBool = AtomicBool::new(false); static SERVER_SUCCESS: AtomicBool = AtomicBool::new(false); @@ -275,6 +280,11 @@ mod tests { #[values(EncryptionAlgorithm::Kyber)] enx: EncryptionAlgorithm, ) { citadel_logging::setup_log(); + + if windows_pipeline_check(kem, secrecy_mode) { + return; + } + citadel_sdk::test_common::TestBarrier::setup(2); static CLIENT_SUCCESS: AtomicBool = AtomicBool::new(false); static SERVER_SUCCESS: AtomicBool = AtomicBool::new(false); @@ -351,6 +361,11 @@ mod tests { enx: EncryptionAlgorithm, ) { citadel_logging::setup_log(); + + if windows_pipeline_check(kem, secrecy_mode) { + return; + } + citadel_sdk::test_common::TestBarrier::setup(2); let client0_success = &AtomicBool::new(false); let client1_success = &AtomicBool::new(false); @@ -448,7 +463,7 @@ mod tests { #[rstest] #[case(500, 3)] - #[timeout(std::time::Duration::from_secs(240))] + #[timeout(std::time::Duration::from_secs(90))] #[tokio::test(flavor = "multi_thread")] async fn stress_test_group_broadcast(#[case] message_count: usize, #[case] peer_count: usize) { citadel_logging::setup_log(); @@ -524,4 +539,18 @@ mod tests { assert!(res.is_ok()); assert_eq!(CLIENT_SUCCESS.load(Ordering::Relaxed), peer_count); } + + /// This test is disabled by default because it is very slow and requires a lot of resources + fn windows_pipeline_check(kem: KemAlgorithm, secrecy_mode: SecrecyMode) -> bool { + if cfg!(windows) + && kem == KemAlgorithm::Ntru + && secrecy_mode == SecrecyMode::Perfect + && std::env::var("IN_CI").is_ok() + { + log::warn!(target: "citadel", "Skipping NTRU/Perfect forward secrecy test on Windows due to performance issues"); + true + } else { + false + } + } } diff --git a/citadel_types/Cargo.toml b/citadel_types/Cargo.toml index 2e33a5c8a..7f8918140 100644 --- a/citadel_types/Cargo.toml +++ b/citadel_types/Cargo.toml @@ -18,7 +18,7 @@ bytes = { workspace = true, features = ["serde"] } twox-hash = { workspace = true } packed_struct = { workspace = true, features = ["serde"] } uuid = { workspace = true, features = ["v4"] } -bincode2 = { workspace = true} +bincode = { workspace = true} [target.'cfg(target_family = "unix")'.dependencies] libc = { workspace = true } diff --git a/citadel_types/src/proto/mod.rs b/citadel_types/src/proto/mod.rs index 3e5fe6474..a9056e5b9 100644 --- a/citadel_types/src/proto/mod.rs +++ b/citadel_types/src/proto/mod.rs @@ -1,7 +1,7 @@ use crate::crypto::{CryptoParameters, SecrecyMode, SecurityLevel}; use crate::utils; use serde::{Deserialize, Serialize}; -use std::fmt::Formatter; +use std::fmt::{Debug, Display, Formatter}; use std::path::PathBuf; use uuid::Uuid; @@ -26,18 +26,50 @@ pub struct VirtualObjectMetadata { pub author: String, pub plaintext_length: usize, pub group_count: usize, - pub object_id: u64, + pub object_id: ObjectId, pub cid: u64, pub transfer_type: TransferType, } +#[derive(Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd, Serialize, Deserialize)] +#[repr(transparent)] +pub struct ObjectId(pub u128); + +impl Debug for 
ObjectId {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
+impl Display for ObjectId {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        Debug::fmt(self, f)
+    }
+}
+
+impl ObjectId {
+    pub fn random() -> Self {
+        Uuid::new_v4().as_u128().into()
+    }
+
+    pub const fn zero() -> Self {
+        Self(0)
+    }
+}
+
+impl From<u128> for ObjectId {
+    fn from(value: u128) -> Self {
+        Self(value)
+    }
+}
+
 impl VirtualObjectMetadata {
     pub fn serialize(&self) -> Vec<u8> {
-        bincode2::serialize(self).unwrap()
+        bincode::serialize(self).unwrap()
     }
 
     pub fn deserialize_from<'a, T: AsRef<[u8]> + 'a>(input: T) -> Option<Self> {
-        bincode2::deserialize(input.as_ref()).ok()
+        bincode::deserialize(input.as_ref()).ok()
     }
 
     pub fn get_security_level(&self) -> Option<SecurityLevel> {
diff --git a/citadel_types/src/utils/mem.rs b/citadel_types/src/utils/mem.rs
index ab254c505..2a7332194 100644
--- a/citadel_types/src/utils/mem.rs
+++ b/citadel_types/src/utils/mem.rs
@@ -1,5 +1,3 @@
-use std::os::raw::c_void;
-
 /// Locks-down the memory location, preventing it from being read until unlocked
 /// For linux, returns zero if successful
 /// # Safety
@@ -8,11 +6,11 @@ use std::os::raw::c_void;
 #[cfg(target_family = "unix")]
 #[allow(unused_results)]
 pub unsafe fn mlock(ptr: *const u8, len: usize) {
-    libc::mlock(ptr as *const c_void, len);
+    libc::mlock(ptr as *const std::os::raw::c_void, len);
     #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
-    libc::madvise(ptr as *mut c_void, len, libc::MADV_NOCORE);
+    libc::madvise(ptr as *mut std::os::raw::c_void, len, libc::MADV_NOCORE);
     #[cfg(target_os = "linux")]
-    libc::madvise(ptr as *mut c_void, len, libc::MADV_DONTDUMP);
+    libc::madvise(ptr as *mut std::os::raw::c_void, len, libc::MADV_DONTDUMP);
 }
 
 #[cfg(target_family = "wasm")]
@@ -34,11 +32,11 @@ pub unsafe fn mlock(ptr: *const u8, len: usize) {
 #[cfg(target_family = "unix")]
 #[allow(unused_results)]
 pub unsafe fn munlock(ptr: *const u8, len: usize) {
-    libc::munlock(ptr as *const c_void, len);
+    libc::munlock(ptr as *const std::os::raw::c_void, len);
     #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
-    libc::madvise(ptr as *mut c_void, len, libc::MADV_CORE);
+    libc::madvise(ptr as *mut std::os::raw::c_void, len, libc::MADV_CORE);
     #[cfg(target_os = "linux")]
-    libc::madvise(ptr as *mut c_void, len, libc::MADV_DODUMP);
+    libc::madvise(ptr as *mut std::os::raw::c_void, len, libc::MADV_DODUMP);
 }
 
 #[cfg(target_family = "wasm")]
diff --git a/citadel_user/Cargo.toml b/citadel_user/Cargo.toml
index 8b6f8b210..a6d9e14af 100644
--- a/citadel_user/Cargo.toml
+++ b/citadel_user/Cargo.toml
@@ -54,7 +54,7 @@ firebase-rtdb = { workspace = true, optional = true }
 jwt = { workspace = true, features = ["openssl"], optional = true }
 openssl = { workspace = true, optional = true }
 uuid = { workspace = true, features = ["v4"] }
-bincode2 = { workspace = true }
+bincode = { workspace = true }
 chrono = { workspace = true, features = ["clock"] }
 tokio-util = { workspace = true, features = ["io"], optional = true }
 tokio-stream = { workspace = true, optional = true }
diff --git a/citadel_user/src/account_loader.rs b/citadel_user/src/account_loader.rs
index 77d69b7f8..d51c800ec 100644
--- a/citadel_user/src/account_loader.rs
+++ b/citadel_user/src/account_loader.rs
@@ -34,7 +34,6 @@ pub fn load_cnac_files(
         .collect())
 }
 
-use crate::serialization::bincode_config;
 use serde::de::DeserializeOwned;
 use std::path::{Path, PathBuf};
 
@@ -83,8 +82,7 @@ pub fn read>(path: P) -> Result BackendConnection for
FilesystemBackend BackendConnection for FilesystemBackend>, std::io::Error> }), ); + let mut reader = tokio::io::BufReader::with_capacity(chunk_size, reader); + if is_virtual_file { // start by writing the metadata file next to it let metadata_path = get_revfs_file_metadata_path(&file_path); @@ -431,7 +436,7 @@ impl BackendConnection for FilesystemBackend Result<(Box, SecurityLevel), AccountError> { + ) -> Result<(Box, VirtualObjectMetadata), AccountError> { let directory_store = self.directory_store.as_ref().unwrap(); let file_path = get_file_path( cid, @@ -453,9 +458,7 @@ impl BackendConnection for FilesystemBackend: Send + Sync { &self, cid: u64, virtual_path: std::path::PathBuf, - ) -> Result<(Box, SecurityLevel), AccountError> { + ) -> Result<(Box, VirtualObjectMetadata), AccountError> { Err(AccountError::Generic( "The target does not support the RE-VFS protocol".into(), )) diff --git a/citadel_user/src/backend/utils/mod.rs b/citadel_user/src/backend/utils/mod.rs index 222b06f27..91e9dbce7 100644 --- a/citadel_user/src/backend/utils/mod.rs +++ b/citadel_user/src/backend/utils/mod.rs @@ -113,25 +113,11 @@ impl ObjectTransferHandler { } fn respond(&mut self, accept: bool) -> Result<(), AccountError> { - if matches!( - self.orientation, - ObjectTransferOrientation::Receiver { - is_revfs_pull: true - } - ) { - let _ = self.start_recv_tx.take(); - return Ok(()); + if let Some(tx) = self.start_recv_tx.take() { + tx.send(accept) + .map_err(|_| AccountError::msg("Failed to send response"))?; } - if matches!(self.orientation, ObjectTransferOrientation::Receiver { .. }) { - self.start_recv_tx - .take() - .ok_or_else(|| AccountError::msg("Start_recv_tx already called"))? - .send(accept) - .map_err(|err| AccountError::msg(err.to_string())) - } else { - let _ = self.start_recv_tx.take(); - Ok(()) - } + Ok(()) } } diff --git a/citadel_user/src/serialization.rs b/citadel_user/src/serialization.rs index 7cbbe9a02..8d82ad09f 100644 --- a/citadel_user/src/serialization.rs +++ b/citadel_user/src/serialization.rs @@ -1,5 +1,5 @@ use crate::misc::AccountError; -use bincode2::BincodeRead; +use bincode::BincodeRead; use bytes::BufMut; use bytes::BytesMut; use serde::de::DeserializeOwned; @@ -28,8 +28,7 @@ pub trait SyncIO { Self: DeserializeOwned, { use bytes::Buf; - bincode_config() - .deserialize_from(input.reader()) + bincode::deserialize_from(input.reader()) .map_err(|err| AccountError::Generic(err.to_string())) } @@ -39,8 +38,7 @@ pub trait SyncIO { T: serde::de::Deserialize<'a>, R: BincodeRead<'a>, { - bincode_config() - .deserialize_in_place(reader, place) + bincode::deserialize_in_place(reader, place) .map_err(|err| AccountError::Generic(err.to_string())) } @@ -49,11 +47,10 @@ pub trait SyncIO { where Self: Serialize, { - bincode_config() - .serialized_size(self) + bincode::serialized_size(self) .and_then(|amt| { buf.reserve(amt as usize); - bincode2::serialize_into(buf.writer(), self) + bincode::serialize_into(buf.writer(), self) }) .map_err(|_| AccountError::Generic("Bad ser".to_string())) } @@ -63,9 +60,7 @@ pub trait SyncIO { where Self: Serialize, { - bincode_config() - .serialize_into(slice, self) - .map_err(|err| AccountError::Generic(err.to_string())) + bincode::serialize_into(slice, self).map_err(|err| AccountError::Generic(err.to_string())) } /// Returns the expected size of the serialized objects @@ -73,35 +68,18 @@ pub trait SyncIO { where Self: Serialize, { - bincode_config() - .serialized_size(self) - .ok() - .map(|res| res as usize) + bincode::serialized_size(self).ok().map(|res| res 
as usize) } } impl<'a, T> SyncIO for T where T: Serialize + Deserialize<'a> + Sized {} -/// A limited config. Helps prevent oversized allocations from occurring when deserializing incompatible -/// objects -#[inline(always)] -#[allow(unused_results)] -pub fn bincode_config() -> bincode2::Config { - let mut cfg = bincode2::config(); - cfg.limit(1000 * 1000 * 1000 * 4); - cfg -} - /// Deserializes the bytes, T, into type D fn bytes_to_type<'a, D: Deserialize<'a>>(bytes: &'a [u8]) -> Result { - bincode_config() - .deserialize(bytes) - .map_err(|err| AccountError::IoError(err.to_string())) + bincode::deserialize(bytes).map_err(|err| AccountError::IoError(err.to_string())) } /// Converts a type, D to Vec fn type_to_bytes(input: D) -> Result, AccountError> { - bincode_config() - .serialize(&input) - .map_err(|err| AccountError::IoError(err.to_string())) + bincode::serialize(&input).map_err(|err| AccountError::IoError(err.to_string())) } diff --git a/citadel_wire/Cargo.toml b/citadel_wire/Cargo.toml index 251160534..ebf492bf4 100644 --- a/citadel_wire/Cargo.toml +++ b/citadel_wire/Cargo.toml @@ -23,7 +23,6 @@ std = [ "serde/std" ] localhost-testing = ["tracing"] -localhost-testing-loopback-only = [] wasm = [ "citadel_io/wasm", "netbeam/wasm" @@ -39,9 +38,9 @@ citadel_io = { workspace = true } anyhow = { workspace = true } serde = { workspace = true, features = ["derive"] } log = { workspace = true } -bincode2 = { workspace = true } +bincode = { workspace = true } async_ip = { workspace = true } -itertools = { workspace = true, features = ["use_alloc"] } +itertools = { workspace = true, features = ["use_alloc", "use_std"] } either = { workspace = true } netbeam = { workspace = true } uuid = { workspace = true, features = ["v4", "serde"] } diff --git a/citadel_wire/examples/client_sym3.rs b/citadel_wire/examples/client_sym3.rs deleted file mode 100644 index 360451431..000000000 --- a/citadel_wire/examples/client_sym3.rs +++ /dev/null @@ -1,91 +0,0 @@ -use citadel_wire::quic::QuicEndpointConnector; -use citadel_wire::udp_traversal::udp_hole_puncher::UdpHolePuncher; -use netbeam::sync::network_endpoint::NetworkEndpoint; -use netbeam::sync::RelativeNodeType; -use std::sync::Arc; -use tokio::io::{AsyncBufReadExt, BufReader}; - -#[tokio::main] -async fn main() { - //citadel_logging::setup_log(); - - let server_stream = citadel_io::TcpStream::connect("51.81.86.78:25025") - .await - .unwrap(); - - log::trace!(target: "citadel", "Established TCP server connection"); - - let hole_punched_socket = UdpHolePuncher::new( - &NetworkEndpoint::register(RelativeNodeType::Initiator, server_stream) - .await - .unwrap(), - Default::default(), - ) - .await - .unwrap(); - let client_config = Arc::new(citadel_wire::quic::insecure::rustls_client_config()); - log::trace!(target: "citadel", "Successfully hole-punched socket to peer @ {:?}", hole_punched_socket.addr); - let (_conn, mut sink, mut stream) = - citadel_wire::quic::QuicClient::new_with_config(hole_punched_socket.socket, client_config) - .unwrap() - .connect_biconn( - hole_punched_socket.addr.receive_address, - "mail.satorisocial.com", - ) - .await - .unwrap(); - log::trace!(target: "citadel", "Successfully obtained QUIC connection ..."); - - let writer = async move { - let mut stdin = BufReader::new(tokio::io::stdin()).lines(); - while let Ok(Some(input)) = stdin.next_line().await { - log::trace!(target: "citadel", "About to send: {}", &input); - sink.write(input.as_bytes()).await.unwrap(); - } - - log::trace!(target: "citadel", "writer ending"); - }; - - let 
reader = async move { - let input = &mut [0u8; 4096]; - loop { - let len = stream.read(input).await.unwrap().unwrap(); - if let Ok(string) = String::from_utf8(Vec::from(&input[..len])) { - log::trace!(target: "citadel", "[Message]: {}", string); - } - } - }; - - tokio::select! { - res0 = writer => res0, - res1 = reader => res1 - } - - /* - let writer = async move { - let mut stdin = BufReader::new(tokio::io::stdin()).lines(); - while let Ok(Some(input)) = stdin.next_line().await { - log::trace!(target: "citadel", "About to send (bind:{:?}->{:?}): {}", hole_punched_socket.socket.local_addr().unwrap(), hole_punched_socket.addr.natted, &input); - hole_punched_socket.socket.send_to(input.as_bytes(), hole_punched_socket.addr.natted).await.unwrap(); - } - - log::trace!(target: "citadel", "writer ending"); - }; - - let reader = async move { - let input = &mut [0u8; 4096]; - loop { - let len = hole_punched_socket.socket.recv(input).await.unwrap(); - if let Ok(string) = String::from_utf8(Vec::from(&input[..len])) { - log::trace!(target: "citadel", "[Message]: {}", string); - } - } - }; - - tokio::select! { - res0 = writer => res0, - res1 = reader => res1 - }*/ - - log::trace!(target: "citadel", "Quitting program clientside"); -} diff --git a/citadel_wire/examples/server_sym3.rs b/citadel_wire/examples/server_sym3.rs deleted file mode 100644 index 4c6b542a1..000000000 --- a/citadel_wire/examples/server_sym3.rs +++ /dev/null @@ -1,89 +0,0 @@ -use citadel_wire::quic::QuicEndpointListener; -use citadel_wire::udp_traversal::udp_hole_puncher::UdpHolePuncher; -use netbeam::sync::network_endpoint::NetworkEndpoint; -use netbeam::sync::RelativeNodeType; -use tokio::io::{AsyncBufReadExt, BufReader}; - -#[tokio::main] -async fn main() { - //setup_log(); - let listener = citadel_io::TcpListener::bind("0.0.0.0:25025") - .await - .unwrap(); - let (client_stream, peer_addr) = listener.accept().await.unwrap(); - log::trace!(target: "citadel", "Received client stream from {:?}", peer_addr); - - let hole_punched_socket = UdpHolePuncher::new( - &NetworkEndpoint::register(RelativeNodeType::Receiver, client_stream) - .await - .unwrap(), - Default::default(), - ) - .await - .unwrap(); - log::trace!(target: "citadel", "Successfully hole-punched socket to peer @ {:?}", hole_punched_socket.addr); - - let (_conn, mut sink, mut stream) = citadel_wire::quic::QuicServer::new_from_pkcs_12_der_path( - hole_punched_socket.socket, - "../keys/testing.p12", - "mrmoney10", - ) - .unwrap() - .next_connection() - .await - .unwrap(); - log::trace!(target: "citadel", "Successfully obtained QUIC connection ..."); - - let writer = async move { - let mut stdin = BufReader::new(tokio::io::stdin()).lines(); - while let Ok(Some(input)) = stdin.next_line().await { - log::trace!(target: "citadel", "About to send: {}", &input); - sink.write(input.as_bytes()).await.unwrap(); - } - - log::trace!(target: "citadel", "writer ending"); - }; - - let reader = async move { - let input = &mut [0u8; 4096]; - loop { - let len = stream.read(input).await.unwrap().unwrap(); - if let Ok(string) = String::from_utf8(Vec::from(&input[..len])) { - log::trace!(target: "citadel", "[Message]: {}", string); - } - } - }; - - tokio::select! 
{ - res0 = writer => res0, - res1 = reader => res1 - } - - /* - let writer = async move { - let mut stdin = BufReader::new(tokio::io::stdin()).lines(); - while let Ok(Some(input)) = stdin.next_line().await { - log::trace!(target: "citadel", "About to send (bind:{:?}->{:?}): {}", hole_punched_socket.socket.local_addr().unwrap(), hole_punched_socket.addr.natted, &input); - hole_punched_socket.socket.send_to(input.as_bytes(), hole_punched_socket.addr.natted).await.unwrap(); - } - - log::trace!(target: "citadel", "writer ending"); - }; - - let reader = async move { - let input = &mut [0u8; 4096]; - loop { - let len = hole_punched_socket.socket.recv(input).await.unwrap(); - if let Ok(string) = String::from_utf8(Vec::from(&input[..len])) { - log::trace!(target: "citadel", "[Message]: {}", string); - } - } - }; - - tokio::select! { - res0 = writer => res0, - res1 = reader => res1 - }*/ - - log::trace!(target: "citadel", "Quitting program serverside"); -} diff --git a/citadel_wire/src/standard/nat_identification.rs b/citadel_wire/src/standard/nat_identification.rs index 6adfff0f8..8760d1e92 100644 --- a/citadel_wire/src/standard/nat_identification.rs +++ b/citadel_wire/src/standard/nat_identification.rs @@ -1,16 +1,13 @@ -#![cfg_attr(feature = "localhost-testing-loopback-only", allow(unreachable_code))] - use crate::error::FirewallError; use crate::socket_helpers::is_ipv6_enabled; use async_ip::IpAddressInfo; use futures::stream::FuturesUnordered; -use futures::{Future, StreamExt}; +use futures::StreamExt; use itertools::Itertools; use serde::{Deserialize, Serialize}; use std::borrow::Cow; use std::net::{IpAddr, SocketAddr}; use std::ops::Sub; -use std::pin::Pin; use std::sync::Arc; use std::time::Duration; @@ -27,7 +24,7 @@ const STUN_SERVERS: [&str; 3] = [ ]; const V4_BIND_ADDR: &str = "0.0.0.0:0"; -const IDENTIFY_TIMEOUT: Duration = Duration::from_millis(4500); +pub const IDENTIFY_TIMEOUT: Duration = Duration::from_millis(3000); #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub enum IpTranslation { @@ -97,6 +94,12 @@ impl NatType { tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err(Debug)) )] pub async fn identify(stun_servers: Option>) -> Result { + if cfg!(feature = "localhost-testing") { + if let Some(nat_type) = LOCALHOST_TESTING_NAT_TYPE.lock().as_ref() { + return Ok(nat_type.clone()); + } + } + match Self::identify_timeout(IDENTIFY_TIMEOUT, stun_servers).await { Ok(nat_type) => Ok(nat_type), Err(err) => { @@ -393,9 +396,8 @@ pub enum TraversalTypeRequired { } // we only need to check the NAT type once per node -lazy_static::lazy_static! 
{ - pub static ref LOCALHOST_TESTING_NAT_TYPE: citadel_io::Mutex> = citadel_io::Mutex::new(None); -} +static LOCALHOST_TESTING_NAT_TYPE: citadel_io::Mutex> = + citadel_io::Mutex::new(None); impl NatType { /// Returns the NAT traversal type required to access self and other, respectively @@ -547,42 +549,31 @@ async fn get_nat_type(stun_servers: Option>) -> Result Err(anyhow::Error::msg("Unable to get all three STUN addrs")), } }; - let ip_info_future = if cfg!(feature = "localhost-testing") { - Box::pin(async move { Ok(Some(async_ip::IpAddressInfo::localhost())) }) - as Pin< - Box< - dyn Future, async_ip::IpRetrieveError>> - + Send, - >, - > - } else { - Box::pin(async move { - match tokio::time::timeout( - Duration::from_millis(1500), - async_ip::get_all_multi_concurrent(None), - ) - .await - { - Ok(Ok(ip_info)) => Ok(Some(ip_info)), - Ok(Err(err)) => Err(err), - Err(_) => Ok(None), - } - }) + let ip_info_future = async move { + match tokio::time::timeout( + Duration::from_millis(2000), + async_ip::get_all_multi_concurrent(None), + ) + .await + { + Ok(Ok(ip_info)) => Ok(Some(ip_info)), + Ok(Err(err)) => Err(err), + Err(_) => Ok(None), + } }; let (nat_type, ip_info) = tokio::join!(nat_type, ip_info_future); let mut nat_type = nat_type?; - log::trace!(target: "citadel", "NAT Type: {nat_type:?} | IpInfo: {ip_info:?}"); let ip_info = match ip_info { Ok(Some(ip_info)) => ip_info, @@ -597,6 +588,9 @@ async fn get_nat_type(stun_servers: Option>) -> Result std::io::Result<()> { citadel_logging::setup_log(); diff --git a/citadel_wire/src/standard/socket_helpers.rs b/citadel_wire/src/standard/socket_helpers.rs index 8c7da7f9c..2a2e4f289 100644 --- a/citadel_wire/src/standard/socket_helpers.rs +++ b/citadel_wire/src/standard/socket_helpers.rs @@ -3,13 +3,6 @@ use socket2::{Domain, SockAddr, Socket, Type}; use std::net::{IpAddr, SocketAddr, SocketAddrV6}; use std::time::Duration; -/// Given an ip bind addr, finds an open socket at that ip addr -pub fn get_unused_udp_socket_at_bind_ip(bind_addr: IpAddr) -> std::io::Result { - let socket = std::net::UdpSocket::bind((bind_addr, 0))?; - socket.set_nonblocking(true)?; - UdpSocket::from_std(socket) -} - fn get_udp_socket_builder(domain: Domain) -> Result { Ok(socket2::Socket::new(domain, Type::DGRAM, None)?) } @@ -30,7 +23,7 @@ fn setup_base_socket(addr: SocketAddr, socket: &Socket, reuse: bool) -> Result<( socket.set_nonblocking(true)?; - if addr.is_ipv6() { + if !cfg!(windows) && addr.is_ipv6() { socket.set_only_v6(false)?; } @@ -63,6 +56,9 @@ fn get_udp_socket_inner( .to_socket_addrs()? .next() .ok_or_else(|| anyhow::Error::msg("Bad socket addr"))?; + + let addr = windows_check(addr); + log::trace!(target: "citadel", "[Socket helper] Getting UDP (reuse={}) socket @ {:?} ...", reuse, &addr); let domain = if addr.is_ipv4() { Domain::IPV4 @@ -72,11 +68,24 @@ fn get_udp_socket_inner( let socket = get_udp_socket_builder(domain)?; setup_bind(addr, &socket, reuse)?; let std_socket: std::net::UdpSocket = socket.into(); - std_socket.set_nonblocking(true)?; let tokio_socket = citadel_io::UdpSocket::from_std(std_socket)?; Ok(tokio_socket) } +fn windows_check(addr: SocketAddr) -> SocketAddr { + // if feature "localhost-testing" is enabled, and, we are not on mac, then, we will bind to 127.0.0.1 + if cfg!(feature = "localhost-testing") && !cfg!(target_os = "macos") { + log::warn!(target: "citadel", "Localhost testing is enabled on non-mac OS. 
Will ensure bind is 127.0.0.1"); + if addr.is_ipv4() { + SocketAddr::new(IpAddr::V4(std::net::Ipv4Addr::LOCALHOST), addr.port()) + } else { + SocketAddr::new(IpAddr::V6(std::net::Ipv6Addr::LOCALHOST), addr.port()) + } + } else { + addr + } +} + fn get_tcp_listener_inner( addr: T, reuse: bool, @@ -85,6 +94,9 @@ fn get_tcp_listener_inner( .to_socket_addrs()? .next() .ok_or_else(|| anyhow::Error::msg("Bad socket addr"))?; + + let addr = windows_check(addr); + log::trace!(target: "citadel", "[Socket helper] Getting TCP listener (reuse={}) socket @ {:?} ...", reuse, &addr); let domain = if addr.is_ipv4() { @@ -94,10 +106,9 @@ fn get_tcp_listener_inner( }; let socket = get_tcp_socket_builder(domain)?; setup_bind(addr, &socket, reuse)?; - let std_tcp_socket: std::net::TcpStream = socket.into(); - std_tcp_socket.set_nonblocking(true)?; - - Ok(citadel_io::TcpSocket::from_std_stream(std_tcp_socket).listen(1024)?) + socket.listen(1024)?; + let std_tcp_socket: std::net::TcpListener = socket.into(); + Ok(citadel_io::TcpListener::from_std(std_tcp_socket)?) } async fn get_tcp_stream_inner( @@ -140,16 +151,7 @@ pub async fn get_reuse_tcp_stream( } pub fn get_udp_socket(addr: T) -> Result { - #[cfg(not(target_os = "windows"))] - { - get_udp_socket_inner(addr, false) - } - #[cfg(target_os = "windows")] - { - let std_socket = std::net::UdpSocket::bind(addr)?; - std_socket.set_nonblocking(true)?; - Ok(citadel_io::UdpSocket::from_std(std_socket)?) - } + get_udp_socket_inner(addr, false) } /// `backlog`: the max number of unprocessed TCP connections diff --git a/citadel_wire/src/udp_traversal/hole_punch_config.rs b/citadel_wire/src/udp_traversal/hole_punch_config.rs index 4481aab17..bd9edabf9 100644 --- a/citadel_wire/src/udp_traversal/hole_punch_config.rs +++ b/citadel_wire/src/udp_traversal/hole_punch_config.rs @@ -1,14 +1,12 @@ -#![cfg_attr(feature = "localhost-testing-loopback-only", allow(unreachable_code))] - use crate::nat_identification::NatType; use citadel_io::UdpSocket; -use std::collections::HashSet; -use std::net::{IpAddr, SocketAddr}; +use itertools::Itertools; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; #[derive(Debug)] pub struct HolePunchConfig { /// The IP addresses that must be connected to based on NAT traversal - pub bands: Vec, + pub bands: Vec>, // sockets bound to ports specially prepared for NAT traversal pub(crate) locally_bound_sockets: Option>, } @@ -20,22 +18,22 @@ pub struct AddrBand { } impl IntoIterator for HolePunchConfig { - type Item = SocketAddr; - type IntoIter = std::collections::hash_set::IntoIter; + type Item = Vec; + type IntoIter = std::vec::IntoIter>; fn into_iter(mut self) -> Self::IntoIter { // Use a HashSet to enforce uniqueness - let mut ret = HashSet::new(); + let mut ret = vec![]; - for mut band in self.bands.drain(..) { - for next in band.by_ref() { - if next.ip() == IpAddr::from([0, 0, 0, 0]) { - // we never want to send to 0.0.0.0 addrs, only loopbacks - ret.insert(SocketAddr::new(IpAddr::from([127, 0, 0, 1]), next.port())); - } else { - ret.insert(next); + for band_set in self.bands.drain(..) 
{ + let mut this_set = vec![]; + for mut band in band_set { + for next in band.by_ref() { + this_set.push(next); } } + + ret.push(this_set); } ret.into_iter() @@ -55,44 +53,52 @@ impl Iterator for AddrBand { impl HolePunchConfig { pub fn new( peer_nat: &NatType, - peer_internal_addr: &SocketAddr, - local_socket: UdpSocket, + peer_internal_addrs: &[SocketAddr], + local_sockets: Vec, ) -> Self { - let mut this = if let Some(bands) = peer_nat.predict(peer_internal_addr) { - Self { - bands, - locally_bound_sockets: Some(vec![local_socket]), - } - } else if cfg!(feature = "localhost-testing") { - log::info!(target: "citadel", "Will revert to localhost testing mode (not recommended for production use (peer addr: {:?}))", peer_internal_addr); - Self { - bands: get_localhost_bands(peer_internal_addr), - locally_bound_sockets: Some(vec![local_socket]), - } - } else { - // the peer nat is untraversable. However, they may still be able to connect to this node. - // As such, we will only listen: - Self { - bands: vec![AddrBand { + assert_eq!(peer_internal_addrs.len(), local_sockets.len()); + let mut this = HolePunchConfig { + bands: Vec::new(), + locally_bound_sockets: Some(local_sockets), + }; + + for peer_internal_addr in peer_internal_addrs { + let mut bands = if let Some(bands) = peer_nat.predict(peer_internal_addr) { + bands + } else if cfg!(feature = "localhost-testing") { + log::info!(target: "citadel", "Will revert to localhost testing mode (not recommended for production use (peer addr: {:?}))", peer_internal_addr); + get_localhost_bands(peer_internal_addr) + } else { + // the peer nat is untraversable. However, they may still be able to connect to this node. + // As such, we will only listen: + vec![AddrBand { necessary_ip: peer_internal_addr.ip(), anticipated_ports: vec![peer_internal_addr.port()], - }], - locally_bound_sockets: Some(vec![local_socket]), - } - }; + }] + }; - // Sometimes, even on localhost testing, both NATs are predictable, therefore the second branch above - // does not execute. This means that it entirely misses out on the localhost adjacent node. - // Therefore, we need to add it here: - this.bands.extend(get_localhost_bands(peer_internal_addr)); + // Sometimes, even on localhost testing, both NATs are predictable, therefore the second branch above + // does not execute. This means that it entirely misses out on the localhost adjacent node. 
+ // Therefore, we need to add it here: + bands.extend(get_localhost_bands(peer_internal_addr)); + + let bands = bands.into_iter().unique().collect(); + this.bands.push(bands) + } this } } fn get_localhost_bands(peer_internal_addr: &SocketAddr) -> Vec { - vec![AddrBand { - necessary_ip: IpAddr::from([127, 0, 0, 1]), - anticipated_ports: vec![peer_internal_addr.port()], - }] + vec![ + AddrBand { + necessary_ip: IpAddr::from(Ipv4Addr::LOCALHOST), + anticipated_ports: vec![peer_internal_addr.port()], + }, + AddrBand { + necessary_ip: IpAddr::V6(Ipv6Addr::LOCALHOST), + anticipated_ports: vec![peer_internal_addr.port()], + }, + ] } diff --git a/citadel_wire/src/udp_traversal/linear/method3.rs b/citadel_wire/src/udp_traversal/linear/method3.rs index 61849f3f2..ddb0426db 100644 --- a/citadel_wire/src/udp_traversal/linear/method3.rs +++ b/citadel_wire/src/udp_traversal/linear/method3.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::io::ErrorKind; use std::net::SocketAddr; @@ -30,7 +30,8 @@ pub struct Method3 { enum NatPacket { Syn(HolePunchID, u32, RelativeNodeType, SocketAddr), // contains the local bind addr of candidate for socket identification - SynAck(HolePunchID, RelativeNodeType, SocketAddr), + SynAck(HolePunchID, RelativeNodeType, SocketAddr, HolePunchID), + Check, } impl Method3 { @@ -51,11 +52,10 @@ impl Method3 { pub(crate) async fn execute( &self, socket: &UdpSocket, - endpoints: &Vec, + endpoints: &[SocketAddr], ) -> Result { match self.this_node_type { RelativeNodeType::Initiator => self.execute_either(socket, endpoints).await, - RelativeNodeType::Receiver => self.execute_either(socket, endpoints).await, } } @@ -79,7 +79,7 @@ impl Method3 { async fn execute_either( &self, socket: &UdpSocket, - endpoints: &Vec, + endpoints: &[SocketAddr], ) -> Result { let default_ttl = socket.ttl().ok(); log::trace!(target: "citadel", "Default TTL: {:?}", default_ttl); @@ -97,7 +97,7 @@ impl Method3 { ttl_init: 20, delta_ttl: Some(60), socket: socket_wrapper, - endpoints, + endpoints: &tokio::sync::Mutex::new(endpoints.iter().copied().collect()), encryptor, millis_delta: MILLIS_DELTA, count: 2, @@ -120,12 +120,11 @@ impl Method3 { ), ); - match timeout.await { - Ok(res) => res, - Err(_) => Err(FirewallError::HolePunch( + timeout.await.unwrap_or_else(|_| { + Err(FirewallError::HolePunch( "Timeout while waiting for UDP penetration".to_string(), - )), - } + )) + }) }; let sender_task = async move { @@ -133,7 +132,6 @@ impl Method3 { Self::send_packet_barrage(packet_send_params, None) .await .map_err(|err| FirewallError::HolePunch(err.to_string()))?; - //Self::send_syn_barrage(120, None, socket_wrapper, endpoints, encryptor, MILLIS_DELTA, 3,unique_id.clone()).await.map_err(|err| FirewallError::HolePunch(err.to_string()))?; Ok(()) as Result<(), FirewallError> }; @@ -142,9 +140,9 @@ impl Method3 { log::trace!(target: "citadel", "Hole-punch join result: recv={:?} and send={:?}", res0, res1); if let Some(default_ttl) = default_ttl { - socket + let _ = socket .set_ttl(default_ttl) - .map_err(|err| FirewallError::HolePunch(err.to_string()))?; + .map_err(|err| FirewallError::HolePunch(err.to_string())); } let hole_punched_addr = res0?; @@ -158,7 +156,7 @@ impl Method3 { #[allow(clippy::too_many_arguments)] async fn send_packet_barrage( params: &SendPacketBarrageParams<'_>, - syn_received_addr: Option, + syn_received_addr: Option<(SocketAddr, HolePunchID)>, ) -> Result<(), anyhow::Error> { let SendPacketBarrageParams { ttl_init, @@ -178,33 +176,29 @@ impl 
Method3 { .map(|idx| ttl_init + (idx * delta_ttl)) .collect::>(); - let mut endpoints_not_reachable = Vec::new(); - // fan-out all packets from a singular source to multiple consumers using the ttls specified for ttl in ttls { let _ = sleep.tick().await; - for endpoint in endpoints.iter() { - if endpoints_not_reachable.contains(endpoint) { - continue; - } + let mut endpoints_lock = endpoints.lock().await; - let packet_ty = if let Some(syn_addr) = syn_received_addr { + for endpoint in endpoints_lock.clone() { + let packet_ty = if let Some((syn_addr, peer_id_recv)) = syn_received_addr { // put the addr the peer used to send to this node, that way the peer knows where // to send the packet, even if the receive address is translated - NatPacket::SynAck(*unique_id, *this_node_type, syn_addr) + NatPacket::SynAck(*unique_id, *this_node_type, syn_addr, peer_id_recv) } else { // put the endpoint we are sending to in the payload, that way, once we get a SynAck, we know // where our sent packet was sent that worked - NatPacket::Syn(*unique_id, ttl, *this_node_type, *endpoint) + NatPacket::Syn(*unique_id, ttl, *this_node_type, endpoint) // SynAck }; - let packet_plaintext = bincode2::serialize(&packet_ty).unwrap(); + let packet_plaintext = bincode::serialize(&packet_ty).unwrap(); let packet = encryptor.generate_packet(&packet_plaintext); log::trace!(target: "citadel", "Sending TTL={} to {} || {:?}", ttl, endpoint, &packet[..] as &[u8]); - match socket.send(&packet, *endpoint, Some(ttl)).await { + match socket.send(&packet, endpoint, Some(ttl)).await { Ok(can_continue) => { if !can_continue { log::trace!(target: "citadel", "Early-terminating SYN barrage"); @@ -212,23 +206,28 @@ impl Method3 { } } Err(err) => { - if err.kind() != ErrorKind::AddrNotAvailable { + let err_kind = err.kind(); + if err_kind != ErrorKind::AddrNotAvailable { log::warn!(target: "citadel", "Error sending packet from {:?} to {endpoint}: {:?}", socket.socket.local_addr()?, err); } - if err.kind().to_string().contains("NetworkUnreachable") { - endpoints_not_reachable.push(*endpoint); + if err.to_string().contains("NetworkUnreachable") { + endpoints_lock.remove(&endpoint); } - if endpoints_not_reachable.len() == endpoints.len() { - log::warn!(target: "citadel", "All endpoints are unreachable"); - return Err(anyhow::Error::msg( - "All UDP endpoints are unreachable for NAT traversal", - )); + if err_kind == ErrorKind::InvalidInput { + endpoints_lock.remove(&endpoint); } } } } + + if endpoints_lock.is_empty() { + log::warn!(target: "citadel", "No endpoints to send to for {unique_id:?} (local bind: {})", socket.socket.local_addr()?); + return Err(anyhow::Error::msg( + "All UDP endpoints are unreachable for NAT traversal", + )); + } } Ok(()) @@ -238,14 +237,14 @@ impl Method3 { async fn recv_until( socket: &UdpWrapper<'_>, encryptor: &HolePunchConfigContainer, - _unique_id: &HolePunchID, + unique_id: &HolePunchID, observed_addrs_on_syn: &Mutex>, _millis_delta: u64, this_node_type: RelativeNodeType, send_packet_params: &SendPacketBarrageParams<'_>, ) -> Result { let buf = &mut [0u8; 4096]; - log::trace!(target: "citadel", "[Hole-punch] Listening on {:?}", socket.socket.local_addr().unwrap()); + log::trace!(target: "citadel", "[Hole-punch] Listening on {:?}", socket.socket.local_addr()?); let mut has_received_syn = false; loop { @@ -260,21 +259,25 @@ impl Method3 { } }; - match bincode2::deserialize(&packet) + match bincode::deserialize(&packet) .map_err(|err| FirewallError::HolePunch(err.to_string())) { + Ok(NatPacket::Check) => { + 
continue; + } Ok(NatPacket::Syn( peer_unique_id, ttl, adjacent_node_type, their_send_addr, )) => { - if adjacent_node_type == this_node_type { - log::warn!(target: "citadel", "RECV loopback packet; will discard"); + if has_received_syn { continue; } - if has_received_syn { + if adjacent_node_type == this_node_type || &peer_unique_id == unique_id + { + log::warn!(target: "citadel", "RECV loopback packet; will discard"); continue; } @@ -292,19 +295,19 @@ impl Method3 { has_received_syn = true; - let send_addrs = send_packet_params - .endpoints - .iter() - .copied() - .chain(std::iter::once(peer_external_addr)) - .collect::>(); - - let mut send_params = send_packet_params.clone(); - send_params.endpoints = &send_addrs; - - Self::send_packet_barrage(&send_params, Some(their_send_addr)) - .await - .map_err(|err| FirewallError::HolePunch(err.to_string()))?; + let mut lock = send_packet_params.endpoints.lock().await; + let send_addrs = std::iter::once(peer_external_addr) + .chain(lock.iter().copied()) + .collect::>(); + *lock = send_addrs; + drop(lock); + + Self::send_packet_barrage( + send_packet_params, + Some((their_send_addr, peer_unique_id)), + ) + .await + .map_err(|err| FirewallError::HolePunch(err.to_string()))?; } // the reception of a SynAck proves the existence of a hole punched since there is bidirectional communication through the NAT @@ -312,6 +315,7 @@ impl Method3 { adjacent_unique_id, adjacent_node_type, address_we_sent_to, + our_id, )) => { log::trace!(target: "citadel", "RECV SYN_ACK"); if adjacent_node_type == this_node_type { @@ -319,14 +323,29 @@ impl Method3 { continue; } + if &our_id != unique_id { + log::warn!(target: "citadel", "RECV Packet from wrong hole punching process. Received {our_id:?}, But expected our id of {unique_id:?}"); + continue; + } + // NOTE: it is entirely possible that we receive a SynAck before even getting a Syn. // Since we send SYNs to the other node, and, it's possible that we don't receive a SYN by the time // the other node ACKs our sent out SYN, we should not terminate. - let expected_addr = address_we_sent_to; - if peer_external_addr != expected_addr { - log::warn!(target: "citadel", "[will allow] RECV SYN_ACK that comes from the wrong addr. RECV: {:?}, Expected: {:?}", peer_external_addr, expected_addr); - //continue; + if peer_external_addr != address_we_sent_to { + let packet = bincode::serialize(&NatPacket::Check).unwrap(); + // See if we can send a packet to the addr + if socket + .socket + .send_to(&packet, peer_external_addr) + .await + .is_ok() + { + log::warn!(target: "citadel", "[will allow] RECV SYN_ACK that comes from the wrong addr. RECV: {:?}, Expected: {:?}", peer_external_addr, address_we_sent_to); + } else { + log::warn!(target: "citadel", "[will NOT allow] RECV SYN_ACK that comes from the wrong addr. RECV: {:?}, Expected: {:?}", peer_external_addr, address_we_sent_to); + continue; + } } // this means there was a successful ping-pong. @@ -353,7 +372,13 @@ impl Method3 { target: "citadel", "Error receiving packet from {:?}: {err:?}", socket.socket.local_addr()? 
- ) + ); + + if err.kind() == ErrorKind::ConnectionReset { + return Err(FirewallError::HolePunch( + "Connection reset while waiting for UDP penetration".to_string(), + )); + } } } } @@ -416,12 +441,11 @@ impl UdpWrapper<'_> { } } -#[derive(Clone)] struct SendPacketBarrageParams<'a> { ttl_init: u32, delta_ttl: Option, socket: &'a UdpWrapper<'a>, - endpoints: &'a Vec, + endpoints: &'a tokio::sync::Mutex>, encryptor: &'a HolePunchConfigContainer, millis_delta: u64, count: u32, diff --git a/citadel_wire/src/udp_traversal/linear/mod.rs b/citadel_wire/src/udp_traversal/linear/mod.rs index 0982d7119..36177ba28 100644 --- a/citadel_wire/src/udp_traversal/linear/mod.rs +++ b/citadel_wire/src/udp_traversal/linear/mod.rs @@ -3,6 +3,7 @@ use std::net::SocketAddr; use citadel_io::UdpSocket; use either::Either; use igd::PortMappingProtocol; +use tokio::sync::mpsc::UnboundedSender; use tokio::time::Duration; use crate::error::FirewallError; @@ -71,7 +72,7 @@ impl SingleUDPHolePuncher { &mut self, method: NatTraversalMethod, mut kill_switch: tokio::sync::broadcast::Receiver<(HolePunchID, HolePunchID)>, - post_kill_rebuild: tokio::sync::mpsc::UnboundedSender>, + mut post_kill_rebuild: tokio::sync::mpsc::UnboundedSender>, ) -> Result { match method { NatTraversalMethod::UPnP => { @@ -112,6 +113,7 @@ impl SingleUDPHolePuncher { socket: self.socket.take().ok_or_else(|| { FirewallError::HolePunch("UDP socket not loaded".to_string()) })?, + local_id: unique_id, }) } @@ -148,27 +150,36 @@ impl SingleUDPHolePuncher { res1 = kill_listener => Either::Left(res1) }; + async fn handle_rebuild_input( + this: &mut SingleUDPHolePuncher, + post_kill_rebuild: &mut UnboundedSender>, + id_opt: Option<(HolePunchID, HolePunchID)>, + ) -> Result { + match id_opt { + Some((_local_id, peer_id)) => { + post_kill_rebuild.send(Some(this.recovery_mode_generate_socket_by_remote_id(peer_id).ok_or_else(|| FirewallError::HolePunch("Kill switch called, but no matching values were found internally".to_string()))?)).map_err(|err| FirewallError::HolePunch(err.to_string()))?; + } + + None => { + log::trace!(target: "citadel", "Will end hole puncher {:?} since kill switch called", this.get_unique_id()); + post_kill_rebuild + .send(None) + .map_err(|err| FirewallError::HolePunch(err.to_string()))?; + } + } + + Err(FirewallError::Skip) + } + match res { Either::Right(addr) => Ok(HolePunchedUdpSocket { socket: self.socket.take().unwrap(), addr, + local_id: this_local_id, }), Either::Left(id_opt) => { - match id_opt { - Some((_local_id, peer_id)) => { - post_kill_rebuild.send(Some(self.recovery_mode_generate_socket_by_remote_id(peer_id).ok_or_else(|| FirewallError::HolePunch("Kill switch called, but no matching values were found internally".to_string()))?)).map_err(|err| FirewallError::HolePunch(err.to_string()))?; - } - - None => { - log::trace!(target: "citadel", "Will end hole puncher {:?} since kill switch called", self.get_unique_id()); - post_kill_rebuild - .send(None) - .map_err(|err| FirewallError::HolePunch(err.to_string()))?; - } - } - - Err(FirewallError::Skip) + handle_rebuild_input(self, &mut post_kill_rebuild, id_opt).await } } } @@ -189,6 +200,7 @@ impl SingleUDPHolePuncher { receive_address: self.peer_external_addr(), unique_id, }, + local_id: unique_id, }) } } @@ -224,17 +236,19 @@ impl SingleUDPHolePuncher { .method3 .1 .get_peer_external_addr_from_peer_hole_punch_id(remote_id)?; - let socket = self.socket.take()?; - Some(HolePunchedUdpSocket { addr, socket }) + self.recovery_mode_generate_socket_by_addr(addr) } - /// this should 
only be called when the adjacent node verified that the connection occurred
     pub fn recovery_mode_generate_socket_by_addr(
         &mut self,
         addr: TargettedSocketAddr,
     ) -> Option<HolePunchedUdpSocket> {
         let socket = self.socket.take()?;
-        Some(HolePunchedUdpSocket { addr, socket })
+        Some(HolePunchedUdpSocket {
+            addr,
+            socket,
+            local_id: self.unique_id,
+        })
     }
 }
diff --git a/citadel_wire/src/udp_traversal/multi/mod.rs b/citadel_wire/src/udp_traversal/multi/mod.rs
index 51ee1d59f..775935cc3 100644
--- a/citadel_wire/src/udp_traversal/multi/mod.rs
+++ b/citadel_wire/src/udp_traversal/multi/mod.rs
@@ -1,24 +1,24 @@
-use std::collections::HashMap;
-use std::net::SocketAddr;
-use std::pin::Pin;
-use std::task::{Context, Poll};
-
-use futures::stream::FuturesUnordered;
-use futures::{Future, StreamExt};
-use serde::de::DeserializeOwned;
-use serde::{Deserialize, Serialize};
-use tokio::sync::mpsc::UnboundedReceiver;
-
 use crate::error::FirewallError;
 use crate::udp_traversal::hole_punch_config::HolePunchConfig;
 use crate::udp_traversal::linear::encrypted_config_container::HolePunchConfigContainer;
 use crate::udp_traversal::linear::SingleUDPHolePuncher;
 use crate::udp_traversal::targetted_udp_socket_addr::HolePunchedUdpSocket;
 use crate::udp_traversal::{HolePunchID, NatTraversalMethod};
-use netbeam::reliable_conn::ReliableOrderedStreamToTarget;
+use futures::future::select_ok;
+use futures::stream::FuturesUnordered;
+use futures::{Future, StreamExt};
+use netbeam::multiplex::MultiplexedConn;
+use netbeam::sync::channel::bi_channel::{ChannelRecvHalf, ChannelSendHalf};
 use netbeam::sync::network_endpoint::NetworkEndpoint;
-use netbeam::sync::subscription::Subscribable;
 use netbeam::sync::RelativeNodeType;
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+use std::net::SocketAddr;
+use std::pin::Pin;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::task::{Context, Poll};
+use std::time::Duration;
+use tokio::sync::mpsc::UnboundedReceiver;
 
 /// Punches a hole using IPv4/6 addrs.
IPv6 is more traversal-friendly since IP-translation between external and internal is not needed (unless the NAT admins are evil) /// @@ -29,11 +29,12 @@ pub(crate) struct DualStackUdpHolePuncher { Pin> + Send + 'static>>, } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone, Copy)] #[allow(variant_size_differences)] -enum DualStackCandidate { - MutexSet(HolePunchID, HolePunchID), +enum DualStackCandidateSignal { + Winner(HolePunchID, HolePunchID), WinnerCanEnd, + AllFailed, } impl DualStackUdpHolePuncher { @@ -53,11 +54,24 @@ impl DualStackUdpHolePuncher { .locally_bound_sockets .take() .ok_or_else(|| anyhow::Error::msg("sockets already taken"))?; - let addrs_to_ping: &Vec = &hole_punch_config.into_iter().collect(); + let addrs_to_ping = hole_punch_config + .into_iter() + .collect::>>(); // each individual hole puncher fans-out from 1 bound socket to n many peer addrs (determined by addrs_to_ping) - for socket in sockets { - // TODO: ensure only *some* of the addrs in addrs_to_ping get passed (MAX 2) + for (socket, mut addrs_to_ping) in sockets.into_iter().zip(addrs_to_ping) { + let socket_local_addr = socket.local_addr()?; + // We can't send from an ipv4 socket to an ipv6, addr, so remove any addrs that are ipv6 + if socket_local_addr.is_ipv4() { + addrs_to_ping.retain(|addr| addr.is_ipv4()); + } + + // We can't send from an ipv6 socket to and ipv4, addr, so remove any addrs that are ipv4 + if socket_local_addr.is_ipv6() { + addrs_to_ping.retain(|addr| addr.is_ipv6()); + } + + log::trace!(target: "citadel", "Hole punching with socket: {socket_local_addr} | addrs to ping: {addrs_to_ping:?}"); let hole_puncher = SingleUDPHolePuncher::new( relative_node_type, encrypted_config_container.clone(), @@ -107,16 +121,18 @@ async fn drive( log::trace!(target: "citadel", "Initiating subscription ..."); // initiate a dedicated channel for sending packets for coordination - let conn = &(app.initiate_subscription().await?); + let conn = app.bi_channel::().await?; + let (ref conn_tx, conn_rx) = conn.split(); + let conn_rx = &tokio::sync::Mutex::new(conn_rx); log::trace!(target: "citadel", "Initiating NetMutex ..."); // setup a mutex for handling contentions - let net_mutex = &(app.mutex::>(value).await?); + let net_mutex = &(app + .mutex::>(value) + .await?); let (final_candidate_tx, final_candidate_rx) = tokio::sync::oneshot::channel::(); - let (reader_done_tx, mut reader_done_rx) = tokio::sync::broadcast::channel::<()>(2); - let mut reader_done_rx_3 = reader_done_tx.subscribe(); let (ref kill_signal_tx, _kill_signal_rx) = tokio::sync::broadcast::channel(hole_punchers.len()); @@ -143,122 +159,184 @@ async fn drive( post_rebuild_rx: Some(post_rebuild_rx), }); - let loser_value_set = &citadel_io::Mutex::new(None); - let mut futures = FuturesUnordered::new(); for (kill_switch_rx, mut hole_puncher) in hole_punchers .into_iter() .map(|r| (kill_signal_tx.subscribe(), r)) { - futures.push(async move { + // TODO: Consider spawning to ensure if the reader/future-processor fail, + // the background still can send its results to the background rebuilder + let post_rebuild_tx = post_rebuild_tx.clone(); + let task = async move { let res = hole_puncher - .try_method( - NatTraversalMethod::Method3, - kill_switch_rx, - post_rebuild_tx.clone(), - ) + .try_method(NatTraversalMethod::Method3, kill_switch_rx, post_rebuild_tx) .await; (res, hole_puncher) - }); + }; + + let task = tokio::task::spawn(task); + + futures.push(task); } - let current_enqueued = 
&tokio::sync::Mutex::new(None); + let current_enqueued: &tokio::sync::Mutex> = + &tokio::sync::Mutex::new(vec![]); let finished_count = &citadel_io::Mutex::new(0); let hole_puncher_count = futures.len(); + let commanded_winner = &tokio::sync::Mutex::new(None); + + let (done_tx, done_rx) = tokio::sync::oneshot::channel::<()>(); + let done_tx = citadel_io::Mutex::new(Some(done_tx)); + + let signal_done = || -> Result<(), anyhow::Error> { + let tx = done_tx + .lock() + .take() + .ok_or_else(|| anyhow::Error::msg("signal_done has already been called"))?; + tx.send(()) + .map_err(|_| anyhow::Error::msg("signal_done oneshot sender failed to send")) + }; + + let failure_occurred = &AtomicBool::new(false); + let set_failure_occurred = || async move { + let no_failure_yet = !failure_occurred.fetch_or(true, Ordering::SeqCst); + if no_failure_yet { + log::trace!(target: "citadel", "All hole-punchers have failed locally. Will send AllFailed signal"); + send(DualStackCandidateSignal::AllFailed, conn_tx).await?; + Ok(()) + } else { + // In this case, remote already set_failure_occurred, so we know that since they + // failed, and now that we failed, we can end. + log::error!(target: "citadel", "Remote has already failed, and locally failed, therefore returning"); + Err(anyhow::Error::msg( + "All local and remote hold punchers failed", + )) + } + }; + // This is called to scan currently-running tasks to terminate, and, returning the rebuilt // hole-punched socket on completion - let assert_rebuild_ready = |local_id: HolePunchID, peer_id: HolePunchID| async move { + let loser_rebuilder_task = async move { let mut lock = rebuilder.lock().await; - // first, check local failures - if let Some(mut failure) = lock.local_failures.remove(&local_id) { - log::trace!(target: "citadel", "[Rebuild] While searching local_failures, found match"); - if let Some(rebuilt) = failure.recovery_mode_generate_socket_by_remote_id(peer_id) { - return Ok(rebuilt); - } else { - log::warn!(target: "citadel", "[Rebuild] Found in local_failures, but, failed to find rebuilt socket"); - } - } - - let _receivers = kill_signal_tx.send((local_id, peer_id))?; let mut post_rebuild_rx = lock .post_rebuild_rx .take() .ok_or_else(|| anyhow::Error::msg("post_rebuild_rx has already been taken"))?; - log::trace!(target: "citadel", "*** Will now await post_rebuild_rx ... 
{} have finished", finished_count.lock()); - let mut count = 0; - // Note: if properly implemented, the below should return almost instantly - loop { - if let Some(current_enqueued) = current_enqueued.lock().await.take() { - log::trace!(target: "citadel", "Grabbed the currently enqueued socket!"); - return Ok(current_enqueued); - } + drop(lock); + + let loser_poller = async move { + let mut ticker = tokio::time::interval(Duration::from_millis(100)); + loop { + ticker.tick().await; + + if let Some((local_id, peer_id)) = *commanded_winner.lock().await { + log::trace!(target: "citadel", "Local {local_id:?} has been commanded to use {peer_id:?}"); + let receivers = kill_signal_tx.send((local_id, peer_id)).unwrap_or(0); + log::trace!(target: "citadel", "Sent kill signal to {receivers} hole-punchers"); + + 'pop: while let Some(current_enqueued) = current_enqueued.lock().await.pop() { + log::trace!(target: "citadel", "Maybe grabbed the currently enqueued local socket {:?}: {:?}", current_enqueued.local_id, current_enqueued.addr); + if current_enqueued.addr.unique_id != peer_id { + log::warn!(target: "citadel", "Cannot use the enqueued socket since ID does not match"); + continue 'pop; + } - match post_rebuild_rx.recv().await { - None => return Err(anyhow::Error::msg("post_rebuild_rx failed")), + return Ok(current_enqueued); + } - Some(None) => { - count += 1; - log::trace!(target: "citadel", "*** [rebuild] So-far, {}/{} have finished", count, hole_puncher_count); - if count == hole_puncher_count { - log::error!(target: "citadel", "This should not happen") + let mut lock = rebuilder.lock().await; + if let Some(failure) = lock.local_failures.get_mut(&local_id) { + log::trace!(target: "citadel", "[Rebuild] While searching local_failures, found match"); + if let Some(rebuilt) = + failure.recovery_mode_generate_socket_by_remote_id(peer_id) + { + return Ok(rebuilt); + } else { + log::warn!(target: "citadel", "[Rebuild] Found in local_failures, but, failed to find rebuilt socket"); + } } - } - Some(Some(res)) => { - log::trace!(target: "citadel", "*** [rebuild] complete"); - return Ok(res); + if lock.local_failures.len() == hole_puncher_count { + return Err(anyhow::Error::msg("All hole-punchers have failed (t1)")); + } } } - } - }; - - let (done_tx, done_rx) = tokio::sync::oneshot::channel::<()>(); - let done_tx = citadel_io::Mutex::new(Some(done_tx)); + }; - let signal_done = || -> Result<(), anyhow::Error> { - let tx = done_tx - .lock() - .take() - .ok_or_else(|| anyhow::Error::msg("signal_done has already been called"))?; - tx.send(()) - .map_err(|_| anyhow::Error::msg("signal_done oneshot sender failed to send")) - }; + let loser_rebuilder_task = async move { + log::trace!(target: "citadel", "*** Will now await post_rebuild_rx ... 
{} have finished", finished_count.lock()); + // Note: if properly implemented, the below should return almost instantly + loop { + let result = post_rebuild_rx.recv().await; + log::trace!(target: "citadel", "*** [rebuild] Received signal {:?}", result); + match result { + None => return Err(anyhow::Error::msg("post_rebuild_rx failed")), + + Some(None) => { + let fail_count = rebuilder.lock().await.local_failures.len(); + log::trace!(target: "citadel", "*** [rebuild] So-far, {}/{} have finished", fail_count, hole_puncher_count); + if fail_count == hole_puncher_count { + return Err(anyhow::Error::msg("All hole-punchers have failed (t2)")); + } + } - let (winner_can_end_tx, winner_can_end_rx) = tokio::sync::oneshot::channel(); + Some(Some(res)) => { + log::trace!(target: "citadel", "*** [rebuild] complete"); + return Ok(res); + } + } + } + }; - let (futures_tx, mut futures_rx) = tokio::sync::mpsc::unbounded_channel(); + let hole_punched_socket_res = select_ok([ + Box::pin(loser_poller) + as Pin>>>, + Box::pin(loser_rebuilder_task), + ]) + .await + .map(|res| res.0); + + match hole_punched_socket_res { + Err(err) => { + // The only way an error can occur is if the total number of failures is equal to the number of hole-punchers + // In this case, while remote claimed a winner, we were unable to create/find the winner (this should be unreachable) + log::error!(target: "citadel", "Rebuilder task failed. Please contact developers on Github: {err:?}"); + set_failure_occurred().await + } - let futures_executor = async move { - while let Some(res) = futures.next().await { - futures_tx - .send(res) - .map_err(|_| anyhow::Error::msg("futures_tx send error"))?; + Ok(hole_punched_socket) => { + log::trace!(target: "citadel", "Selecting socket: {hole_punched_socket:?}"); + let _ = hole_punched_socket.cleanse(); + submit_final_candidate(hole_punched_socket)?; + signal_done() + } } - - log::trace!(target: "citadel", "Finished polling all futures"); - Ok(reader_done_rx_3.recv().await?) as Result<(), anyhow::Error> }; - // the goal of the sender is just to send results as local finishes, nothing else let futures_resolver = async move { - while let Some((res, hole_puncher)) = futures_rx.recv().await { - log::trace!(target: "citadel", "[Future resolver loop] Received {:?}", res); + while let Some(res) = futures.next().await { *finished_count.lock() += 1; + + let (res, hole_puncher) = match res { + Ok(res) => res, + Err(err) => { + log::warn!(target: "citadel", "Hole-puncher task failed: {err:?}"); + continue; + } + }; + + log::trace!(target: "citadel", "[Future resolver loop] Received {res:?}"); + match res { Ok(socket) => { let peer_unique_id = socket.addr.unique_id; let local_id = hole_puncher.get_unique_id(); + current_enqueued.lock().await.push(socket); - if let Some((pre_local, pre_remote)) = *loser_value_set.lock() { - log::trace!(target: "citadel", "*** Local did not win, and, already received a MutexSet: ({:?}, {:?})", pre_local, pre_remote); - if local_id == pre_local && peer_unique_id == pre_remote { - log::trace!(target: "citadel", "*** Local did not win, and, is currently waiting for the current value! (returning)"); - // this implies local is already waiting for this result. 
Submit and finish here - post_rebuild_tx.send(Some(socket))?; - } - - // continue to keep polling futures + if let Some((required_local, required_remote)) = *commanded_winner.lock().await + { + log::trace!(target: "citadel", "*** [Future resolver loop] Commanded winner (skipping NetMutex acquisition): {required_local:?}, {required_remote:?}. Will require rebuilder task to return the valid socket ..."); continue; } @@ -266,28 +344,53 @@ async fn drive( // future: if this node gets here, and waits for the mutex to drop from the other end, // the other end may say that the current result is valid, but, be unaccessible since // we are blocked waiting for the mutex. As such, we need to set the enqueued field - *current_enqueued.lock().await = Some(socket); - let mut net_lock = net_mutex.lock().await?; - if let Some(socket) = current_enqueued.lock().await.take() { - if net_lock.as_ref().is_none() { - log::trace!(target: "citadel", "*** Local won! Will command other side to use ({:?}, {:?})", peer_unique_id, local_id); - *net_lock = Some(()); + log::trace!(target: "citadel", "*** [Future resolver loop] Acquiring NetMutex ...."); + let Ok(mut net_lock) = net_mutex.lock().await else { + log::trace!(target: "citadel", "*** [Future resolver loop] Mutex failed to acquire. Likely dropped. Will continue ..."); + continue; + }; + + log::trace!(target: "citadel", "*** [Future resolver loop] Mutex acquired. Local = {local_id:?}, Remote = {peer_unique_id:?}"); + if let Some((local, remote)) = *net_lock { + log::trace!(target: "citadel", "*** The Mutex is already set! Will not claim winner status ..."); + *commanded_winner.lock().await = Some((local, remote)); + if local_id == local && peer_unique_id == remote { + log::trace!(target: "citadel", "*** [Future resolver loop] The received socket *IS* the socket remote requested. Will wait for background rebuilder to finish ..."); + } else { + log::trace!(target: "citadel", "*** [Future resolver loop] The received socket *is NOT* the socket remote requested. Will wait for background rebuilder to finish ..."); + } + } else { + // We are the winner + log::trace!(target: "citadel", "*** Local won! Will command other side to use ({:?}, {:?})", peer_unique_id, local_id); + // Tell the other side we won, that way the rebuilder background process for the other + // side can respond. If we don't send this message, then, it's possible hanging occurs + // on the loser end because the winner combo isn't obtained until this futures + // resolver received a completed future; since in variable NAT setups, the adjacent side may fail + // entirely, it could never finish, thus never trigger the code that sets the commanded_winner + // and thus prompts the background code to return the socket on the adjacent node. + send( + DualStackCandidateSignal::Winner(peer_unique_id, local_id), + conn_tx, + ) + .await?; + while let Some(socket) = current_enqueued.lock().await.pop() { + if socket.local_id != local_id { + log::warn!(target: "citadel", "*** Winner: socket ID mismatch. Expected {local_id:?}, got {:?}. Looping ...", socket.local_id); + continue; + } + + *net_lock = Some((peer_unique_id, local_id)); let _ = socket.cleanse(); - // Hold the mutex to prevent the other side from accessing the data. 
It will need to end via the other means - send(DualStackCandidate::MutexSet(peer_unique_id, local_id), conn) - .await?; submit_final_candidate(socket)?; log::trace!(target: "citadel", "*** [winner] Awaiting the signal ..."); + drop(net_lock); // the winner will drop once the adjacent node sends a WinnerCanEnd signal - winner_can_end_rx.await?; + //winner_can_end_rx.await?; log::trace!(target: "citadel", "*** [winner] received the signal"); - std::mem::drop(net_lock); return signal_done(); - } else { - log::error!(target: "citadel", "This should not happen"); } - } else { - log::trace!(target: "citadel", "While looping, detected that the socket was taken") + + unreachable!("Winner did not find any enqueued sockets. This is a developer bug. Please report this issue to github"); } } @@ -297,59 +400,100 @@ async fn drive( Err(err) => { log::warn!(target: "citadel", "[non-terminating] Hole-punch for local bind addr {:?} failed: {:?}", hole_puncher.get_unique_id(), err); - rebuilder - .lock() - .await - .local_failures - .insert(hole_puncher.get_unique_id(), hole_puncher); + let fail_count = { + let mut lock = rebuilder.lock().await; + let _ = lock + .local_failures + .insert(hole_puncher.get_unique_id(), hole_puncher); + lock.local_failures.len() + }; + + if fail_count == hole_puncher_count { + // All failed locally, but, remote may claim that it has a valid socket/ + // Run the function below to exit if remote already set_failure_occurred + log::warn!(target: "citadel", "All hole-punchers have failed locally"); + set_failure_occurred().await?; + } } } } - // if we get here before the reader finishes, we need to wait for the reader to finish - Ok(reader_done_rx.recv().await?) as Result<(), anyhow::Error> - //Ok(()) as Result<(), anyhow::Error> + log::trace!(target: "citadel", "Finished polling all futures"); + Ok(()) }; let reader = async move { - match receive::(conn).await? { - DualStackCandidate::MutexSet(local, remote) => { - log::trace!(target: "citadel", "*** received MutexSet. Will unconditionally end ..."); - assert!(loser_value_set.lock().replace((local, remote)).is_none()); - let hole_punched_socket = assert_rebuild_ready(local, remote).await?; - let _ = hole_punched_socket.cleanse(); - submit_final_candidate(hole_punched_socket)?; - // return here. The winner must exit last - send(DualStackCandidate::WinnerCanEnd, conn).await?; - signal_done() - } + let mut conn_rx = conn_rx.lock().await; + loop { + match receive(&mut conn_rx).await? { + DualStackCandidateSignal::Winner(local_id, peer_id) => { + log::trace!(target: "citadel", "[READER] Remote commanded local to use peer={peer_id:?} and local={local_id:?}"); + *commanded_winner.lock().await = Some((local_id, peer_id)); + } + DualStackCandidateSignal::AllFailed => { + // All failed locally, but, remote may claim that it has a valid socket/ + // Run the function below to exit if remote already set_failure_occurred + log::warn!(target: "citadel", "Remote claims all hole punchers failed"); + set_failure_occurred().await?; + // If we reach here, it implies this node is still resolving futures. 
Do not return + // until the other joined future resolves itself + } - DualStackCandidate::WinnerCanEnd => { - winner_can_end_tx - .send(()) - .map_err(|_| anyhow::Error::msg("Unable to send through winner_can_end_tx"))?; - Ok(()) + DualStackCandidateSignal::WinnerCanEnd => { + /*winner_can_end_tx.send(()).map_err(|_| { + anyhow::Error::msg("Unable to send through winner_can_end_tx") + })?;*/ + return Ok::<_, anyhow::Error>(()); + } } } }; log::trace!(target: "citadel", "[DualStack] Executing hole-puncher ...."); - let sender_reader_combo = futures::future::try_join(futures_resolver, reader); + let sender_reader_combo = async move { + let res = futures::future::select_ok([ + Box::pin(futures_resolver) + as Pin> + Send>>, + Box::pin(reader), + ]) + .await; + if let Some(err) = res.as_ref().err() { + log::warn!(target: "citadel", "Both reader/resolver futures failed: {err:?}") + } + + // Just wait for the background process to finish up + futures::future::pending().await + }; tokio::select! { - res0 = sender_reader_combo => { - log::trace!(target: "citadel", "[DualStack] Sender/Reader combo finished {res0:?}"); - res0.map(|_| ())? + _res0 = sender_reader_combo => { + log::trace!(target: "citadel", "[DualStack] Sender/Reader combo finished"); }, res1 = done_rx => { log::trace!(target: "citadel", "[DualStack] Done signal received {res1:?}"); res1? }, - res2 = futures_executor => { - log::trace!(target: "citadel", "[DualStack] Futures executor finished {res2:?}"); + res2 = loser_rebuilder_task => { + log::trace!(target: "citadel", "[DualStack] Loser rebuilder task finished {res2:?}"); res2? } - }; + } + + if commanded_winner.lock().await.is_none() { + // We are the "winner" + log::trace!(target: "citadel", "Winner: awaiting WinnerCanEnd signal"); + let mut conn_rx = conn_rx.lock().await; + let signal = receive(&mut conn_rx).await?; + if let DualStackCandidateSignal::WinnerCanEnd = signal { + log::trace!(target: "citadel", "Received WinnerCanEnd signal"); + } else { + log::warn!(target: "citadel", "Received unexpected signal: {:?}", signal); + } + } else { + // We are the "loser" + log::trace!(target: "citadel", "Loser: sending WinnerCanEnd signal"); + send(DualStackCandidateSignal::WinnerCanEnd, conn_tx).await?; + } log::trace!(target: "citadel", "*** ENDING DualStack ***"); @@ -359,17 +503,17 @@ async fn drive( Ok(sock) } -async fn send( - input: R, - conn: &V, +async fn send( + input: DualStackCandidateSignal, + conn: &ChannelSendHalf, ) -> Result<(), anyhow::Error> { - Ok(conn - .send_to_peer(&bincode2::serialize(&input).unwrap()) - .await?) + conn.send_item(input).await } -async fn receive( - conn: &V, -) -> Result { - Ok(bincode2::deserialize(&conn.recv().await?)?) +async fn receive( + conn: &mut ChannelRecvHalf, +) -> Result { + conn.recv() + .await + .ok_or_else(|| anyhow::Error::msg("recv from bichannel failed: stream ended"))? 
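For reference, the loser path above races loser_poller against loser_rebuilder_task (and later futures_resolver against reader) with futures::future::select_ok over pinned, boxed futures. A minimal, self-contained sketch of that pattern, with illustrative task bodies and a placeholder payload type rather than the real hole-punch types:

    use std::pin::Pin;
    use futures::future::select_ok;
    use futures::Future;

    // Race two fallible tasks: the first Ok resolves the race; Err is returned
    // only if every task fails (how the loser path detects "all hole-punchers failed").
    async fn race_tasks() -> Result<u32, anyhow::Error> {
        let poller = async {
            // stand-in for polling shared state for a commanded winner
            Ok::<u32, anyhow::Error>(1)
        };
        let rebuilder = async {
            // stand-in for waiting on post_rebuild_rx
            Ok::<u32, anyhow::Error>(2)
        };
        let tasks: Vec<Pin<Box<dyn Future<Output = Result<u32, anyhow::Error>> + Send>>> =
            vec![Box::pin(poller), Box::pin(rebuilder)];
        // select_ok drops the remaining futures once one resolves with Ok
        select_ok(tasks).await.map(|(value, _rest)| value)
    }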
} diff --git a/citadel_wire/src/udp_traversal/targetted_udp_socket_addr.rs b/citadel_wire/src/udp_traversal/targetted_udp_socket_addr.rs index 71b6551c4..32be00be1 100644 --- a/citadel_wire/src/udp_traversal/targetted_udp_socket_addr.rs +++ b/citadel_wire/src/udp_traversal/targetted_udp_socket_addr.rs @@ -3,6 +3,7 @@ use citadel_io::UdpSocket; use serde::{Deserialize, Serialize}; use std::fmt::{Display, Formatter}; use std::net::{IpAddr, SocketAddr}; +use std::time::Duration; #[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)] pub struct TargettedSocketAddr { @@ -60,11 +61,49 @@ impl Display for TargettedSocketAddr { #[derive(Debug)] pub struct HolePunchedUdpSocket { - pub socket: UdpSocket, + pub local_id: HolePunchID, + pub(crate) socket: UdpSocket, pub addr: TargettedSocketAddr, } impl HolePunchedUdpSocket { + pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> std::io::Result { + let bind_addr = self.socket.local_addr()?; + let bind_ip = bind_addr.ip(); + let send_ip = self.addr.send_address.ip(); + let send_ip = match (bind_ip, send_ip) { + (IpAddr::V4(_bind_ip), IpAddr::V6(send_ip)) => { + // If we're sending from an IPv4 address to an IPv6 address, we need to convert the + // IPv4 address to an IPv4-mapped IPv6 address + if let Some(addr) = send_ip.to_ipv4_mapped() { + IpAddr::V4(addr) + } else { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "IPv4-mapped IPv6 address conversion failed; Cannot send from ipv4 socket to v6", + )); + } + } + + (IpAddr::V6(_bind_ip), IpAddr::V4(send_ip)) => { + // If we're sending from an IPv6 address to an IPv4 address, we need to convert the + // IPv4 address to an IPv4-mapped IPv6 address + IpAddr::V6(send_ip.to_ipv6_mapped()) + } + + _ => send_ip, + }; + + let target_addr = SocketAddr::new(send_ip, addr.port()); + log::trace!(target: "citadel", "Sending packet from {bind_addr} to {target_addr}"); + + tokio::time::timeout( + Duration::from_secs(2), + self.socket.send_to(buf, target_addr), + ) + .await + .map_err(|err| std::io::Error::new(std::io::ErrorKind::TimedOut, err.to_string()))? 
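The send_to helper above normalizes the destination address family against the socket's bind family before sending. A standalone sketch of just that mapping, using the std helpers it relies on (the normalize_families name is illustrative, not part of the patch):

    use std::net::IpAddr;

    // Map the destination IP into the same family as the bound socket, or None if
    // an IPv6 destination has no IPv4-mapped form (surfaced as an io::Error above).
    fn normalize_families(bind_ip: IpAddr, send_ip: IpAddr) -> Option<IpAddr> {
        match (bind_ip, send_ip) {
            // IPv4 socket, IPv6 target: only valid if the target is IPv4-mapped (::ffff:a.b.c.d)
            (IpAddr::V4(_), IpAddr::V6(v6)) => v6.to_ipv4_mapped().map(IpAddr::V4),
            // IPv6 socket, IPv4 target: always representable as an IPv4-mapped IPv6 address
            (IpAddr::V6(_), IpAddr::V4(v4)) => Some(IpAddr::V6(v4.to_ipv6_mapped())),
            // Families already agree
            _ => Some(send_ip),
        }
    }

    #[test]
    fn maps_v4_target_for_v6_socket() {
        let bind: IpAddr = "::".parse().unwrap();
        let peer: IpAddr = "203.0.113.7".parse().unwrap();
        assert_eq!(normalize_families(bind, peer), Some("::ffff:203.0.113.7".parse().unwrap()));
    }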
+ } // After hole-punching, some packets may be sent that need to be flushed // this cleanses the stream pub(crate) fn cleanse(&self) -> std::io::Result<()> { @@ -81,4 +120,16 @@ impl HolePunchedUdpSocket { } } } + + pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> { + self.socket.recv_from(buf).await + } + + pub fn local_addr(&self) -> std::io::Result { + self.socket.local_addr() + } + + pub fn into_socket(self) -> UdpSocket { + self.socket + } } diff --git a/citadel_wire/src/udp_traversal/udp_hole_puncher.rs b/citadel_wire/src/udp_traversal/udp_hole_puncher.rs index 980b95af5..fcd4107b5 100644 --- a/citadel_wire/src/udp_traversal/udp_hole_puncher.rs +++ b/citadel_wire/src/udp_traversal/udp_hole_puncher.rs @@ -1,4 +1,4 @@ -use crate::nat_identification::NatType; +use crate::nat_identification::{NatType, IDENTIFY_TIMEOUT}; use crate::udp_traversal::hole_punch_config::HolePunchConfig; use crate::udp_traversal::linear::encrypted_config_container::HolePunchConfigContainer; use crate::udp_traversal::multi::DualStackUdpHolePuncher; @@ -16,7 +16,8 @@ pub struct UdpHolePuncher<'a> { driver: Pin> + Send + 'a>>, } -const DEFAULT_TIMEOUT: Duration = Duration::from_millis(6000); +const DEFAULT_TIMEOUT: Duration = + Duration::from_millis((IDENTIFY_TIMEOUT.as_millis() + 5000) as u64); impl<'a> UdpHolePuncher<'a> { pub fn new( @@ -32,9 +33,9 @@ impl<'a> UdpHolePuncher<'a> { timeout: Duration, ) -> Self { Self { - driver: Box::pin(async move { - tokio::time::timeout(timeout, driver(conn, encrypted_config_container)).await? - }), + driver: Box::pin( + async move { driver(conn, encrypted_config_container, timeout).await }, + ), } } } @@ -47,14 +48,46 @@ impl Future for UdpHolePuncher<'_> { } } +const MAX_RETRIES: usize = 3; + #[cfg_attr( feature = "localhost-testing", tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err(Debug)) )] async fn driver( + conn: &NetworkEndpoint, + encrypted_config_container: HolePunchConfigContainer, + timeout: Duration, +) -> Result { + let mut retries = 0; + loop { + let task = tokio::time::timeout( + timeout, + driver_inner(conn, encrypted_config_container.clone()), + ); + match task.await { + Ok(Ok(res)) => return Ok(res), + Ok(Err(err)) => { + log::warn!(target: "citadel", "Hole puncher failed: {err:?}"); + } + Err(_) => { + log::warn!(target: "citadel", "Hole puncher timed-out"); + } + } + + retries += 1; + + if retries >= MAX_RETRIES { + return Err(anyhow::Error::msg("Max retries reached for UDP Traversal")); + } + } +} + +async fn driver_inner( conn: &NetworkEndpoint, mut encrypted_config_container: HolePunchConfigContainer, ) -> Result { + log::trace!(target: "citadel", "[driver] Starting hole puncher ..."); // create stream let stream = &(conn.initiate_subscription().await?); let stun_servers = encrypted_config_container.take_stun_servers(); @@ -66,21 +99,24 @@ async fn driver( let peer_nat_type = &(stream.recv_serialized::().await?); let local_initial_socket = get_optimal_bind_socket(local_nat_type, peer_nat_type)?; - let internal_bind_addr = local_initial_socket.local_addr()?; + let internal_bind_addr_optimal = local_initial_socket.local_addr()?; + let mut sockets = vec![local_initial_socket]; + let mut internal_addresses = vec![internal_bind_addr_optimal]; + if internal_bind_addr_optimal.is_ipv6() { + let additional_socket = crate::socket_helpers::get_udp_socket("0.0.0.0:0")?; + internal_addresses.push(additional_socket.local_addr()?); + sockets.push(additional_socket); + } // exchange internal bind port, 
also synchronizing the beginning of the hole punch process // while doing so - let peer_internal_bind_addr = conn.sync_exchange_payload(internal_bind_addr).await?; + let peer_internal_bind_addrs = conn.sync_exchange_payload(internal_addresses).await?; log::trace!(target: "citadel", "\n~~~~~~~~~~~~\n [driver] Local NAT type: {:?}\n Peer NAT type: {:?}", local_nat_type, peer_nat_type); - log::trace!(target: "citadel", "[driver] Local internal bind addr: {internal_bind_addr:?}\nPeer internal bind addr: {peer_internal_bind_addr:?}"); + log::trace!(target: "citadel", "[driver] Local internal bind addr: {internal_bind_addr_optimal:?}\nPeer internal bind addr: {peer_internal_bind_addrs:?}"); log::trace!(target: "citadel", "\n~~~~~~~~~~~~\n"); // the next functions takes everything insofar obtained into account without causing collisions with any existing // connections (e.g., no conflicts with the primary stream existing in conn) - let hole_punch_config = HolePunchConfig::new( - peer_nat_type, - &peer_internal_bind_addr, - local_initial_socket, - ); + let hole_punch_config = HolePunchConfig::new(peer_nat_type, &peer_internal_bind_addrs, sockets); let conn = conn.clone(); log::trace!(target: "citadel", "[driver] Synchronized; will now execute dualstack hole-puncher ... config: {:?}", hole_punch_config); @@ -91,6 +127,9 @@ async fn driver( conn, )? .await; + + log::info!(target: "citadel", "Hole Punch Status: {res:?}"); + res.map_err(|err| { anyhow::Error::msg(format!( "**HOLE-PUNCH-ERR**: {err:?} | local_nat_type: {local_nat_type:?} | peer_nat_type: {peer_nat_type:?}", @@ -127,7 +166,7 @@ pub fn get_optimal_bind_socket( let local_allows_ipv6 = local_nat_info.is_ipv6_compatible(); let peer_allows_ipv6 = peer_nat_info.is_ipv6_compatible(); - // only bind to ipv6 if v6 is enabled locally, and, there both nodes have an external ipv6 addr, + // only bind to ipv6 if v6 is enabled locally, and, both nodes have an external ipv6 addr, // AND, the peer allows ipv6, then go with ipv6 if local_allows_ipv6 && local_has_an_external_ipv6_addr @@ -137,14 +176,7 @@ pub fn get_optimal_bind_socket( // bind to IN_ADDR6_ANY. Allows both conns from loopback and public internet crate::socket_helpers::get_udp_socket("[::]:0") } else { - #[cfg(not(feature = "localhost-testing"))] - { - crate::socket_helpers::get_udp_socket("0.0.0.0:0") - } - #[cfg(feature = "localhost-testing")] - { - crate::socket_helpers::get_udp_socket("127.0.0.1:0") - } + crate::socket_helpers::get_udp_socket("0.0.0.0:0") } } @@ -206,22 +238,20 @@ mod tests { log::trace!(target: "citadel", "A"); _res0 - .socket .send_to(dummy_bytes as &[u8], _res0.addr.send_address) .await .unwrap(); log::trace!(target: "citadel", "B"); let buf = &mut [0u8; 4096]; - let (len, _addr) = _res1.socket.recv_from(buf).await.unwrap(); + let (len, _addr) = _res1.recv_from(buf).await.unwrap(); //assert_eq!(res1.addr.receive_address, addr); log::trace!(target: "citadel", "C"); assert_ne!(len, 0); _res1 - .socket .send_to(dummy_bytes, _res1.addr.send_address) .await .unwrap(); - let (len, _addr) = _res0.socket.recv_from(buf).await.unwrap(); + let (len, _addr) = _res0.recv_from(buf).await.unwrap(); assert_ne!(len, 0); //assert_eq!(res0.addr.receive_address, addr); log::trace!(target: "citadel", "D"); diff --git a/docker/client/Dockerfile b/docker/client/Dockerfile index 27ccd8a87..17091810c 100644 --- a/docker/client/Dockerfile +++ b/docker/client/Dockerfile @@ -3,6 +3,6 @@ WORKDIR /usr/src/client COPY . . COPY ./docker/set_nat.sh . COPY ./docker/client/exec.sh . 
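Returning briefly to udp_hole_puncher.rs above: the new driver wraps driver_inner in a per-attempt timeout plus a bounded retry loop. A generic sketch of that shape, assuming the tokio, log, and anyhow crates already used in this workspace (the with_retries helper itself is illustrative, not part of the patch):

    use std::future::Future;
    use std::time::Duration;

    // Run `attempt` up to `max_retries` times, bounding each attempt with `timeout`.
    async fn with_retries<F, Fut, T>(
        max_retries: usize,
        timeout: Duration,
        mut attempt: F,
    ) -> Result<T, anyhow::Error>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = Result<T, anyhow::Error>>,
    {
        for n in 1..=max_retries {
            match tokio::time::timeout(timeout, attempt()).await {
                Ok(Ok(value)) => return Ok(value),
                Ok(Err(err)) => log::warn!("attempt {n} failed: {err:?}"),
                Err(_elapsed) => log::warn!("attempt {n} timed out"),
            }
        }
        Err(anyhow::Error::msg("max retries reached"))
    }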
-RUN apt-get update --fix-missing && apt-get install -y openssl libclang-dev build-essential cmake iptables inetutils-ping net-tools iproute2 && rm -rf /var/lib/apt/lists/* +RUN apt-get update --fix-missing && apt-get install --fix-missing -y openssl libclang-dev build-essential cmake iptables inetutils-ping net-tools iproute2 && rm -rf /var/lib/apt/lists/* RUN cargo install --example client --path ./citadel_sdk --debug RUN ["chmod", "u+x", "exec.sh"] \ No newline at end of file diff --git a/docker/peer/Dockerfile b/docker/peer/Dockerfile index de065c729..ffa1796f5 100644 --- a/docker/peer/Dockerfile +++ b/docker/peer/Dockerfile @@ -3,6 +3,6 @@ WORKDIR /usr/src/peer COPY . . COPY ./docker/set_nat.sh . COPY ./docker/peer/exec.sh . -RUN apt-get update --fix-missing && apt-get install -y openssl libclang-dev build-essential cmake iptables inetutils-ping net-tools iproute2 && rm -rf /var/lib/apt/lists/* +RUN apt-get update --fix-missing && apt-get install --fix-missing -y openssl libclang-dev build-essential cmake iptables inetutils-ping net-tools iproute2 && rm -rf /var/lib/apt/lists/* RUN cargo install --example peer --path ./citadel_sdk --debug RUN ["chmod", "u+x", "exec.sh"] \ No newline at end of file diff --git a/docker/server/Dockerfile b/docker/server/Dockerfile index a7ea1237a..8dc62399e 100644 --- a/docker/server/Dockerfile +++ b/docker/server/Dockerfile @@ -3,6 +3,6 @@ WORKDIR /usr/src/server COPY . . COPY ./docker/set_nat.sh . COPY ./docker/server/exec.sh . -RUN apt-get update --fix-missing && apt-get install -y openssl libclang-dev build-essential cmake iptables inetutils-ping net-tools iproute2 && rm -rf /var/lib/apt/lists/* +RUN apt-get update --fix-missing && apt-get install --fix-missing -y openssl libclang-dev build-essential cmake iptables inetutils-ping net-tools iproute2 && rm -rf /var/lib/apt/lists/* RUN cargo install --example server --path ./citadel_sdk --debug RUN ["chmod", "u+x", "exec.sh"] \ No newline at end of file diff --git a/netbeam/Cargo.toml b/netbeam/Cargo.toml index b1b617757..b9102a66d 100644 --- a/netbeam/Cargo.toml +++ b/netbeam/Cargo.toml @@ -31,7 +31,7 @@ citadel_io = { workspace = true } futures = { workspace = true, features = ["std"] } bytes = { workspace = true } async-trait = { workspace = true } -bincode2 = { workspace = true } +bincode = { workspace = true } serde = { workspace = true, features = ["derive"] } anyhow = { workspace = true } tokio-util = { workspace = true, features = ["codec"] } diff --git a/netbeam/src/reliable_conn.rs b/netbeam/src/reliable_conn.rs index e061c8951..34c8275a6 100644 --- a/netbeam/src/reliable_conn.rs +++ b/netbeam/src/reliable_conn.rs @@ -30,7 +30,7 @@ impl ReliableOrderedConnectionToTar pub trait ReliableOrderedStreamToTargetExt: ReliableOrderedStreamToTarget { async fn recv_serialized(&self) -> std::io::Result { let packet = &self.recv().await?; - Ok(bincode2::deserialize(packet) + Ok(bincode::deserialize(packet) .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?) 
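The recv_serialized/send_serialized helpers above now go through the bincode 1.x API and map its errors into std::io::Error. A self-contained round-trip using the same calls (the Probe type is a made-up example):

    use serde::{Deserialize, Serialize};

    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct Probe {
        id: u64,
        tag: String,
    }

    fn round_trip() -> std::io::Result<()> {
        let value = Probe { id: 7, tag: "hole-punch".into() };
        // bincode::serialize / bincode::deserialize replace the old bincode2 calls
        let bytes = bincode::serialize(&value)
            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
        let decoded: Probe = bincode::deserialize(&bytes)
            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
        assert_eq!(value, decoded);
        Ok(())
    }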
} @@ -55,7 +55,7 @@ pub trait ReliableOrderedStreamToTargetExt: ReliableOrderedStreamToTarget { } async fn send_serialized(&self, t: T) -> std::io::Result<()> { - let packet = &bincode2::serialize(&t) + let packet = &bincode::serialize(&t) .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?; self.send_to_peer(packet).await } diff --git a/netbeam/src/sync/network_application.rs b/netbeam/src/sync/network_application.rs index b6bc99656..22292f429 100644 --- a/netbeam/src/sync/network_application.rs +++ b/netbeam/src/sync/network_application.rs @@ -120,7 +120,7 @@ impl MultiplexedConn { } pub async fn forward_packet(&self, packet: &[u8]) -> Result<(), anyhow::Error> { - let deserialized = bincode2::deserialize::>(packet)?; + let deserialized = bincode::deserialize::>(packet)?; match deserialized { MultiplexedPacket::ApplicationLayer { id, payload } => { let lock = self.subscriptions().read(); diff --git a/netbeam/src/sync/primitives/net_mutex.rs b/netbeam/src/sync/primitives/net_mutex.rs index de0c413ab..b4aeb4f17 100644 --- a/netbeam/src/sync/primitives/net_mutex.rs +++ b/netbeam/src/sync/primitives/net_mutex.rs @@ -82,7 +82,7 @@ impl NetMutex { passive_background_handler::(channel, shared_state, stop_rx, active_to_bg_rx) .await { - log::error!(target: "citadel", "[NetMutex Passive Background Handler] Err: {:?}", err.to_string()); + log::warn!(target: "citadel", "[NetMutex Passive Background Handler] Err: {:?}", err.to_string()); } log::trace!(target: "citadel", "[NetMutex] Passive background handler ending") @@ -150,6 +150,8 @@ impl Drop for NetMutexGuard { if let Ok(rt) = tokio::runtime::Handle::try_current() { let future = NetMutexGuardDropCode::new::(app, guard); rt.spawn(future); + } else { + log::warn!(target: "citadel", "Failed to spawn drop code for NetMutexGuard since no runtime was found"); } // if the RT is down, then we are not interested in continuing the program's synchronization @@ -240,9 +242,10 @@ async fn net_mutex_drop_code( lock: LocalLockHolder, ) -> Result<(), anyhow::Error> { log::trace!(target: "citadel", "[NetMutex] Drop code initialized for {:?}...", conn.node_type()); - conn.send_serialized(UpdatePacket::Released(bincode2::serialize( - &lock.deref().0, - )?)) + conn.send_serialized(UpdatePacket::Released( + bincode::serialize(&lock.deref().0)?, + true, + )) .await?; let mut adjacent_trying_to_acquire = false; @@ -252,10 +255,13 @@ async fn net_mutex_drop_code( log::trace!(target: "citadel", "[NetMutex] [Drop Code] RECV {:?} on {:?}", &packet, conn.node_type()); match packet { UpdatePacket::ReleasedVerified => { - log::trace!(target: "citadel", "[NetMutex] [Drop Code] Release has been verified for {:?}. Adjacent node updated; will drop local lock", conn.node_type()); + log::trace!(target: "citadel", "[NetMutex] [Drop Code] Release has been verified for {:?}. Adjacent node updated; will drop local lock. Adjacent trying to acquire? 
{adjacent_trying_to_acquire}", conn.node_type()); if adjacent_trying_to_acquire { - return yield_lock::(&conn, lock).await.map(|_| ()); + // Since we are holding the local lock, even if the local node tries to acquire + // the lock again, it will be blocked until the adjacent node releases the lock + // and the yield_lock subroutine finishes + return yield_lock::(&conn, lock, false).await.map(|_| ()); } return Ok(()); @@ -279,7 +285,7 @@ pub struct NetMutexGuardAcquirer<'a, T: NetObject + 'static, S: Subscribable + ' #[derive(Serialize, Deserialize, Debug)] enum UpdatePacket { TryAcquire(i64), - Released(Vec), + Released(Vec, bool), LockAcquired, Halt, ReleasedVerified, @@ -329,8 +335,8 @@ async fn net_mutex_guard_acquirer( // the adjacent side will return one of two packets. In the first case, we wait until it drops the adjacent lock, in which case, // we get a Released packet. The side that gets this will automatically be allowed to acquire the mutex lock match packet { - UpdatePacket::Released(new_data) => { - let new_data = bincode2::deserialize::(&new_data)?; + UpdatePacket::Released(new_data, _) => { + let new_data = bincode::deserialize::(&new_data)?; *value = new_data; // now, send a LockAcquired packet conn.send_serialized(UpdatePacket::LockAcquired).await?; @@ -351,22 +357,28 @@ async fn net_mutex_guard_acquirer( let local_wins = if remote_request_time == local_request_time { mutex.node_type() == RelativeNodeType::Initiator } else { - remote_request_time < local_request_time + remote_request_time > local_request_time }; - if local_wins { - // remote gets the lock. We send the local value first. Then, we must continue looping - // yield the lock - owned_local_lock = yield_lock::(conn, owned_local_lock).await?; - // the next time a conflict happens, the local node will win unconditionally since its time is lesser than the next possible adjacent request time - } else { + log::trace!(target: "citadel", "Local {:?} wins?: {} (remote time ({remote_request_time}) < local time ({local_request_time}))", mutex.node_type(), local_wins); + + return if local_wins { // we requested before the remote node; tell the remote node we took the value conn.send_serialized(UpdatePacket::LockAcquired).await?; - return Ok(NetMutexGuard { + Ok(NetMutexGuard { conn: mutex.app.clone(), guard: Some(owned_local_lock), - }); - } + }) + } else { + // remote gets the lock. We send the local value first. 
Then, we must continue looping + // yield the lock + owned_local_lock = yield_lock::(conn, owned_local_lock, false).await?; + log::trace!(target: "citadel", "{:?} finished yielding lock to remote, will now return the mutex to local", mutex.node_type()); + Ok(NetMutexGuard { + conn: mutex.app.clone(), + guard: Some(owned_local_lock), + }) + }; } UpdatePacket::Halt => { @@ -385,18 +397,23 @@ async fn yield_lock( channel: &Arc>, mut lock: LocalLockHolder, + send_release: bool, ) -> Result, anyhow::Error> { - channel - .send_serialized(UpdatePacket::Released( - bincode2::serialize(&lock.deref().0).unwrap(), - )) - .await?; + if send_release { + channel + .send_serialized(UpdatePacket::Released( + bincode::serialize(&lock.deref().0).unwrap(), + false, + )) + .await?; + } loop { let next_packet = channel.recv_serialized().await?; + log::trace!(target: "citadel", "[YIELD LOCK] {:?} received packet: {:?}", channel.node_type(), &next_packet); match next_packet { - UpdatePacket::Released(new_value) => { - lock.deref_mut().0 = bincode2::deserialize(&new_value)?; + UpdatePacket::Released(new_value, _) => { + lock.deref_mut().0 = bincode::deserialize(&new_value)?; channel.send_serialized(UpdatePacket::LockAcquired).await?; channel .send_serialized(UpdatePacket::ReleasedVerified) @@ -459,11 +476,19 @@ async fn passive_background_handler( // we hold the lock locally, preventing local from sending any packets outbound from the active channel since the adjacent node is actively seeking to // establish a lock // we set "true" to the local lock holder to imply that the drop code won't alert the background (b/c we already are in BG) - yield_lock::(&channel, LocalLockHolder(Some(lock), true)).await?; + yield_lock::(&channel, LocalLockHolder(Some(lock), true), true) + .await?; // return on error } - UpdatePacket::Released(..) + UpdatePacket::Released(_, true) => { + // In the case that the local node has dropped the mutex, and, + // the remote node has released the lock thereafter, this branch + // may execute (mostly when latency is very low, e.g., on localhost-testing) + continue; + } + + UpdatePacket::Released(_, false) | UpdatePacket::ReleasedVerified | UpdatePacket::LockAcquired => { unreachable!("[BG] RELEASED/RELEASED_VERIFIED/LOCK_ACQUIRED should only be received in the yield_lock subroutine."); diff --git a/netbeam/src/sync/primitives/net_rwlock.rs b/netbeam/src/sync/primitives/net_rwlock.rs index aec0fc9de..fd4e955d9 100644 --- a/netbeam/src/sync/primitives/net_rwlock.rs +++ b/netbeam/src/sync/primitives/net_rwlock.rs @@ -339,7 +339,7 @@ mod drop { } LocalLockHolder::Write(_guard, ..) 
=> { - conn.send_serialized(UpdatePacket::ReleasedWrite(bincode2::serialize( + conn.send_serialized(UpdatePacket::ReleasedWrite(bincode::serialize( &lock.deref(), )?)) .await?; @@ -529,7 +529,7 @@ async fn yield_lock( LocalLockHolder::Write(val, _) => { channel .send_serialized(UpdatePacket::ReleasedWrite( - bincode2::serialize(&val.as_ref().unwrap().0).unwrap(), + bincode::serialize(&val.as_ref().unwrap().0).unwrap(), )) .await?; } @@ -542,7 +542,7 @@ async fn yield_lock( UpdatePacket::ReleasedWrite(new_value) => match &mut lock { LocalLockHolder::Write(val, _) => { log::trace!(target: "citadel", "Yield:: Releasing Write lock"); - val.as_mut().unwrap().0 = bincode2::deserialize(&new_value)?; + val.as_mut().unwrap().0 = bincode::deserialize(&new_value)?; channel .send_serialized(UpdatePacket::ReleasedVerified(LockType::Write)) .await?; @@ -699,7 +699,7 @@ where // we get a Released packet. The side that gets this will automatically be allowed to acquire the mutex lock match packet { UpdatePacket::ReleasedWrite(new_data) => { - let new_data = bincode2::deserialize::(&new_data)?; + let new_data = bincode::deserialize::(&new_data)?; match &mut owned_local_lock { LocalLockHolder::Write(lock, ..) => { lock.as_mut().unwrap().0 = new_data; @@ -768,13 +768,21 @@ where let local_wins = if remote_request_time == local_request_time { rwlock.node_type() == RelativeNodeType::Initiator } else { - remote_request_time < local_request_time + remote_request_time > local_request_time }; if local_wins { - // remote gets the lock. We send the local value first. Then, we must continue looping - // yield the lock + // we requested before the remote node; tell the remote node we took the value + conn.send_serialized(UpdatePacket::LockAcquired(lock_type)) + .await?; + if owned_local_lock.lock_type() == LockType::Read { + *rwlock.local_active_read_lock.write() = + Some(owned_local_lock.assert_read().clone()) + } + + return Ok((fx)(owned_local_lock)); + } else { // transform only if local wins if owned_local_lock.lock_type() != lock_type { log::trace!(target: "citadel", "Remote is trying to acquire lock type not equal to local type. Must transform. Local {:?}, Remote {:?}", owned_local_lock.lock_type(), lock_type); @@ -816,17 +824,6 @@ where log::trace!(target: "citadel", "Asserted local is write and downgraded"); } } - } else { - // we requested before the remote node; tell the remote node we took the value - conn.send_serialized(UpdatePacket::LockAcquired(lock_type)) - .await?; - - if owned_local_lock.lock_type() == LockType::Read { - *rwlock.local_active_read_lock.write() = - Some(owned_local_lock.assert_read().clone()) - } - - return Ok((fx)(owned_local_lock)); } } diff --git a/netbeam/src/sync/subscription.rs b/netbeam/src/sync/subscription.rs index 112ddb298..085e2f330 100644 --- a/netbeam/src/sync/subscription.rs +++ b/netbeam/src/sync/subscription.rs @@ -76,7 +76,7 @@ impl ReliableOrderedStreamToTarget for R { payload: input.to_vec(), }; self.conn() - .send_to_peer(&bincode2::serialize(&packet).unwrap()) + .send_to_peer(&bincode::serialize(&packet).unwrap()) .await }
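Both NetMutex and NetRwLock now resolve simultaneous acquisition attempts the same way: the side whose request timestamp is earlier wins, and an exact tie goes to the Initiator. A condensed sketch of that rule with simplified stand-in types:

    #[derive(PartialEq)]
    enum RelativeNodeType {
        Initiator,
        Receiver,
    }

    // Local keeps the lock if it asked first; exact ties are broken deterministically.
    fn local_wins(local_request_time: i64, remote_request_time: i64, local: RelativeNodeType) -> bool {
        if remote_request_time == local_request_time {
            local == RelativeNodeType::Initiator
        } else {
            // local requested earlier than remote
            remote_request_time > local_request_time
        }
    }

    #[test]
    fn earlier_request_wins_and_ties_go_to_initiator() {
        assert!(local_wins(10, 20, RelativeNodeType::Receiver));
        assert!(!local_wins(20, 10, RelativeNodeType::Receiver));
        assert!(local_wins(15, 15, RelativeNodeType::Initiator));
    }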