diff --git a/citadel_crypt/Cargo.toml b/citadel_crypt/Cargo.toml index d30d7c117..957651d15 100644 --- a/citadel_crypt/Cargo.toml +++ b/citadel_crypt/Cargo.toml @@ -62,6 +62,7 @@ num_cpus = { workspace = true } [dev-dependencies] citadel_logging = { workspace = true } rstest = { workspace = true } +lazy_static = { workspace = true } [lib] doctest = false diff --git a/citadel_crypt/src/endpoint_crypto_container.rs b/citadel_crypt/src/endpoint_crypto_container.rs index 5abc6d55a..415ea9cdf 100644 --- a/citadel_crypt/src/endpoint_crypto_container.rs +++ b/citadel_crypt/src/endpoint_crypto_container.rs @@ -260,12 +260,17 @@ pub trait EndpointRatchetConstructor: Send + Sync + 'static { new_drill_vers: u32, opts: Vec, transfer: AliceToBobTransferType, + psks: &[Vec], ) -> Option where Self: Sized; fn stage0_alice(&self) -> Option; fn stage0_bob(&self) -> Option; - fn stage1_alice(&mut self, transfer: BobToAliceTransferType) -> Result<(), CryptError>; + fn stage1_alice( + &mut self, + transfer: BobToAliceTransferType, + psks: &[Vec], + ) -> Result<(), CryptError>; fn update_version(&mut self, version: u32) -> Option<()>; fn finish_with_custom_cid(self, cid: u64) -> Option; diff --git a/citadel_crypt/src/fcm/fcm_ratchet.rs b/citadel_crypt/src/fcm/fcm_ratchet.rs index 49c7fc1af..21938d609 100644 --- a/citadel_crypt/src/fcm/fcm_ratchet.rs +++ b/citadel_crypt/src/fcm/fcm_ratchet.rs @@ -149,10 +149,11 @@ impl EndpointRatchetConstructor for ThinRatchetConstructor { _new_drill_vers: u32, mut opts: Vec, transfer: AliceToBobTransferType, + psks: &[Vec], ) -> Option { match transfer { AliceToBobTransferType::Fcm(transfer) => { - ThinRatchetConstructor::new_bob(opts.remove(0), transfer) + ThinRatchetConstructor::new_bob(opts.remove(0), transfer, psks) } _ => { @@ -170,9 +171,13 @@ impl EndpointRatchetConstructor for ThinRatchetConstructor { Some(BobToAliceTransferType::Fcm(self.stage0_bob()?)) } - fn stage1_alice(&mut self, transfer: BobToAliceTransferType) -> Result<(), CryptError> { + fn stage1_alice( + &mut self, + transfer: BobToAliceTransferType, + psks: &[Vec], + ) -> Result<(), CryptError> { match transfer { - BobToAliceTransferType::Fcm(transfer) => self.stage1_alice(transfer), + BobToAliceTransferType::Fcm(transfer) => self.stage1_alice(transfer, psks), _ => Err(CryptError::DrillUpdateError( "Incompatible Ratchet Type passed! [X-44]".to_string(), @@ -228,9 +233,13 @@ impl ThinRatchetConstructor { } /// - pub fn new_bob(opts: ConstructorOpts, transfer: FcmAliceToBobTransfer) -> Option { + pub fn new_bob( + opts: ConstructorOpts, + transfer: FcmAliceToBobTransfer, + psks: &[Vec], + ) -> Option { let params = transfer.params; - let pqc = PostQuantumContainer::new_bob(opts, transfer.transfer_params).ok()?; + let pqc = PostQuantumContainer::new_bob(opts, transfer.transfer_params, psks).ok()?; let drill = EntropyBank::new(transfer.cid, transfer.version, params.encryption_algorithm).ok()?; @@ -268,9 +277,13 @@ impl ThinRatchetConstructor { } /// - pub fn stage1_alice(&mut self, transfer: FcmBobToAliceTransfer) -> Result<(), CryptError> { + pub fn stage1_alice( + &mut self, + transfer: FcmBobToAliceTransfer, + psks: &[Vec], + ) -> Result<(), CryptError> { self.pqc - .alice_on_receive_ciphertext(transfer.params_tx) + .alice_on_receive_ciphertext(transfer.params_tx, psks) .map_err(|err| CryptError::DrillUpdateError(err.to_string()))?; let bytes = self .pqc diff --git a/citadel_crypt/src/stacked_ratchet.rs b/citadel_crypt/src/stacked_ratchet.rs index 06062fe70..0c855582a 100644 --- a/citadel_crypt/src/stacked_ratchet.rs +++ b/citadel_crypt/src/stacked_ratchet.rs @@ -434,11 +434,14 @@ pub mod constructor { } impl ConstructorType { - pub fn stage1_alice(&mut self, transfer: BobToAliceTransferType) -> Result<(), CryptError> { + pub fn stage1_alice( + &mut self, + transfer: BobToAliceTransferType, + psks: &[Vec], + ) -> Result<(), CryptError> { match self { - ConstructorType::Default(con) => con.stage1_alice(transfer), - - ConstructorType::Fcm(con) => con.stage1_alice(transfer), + ConstructorType::Default(con) => con.stage1_alice(transfer, psks), + ConstructorType::Fcm(con) => con.stage1_alice(transfer, psks), } } @@ -510,10 +513,11 @@ pub mod constructor { new_drill_vers: u32, opts: Vec, transfer: AliceToBobTransferType, + psks: &[Vec], ) -> Option { match transfer { AliceToBobTransferType::Default(transfer) => { - StackedRatchetConstructor::new_bob(cid, new_drill_vers, opts, transfer) + StackedRatchetConstructor::new_bob(cid, new_drill_vers, opts, transfer, psks) } _ => { @@ -531,8 +535,12 @@ pub mod constructor { Some(BobToAliceTransferType::Default(self.stage0_bob()?)) } - fn stage1_alice(&mut self, transfer: BobToAliceTransferType) -> Result<(), CryptError> { - self.stage1_alice(transfer) + fn stage1_alice( + &mut self, + transfer: BobToAliceTransferType, + psks: &[Vec], + ) -> Result<(), CryptError> { + self.stage1_alice(transfer, psks) } fn update_version(&mut self, version: u32) -> Option<()> { @@ -665,6 +673,7 @@ pub mod constructor { new_drill_vers: u32, opts: Vec, transfer: AliceToBobTransfer, + psks: &[Vec], ) -> Option { log::trace!(target: "citadel", "[BOB] creating container with {:?} security level", transfer.security_level); let count = transfer.security_level.value() as usize + 1; @@ -679,7 +688,7 @@ pub mod constructor { EntropyBank::new(cid, new_drill_vers, params.encryption_algorithm) .ok()?, ), - pqc: PostQuantumContainer::new_bob(opts, params_tx).ok()?, + pqc: PostQuantumContainer::new_bob(opts, params_tx, psks).ok()?, }) }) .collect(); @@ -699,6 +708,7 @@ pub mod constructor { pqc: PostQuantumContainer::new_bob( ConstructorOpts::new_init(Some(params)), transfer.scramble_alice_params, + psks, ) .ok()?, }, @@ -798,7 +808,11 @@ pub mod constructor { } /// Returns Some(()) if process succeeded - pub fn stage1_alice(&mut self, transfer: BobToAliceTransferType) -> Result<(), CryptError> { + pub fn stage1_alice( + &mut self, + transfer: BobToAliceTransferType, + psks: &[Vec], + ) -> Result<(), CryptError> { if let BobToAliceTransferType::Default(transfer) = transfer { let nonce_msg = &self.nonce_message; @@ -810,7 +824,7 @@ pub mod constructor { { container .pqc - .alice_on_receive_ciphertext(bob_param_tx) + .alice_on_receive_ciphertext(bob_param_tx, psks) .map_err(|err| CryptError::DrillUpdateError(err.to_string()))?; } @@ -834,7 +848,7 @@ pub mod constructor { let nonce_scramble = &self.nonce_scramble; self.scramble .pqc - .alice_on_receive_ciphertext(transfer.scramble_bob_params_tx) + .alice_on_receive_ciphertext(transfer.scramble_bob_params_tx, psks) .map_err(|err| CryptError::DrillUpdateError(err.to_string()))?; // do the same as above let decrypted_scramble_drill = self diff --git a/citadel_crypt/tests/primary.rs b/citadel_crypt/tests/primary.rs index 9e9054e57..49d021c6a 100644 --- a/citadel_crypt/tests/primary.rs +++ b/citadel_crypt/tests/primary.rs @@ -21,6 +21,11 @@ mod tests { #[cfg(not(target_family = "wasm"))] use std::path::PathBuf; + lazy_static::lazy_static! { + pub static ref PRE_SHARED_KEYS: Vec> = vec!["Hello".into(), "World".into()]; + pub static ref PRE_SHARED_KEYS2: Vec> = vec!["World".into(), "Hello".into()]; + } + #[cfg(not(target_family = "wasm"))] #[tokio::test] async fn argon_autotuner() { @@ -217,11 +222,15 @@ mod tests { KemAlgorithm::from_u8(x).unwrap() + EncryptionAlgorithm::AES_GCM_256, Some(sec.into()), false, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, ); let _ = hyper_ratchet::( KemAlgorithm::from_u8(x).unwrap() + EncryptionAlgorithm::ChaCha20Poly_1305, Some(sec.into()), false, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, ); } } @@ -236,11 +245,15 @@ mod tests { KemAlgorithm::from_u8(x).unwrap() + EncryptionAlgorithm::AES_GCM_256, Some(sec.into()), true, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, ); let _ = hyper_ratchet::( KemAlgorithm::from_u8(x).unwrap() + EncryptionAlgorithm::ChaCha20Poly_1305, Some(sec.into()), true, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, ); } } @@ -254,6 +267,8 @@ mod tests { KemAlgorithm::Kyber + EncryptionAlgorithm::AES_GCM_256, Some(sec.into()), false, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, ); for x in 0..sec { assert!(ratchet.verify_level(Some(x.into())).is_ok()) @@ -269,6 +284,8 @@ mod tests { algorithm: Z, security_level: Option, is_fcm: bool, + bob_psks: &[Vec], + alice_psks: &[Vec], ) -> R { let algorithm = algorithm.into(); log::trace!(target: "citadel", "Using {:?} with {:?} @ {:?} security level | is FCM: {}", algorithm.kem_algorithm, algorithm.encryption_algorithm, security_level, is_fcm); @@ -288,11 +305,14 @@ mod tests { 0, ConstructorOpts::new_vec_init(algorithm, count), transfer, + bob_psks, ) .unwrap(); let transfer = bob_hyper_ratchet.stage0_bob().unwrap(); - alice_hyper_ratchet.stage1_alice(transfer).unwrap(); + alice_hyper_ratchet + .stage1_alice(transfer, alice_psks) + .unwrap(); let alice_hyper_ratchet = alice_hyper_ratchet.finish().unwrap(); let bob_hyper_ratchet = bob_hyper_ratchet.finish().unwrap(); @@ -362,13 +382,30 @@ mod tests { const COUNT: u32 = 100; let security_level = SecurityLevel::Standard; - let (alice, _bob) = gen::(0, 0, security_level, enx + kem + sig); + let (alice, _bob) = gen::( + 0, + 0, + security_level, + enx + kem + sig, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, + ); let mut toolset = Toolset::new(0, alice); for x in 1..COUNT { let res = toolset - .update_from(gen::(0, x, security_level, enx + kem + sig).0) + .update_from( + gen::( + 0, + x, + security_level, + enx + kem + sig, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, + ) + .0, + ) .unwrap(); match res { UpdateStatus::Committed { .. } => { @@ -399,7 +436,17 @@ mod tests { } let _res = toolset - .update_from(gen::(0, COUNT, security_level, enx + kem + sig).0) + .update_from( + gen::( + 0, + COUNT, + security_level, + enx + kem + sig, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, + ) + .0, + ) .unwrap(); assert_eq!(toolset.len(), MAX_HYPER_RATCHETS_IN_MEMORY + 1); assert_eq!( @@ -421,6 +468,8 @@ mod tests { version: u32, sec: SecurityLevel, algorithm: CryptoParameters, + bob_psks: &[Vec], + alice_psks: &[Vec], ) -> (R, R) { let count = sec.value() as usize + 1; let mut alice = R::Constructor::new_alice( @@ -435,10 +484,11 @@ mod tests { version, ConstructorOpts::new_vec_init(Some(algorithm), count), alice.stage0_alice().unwrap(), + bob_psks, ) .unwrap(); let stage0_bob = bob.stage0_bob().unwrap(); - alice.stage1_alice(stage0_bob).unwrap(); + alice.stage1_alice(stage0_bob, alice_psks).unwrap(); (alice.finish().unwrap(), bob.finish().unwrap()) } @@ -481,7 +531,14 @@ mod tests { citadel_logging::setup_log(); let vers = u32::MAX - 1; let cid = 10; - let hr = gen::(cid, vers, SecurityLevel::Standard, enx + kem + sig); + let hr = gen::( + cid, + vers, + SecurityLevel::Standard, + enx + kem + sig, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, + ); let mut toolset = Toolset::new_debug(cid, hr.0, vers, vers); let r = toolset.get_hyper_ratchet(vers).unwrap(); assert_eq!(r.version(), vers); @@ -495,7 +552,17 @@ mod tests { } toolset - .update_from(gen::(cid, cur_vers, SecurityLevel::Standard, enx + kem + sig).0) + .update_from( + gen::( + cid, + cur_vers, + SecurityLevel::Standard, + enx + kem + sig, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, + ) + .0, + ) .unwrap(); let ratchet = toolset.get_hyper_ratchet(cur_vers).unwrap(); assert_eq!(ratchet.version(), cur_vers); @@ -630,9 +697,22 @@ mod tests { const HEADER_SIZE_BYTES: usize = 44; let mut data = BytesMut::with_capacity(1500); - let (ratchet_alice, ratchet_bob) = gen::(10, 0, SECURITY_LEVEL, enx + kem + sig); - let (pseudo_static_aux_ratchet_alice, pseudo_static_aux_ratchet_bob) = - gen::(10, 0, SECURITY_LEVEL, enx + kem + sig); + let (ratchet_alice, ratchet_bob) = gen::( + 10, + 0, + SECURITY_LEVEL, + enx + kem + sig, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, + ); + let (pseudo_static_aux_ratchet_alice, pseudo_static_aux_ratchet_bob) = gen::( + 10, + 0, + SECURITY_LEVEL, + enx + kem + sig, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, + ); for x in 0..1500_usize { if x != 0 { @@ -792,9 +872,22 @@ mod tests { use citadel_crypt::streaming_crypt_scrambler::scramble_encrypt_source; - let (alice, bob) = gen::(0, 0, security_level, enx + kem + sig); - let (pseudo_static_aux_ratchet_alice, pseudo_static_aux_ratchet_bob) = - gen::(0, 0, security_level, enx + kem + sig); + let (alice, bob) = gen::( + 0, + 0, + security_level, + enx + kem + sig, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, + ); + let (pseudo_static_aux_ratchet_alice, pseudo_static_aux_ratchet_bob) = gen::( + 0, + 0, + security_level, + enx + kem + sig, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, + ); let cmp = include_bytes!("../../resources/TheBridge.pdf"); let source = PathBuf::from("../resources/TheBridge.pdf"); @@ -901,6 +994,31 @@ mod tests { }); } + #[should_panic(expected = "EncryptionFailure")] + #[rstest] + #[case( + EncryptionAlgorithm::AES_GCM_256, + KemAlgorithm::Kyber, + SigAlgorithm::None + )] + fn test_drill_encrypt_decrypt_basic_bad_psks( + #[case] enx: EncryptionAlgorithm, + #[case] kem: KemAlgorithm, + #[case] sig: SigAlgorithm, + ) { + citadel_logging::setup_log_no_panic_hook(); + test_harness_with_psks( + enx + kem + sig, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS2, + |alice, bob, _, data| { + let encrypted = alice.encrypt(data).unwrap(); + let decrypted = bob.decrypt(encrypted).unwrap(); + assert_eq!(decrypted, data); + }, + ); + } + #[rstest] #[case( EncryptionAlgorithm::AES_GCM_256, @@ -973,14 +1091,24 @@ mod tests { fn test_harness( params: CryptoParameters, fx: impl Fn(&StackedRatchet, &StackedRatchet, SecurityLevel, &[u8]), + ) { + test_harness_with_psks(params, &PRE_SHARED_KEYS, &PRE_SHARED_KEYS, fx); + } + + fn test_harness_with_psks( + params: CryptoParameters, + bob_psks: &[Vec], + alice_psks: &[Vec], + fx: impl Fn(&StackedRatchet, &StackedRatchet, SecurityLevel, &[u8]), ) { let data = Vec::from("Hello, world!"); for sec in 0..5 { let security_level = SecurityLevel::from(sec); - let (hr_alice, hr_bob) = gen::(0, 0, security_level, params); + let (hr_alice, hr_bob) = + gen::(0, 0, security_level, params, bob_psks, alice_psks); for idx in 0..data.len() { - (fx)(&hr_alice, &hr_bob, security_level, &data[..idx]); + fx(&hr_alice, &hr_bob, security_level, &data[..idx]); } } } diff --git a/citadel_pqcrypto/Cargo.toml b/citadel_pqcrypto/Cargo.toml index 8561718e8..8c0b11eeb 100644 --- a/citadel_pqcrypto/Cargo.toml +++ b/citadel_pqcrypto/Cargo.toml @@ -59,6 +59,7 @@ pqcrypto-traits-wasi = { workspace = true } [dev-dependencies] citadel_logging = { workspace = true } +lazy_static = { workspace = true } [lib] doctest = false diff --git a/citadel_pqcrypto/src/lib.rs b/citadel_pqcrypto/src/lib.rs index 68a4fece0..ea2b3575d 100644 --- a/citadel_pqcrypto/src/lib.rs +++ b/citadel_pqcrypto/src/lib.rs @@ -17,7 +17,6 @@ use serde::{Deserialize, Serialize}; use sha3::Digest; use std::fmt::Debug; use std::fmt::Formatter; -use std::ops::Deref; use std::sync::Arc; use zeroize::Zeroizing; @@ -261,6 +260,7 @@ impl PostQuantumContainer { pub fn new_bob( opts: ConstructorOpts, tx_params: AliceToBobTransferParameters, + psks: &[Vec], ) -> Result { let pq_node = PQNode::Bob; let params = opts.cryptography.unwrap_or_default(); @@ -275,7 +275,7 @@ impl PostQuantumContainer { let sig = data.sig().cloned(); let (chain, keys) = - Self::generate_recursive_keystore(pq_node, params, sig, ss, chain.as_ref(), kex) + Self::generate_recursive_keystore(pq_node, params, sig, ss, chain.as_ref(), kex, psks) .map_err(|err| { Error::Other(format!("Error while calculating recursive keystore: {err}",)) })?; @@ -293,6 +293,7 @@ impl PostQuantumContainer { }) } + /// `psks`: Pre-shared keys fn generate_recursive_keystore( pq_node: PQNode, params: CryptoParameters, @@ -300,10 +301,11 @@ impl PostQuantumContainer { ss: Arc>>, previous_chain: Option<&RecursiveChain>, kex: PostQuantumMetaKex, + psks: &[Vec], ) -> Result<(RecursiveChain, KeyStore), Error> { let (chain, alice_key, bob_key) = if let Some(prev) = previous_chain { // prev = C_n - // If a previous key, S_n, existed, we calculate S_(n+1)' = KDF(C_n || S_n)) + // If a previous key, S_n, existed, we calculate S_(n+1)' = KDF(C_n || S_n || psks)) let mut hasher_temp = sha3::Sha3_512::new(); let mut hasher_alice = sha3::Sha3_256::new(); let mut hasher_bob = sha3::Sha3_256::new(); @@ -312,7 +314,8 @@ impl PostQuantumContainer { .chain .iter() .chain(ss.iter()) - .cloned() + .chain(psks.iter().flatten()) + .copied() .collect::>()[..], ); @@ -372,7 +375,12 @@ impl PostQuantumContainer { } else { // The first key, S_0', = KDF(S_0) let mut hasher_temp = sha3::Sha3_512::new(); - hasher_temp.update(ss.deref()); + hasher_temp.update( + ss.iter() + .chain(psks.iter().flatten()) + .copied() + .collect::>(), + ); let temp_key = hasher_temp.finalize(); let (alice_key, bob_key) = temp_key.as_slice().split_at(32); @@ -446,7 +454,7 @@ impl PostQuantumContainer { } /// This should always be called after deserialization - fn load_symmetric_keys(&mut self) -> Result<(), Error> { + fn load_symmetric_keys(&mut self, psks: &[Vec]) -> Result<(), Error> { let pq_node = self.node; let params = self.params; let sig = self.data.sig().cloned(); @@ -454,8 +462,15 @@ impl PostQuantumContainer { let kex = self.data.kex().clone(); let prev_symmetric_key = self.chain.as_ref(); - let (chain, key) = - Self::generate_recursive_keystore(pq_node, params, sig, ss, prev_symmetric_key, kex)?; + let (chain, key) = Self::generate_recursive_keystore( + pq_node, + params, + sig, + ss, + prev_symmetric_key, + kex, + psks, + )?; self.key_store = Some(key); self.chain = Some(chain); @@ -467,10 +482,11 @@ impl PostQuantumContainer { pub fn alice_on_receive_ciphertext( &mut self, params: BobToAliceTransferParameters, + psks: &[Vec], ) -> Result<(), Error> { self.data.alice_on_receive_ciphertext(params)?; let _ss = self.data.get_shared_secret()?; // call once to load internally - self.load_symmetric_keys() + self.load_symmetric_keys(psks) } /// Returns true if either Tx/Rx Anti-replay attack counters have been engaged (useful for determining diff --git a/citadel_pqcrypto/tests/primary.rs b/citadel_pqcrypto/tests/primary.rs index 38e72c423..60b1aae9d 100644 --- a/citadel_pqcrypto/tests/primary.rs +++ b/citadel_pqcrypto/tests/primary.rs @@ -17,10 +17,17 @@ mod tests { use std::fmt::Debug; use std::iter::FromIterator; + lazy_static::lazy_static! { + pub static ref PRE_SHARED_KEYS: Vec> = vec!["Hello".into(), "World".into()]; + pub static ref PRE_SHARED_KEYS2: Vec> = vec!["World".into(), "Hello".into()]; + } + fn gen( kem_algorithm: KemAlgorithm, encryption_algorithm: EncryptionAlgorithm, sig_alg: SigAlgorithm, + bob_psks: &[Vec], + alice_psks: &[Vec], ) -> (PostQuantumContainer, PostQuantumContainer) { log::trace!(target: "citadel", "Test algorithm {:?} w/ {:?}", kem_algorithm, encryption_algorithm); let mut alice_container = PostQuantumContainer::new_alice(ConstructorOpts::new_init(Some( @@ -32,32 +39,51 @@ mod tests { let bob_container = PostQuantumContainer::new_bob( ConstructorOpts::new_init(Some(kem_algorithm + encryption_algorithm + sig_alg)), tx_params, + bob_psks, ) .unwrap(); let tx_params = bob_container.generate_bob_to_alice_transfer().unwrap(); alice_container - .alice_on_receive_ciphertext(tx_params) + .alice_on_receive_ciphertext(tx_params, alice_psks) .unwrap(); (alice_container, bob_container) } #[test] fn runit() { - run(0, EncryptionAlgorithm::AES_GCM_256, SigAlgorithm::None).unwrap(); + run( + 0, + EncryptionAlgorithm::AES_GCM_256, + SigAlgorithm::None, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, + ) + .unwrap(); run( 0, EncryptionAlgorithm::ChaCha20Poly_1305, SigAlgorithm::None, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, + ) + .unwrap(); + run( + 0, + EncryptionAlgorithm::Ascon80pq, + SigAlgorithm::None, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, ) .unwrap(); - run(0, EncryptionAlgorithm::Ascon80pq, SigAlgorithm::None).unwrap(); } fn run( algorithm: u8, encryption_algorithm: EncryptionAlgorithm, signature_algorithm: SigAlgorithm, + bob_psk: &[Vec], + alice_psk: &[Vec], ) -> Result<(), Box> { let kem_algorithm = KemAlgorithm::from_u8(algorithm).unwrap(); log::trace!(target: "citadel", "Test: {:?} w/ {:?} w/ {:?}", kem_algorithm, encryption_algorithm, signature_algorithm); @@ -75,12 +101,14 @@ mod tests { kem_algorithm + encryption_algorithm + signature_algorithm, )), tx_params.clone(), + bob_psk, )?; let eve_container = PostQuantumContainer::new_bob( ConstructorOpts::new_init(Some( kem_algorithm + encryption_algorithm + signature_algorithm, )), tx_params, + bob_psk, )?; // Internally, this computes the CipherText. The next step is to send this CipherText back over to alice let bob_ciphertext = bob_container.get_ciphertext().unwrap(); @@ -90,7 +118,7 @@ mod tests { // Next, alice received Bob's ciphertext. She must now run an update on her internal data in order to get the shared secret let tx_params = bob_container.generate_bob_to_alice_transfer().unwrap(); alice_container - .alice_on_receive_ciphertext(tx_params) + .alice_on_receive_ciphertext(tx_params, alice_psk) .unwrap(); let alice_ss = alice_container.get_shared_secret().unwrap(); @@ -163,8 +191,13 @@ mod tests { let encryption_algorithm = EncryptionAlgorithm::AES_GCM_256; let signature_algorithm = SigAlgorithm::None; - let (alice_container, bob_container) = - gen(kem_algorithm, encryption_algorithm, signature_algorithm); + let (alice_container, bob_container) = gen( + kem_algorithm, + encryption_algorithm, + signature_algorithm, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, + ); for x in 0..256 { run_protection::>(&alice_container, &bob_container, HEADER_LEN, x); @@ -180,8 +213,13 @@ mod tests { let encryption_algorithm = EncryptionAlgorithm::Kyber; let signature_algorithm = SigAlgorithm::Falcon1024; - let (alice_container, bob_container) = - gen(kem_algorithm, encryption_algorithm, signature_algorithm); + let (alice_container, bob_container) = gen( + kem_algorithm, + encryption_algorithm, + signature_algorithm, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, + ); for x in 0..256 { run_protection::>(&alice_container, &bob_container, HEADER_LEN, x); @@ -234,8 +272,13 @@ mod tests { let signature_algorithm = SigAlgorithm::None; let nonce_len = encryption_algorithm.nonce_len() as u8; - let (alice_container, bob_container) = - gen(kem_algorithm, encryption_algorithm, signature_algorithm); + let (alice_container, bob_container) = gen( + kem_algorithm, + encryption_algorithm, + signature_algorithm, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, + ); let mut zeroth = Vec::::default(); let mut zeroth_nonce = Vec::::from_iter(0..nonce_len); for y in 0..(HISTORY_LEN + 10) { @@ -308,8 +351,13 @@ mod tests { let encryption_algorithm = EncryptionAlgorithm::AES_GCM_256; let signature_algorithm = SigAlgorithm::None; let nonce_len = encryption_algorithm.nonce_len(); - let (alice_container, bob_container) = - gen(kem_algorithm, encryption_algorithm, signature_algorithm); + let (alice_container, bob_container) = gen( + kem_algorithm, + encryption_algorithm, + signature_algorithm, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, + ); let mut packet0 = (0..TOTAL_LEN as u8).collect::>(); let nonce = Vec::from_iter(0..nonce_len as u8); @@ -344,18 +392,24 @@ mod tests { algorithm.as_u8(), EncryptionAlgorithm::AES_GCM_256, SigAlgorithm::None, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, ) .unwrap(); run( algorithm.as_u8(), EncryptionAlgorithm::ChaCha20Poly_1305, SigAlgorithm::None, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, ) .unwrap(); run( algorithm.as_u8(), EncryptionAlgorithm::Ascon80pq, SigAlgorithm::None, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, ) .unwrap(); if algorithm == KemAlgorithm::Kyber { @@ -363,6 +417,8 @@ mod tests { algorithm.as_u8(), EncryptionAlgorithm::Kyber, SigAlgorithm::Falcon1024, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, ) .unwrap(); } @@ -376,6 +432,22 @@ mod tests { KemAlgorithm::Kyber.as_u8(), EncryptionAlgorithm::Kyber, SigAlgorithm::Falcon1024, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, + ) + .unwrap() + } + + #[should_panic] + #[test] + fn test_kyber_bad_psks() { + citadel_logging::setup_log_no_panic_hook(); + run( + KemAlgorithm::Kyber.as_u8(), + EncryptionAlgorithm::AES_GCM_256, + SigAlgorithm::Falcon1024, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS2, ) .unwrap() } @@ -401,8 +473,13 @@ mod tests { let kem_algorithm = KemAlgorithm::Kyber; let encryption_algorithm = EncryptionAlgorithm::AES_GCM_256; let signature_algorithm = SigAlgorithm::None; - let (alice_container, bob_container) = - gen(kem_algorithm, encryption_algorithm, signature_algorithm); + let (alice_container, bob_container) = gen( + kem_algorithm, + encryption_algorithm, + signature_algorithm, + &PRE_SHARED_KEYS, + &PRE_SHARED_KEYS, + ); let nonce = &mut [0u8; 12]; ThreadRng::default().fill_bytes(nonce); diff --git a/citadel_proto/Cargo.toml b/citadel_proto/Cargo.toml index 5353008c4..bd3a8f2dd 100644 --- a/citadel_proto/Cargo.toml +++ b/citadel_proto/Cargo.toml @@ -53,7 +53,7 @@ futures = { workspace = true } log = { workspace = true } async-trait = { workspace = true } tokio-util = { workspace = true, features = ["net", "codec", "time", "io"] } -tokio = { workspace = true } +tokio = { workspace = true, features = ["parking_lot"] } auto_impl = { workspace = true } tokio-stream = { workspace = true } zerocopy = { workspace = true, features = ["byteorder", "derive"] } diff --git a/citadel_proto/src/kernel/kernel_executor.rs b/citadel_proto/src/kernel/kernel_executor.rs index 8fa74f42b..c8382a78a 100644 --- a/citadel_proto/src/kernel/kernel_executor.rs +++ b/citadel_proto/src/kernel/kernel_executor.rs @@ -47,6 +47,7 @@ impl KernelExecutor { client_config, kernel_executor_settings, stun_servers, + server_only_session_password, } = args; let (server_to_kernel_tx, server_to_kernel_rx) = unbounded(); let (server_shutdown_alerter_tx, server_shutdown_alerter_rx) = @@ -60,6 +61,7 @@ impl KernelExecutor { underlying_proto, client_config, stun_servers, + server_only_session_password, ) .await .map_err(|err| NetworkError::Generic(err.to_string()))?; diff --git a/citadel_proto/src/kernel/mod.rs b/citadel_proto/src/kernel/mod.rs index e5a1b260c..97074ede5 100644 --- a/citadel_proto/src/kernel/mod.rs +++ b/citadel_proto/src/kernel/mod.rs @@ -7,7 +7,7 @@ use tokio::runtime::Handle; use crate::error::NetworkError; use crate::macros::ContextRequirements; -use crate::prelude::ServerUnderlyingProtocol; +use crate::prelude::{PreSharedKey, ServerUnderlyingProtocol}; /// for handling easy asynchronous callbacks pub mod kernel_communicator; @@ -44,4 +44,5 @@ pub struct KernelExecutorArguments { pub client_config: Option>, pub kernel_executor_settings: KernelExecutorSettings, pub stun_servers: Option>, + pub server_only_session_password: Option, } diff --git a/citadel_proto/src/proto/misc/dual_cell.rs b/citadel_proto/src/proto/misc/dual_cell.rs index d84c58c92..7c621c0bb 100644 --- a/citadel_proto/src/proto/misc/dual_cell.rs +++ b/citadel_proto/src/proto/misc/dual_cell.rs @@ -15,7 +15,7 @@ impl DualCell { { #[cfg(not(feature = "multi-threaded"))] { - let _ = self.inner.set(new); + self.inner.set(new); } #[cfg(feature = "multi-threaded")] { diff --git a/citadel_proto/src/proto/node.rs b/citadel_proto/src/proto/node.rs index b6e786c22..7e40c4a7d 100644 --- a/citadel_proto/src/proto/node.rs +++ b/citadel_proto/src/proto/node.rs @@ -19,7 +19,7 @@ use crate::error::NetworkError; use crate::functional::PairMap; use crate::kernel::kernel_communicator::KernelAsyncCallbackHandler; use crate::kernel::RuntimeFuture; -use crate::prelude::{DeleteObject, PullObject}; +use crate::prelude::{DeleteObject, PreSharedKey, PullObject}; use crate::proto::misc::net::{ DualListener, FirstPacket, GenericNetworkListener, GenericNetworkStream, TlsListener, }; @@ -58,10 +58,14 @@ pub struct NodeInner { nat_type: NatType, // for TLS params client_config: Arc, + // All connecting/registering clients must present this pre-shared password in order to register and connect + // to the server. This is an additional security measure to prevent unauthorized connections. + server_only_c2s_session_password: PreSharedKey, } impl Node { /// Creates a new [`Node`] + #[allow(clippy::too_many_arguments)] pub(crate) async fn init( local_node_type: NodeType, to_kernel: UnboundedSender, @@ -70,6 +74,7 @@ impl Node { underlying_proto: ServerUnderlyingProtocol, client_config: Option>, stun_servers: Option>, + server_only_c2s_session_password: Option, ) -> io::Result<( NodeRemote, Pin>, @@ -124,6 +129,7 @@ impl Node { session_manager, nat_type, client_config, + server_only_c2s_session_password: server_only_c2s_session_password.unwrap_or_default(), }; let this = Self::from(inner); @@ -173,6 +179,7 @@ impl Node { .session_manager .load_server_remote_get_tt(remote.clone()); let session_manager = read.session_manager.clone(); + let server_only_session_password = read.server_only_c2s_session_password.clone(); drop(read); @@ -195,6 +202,7 @@ impl Node { this.clone(), tt, kernel_tx.clone(), + server_only_session_password, session_spawner_tx.clone(), )) } else { @@ -224,6 +232,7 @@ impl Node { this.clone(), tt, kernel_tx.clone(), + server_only_session_password, session_spawner_tx.clone(), )) } else { @@ -593,6 +602,7 @@ impl Node { server: Node, _tt: TimeTracker, to_kernel: UnboundedSender, + server_only_session_password: PreSharedKey, session_spawner: UnboundedSender>>, ) -> Result<(), NetworkError> { let primary_port_future = { @@ -600,12 +610,13 @@ impl Node { let listener = this.primary_socket.take().unwrap(); let session_manager = this.session_manager.clone(); let local_nat_type = this.nat_type.clone(); - std::mem::drop(this); + drop(this); Self::primary_session_creator_loop( to_kernel, local_nat_type, session_manager, listener, + server_only_session_password, session_spawner, ) }; @@ -618,6 +629,7 @@ impl Node { local_nat_type: NatType, session_manager: HdpSessionManager, mut socket: DualListener, + server_session_password: PreSharedKey, session_spawner: UnboundedSender>>, ) -> Result<(), NetworkError> { loop { @@ -633,6 +645,7 @@ impl Node { local_nat_type.clone(), peer_addr, stream, + server_session_password.clone(), ) { Ok(session) => { session_spawner @@ -739,6 +752,7 @@ impl Node { remote_addr: peer_addr, proposed_credentials: credentials, static_security_settings: security_settings, + session_password, }) => { match session_manager .initiate_connection( @@ -752,6 +766,7 @@ impl Node { None, security_settings, &default_client_config, + session_password, ) .await { @@ -773,6 +788,7 @@ impl Node { udp_mode, keep_alive_timeout, session_security_settings: security_settings, + session_password, }) => { match session_manager .initiate_connection( @@ -786,6 +802,7 @@ impl Node { keep_alive_timeout.map(|val| (val as i64) * 1_000_000_000), security_settings, &default_client_config, + session_password, ) .await { diff --git a/citadel_proto/src/proto/node_request.rs b/citadel_proto/src/proto/node_request.rs index 31c759aad..3f3f977c2 100644 --- a/citadel_proto/src/proto/node_request.rs +++ b/citadel_proto/src/proto/node_request.rs @@ -1,11 +1,13 @@ use crate::auth::AuthenticationRequest; use crate::prelude::{GroupBroadcast, PeerSignal, VirtualTargetType}; use crate::proto::state_container::VirtualConnectionType; +use crate::re_imports::openssl::sha::sha256; use citadel_crypt::streaming_crypt_scrambler::ObjectSource; use citadel_types::crypto::SecurityLevel; use citadel_types::proto::TransferType; use citadel_types::proto::{ConnectMode, SessionSecuritySettings, UdpMode}; use citadel_user::auth::proposed_credentials::ProposedCredentials; +use serde::{Deserialize, Serialize}; use std::fmt::{Debug, Formatter}; use std::net::SocketAddr; use std::path::PathBuf; @@ -14,6 +16,8 @@ pub struct RegisterToHypernode { pub remote_addr: SocketAddr, pub proposed_credentials: ProposedCredentials, pub static_security_settings: SessionSecuritySettings, + // Some servers require a password in order to register and connect. By default, it is empty. + pub session_password: PreSharedKey, } pub struct PeerCommand { @@ -31,6 +35,7 @@ pub struct ConnectToHypernode { pub connect_mode: ConnectMode, pub udp_mode: UdpMode, pub keep_alive_timeout: Option, + pub session_password: PreSharedKey, pub session_security_settings: SessionSecuritySettings, } @@ -130,3 +135,31 @@ impl Debug for NodeRequest { write!(f, "NodeRequest") } } + +#[derive(Default, Clone, Eq, PartialEq, Debug, Serialize, Deserialize)] +pub struct PreSharedKey { + passwords: Vec>, +} + +impl PreSharedKey { + /// Adds a password to the session password list. Both connecting nodes + /// must have matching passwords in order to establish a connection. + /// + /// Note: The password is hashed using SHA-256 before being added to the list to increase security. + pub fn add_password>(mut self, password: T) -> Self { + self.passwords.push(sha256(password.as_ref()).to_vec()); + self + } +} + +impl AsRef<[Vec]> for PreSharedKey { + fn as_ref(&self) -> &[Vec] { + &self.passwords + } +} + +impl> From for PreSharedKey { + fn from(password: T) -> Self { + PreSharedKey::default().add_password(password) + } +} diff --git a/citadel_proto/src/proto/packet_processor/peer/mod.rs b/citadel_proto/src/proto/packet_processor/peer/mod.rs index 5cedf2614..bb0d8e290 100644 --- a/citadel_proto/src/proto/packet_processor/peer/mod.rs +++ b/citadel_proto/src/proto/packet_processor/peer/mod.rs @@ -1,3 +1,7 @@ +use crate::error::NetworkError; +use crate::prelude::{ConnectFail, NodeResult, Ticket}; +use crate::proto::session::HdpSession; + /// pub mod group_broadcast; /// @@ -6,3 +10,20 @@ pub mod peer_cmd_packet; pub mod server; /// pub mod signal_handler_interface; + +pub(crate) fn send_dc_signal_peer>( + session: &HdpSession, + ticket: Ticket, + err: T, +) -> Result<(), NetworkError> { + let implicated_cid = session.implicated_cid.get().expect("Should exist"); + session + .send_to_kernel(NodeResult::ConnectFail(ConnectFail { + ticket, + cid_opt: Some(implicated_cid), + error_message: err.into(), + })) + .map_err(|err| NetworkError::Generic(err.to_string()))?; + + Ok(()) +} diff --git a/citadel_proto/src/proto/packet_processor/peer/peer_cmd_packet.rs b/citadel_proto/src/proto/packet_processor/peer/peer_cmd_packet.rs index c0ee8af5f..8f4d12798 100644 --- a/citadel_proto/src/proto/packet_processor/peer/peer_cmd_packet.rs +++ b/citadel_proto/src/proto/packet_processor/peer/peer_cmd_packet.rs @@ -18,7 +18,7 @@ use crate::error::NetworkError; use crate::proto::node_result::{PeerChannelCreated, PeerEvent}; use crate::proto::outbound_sender::OutboundPrimaryStreamSender; use crate::proto::packet_processor::includes::*; -use crate::proto::packet_processor::peer::group_broadcast; +use crate::proto::packet_processor::peer::{group_broadcast, send_dc_signal_peer}; use crate::proto::packet_processor::preconnect_packet::{ calculate_sync_time, generate_hole_punch_crypt_container, }; @@ -29,7 +29,7 @@ use crate::proto::peer::hole_punch_compat_sink_stream::ReliableOrderedCompatStre use crate::proto::peer::p2p_conn_handler::attempt_simultaneous_hole_punch; use crate::proto::peer::peer_crypt::{KeyExchangeProcess, PeerNatInfo}; use crate::proto::peer::peer_layer::{ - HyperNodePeerLayerInner, NodeConnectionType, PeerConnectionType, PeerResponse, PeerSignal, + NodeConnectionType, PeerConnectionType, PeerResponse, PeerSignal, }; use crate::proto::remote::Ticket; use crate::proto::session_manager::HdpSessionManager; @@ -74,7 +74,7 @@ pub async fn process_peer_cmd( let task = async move { let session = &session; // we can unwrap below safely since the header layout has already been verified - let header = Ref::new(&*header).unwrap() as Ref<&[u8], HdpHeader>; + let header = Ref::new(&*header).unwrap(); match aux_cmd { packet_flags::cmd::aux::peer_cmd::GROUP_BROADCAST => { @@ -240,6 +240,7 @@ pub async fn process_peer_cmd( event: PeerSignal::SignalError { ticket, error: err.into_string(), + peer_connection_type: vconn.reverse(), }, ticket, implicated_cid, @@ -256,9 +257,10 @@ pub async fn process_peer_cmd( invitee_response: Some(resp), session_security_settings: endpoint_security_settings, udp_mode: udp_enabled, + session_password: _, } => { + log::trace!(target: "citadel", "Handling peer connect"); let accepted = matches!(resp, PeerResponse::Accept(_)); - // TODO: handle non-accept case // the connection was mutually accepted. Now, we must begin the KEM subroutine if accepted { return match conn { @@ -272,11 +274,6 @@ pub async fn process_peer_cmd( // unique to the session. //let mut state_container = inner_mut!(session.state_container); //let peer_cid = conn.get_original_implicated_cid(); - let mut peer_kem_state_container = - PeerKemStateContainer::new( - *endpoint_security_settings, - *udp_enabled == UdpMode::Enabled, - ); let alice_constructor = return_if_none!(StackedRatchetConstructor::new_alice( @@ -299,14 +296,36 @@ pub async fn process_peer_cmd( //log::trace!(target: "citadel", "0. Len: {}, {:?}", alice_pub_key.len(), &alice_pub_key[..10]); let msg_bytes = return_if_none!(transfer.serialize_to_vec()); + + let mut state_container = + inner_mut_state!(session.state_container); + + let session_password = state_container + .get_session_password( + conn.get_original_implicated_cid(), + ) + .cloned(); + if session_password.is_none() { + log::error!(target: "citadel", "The session password locally is set to None. This is a development issue, please report"); + } + + let session_password = session_password.unwrap_or_default(); + let mut peer_kem_state_container = + PeerKemStateContainer::new( + *endpoint_security_settings, + *udp_enabled == UdpMode::Enabled, + session_password.clone(), + ); + peer_kem_state_container.constructor = Some(alice_constructor); - inner_mut_state!(session.state_container) - .peer_kem_states - .insert( - *original_implicated_cid, - peer_kem_state_container, - ); + + state_container.peer_kem_states.insert( + *original_implicated_cid, + peer_kem_state_container, + ); + + drop(state_container); // finally, prepare the signal and send outbound // signal: PeerSignal, pqc: &Rc, drill: &EntropyBank, ticket: Ticket, timestamp: i64 let signal = PeerSignal::Kex { @@ -339,6 +358,17 @@ pub async fn process_peer_cmd( Ok(PrimaryProcessorResult::Void) } }; + } else { + // Send error to kernel for peer connect fail. Reason: did not accept + session.send_to_kernel(NodeResult::PeerEvent(PeerEvent { + event: PeerSignal::SignalError { + ticket, + error: "Peer did not accept connection".to_string(), + peer_connection_type: conn.reverse(), + }, + ticket, + implicated_cid, + }))?; } } @@ -360,6 +390,18 @@ pub async fn process_peer_cmd( let transfer_deser = return_if_none!( AliceToBobTransfer::deserialize_from(transfer) ); + + let mut state_container = + inner_mut_state!(session.state_container); + + let session_password = + state_container.get_session_password(peer_cid).cloned(); + if session_password.is_none() { + log::error!(target: "citadel", "The session password locally is set to None. This is a development issue, please report"); + } + + let session_password = session_password.unwrap_or_default(); + let bob_constructor = return_if_none!(StackedRatchetConstructor::new_bob( conn.get_original_target_cid(), @@ -370,10 +412,10 @@ pub async fn process_peer_cmd( + 1) as usize ), - transfer_deser + transfer_deser, + session_password.as_ref(), )); let transfer = return_if_none!(bob_constructor.stage0_bob()); - let bob_transfer = return_if_none!(transfer.serialize_to_vector().ok()); @@ -394,12 +436,14 @@ pub async fn process_peer_cmd( let mut state_container_kem = PeerKemStateContainer::new( *session_security_settings, *udp_enabled == UdpMode::Enabled, + session_password, ); state_container_kem.constructor = Some(bob_constructor); - inner_mut_state!(session.state_container) + state_container .peer_kem_states .insert(peer_cid, state_container_kem); + drop(state_container); let stage1_kem = packet_crafter::peer_cmd::craft_peer_signal( &sess_hyper_ratchet, signal, @@ -449,11 +493,35 @@ pub async fn process_peer_cmd( BobToAliceTransfer::deserialize_from(transfer), "bad deser" ); - alice_constructor - .stage1_alice(BobToAliceTransferType::Default(deser)) - .map_err(|err| { - NetworkError::Generic(err.to_string()) - })?; + + if let Err(err) = alice_constructor.stage1_alice( + BobToAliceTransferType::Default(deser), + kem_state.session_password.as_ref(), + ) { + log::warn!(target: "citadel", "Failed to complete key exchange for {implicated_cid} | Wrong session passwords? Err: {err:?}"); + send_dc_signal_peer( + session, + ticket, + format!("{err:?}"), + )?; + // Send the error signal to the peer + let error_signal = PeerSignal::SignalError { + ticket, + error: err.into_string(), + peer_connection_type: conn.reverse(), + }; + let error_packet = + packet_crafter::peer_cmd::craft_peer_signal( + &sess_hyper_ratchet, + error_signal, + ticket, + timestamp, + security_level, + ); + return Ok(PrimaryProcessorResult::ReplyToSender( + error_packet, + )); + } let hyper_ratchet = return_if_none!( alice_constructor.finish_with_custom_cid(this_cid) ); @@ -524,7 +592,7 @@ pub async fn process_peer_cmd( let ticket_for_chan = state_container .outgoing_peer_connect_attempts .remove(&peer_cid); - std::mem::drop(state_container); + drop(state_container); let stun_servers = session.stun_servers.clone(); let encrypted_config_container = generate_hole_punch_crypt_container( @@ -767,7 +835,6 @@ pub async fn process_peer_cmd( _ => {} } - log::trace!(target: "citadel", "Forwarding PEER signal to kernel ..."); session .kernel_tx .unbounded_send(NodeResult::PeerEvent(PeerEvent { @@ -857,21 +924,27 @@ async fn process_signal_command_as_server( return Ok(PrimaryProcessorResult::Void); } - let res = sess_mgr.send_signal_to_peer_direct( - conn.get_original_target_cid(), - move |peer_hyper_ratchet| { - packet_crafter::peer_cmd::craft_peer_signal( - peer_hyper_ratchet, - signal_to, - ticket, - timestamp, - security_level, - ) - }, - ); + let peer_cid = conn.get_original_target_cid(); + + let res = sess_mgr.send_signal_to_peer_direct(peer_cid, move |peer_hyper_ratchet| { + packet_crafter::peer_cmd::craft_peer_signal( + peer_hyper_ratchet, + signal_to, + ticket, + timestamp, + security_level, + ) + }); if let Err(err) = res { - reply_to_sender_err(err, &sess_hyper_ratchet, ticket, timestamp, security_level) + reply_to_sender_err( + err, + &sess_hyper_ratchet, + ticket, + timestamp, + security_level, + peer_cid, + ) } else { Ok(PrimaryProcessorResult::Void) } @@ -896,7 +969,6 @@ async fn process_signal_command_as_server( if let Some(peer_response) = peer_response { // the signal is going to be routed from HyperLAN Client B to HyperLAN client A (response phase) super::server::post_register::handle_response_phase_post_register( - &mut *session.hypernode_peer_layer.inner.write().await, peer_conn_type, username, peer_response, @@ -936,10 +1008,11 @@ async fn process_signal_command_as_server( peer_layer.check_simultaneous_register(implicated_cid, target_cid) { log::trace!(target: "citadel", "Simultaneous register detected! Simulating implicated_cid={} sent an accept_register to target={}", implicated_cid, target_cid); + peer_layer.insert_mapped_ticket(implicated_cid, ticket_new, ticket); // route signal to peer + drop(peer_layer); let _ = super::server::post_register::handle_response_phase_post_register( - &mut peer_layer, peer_conn_type, username.clone(), PeerResponse::Accept(Some(username)), @@ -984,8 +1057,8 @@ async fn process_signal_command_as_server( let to_primary_stream = return_if_none!(session.to_primary_stream.clone()); let sess_mgr = session.session_manager.clone(); + drop(peer_layer); route_signal_and_register_ticket_forwards( - &peer_layer, PeerSignal::PostRegister { peer_conn_type, inviter_username: username, @@ -1093,6 +1166,10 @@ async fn process_signal_command_as_server( let error_signal = PeerSignal::SignalError { ticket, error: err.into_string(), + peer_connection_type: PeerConnectionType::LocalGroupPeer { + implicated_cid, + peer_cid: target_cid, + }, }; let error_packet = packet_crafter::peer_cmd::craft_peer_signal( &sess_hyper_ratchet, @@ -1123,6 +1200,7 @@ async fn process_signal_command_as_server( invitee_response: peer_response, session_security_settings: endpoint_security_level, udp_mode: udp_enabled, + session_password, } => { match peer_conn_type { PeerConnectionType::LocalGroupPeer { @@ -1131,10 +1209,8 @@ async fn process_signal_command_as_server( } => { // TODO: Change timeouts. Create a better timeout system, in general const TIMEOUT: Duration = Duration::from_secs(60 * 60); - let mut peer_layer = session.hypernode_peer_layer.inner.write().await; if let Some(peer_response) = peer_response { super::server::post_connect::handle_response_phase_post_connect( - &mut peer_layer, peer_conn_type, ticket, peer_response, @@ -1152,16 +1228,18 @@ async fn process_signal_command_as_server( // the signal is going to be routed from HyperLAN client A to HyperLAN client B (initiation phase) let to_primary_stream = return_if_none!(session.to_primary_stream.clone()); let sess_mgr = session.session_manager.clone(); + let mut peer_layer = session.hypernode_peer_layer.inner.write().await; if let Some(ticket_new) = peer_layer.check_simultaneous_connect(implicated_cid, target_cid) { log::trace!(target: "citadel", "Simultaneous connect detected! Simulating implicated_cid={} sent an accept_connect to target={}", implicated_cid, target_cid); log::trace!(target: "citadel", "Simultaneous connect: first_ticket: {} | sender expected ticket: {}", ticket_new, ticket); + peer_layer.insert_mapped_ticket(implicated_cid, ticket_new, ticket); // NOTE: Packet will rebound to sender, then, sender will locally send // packet to the peer who first attempted a connect request + drop(peer_layer); let _ = super::server::post_connect::handle_response_phase_post_connect( - &mut peer_layer, peer_conn_type, ticket_new, PeerResponse::Accept(None), @@ -1177,14 +1255,15 @@ async fn process_signal_command_as_server( .await?; Ok(PrimaryProcessorResult::Void) } else { + drop(peer_layer); route_signal_and_register_ticket_forwards( - &peer_layer, PeerSignal::PostConnect { peer_conn_type, ticket_opt: Some(ticket), invitee_response: None, session_security_settings: endpoint_security_level, udp_mode: udp_enabled, + session_password, }, TIMEOUT, implicated_cid, @@ -1421,16 +1500,58 @@ async fn process_signal_command_as_server( ticket: _ticket, } => Ok(PrimaryProcessorResult::Void), - PeerSignal::SignalError { ticket, error: err } => { + PeerSignal::SignalError { + ticket, + error, + peer_connection_type, + } => { // in this case, we delegate the error to the higher-level kernel to determine what to do + let signal = PeerSignal::SignalError { + ticket, + error, + peer_connection_type, + }; session .kernel_tx .unbounded_send(NodeResult::PeerEvent(PeerEvent { - event: PeerSignal::SignalError { ticket, error: err }, + event: signal.clone(), ticket, implicated_cid: sess_hyper_ratchet.get_cid(), }))?; - Ok(PrimaryProcessorResult::Void) + + let peer_cid = peer_connection_type.get_original_target_cid(); + // If this was a simultaneous connect, we need to remap the ticket + let mut peer_layer = session.hypernode_peer_layer.inner.write().await; + let ticket = peer_layer + .take_mapped_ticket(peer_cid, ticket) + .unwrap_or(ticket); + drop(peer_layer); + + let res = inner!(session.session_manager).send_signal_to_peer_direct( + peer_cid, + move |peer_hyper_ratchet| { + packet_crafter::peer_cmd::craft_peer_signal( + peer_hyper_ratchet, + signal, + ticket, + timestamp, + security_level, + ) + }, + ); + + if let Err(err) = res { + reply_to_sender_err( + err, + &sess_hyper_ratchet, + ticket, + timestamp, + security_level, + peer_cid, + ) + } else { + Ok(PrimaryProcessorResult::Void) + } } PeerSignal::SignalReceived { ticket } => { @@ -1480,9 +1601,17 @@ fn reply_to_sender_err( ticket: Ticket, timestamp: i64, security_level: SecurityLevel, + peer_cid: u64, ) -> Result { Ok(PrimaryProcessorResult::ReplyToSender( - construct_error_signal(err, hyper_ratchet, ticket, timestamp, security_level), + construct_error_signal( + err, + hyper_ratchet, + ticket, + timestamp, + security_level, + peer_cid, + ), )) } @@ -1492,10 +1621,15 @@ fn construct_error_signal( ticket: Ticket, timestamp: i64, security_level: SecurityLevel, + peer_cid: u64, ) -> BytesMut { let err_signal = PeerSignal::SignalError { ticket, error: err.to_string(), + peer_connection_type: PeerConnectionType::LocalGroupPeer { + implicated_cid: hyper_ratchet.get_cid(), + peer_cid, + }, }; packet_crafter::peer_cmd::craft_peer_signal( hyper_ratchet, @@ -1508,7 +1642,6 @@ fn construct_error_signal( #[allow(clippy::too_many_arguments)] pub(crate) async fn route_signal_and_register_ticket_forwards( - peer_layer: &HyperNodePeerLayerInner, signal: PeerSignal, timeout: Duration, implicated_cid: u64, @@ -1524,7 +1657,7 @@ pub(crate) async fn route_signal_and_register_ticket_forwards( let to_primary_stream = to_primary_stream.clone(); // Give the target_cid 10 seconds to respond - let res = sess_mgr.route_signal_primary(peer_layer, implicated_cid, target_cid, ticket, signal.clone(), move |peer_hyper_ratchet| { + let res = sess_mgr.route_signal_primary(implicated_cid, target_cid, ticket, signal.clone(), move |peer_hyper_ratchet| { packet_crafter::peer_cmd::craft_peer_signal(peer_hyper_ratchet, signal.clone(), ticket, timestamp, security_level) }, timeout, move |stale_signal| { // on timeout, run this @@ -1536,7 +1669,14 @@ pub(crate) async fn route_signal_and_register_ticket_forwards( // Then, we tell the implicated_cid's node that we have handled the message. However, the peer has yet to respond if let Err(err) = res { - reply_to_sender_err(err, sess_hyper_ratchet, ticket, timestamp, security_level) + reply_to_sender_err( + err, + sess_hyper_ratchet, + ticket, + timestamp, + security_level, + target_cid, + ) } else { let received_signal = PeerSignal::SignalReceived { ticket }; reply_to_sender( @@ -1557,13 +1697,12 @@ pub(crate) async fn route_signal_response( target_cid: u64, timestamp: i64, ticket: Ticket, - peer_layer: &mut HyperNodePeerLayerInner, session: HdpSession, sess_hyper_ratchet: &StackedRatchet, on_route_finished: impl FnOnce(&HdpSession, &HdpSession, PeerSignal), security_level: SecurityLevel, ) -> Result { - log::trace!(target: "citadel", "Routing signal {:?} | impl: {} | target: {}", signal, implicated_cid, target_cid); + trace!(target: "citadel", "Routing signal {:?} | impl: {} | target: {}", signal, implicated_cid, target_cid); let sess_ref = &session; let res = session @@ -1572,7 +1711,7 @@ pub(crate) async fn route_signal_response( implicated_cid, target_cid, ticket, - peer_layer, + sess_ref, move |peer_hyper_ratchet| { packet_crafter::peer_cmd::craft_peer_signal( peer_hyper_ratchet, @@ -1605,7 +1744,14 @@ pub(crate) async fn route_signal_response( Err(err) => { log::warn!(target: "citadel", "Unable to route signal! {:?}", err); - reply_to_sender_err(err, sess_hyper_ratchet, ticket, timestamp, security_level) + reply_to_sender_err( + err, + sess_hyper_ratchet, + ticket, + timestamp, + security_level, + target_cid, + ) } } } diff --git a/citadel_proto/src/proto/packet_processor/peer/server/post_connect.rs b/citadel_proto/src/proto/packet_processor/peer/server/post_connect.rs index e93657f69..c2f5e85d3 100644 --- a/citadel_proto/src/proto/packet_processor/peer/server/post_connect.rs +++ b/citadel_proto/src/proto/packet_processor/peer/server/post_connect.rs @@ -3,7 +3,6 @@ use crate::prelude::{PeerConnectionType, PeerResponse, PeerSignal}; use crate::proto::packet_processor::includes::VirtualConnectionType; use crate::proto::packet_processor::peer::peer_cmd_packet::route_signal_response; use crate::proto::packet_processor::PrimaryProcessorResult; -use crate::proto::peer::peer_layer::HyperNodePeerLayerInner; use crate::proto::remote::Ticket; use crate::proto::session::HdpSession; use citadel_crypt::stacked_ratchet::StackedRatchet; @@ -13,7 +12,6 @@ use citadel_types::proto::{SessionSecuritySettings, UdpMode}; #[cfg_attr(feature = "localhost-testing", tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err, fields(is_server = session.is_server, implicated_cid = implicated_cid, target_cid = target_cid)))] #[allow(clippy::too_many_arguments)] pub(crate) async fn handle_response_phase_post_connect( - peer_layer: &mut HyperNodePeerLayerInner, peer_conn_type: PeerConnectionType, ticket: Ticket, peer_response: PeerResponse, @@ -32,8 +30,9 @@ pub(crate) async fn handle_response_phase_post_connect( ticket_opt: Some(ticket), invitee_response: Some(peer_response), session_security_settings: endpoint_security_level, - udp_mode: udp_enabled - }, implicated_cid, target_cid, timestamp, ticket, peer_layer, session.clone(), sess_hyper_ratchet, + udp_mode: udp_enabled, + session_password: None, + }, implicated_cid, target_cid, timestamp, ticket,session.clone(), sess_hyper_ratchet, |this_sess, peer_sess, _original_tracked_posting| { // when the route finishes, we need to update both sessions to allow high-level message-passing // In other words, forge a virtual connection diff --git a/citadel_proto/src/proto/packet_processor/peer/server/post_register.rs b/citadel_proto/src/proto/packet_processor/peer/server/post_register.rs index e51a0424f..925057cec 100644 --- a/citadel_proto/src/proto/packet_processor/peer/server/post_register.rs +++ b/citadel_proto/src/proto/packet_processor/peer/server/post_register.rs @@ -2,7 +2,7 @@ use crate::error::NetworkError; use crate::prelude::{PeerConnectionType, PeerResponse, PeerSignal}; use crate::proto::packet_processor::peer::peer_cmd_packet::route_signal_response; use crate::proto::packet_processor::PrimaryProcessorResult; -use crate::proto::peer::peer_layer::{HyperNodePeerLayerInner, Username}; +use crate::proto::peer::peer_layer::Username; use crate::proto::remote::Ticket; use crate::proto::session::HdpSession; use citadel_crypt::stacked_ratchet::StackedRatchet; @@ -11,7 +11,6 @@ use citadel_types::crypto::SecurityLevel; #[cfg_attr(feature = "localhost-testing", tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err, fields(is_server = session.is_server, implicated_cid = implicated_cid, target_cid = target_cid)))] #[allow(clippy::too_many_arguments)] pub async fn handle_response_phase_post_register( - peer_layer: &mut HyperNodePeerLayerInner, peer_conn_type: PeerConnectionType, username: Username, peer_response: PeerResponse, @@ -30,8 +29,8 @@ pub async fn handle_response_phase_post_register( inviter_username: username, invitee_username: None, ticket_opt: Some(ticket), - invitee_response: Some(peer_response) - }, implicated_cid, target_cid, timestamp, ticket, peer_layer, session.clone(), sess_hyper_ratchet, + invitee_response: Some(peer_response), + }, implicated_cid, target_cid, timestamp, ticket, session.clone(), sess_hyper_ratchet, |this_sess, _peer_sess, _original_tracked_posting| { if !decline { let account_manager = this_sess.account_manager.clone(); diff --git a/citadel_proto/src/proto/packet_processor/preconnect_packet.rs b/citadel_proto/src/proto/packet_processor/preconnect_packet.rs index 6e7866d06..d9ac72d17 100644 --- a/citadel_proto/src/proto/packet_processor/preconnect_packet.rs +++ b/citadel_proto/src/proto/packet_processor/preconnect_packet.rs @@ -81,6 +81,7 @@ pub async fn process_preconnect( &cnac, packet, &session.session_manager, + &session.session_password, ) { Ok(( static_aux_ratchet, @@ -163,6 +164,7 @@ pub async fn process_preconnect( let implicated_cid = header.session_cid.get(); if let Some((new_hyper_ratchet, nat_type)) = validation::pre_connect::validate_syn_ack( + &session.session_password, cnac, alice_constructor, packet, diff --git a/citadel_proto/src/proto/packet_processor/primary_group_packet.rs b/citadel_proto/src/proto/packet_processor/primary_group_packet.rs index b139a3d22..95162afae 100644 --- a/citadel_proto/src/proto/packet_processor/primary_group_packet.rs +++ b/citadel_proto/src/proto/packet_processor/primary_group_packet.rs @@ -3,7 +3,7 @@ use crate::constants::GROUP_EXPIRE_TIME_MS; use crate::error::NetworkError; use crate::functional::IfTrueConditional; use crate::inner_arg::ExpectedInnerTarget; -use crate::prelude::InternalServerError; +use crate::prelude::{InternalServerError, PreSharedKey}; use crate::proto::node_result::OutboundRequestRejected; use crate::proto::packet_crafter::peer_cmd::C2S_ENCRYPTION_ONLY; use crate::proto::session_queue_handler::QueueWorkerResult; @@ -144,6 +144,7 @@ pub fn process_primary_packet( // now, update the keys (if applicable) let transfer = return_if_none!( attempt_kem_as_bob( + session, resp_target_cid, &header, transfer.map(AliceToBobTransferType::Default), @@ -330,6 +331,7 @@ pub fn process_primary_packet( // TODO: make the below function return a result, not bools if state_container.on_group_header_ack_received( + session, secrecy_mode, peer_cid, target_cid, @@ -675,6 +677,7 @@ impl ToolsetUpdate<'_> { /// /// Returns: Ok(latest_hyper_ratchet) pub(crate) fn attempt_kem_as_alice_finish( + session: &HdpSession, base_session_secrecy_mode: SecrecyMode, peer_cid: u64, target_cid: u64, @@ -682,7 +685,10 @@ pub(crate) fn attempt_kem_as_alice_finish( state_container: &mut StateContainerInner, constructor: Option>, ) -> Result>, ()> { - let (mut toolset_update_method, secrecy_mode) = if target_cid != C2S_ENCRYPTION_ONLY { + let (mut toolset_update_method, secrecy_mode, session_password) = if target_cid + != C2S_ENCRYPTION_ONLY + { + let session_password = state_container.get_session_password(peer_cid).cloned(); let endpoint_container = state_container .active_virtual_connections .get_mut(&peer_cid) @@ -691,12 +697,17 @@ pub(crate) fn attempt_kem_as_alice_finish( .as_mut() .ok_or(())?; let crypt = &mut endpoint_container.endpoint_crypto; + if session_password.is_none() { + log::error!(target: "citadel", "Session password not found for peer_cid {}", peer_cid); + return Err(()); + } ( ToolsetUpdate::E2E { crypt, local_cid: target_cid, }, endpoint_container.default_security_settings.secrecy_mode, + session_password.unwrap(), ) } else { let crypt = &mut state_container @@ -710,6 +721,7 @@ pub(crate) fn attempt_kem_as_alice_finish( local_cid: peer_cid, }, base_session_secrecy_mode, + session.session_password.clone(), ) }; @@ -719,7 +731,7 @@ pub(crate) fn attempt_kem_as_alice_finish( match transfer { KemTransferStatus::Some(transfer, ..) => { if let Some(mut constructor) = constructor { - if let Err(err) = constructor.stage1_alice(transfer) { + if let Err(err) = constructor.stage1_alice(transfer, session_password.as_ref()) { log::error!(target: "citadel", "Unable to construct hyper ratchet {:?}", err); return Err(()); // return true, otherwise, the session ends } @@ -776,6 +788,7 @@ pub(crate) fn attempt_kem_as_alice_finish( /// NOTE! Assumes the `hr` passed is the latest version IF the transfer is some pub(crate) fn attempt_kem_as_bob( + session: &HdpSession, resp_target_cid: u64, header: &Ref<&[u8], HdpHeader>, transfer: Option, @@ -783,30 +796,44 @@ pub(crate) fn attempt_kem_as_bob( hr: &StackedRatchet, ) -> Option { if let Some(transfer) = transfer { - let update = if resp_target_cid != C2S_ENCRYPTION_ONLY { + let (update, session_password) = if resp_target_cid != C2S_ENCRYPTION_ONLY { + let session_password = state_container + .get_session_password(resp_target_cid) + .cloned(); + if session_password.is_none() { + log::error!(target: "citadel", "Session password not found for peer_cid {}", resp_target_cid); + return None; + } + let crypt = &mut state_container .active_virtual_connections .get_mut(&resp_target_cid)? .endpoint_container .as_mut()? .endpoint_crypto; - ToolsetUpdate::E2E { - crypt, - local_cid: header.target_cid.get(), - } + ( + ToolsetUpdate::E2E { + crypt, + local_cid: header.target_cid.get(), + }, + session_password.unwrap(), + ) } else { let crypt = &mut state_container .c2s_channel_container .as_mut() .unwrap() .peer_session_crypto; - ToolsetUpdate::E2E { - crypt, - local_cid: header.session_cid.get(), - } + ( + ToolsetUpdate::E2E { + crypt, + local_cid: header.session_cid.get(), + }, + session.session_password.clone(), + ) }; - update_toolset_as_bob(update, transfer, hr) + update_toolset_as_bob(update, transfer, hr, session_password) } else { Some(KemTransferStatus::Empty) } @@ -816,6 +843,7 @@ pub(crate) fn update_toolset_as_bob( mut update_method: ToolsetUpdate<'_>, transfer: AliceToBobTransferType, hr: &StackedRatchet, + session_password: PreSharedKey, ) -> Option { let cid = update_method.get_local_cid(); let new_version = transfer.get_declared_new_version(); @@ -823,8 +851,13 @@ pub(crate) fn update_toolset_as_bob( //let opts = ConstructorOpts::new_vec_init(Some(crypto_params), (session_security_level.value() + 1) as usize); let opts = hr.get_next_constructor_opts(); if matches!(transfer, AliceToBobTransferType::Fcm(..)) { - let constructor = - EndpointRatchetConstructor::::new_bob(cid, new_version, opts, transfer)?; + let constructor = EndpointRatchetConstructor::::new_bob( + cid, + new_version, + opts, + transfer, + session_password.as_ref(), + )?; Some( update_method .update(ConstructorType::Fcm(constructor), false) @@ -836,6 +869,7 @@ pub(crate) fn update_toolset_as_bob( new_version, opts, transfer, + session_password.as_ref(), )?; Some( update_method diff --git a/citadel_proto/src/proto/packet_processor/register_packet.rs b/citadel_proto/src/proto/packet_processor/register_packet.rs index 48ecc741c..cb4e25969 100644 --- a/citadel_proto/src/proto/packet_processor/register_packet.rs +++ b/citadel_proto/src/proto/packet_processor/register_packet.rs @@ -61,6 +61,7 @@ pub async fn process_register( } std::mem::drop(state_container); + let session_password = session.session_password.clone(); async move { let cid = header.session_cid.get(); @@ -72,6 +73,7 @@ pub async fn process_register( (transfer.security_level.value() + 1) as usize, ), transfer, + session_password.as_ref(), ) .ok_or(NetworkError::InvalidRequest("Bad bob transfer"))?; let transfer = return_if_none!( @@ -143,7 +145,10 @@ pub async fn process_register( ); let security_level = transfer.security_level; alice_constructor - .stage1_alice(BobToAliceTransferType::Default(transfer)) + .stage1_alice( + BobToAliceTransferType::Default(transfer), + session.session_password.as_ref(), + ) .map_err(|err| NetworkError::Generic(err.to_string()))?; let new_hyper_ratchet = return_if_none!( alice_constructor.finish(), diff --git a/citadel_proto/src/proto/packet_processor/rekey_packet.rs b/citadel_proto/src/proto/packet_processor/rekey_packet.rs index 1f9766023..c5cd20405 100644 --- a/citadel_proto/src/proto/packet_processor/rekey_packet.rs +++ b/citadel_proto/src/proto/packet_processor/rekey_packet.rs @@ -56,6 +56,7 @@ pub fn process_rekey( let resp_target_cid = get_resp_target_cid_from_header(&header); let status = return_if_none!( attempt_kem_as_bob( + session, resp_target_cid, &header, Some(AliceToBobTransferType::Default(transfer)), @@ -112,6 +113,7 @@ pub fn process_rekey( let latest_hr = return_if_none!(return_if_none!( attempt_kem_as_alice_finish( + session, secrecy_mode, peer_cid, target_cid, diff --git a/citadel_proto/src/proto/peer/p2p_conn_handler.rs b/citadel_proto/src/proto/peer/p2p_conn_handler.rs index 5d4c73291..01d0beb5c 100644 --- a/citadel_proto/src/proto/peer/p2p_conn_handler.rs +++ b/citadel_proto/src/proto/peer/p2p_conn_handler.rs @@ -180,6 +180,7 @@ fn handle_p2p_stream( let p2p_primary_stream_tx = OutboundPrimaryStreamSender::from(p2p_primary_stream_tx); let p2p_primary_stream_rx = OutboundPrimaryStreamReceiver::from(p2p_primary_stream_rx); //let (header_obfuscator, packet_opt) = HeaderObfuscator::new(from_listener); + let peer_cid = v_conn.get_target_cid(); let (stopper_tx, stopper_rx) = channel(); let p2p_handle = P2PInboundHandle::new( @@ -188,6 +189,7 @@ fn handle_p2p_stream( implicated_cid, kernel_tx, p2p_primary_stream_tx.clone(), + peer_cid, ); let writer_future = HdpSession::outbound_stream(p2p_primary_stream_rx, sink); let reader_future = @@ -228,6 +230,12 @@ fn handle_p2p_stream( } } + let mut state_container = inner_mut_state!(sess.state_container); + state_container.active_virtual_connections.remove(&peer_cid); + state_container + .outgoing_peer_connect_attempts + .remove(&peer_cid); + log::trace!(target: "citadel", "[P2P-stream] Dropping tri-joined future"); res }; @@ -244,6 +252,7 @@ pub struct P2PInboundHandle { pub implicated_cid: DualRwLock>, pub kernel_tx: UnboundedSender, pub to_primary_stream: OutboundPrimaryStreamSender, + pub peer_cid: u64, } impl P2PInboundHandle { @@ -253,6 +262,7 @@ impl P2PInboundHandle { implicated_cid: DualRwLock>, kernel_tx: UnboundedSender, to_primary_stream: OutboundPrimaryStreamSender, + peer_cid: u64, ) -> Self { Self { remote_peer, @@ -260,6 +270,7 @@ impl P2PInboundHandle { implicated_cid, kernel_tx, to_primary_stream, + peer_cid, } } } @@ -335,7 +346,7 @@ pub(crate) async fn attempt_simultaneous_hole_punch( ) } else { log::trace!(target: "citadel", "Non-initiator will begin listening immediately"); - std::mem::drop(hole_punched_socket); // drop to prevent conflicts caused by SO_REUSE_ADDR + drop(hole_punched_socket); // drop to prevent conflicts caused by SO_REUSE_ADDR setup_listener_non_initiator(local_addr, remote_connect_addr, session.clone(), v_conn, addr, ticket) .await .map_err(|err|generic_error(format!("Non-initiator was unable to secure connection despite hole-punching success: {err:?}"))) diff --git a/citadel_proto/src/proto/peer/peer_layer.rs b/citadel_proto/src/proto/peer/peer_layer.rs index 964a79e19..4df59cda3 100644 --- a/citadel_proto/src/proto/peer/peer_layer.rs +++ b/citadel_proto/src/proto/peer/peer_layer.rs @@ -1,5 +1,6 @@ use crate::error::NetworkError; use crate::macros::SyncContextRequirements; +use crate::prelude::PreSharedKey; use crate::proto::packet_processor::peer::group_broadcast::GroupBroadcast; use crate::proto::peer::message_group::{MessageGroup, MessageGroupPeer}; use crate::proto::peer::peer_crypt::KeyExchangeProcess; @@ -34,6 +35,7 @@ pub struct HyperNodePeerLayerInner { // When a signal is routed to the target destination, the server needs to keep track of the state while awaiting pub(crate) persistence_handler: PersistenceHandler, pub(crate) message_groups: HashMap>, + pub(crate) simultaneous_ticket_mappings: HashMap>, waker: Arc, inner: Arc>, } @@ -83,15 +85,17 @@ impl TrackedPosting { } impl HyperNodePeerLayer { + #[allow(clippy::arc_with_non_send_sync)] pub fn new(persistence_handler: PersistenceHandler) -> HyperNodePeerLayer { - let waker = std::sync::Arc::new(AtomicWaker::new()); + let waker = Arc::new(AtomicWaker::new()); let inner = HyperNodePeerLayerInner { waker: waker.clone(), inner: Arc::new(citadel_io::RwLock::new(Default::default())), + simultaneous_ticket_mappings: Default::default(), persistence_handler, message_groups: HashMap::new(), }; - let inner = std::sync::Arc::new(tokio::sync::RwLock::new(inner)); + let inner = Arc::new(tokio::sync::RwLock::new(inner)); Self { inner, waker } } @@ -386,6 +390,18 @@ impl HyperNodePeerLayerExecutor { } impl HyperNodePeerLayerInner { + pub fn insert_mapped_ticket(&mut self, cid: u64, ticket: Ticket, mapped_ticket: Ticket) { + self.simultaneous_ticket_mappings + .entry(cid) + .or_default() + .insert(ticket, mapped_ticket); + } + + pub fn take_mapped_ticket(&mut self, cid: u64, ticket: Ticket) -> Option { + self.simultaneous_ticket_mappings + .get_mut(&cid)? + .remove(&ticket) + } /// Determines if `peer_cid` is already attempting to register to `implicated_cid` /// Returns the target's ticket for their corresponding request pub fn check_simultaneous_register( @@ -395,7 +411,7 @@ impl HyperNodePeerLayerInner { ) -> Option { log::trace!(target: "citadel", "Checking simultaneous register between {} and {}", implicated_cid, peer_cid); - self.check_simultaneous_event(peer_cid, |posting| if let PeerSignal::PostRegister { peer_conn_type: conn, inviter_username: _, invitee_username: _, ticket_opt: _, invitee_response: None } = &posting.signal { + self.check_simultaneous_event(peer_cid, |posting| if let PeerSignal::PostRegister { peer_conn_type: conn, inviter_username: _, invitee_username: _, ticket_opt: _, invitee_response: None, .. } = &posting.signal { log::trace!(target: "citadel", "Checking if posting from conn={:?} ~ {:?}", conn, implicated_cid); if let PeerConnectionType::LocalGroupPeer { implicated_cid: _, peer_cid: b } = conn { *b == implicated_cid @@ -416,7 +432,7 @@ impl HyperNodePeerLayerInner { ) -> Option { log::trace!(target: "citadel", "Checking simultaneous register between {} and {}", implicated_cid, peer_cid); - self.check_simultaneous_event(peer_cid, |posting| if let PeerSignal::PostConnect { peer_conn_type: conn, ticket_opt: _, invitee_response: _, session_security_settings: _, udp_mode: _ } = &posting.signal { + self.check_simultaneous_event(peer_cid, |posting| if let PeerSignal::PostConnect { peer_conn_type: conn, ticket_opt: _, invitee_response: _, session_security_settings: _, udp_mode: _, .. } = &posting.signal { log::trace!(target: "citadel", "Checking if posting from conn={:?} ~ {:?}", conn, implicated_cid); if let PeerConnectionType::LocalGroupPeer { implicated_cid: _, peer_cid: b } = conn { *b == implicated_cid @@ -556,6 +572,9 @@ pub enum PeerSignal { invitee_response: Option, session_security_settings: SessionSecuritySettings, udp_mode: UdpMode, + // On the wire, this should be set to None. Should always be set to some value when submitting. + #[serde(skip)] + session_password: Option, }, Disconnect { peer_conn_type: PeerConnectionType, @@ -590,6 +609,7 @@ pub enum PeerSignal { SignalError { ticket: Ticket, error: String, + peer_connection_type: PeerConnectionType, }, DeregistrationSuccess { peer_conn_type: PeerConnectionType, diff --git a/citadel_proto/src/proto/session.rs b/citadel_proto/src/proto/session.rs index 94c786e29..be9442699 100644 --- a/citadel_proto/src/proto/session.rs +++ b/citadel_proto/src/proto/session.rs @@ -43,7 +43,7 @@ use citadel_types::proto::SessionSecuritySettings; //use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender, channel, TrySendError}; use crate::auth::AuthenticationRequest; use crate::kernel::RuntimeFuture; -use crate::prelude::{GroupBroadcast, PeerEvent, PeerResponse, SecureProtocolPacket}; +use crate::prelude::{GroupBroadcast, PeerEvent, PeerResponse, PreSharedKey, SecureProtocolPacket}; use crate::proto::endpoint_crypto_accessor::EndpointCryptoAccessor; use crate::proto::misc::dual_cell::DualCell; use crate::proto::misc::dual_late_init::DualLateInit; @@ -209,6 +209,7 @@ pub struct HdpSessionInner { pub(super) stun_servers: Option>, pub(super) init_time: Instant, pub(super) file_transfer_compatible: DualLateInit, + pub(super) session_password: PreSharedKey, on_drop: UnboundedSender<()>, } @@ -258,6 +259,7 @@ pub(crate) struct SessionInitParams { pub client_only_settings: Option, pub stun_servers: Option>, pub init_time: Instant, + pub session_password: PreSharedKey, } pub(crate) struct ClientOnlySessionInitSettings { @@ -352,6 +354,7 @@ impl HdpSession { .unwrap_or(KEEP_ALIVE_TIMEOUT_NS); let stun_servers = session_init_params.stun_servers; let init_time = session_init_params.init_time; + let session_password = session_init_params.session_password; let mut inner = HdpSessionInner { hypernode_peer_layer, @@ -394,6 +397,7 @@ impl HdpSession { stun_servers, init_time, file_transfer_compatible: DualLateInit::default(), + session_password, }; if let Some(proposed_credentials) = session_init_params @@ -919,7 +923,7 @@ impl HdpSession { ref implicated_cid, ref kernel_tx, ref primary_stream, - p2p, + peer_cid, is_server, ) = if let Some(p2p) = p2p_handle { ( @@ -928,7 +932,7 @@ impl HdpSession { p2p.implicated_cid, p2p.kernel_tx, p2p.to_primary_stream, - true, + Some(p2p.peer_cid), false, ) } else { @@ -945,7 +949,7 @@ impl HdpSession { implicated_cid, kernel_tx, primary_stream, - false, + None, is_server, ) }; @@ -995,9 +999,10 @@ impl HdpSession { } fn handle_session_terminating_error( + session: &HdpSession, err: std::io::Error, is_server: bool, - p2p: bool, + peer_cid: Option, ) -> SessionShutdownReason { const _WINDOWS_FORCE_SHUTDOWN: i32 = 10054; const _RST: i32 = 104; @@ -1006,7 +1011,7 @@ impl HdpSession { let error = err.raw_os_error().unwrap_or(-1); // error != WINDOWS_FORCE_SHUTDOWN && error != RST && error != ECONN_RST && if error != -1 { - log::error!(target: "citadel", "primary port reader error {}: {}. is server: {}. P2P: {}", error, err.to_string(), is_server, p2p); + log::error!(target: "citadel", "primary port reader error {}: {}. is server: {}. P2P: {}", error, err.to_string(), is_server, peer_cid.is_some()); } let err_string = err.to_string(); @@ -1014,6 +1019,35 @@ impl HdpSession { if err_string.contains(SUCCESS_DISCONNECT) { SessionShutdownReason::ProperShutdown } else { + let implicated_cid = session.implicated_cid.get().unwrap_or_default(); + let v_conn_type = if let Some(peer_cid) = peer_cid { + VirtualConnectionType::LocalGroupPeer { + implicated_cid, + peer_cid, + } + } else { + VirtualConnectionType::LocalGroupServer { implicated_cid } + }; + + if let Err(err) = session.send_to_kernel(NodeResult::Disconnect(Disconnect { + ticket: session.kernel_ticket.get(), + cid_opt: session.implicated_cid.get(), + success: false, + v_conn_type: Some(v_conn_type), + message: err_string.clone(), + })) { + log::error!(target: "citadel", "Error sending disconnect signal to kernel: {err:?}"); + } + + if peer_cid.is_none() { + // If this is a c2s connection, close the session + session.send_session_dc_signal( + Some(session.kernel_ticket.get()), + false, + err_string.as_str(), + ); + } + SessionShutdownReason::Error(NetworkError::Generic(err_string)) } } @@ -1037,7 +1071,7 @@ impl HdpSession { .await; evaluate_result(result, primary_stream, kernel_tx, this_main, implicated_cid) }) - .map_err(|err| handle_session_terminating_error(err, is_server, p2p)) + .map_err(|err| handle_session_terminating_error(this_main, err, is_server, peer_cid)) .await; match res { @@ -1969,10 +2003,12 @@ impl HdpSession { PeerSignal::PostConnect { peer_conn_type: a, ticket_opt: b, - invitee_response: None, + invitee_response, session_security_settings: d, udp_mode: e, + session_password, } => { + let session_password = session_password.unwrap_or_default(); if state_container .outgoing_peer_connect_attempts .contains_key(&a.get_original_target_cid()) @@ -1980,6 +2016,9 @@ impl HdpSession { log::warn!(target: "citadel", "{} is already attempting to connect to {}", a.get_original_implicated_cid(), a.get_original_target_cid()) } + state_container + .store_session_password(a.get_original_target_cid(), session_password); + // in case the ticket gets mapped during simultaneous_connect, store locally let _ = state_container .outgoing_peer_connect_attempts @@ -1987,9 +2026,10 @@ impl HdpSession { PeerSignal::PostConnect { peer_conn_type: a, ticket_opt: b, - invitee_response: None, + invitee_response, session_security_settings: d, udp_mode: e, + session_password: None, } } diff --git a/citadel_proto/src/proto/session_manager.rs b/citadel_proto/src/proto/session_manager.rs index a3044845b..5a0046195 100644 --- a/citadel_proto/src/proto/session_manager.rs +++ b/citadel_proto/src/proto/session_manager.rs @@ -19,7 +19,7 @@ use crate::constants::{DO_CONNECT_EXPIRE_TIME_MS, KEEP_ALIVE_TIMEOUT_NS, UDP_MOD use crate::error::NetworkError; use crate::kernel::RuntimeFuture; use crate::macros::SyncContextRequirements; -use crate::prelude::Disconnect; +use crate::prelude::{Disconnect, PreSharedKey}; use crate::proto::endpoint_crypto_accessor::EndpointCryptoAccessor; use crate::proto::misc::net::GenericNetworkStream; use crate::proto::misc::underlying_proto::ServerUnderlyingProtocol; @@ -31,8 +31,7 @@ use crate::proto::packet_processor::includes::{Duration, Instant}; use crate::proto::packet_processor::peer::group_broadcast::GroupBroadcast; use crate::proto::packet_processor::PrimaryProcessorResult; use crate::proto::peer::peer_layer::{ - HyperNodePeerLayer, HyperNodePeerLayerInner, MailboxTransfer, PeerConnectionType, PeerResponse, - PeerSignal, + HyperNodePeerLayer, MailboxTransfer, PeerConnectionType, PeerResponse, PeerSignal, }; use crate::proto::remote::{NodeRemote, Ticket}; use crate::proto::session::{ @@ -144,6 +143,7 @@ impl HdpSessionManager { keep_alive_timeout_ns: Option, security_settings: SessionSecuritySettings, default_client_config: &Arc, + session_password: PreSharedKey, ) -> Result>, NetworkError> { let (session_manager, new_session, peer_addr, primary_stream) = { let session_manager_clone = self.clone(); @@ -309,6 +309,7 @@ impl HdpSessionManager { client_only_settings: Some(client_only_settings), stun_servers, init_time, + session_password, }; let (stopper, new_session) = HdpSession::new(session_init_params)?; @@ -479,6 +480,7 @@ impl HdpSessionManager { local_nat_type: NatType, peer_addr: SocketAddr, primary_stream: GenericNetworkStream, + session_password: PreSharedKey, ) -> Result>, NetworkError> { let this_dc = self.clone(); let mut this = inner_mut!(self); @@ -513,6 +515,7 @@ impl HdpSessionManager { client_only_settings: None, stun_servers, init_time, + session_password, }; let (stopper, new_session) = HdpSession::new(session_init_params)?; @@ -1080,7 +1083,6 @@ impl HdpSessionManager { #[allow(clippy::too_many_arguments)] pub async fn route_signal_primary( &self, - peer_layer: &HyperNodePeerLayerInner, implicated_cid: u64, target_cid: u64, ticket: Ticket, @@ -1109,7 +1111,11 @@ impl HdpSessionManager { // get the target cid's session if let Some(ref sess_ref) = sess { - peer_layer + sess_ref + .hypernode_peer_layer + .inner + .write() + .await .insert_tracked_posting(implicated_cid, timeout, ticket, signal, on_timeout) .await; let peer_sender = sess_ref.to_primary_stream.as_ref().unwrap(); @@ -1123,15 +1129,21 @@ impl HdpSessionManager { } else { // session is not active, but user is registered (thus offline). Setup return ticket tracker on implicated_cid // and deliver to the mailbox of target_cid, that way target_cid receives mail on connect. TODO: external svc route, if available - peer_layer - .insert_tracked_posting( - implicated_cid, - timeout, - ticket, - signal.clone(), - on_timeout, - ) - .await; + { + let peer_layer = { inner!(self).hypernode_peer_layer.clone() }; + peer_layer + .inner + .write() + .await + .insert_tracked_posting( + implicated_cid, + timeout, + ticket, + signal.clone(), + on_timeout, + ) + .await; + } HyperNodePeerLayer::try_add_mailbox(&pers, target_cid, signal) .await .map_err(|err| err.into_string()) @@ -1247,12 +1259,20 @@ impl HdpSessionManager { implicated_cid: u64, target_cid: u64, ticket: Ticket, - peer_layer: &mut HyperNodePeerLayerInner, + session: &HdpSession, packet: impl FnOnce(&StackedRatchet) -> BytesMut, post_send: impl FnOnce(&HdpSession, PeerSignal) -> Result, ) -> Result, String> { // Instead of checking for registration, check the `implicated_cid`'s timed queue for a ticket corresponding to Ticket. - if let Some(tracked_posting) = peer_layer.remove_tracked_posting_inner(target_cid, ticket) { + let tracked_posting = { + session + .hypernode_peer_layer + .inner + .write() + .await + .remove_tracked_posting_inner(target_cid, ticket) + }; + if let Some(tracked_posting) = tracked_posting { // since the posting was valid, we just need to forward the signal to `implicated_cid` let this = inner!(self); if let Some(target_sess) = this.sessions.get(&target_cid) { diff --git a/citadel_proto/src/proto/state_container.rs b/citadel_proto/src/proto/state_container.rs index 683d463ea..a1bcc388f 100644 --- a/citadel_proto/src/proto/state_container.rs +++ b/citadel_proto/src/proto/state_container.rs @@ -24,7 +24,7 @@ use crate::constants::{ }; use crate::error::NetworkError; use crate::functional::IfEqConditional; -use crate::prelude::{InternalServerError, ReKeyResult, ReKeyReturnType}; +use crate::prelude::{InternalServerError, PreSharedKey, ReKeyResult, ReKeyReturnType}; use crate::proto::misc::dual_late_init::DualLateInit; use crate::proto::misc::dual_rwlock::DualRwLock; use crate::proto::misc::ordered_channel::OrderedChannel; @@ -127,6 +127,7 @@ pub struct StateContainerInner { pub(super) group_channels: HashMap>, pub(super) transfer_stats: TransferStats, pub(super) udp_mode: UdpMode, + pub(super) session_passwords: HashMap, is_server: bool, } @@ -586,10 +587,26 @@ impl StateContainerInner { peer_kem_states: HashMap::new(), inbound_files: HashMap::new(), outbound_files: HashMap::new(), + session_passwords: HashMap::new(), }; inner.into() } + // Note: c2s connections should not be stored here. They are stored in the session.rs file + pub fn store_session_password(&mut self, peer_cid: u64, session_password: PreSharedKey) { + self.session_passwords.insert(peer_cid, session_password); + } + + pub fn get_session_password(&self, peer_cid: u64) -> Option<&PreSharedKey> { + self.session_passwords.get(&peer_cid) + } + + // TODO: use this + #[allow(dead_code)] + pub fn remove_session_password(&mut self, peer_cid: u64) { + self.session_passwords.remove(&peer_cid); + } + /// Attempts to find the direct p2p stream. If not found, will use the default /// to_server stream. Note: the underlying crypto is still the same pub fn get_preferred_stream(&self, peer_cid: u64) -> &OutboundPrimaryStreamSender { @@ -1340,6 +1357,7 @@ impl StateContainerInner { #[allow(clippy::too_many_arguments)] pub fn on_group_header_ack_received( &mut self, + session: &HdpSession, base_session_secrecy_mode: SecrecyMode, peer_cid: u64, target_cid: u64, @@ -1363,6 +1381,7 @@ impl StateContainerInner { }; if attempt_kem_as_alice_finish( + session, base_session_secrecy_mode, peer_cid, target_cid, diff --git a/citadel_proto/src/proto/state_subcontainers/peer_kem_state_container.rs b/citadel_proto/src/proto/state_subcontainers/peer_kem_state_container.rs index 9ebec6ce6..7a6082676 100644 --- a/citadel_proto/src/proto/state_subcontainers/peer_kem_state_container.rs +++ b/citadel_proto/src/proto/state_subcontainers/peer_kem_state_container.rs @@ -1,3 +1,4 @@ +use crate::prelude::PreSharedKey; use crate::proto::state_subcontainers::preconnect_state_container::UdpChannelSender; use citadel_crypt::stacked_ratchet::constructor::StackedRatchetConstructor; use citadel_types::proto::SessionSecuritySettings; @@ -7,12 +8,18 @@ pub struct PeerKemStateContainer { pub(crate) local_is_initiator: bool, pub(crate) session_security_settings: SessionSecuritySettings, pub(crate) udp_channel_sender: UdpChannelSender, + pub(crate) session_password: PreSharedKey, } impl PeerKemStateContainer { - pub fn new(session_security_settings: SessionSecuritySettings, udp_enabled: bool) -> Self { + pub fn new( + session_security_settings: SessionSecuritySettings, + udp_enabled: bool, + session_password: PreSharedKey, + ) -> Self { Self { constructor: None, + session_password, local_is_initiator: false, session_security_settings, udp_channel_sender: if udp_enabled { diff --git a/citadel_proto/src/proto/validation.rs b/citadel_proto/src/proto/validation.rs index be4363fee..83b4999fc 100644 --- a/citadel_proto/src/proto/validation.rs +++ b/citadel_proto/src/proto/validation.rs @@ -214,6 +214,7 @@ pub(crate) mod pre_connect { use citadel_wire::hypernode_type::NodeType; use crate::error::NetworkError; + use crate::prelude::PreSharedKey; use crate::proto::packet::HdpPacket; use crate::proto::packet_crafter::pre_connect::{PreConnectStage0, SynPacket}; use crate::proto::packet_processor::includes::packet_crafter::pre_connect::SynAckPacket; @@ -244,6 +245,7 @@ pub(crate) mod pre_connect { cnac: &ClientNetworkAccount, packet: HdpPacket, session_manager: &HdpSessionManager, + session_password: &PreSharedKey, ) -> Result { // TODO: NOTE: This can interrupt any active session's. This should be moved up after checking the connect mode let static_auxiliary_ratchet = cnac.refresh_static_hyper_ratchet(); @@ -289,6 +291,7 @@ pub(crate) mod pre_connect { 0, opts, transfer.transfer, + session_password.as_ref(), ) .ok_or(NetworkError::InternalError( "Unable to create bob container", @@ -321,6 +324,7 @@ pub(crate) mod pre_connect { /// This returns an error if the packet is maliciously invalid (e.g., due to a false packet) /// This returns Ok(true) if the system was already synchronized, or Ok(false) if the system needed to synchronize toolsets pub fn validate_syn_ack( + session_password: &PreSharedKey, cnac: &ClientNetworkAccount, mut alice_constructor: StackedRatchetConstructor, packet: HdpPacket, @@ -333,9 +337,10 @@ pub(crate) mod pre_connect { let lvl = packet.transfer.security_level; log::trace!(target: "citadel", "Session security level based-on returned transfer: {:?}", lvl); - if let Err(err) = - alice_constructor.stage1_alice(BobToAliceTransferType::Default(packet.transfer)) - { + if let Err(err) = alice_constructor.stage1_alice( + BobToAliceTransferType::Default(packet.transfer), + session_password.as_ref(), + ) { log::error!(target: "citadel", "Error on stage1_alice: {:?}", err); return None; } diff --git a/citadel_sdk/examples/client.rs b/citadel_sdk/examples/client.rs index ae900308c..bfa98d5c5 100644 --- a/citadel_sdk/examples/client.rs +++ b/citadel_sdk/examples/client.rs @@ -11,7 +11,7 @@ async fn main() { let finished = &AtomicBool::new(false); let client = citadel_sdk::prefabs::client::single_connection - ::SingleClientServerConnectionKernel::new_register("Dummy user", "dummyusername", "notsecurepassword", addr, UdpMode::Enabled, Default::default(), |mut connection, remote| async move { + ::SingleClientServerConnectionKernel::new_register("Dummy user", "dummyusername", "notsecurepassword", addr, UdpMode::Enabled, Default::default(), None, |mut connection, remote| async move { let chan = connection.udp_channel_rx.take(); tokio::task::spawn(citadel_sdk::test_common::udp_mode_assertions(UdpMode::Enabled, chan)) .await.map_err(|err| NetworkError::Generic(err.to_string()))?; diff --git a/citadel_sdk/examples/peer.rs b/citadel_sdk/examples/peer.rs index 16d8a9e7c..a9e1512a1 100644 --- a/citadel_sdk/examples/peer.rs +++ b/citadel_sdk/examples/peer.rs @@ -27,6 +27,7 @@ async fn main() { addr, UdpMode::Enabled, Default::default(), + None, |mut connection, remote| async move { let mut connection = connection.recv().await.unwrap()?; let chan = connection.udp_channel_rx.take(); diff --git a/citadel_sdk/src/builder/node_builder.rs b/citadel_sdk/src/builder/node_builder.rs index 975e45b99..4a9b45ff1 100644 --- a/citadel_sdk/src/builder/node_builder.rs +++ b/citadel_sdk/src/builder/node_builder.rs @@ -24,6 +24,7 @@ pub struct NodeBuilder { client_tls_config: Option, kernel_executor_settings: Option, stun_servers: Option>, + server_session_password: Option, } /// An awaitable future whose return value propagates any internal protocol or kernel-level errors @@ -104,6 +105,8 @@ impl NodeBuilder { .map_err(|err| anyhow::Error::msg(err.into_string()))? }; + let server_only_session_password = self.server_session_password.take(); + Ok(NodeFuture { _pd: Default::default(), inner: Box::pin(async move { @@ -128,6 +131,7 @@ impl NodeBuilder { client_config, kernel_executor_settings, stun_servers, + server_only_session_password, }; log::trace!(target: "citadel", "[NodeBuilder] Creating KernelExecutor ..."); @@ -271,6 +275,15 @@ impl NodeBuilder { self } + /// Sets the pre-shared key for the server. Only a server should set this value + /// If no value is set, any client can connect to the server. If a pre-shared key + /// is specified, the client must have the matching pre-shared key in order to + /// register and connect with the server. + pub fn with_server_password>(&mut self, password: T) -> &mut Self { + self.server_session_password = Some(password.into()); + self + } + fn check(&self) -> anyhow::Result<()> { #[cfg(feature = "google-services")] if let Some(svc) = self.services.as_ref() { diff --git a/citadel_sdk/src/fs.rs b/citadel_sdk/src/fs.rs index 22fe04c38..6797fb4e0 100644 --- a/citadel_sdk/src/fs.rs +++ b/citadel_sdk/src/fs.rs @@ -127,11 +127,12 @@ mod tests { .build() .unwrap(); - let client_kernel = SingleClientServerConnectionKernel::new_passwordless( + let client_kernel = SingleClientServerConnectionKernel::new_authless( uuid, server_addr, UdpMode::Disabled, session_security_settings, + None, |_channel, remote| async move { log::trace!(target: "citadel", "***CLIENT LOGIN SUCCESS :: File transfer next ***"); let virtual_path = PathBuf::from("/home/john.doe/TheBridge.pdf"); @@ -197,11 +198,12 @@ mod tests { .build() .unwrap(); - let client_kernel = SingleClientServerConnectionKernel::new_passwordless( + let client_kernel = SingleClientServerConnectionKernel::new_authless( uuid, server_addr, UdpMode::Disabled, session_security_settings, + None, |_channel, remote| async move { log::trace!(target: "citadel", "***CLIENT LOGIN SUCCESS :: File transfer next ***"); let virtual_path = PathBuf::from("/home/john.doe/TheBridge.pdf"); @@ -269,11 +271,12 @@ mod tests { .build() .unwrap(); - let client_kernel = SingleClientServerConnectionKernel::new_passwordless( + let client_kernel = SingleClientServerConnectionKernel::new_authless( uuid, server_addr, UdpMode::Disabled, session_security_settings, + None, |_channel, remote| async move { log::trace!(target: "citadel", "***CLIENT LOGIN SUCCESS :: File transfer next ***"); let virtual_path = PathBuf::from("/home/john.doe/TheBridge.pdf"); @@ -345,12 +348,13 @@ mod tests { // TODO: SinglePeerConnectionKernel // to not hold up all conns - let client_kernel0 = PeerConnectionKernel::new_passwordless( + let client_kernel0 = PeerConnectionKernel::new_authless( uuid0, server_addr, vec![uuid1.into()], UdpMode::Disabled, session_security, + None, move |mut connection, remote_outer| async move { wait_for_peers().await; let mut connection = connection.recv().await.unwrap()?; @@ -388,12 +392,13 @@ mod tests { ) .unwrap(); - let client_kernel1 = PeerConnectionKernel::new_passwordless( + let client_kernel1 = PeerConnectionKernel::new_authless( uuid1, server_addr, vec![uuid0.into()], UdpMode::Disabled, session_security, + None, move |mut connection, remote_outer| async move { wait_for_peers().await; let connection = connection.recv().await.unwrap()?; diff --git a/citadel_sdk/src/prefabs/client/broadcast.rs b/citadel_sdk/src/prefabs/client/broadcast.rs index ebad10fb8..2609fc8ed 100644 --- a/citadel_sdk/src/prefabs/client/broadcast.rs +++ b/citadel_sdk/src/prefabs/client/broadcast.rs @@ -416,7 +416,7 @@ mod tests { } }; - let client_kernel = BroadcastKernel::new_passwordless_defaults( + let client_kernel = BroadcastKernel::new_authless_defaults( uuid, server_addr, request, @@ -494,7 +494,7 @@ mod tests { .map(UserIdentifier::from) .collect::>(); - let client_kernel = PeerConnectionKernel::new_passwordless_defaults( + let client_kernel = PeerConnectionKernel::new_authless_defaults( uuid, server_addr, peers, diff --git a/citadel_sdk/src/prefabs/client/mod.rs b/citadel_sdk/src/prefabs/client/mod.rs index 7c026ebd0..25c5cec08 100644 --- a/citadel_sdk/src/prefabs/client/mod.rs +++ b/citadel_sdk/src/prefabs/client/mod.rs @@ -36,6 +36,7 @@ pub trait PrefabFunctions<'a, Arg: Send + 'a>: Sized + 'a { arg: Arg, udp_mode: UdpMode, session_security_settings: SessionSecuritySettings, + server_password: Option, on_channel_received: Self::UserLevelInputFunction, ) -> Self { let (tx, rx) = tokio::sync::oneshot::channel(); @@ -44,6 +45,7 @@ pub trait PrefabFunctions<'a, Arg: Send + 'a>: Sized + 'a { password, udp_mode, session_security_settings, + server_password, |connect_success, remote| { on_channel_received_fn::<_, Self>( connect_success, @@ -73,6 +75,7 @@ pub trait PrefabFunctions<'a, Arg: Send + 'a>: Sized + 'a { arg, Default::default(), Default::default(), + Default::default(), on_channel_received, ) } @@ -87,6 +90,7 @@ pub trait PrefabFunctions<'a, Arg: Send + 'a>: Sized + 'a { server_addr: V, udp_mode: UdpMode, session_security_settings: SessionSecuritySettings, + server_password: Option, on_channel_received: Self::UserLevelInputFunction, ) -> Result { let (tx, rx) = tokio::sync::oneshot::channel(); @@ -97,6 +101,7 @@ pub trait PrefabFunctions<'a, Arg: Send + 'a>: Sized + 'a { server_addr, udp_mode, session_security_settings, + server_password, |connect_success, remote| { on_channel_received_fn::<_, Self>( connect_success, @@ -135,25 +140,28 @@ pub trait PrefabFunctions<'a, Arg: Send + 'a>: Sized + 'a { server_addr, Default::default(), Default::default(), + Default::default(), on_channel_received, ) } /// Creates a new authless connection with custom arguments - fn new_passwordless( + fn new_authless( uuid: Uuid, server_addr: V, arg: Arg, udp_mode: UdpMode, session_security_settings: SessionSecuritySettings, + server_password: Option, on_channel_received: Self::UserLevelInputFunction, ) -> Result { let (tx, rx) = tokio::sync::oneshot::channel(); - let server_conn_kernel = SingleClientServerConnectionKernel::new_passwordless( + let server_conn_kernel = SingleClientServerConnectionKernel::new_authless( uuid, server_addr, udp_mode, session_security_settings, + server_password, |connect_success, remote| { on_channel_received_fn::<_, Self>( connect_success, @@ -171,18 +179,19 @@ pub trait PrefabFunctions<'a, Arg: Send + 'a>: Sized + 'a { } /// Creates a new authless connection with default arguments - fn new_passwordless_defaults( + fn new_authless_defaults( uuid: Uuid, server_addr: V, arg: Arg, on_channel_received: Self::UserLevelInputFunction, ) -> Result { - Self::new_passwordless( + Self::new_authless( uuid, server_addr, arg, Default::default(), Default::default(), + Default::default(), on_channel_received, ) } diff --git a/citadel_sdk/src/prefabs/client/peer_connection.rs b/citadel_sdk/src/prefabs/client/peer_connection.rs index 33802ffc5..83b4a5119 100644 --- a/citadel_sdk/src/prefabs/client/peer_connection.rs +++ b/citadel_sdk/src/prefabs/client/peer_connection.rs @@ -163,6 +163,7 @@ struct PeerConnectionSettings { session_security_settings: SessionSecuritySettings, udp_mode: UdpMode, ensure_registered: bool, + peer_session_password: Option, } pub struct AddedPeer { @@ -171,6 +172,7 @@ pub struct AddedPeer { session_security_settings: Option, ensure_registered: bool, udp_mode: Option, + peer_session_password: Option, } impl AddedPeer { @@ -181,6 +183,7 @@ impl AddedPeer { session_security_settings: self.session_security_settings.unwrap_or_default(), udp_mode: self.udp_mode.unwrap_or_default(), ensure_registered: self.ensure_registered, + peer_session_password: self.peer_session_password, }; self.list.inner.push(new); @@ -207,6 +210,13 @@ impl AddedPeer { self.ensure_registered = true; self } + + /// Adds a pre-shared key to the peer session password list. Both connecting nodes + /// must have matching passwords in order to establish a connection. Default is None. + pub fn with_session_password>(mut self, password: T) -> Self { + self.peer_session_password = Some(password.into()); + self + } } impl PeerConnectionSetupAggregator { @@ -244,6 +254,7 @@ impl PeerConnectionSetupAggregator { ensure_registered: false, session_security_settings: None, udp_mode: None, + peer_session_password: None, } } } @@ -329,6 +340,7 @@ where session_security_settings, udp_mode, ensure_registered, + peer_session_password, } = peer_to_connect; let task = async move { @@ -357,7 +369,11 @@ where }; handle - .connect_to_peer_custom(session_security_settings, udp_mode) + .connect_to_peer_custom( + session_security_settings, + udp_mode, + peer_session_password, + ) .await .map(|mut success| { let peer_conn = success.channel.get_peer_conn_type().unwrap(); @@ -414,6 +430,7 @@ mod tests { use rstest::rstest; use std::collections::HashMap; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; + use std::time::Duration; use uuid::Uuid; lazy_static::lazy_static! { @@ -563,7 +580,7 @@ mod tests { .map(UserIdentifier::from) .collect::>(); - let client_kernel = PeerConnectionKernel::new_passwordless_defaults( + let client_kernel = PeerConnectionKernel::new_authless_defaults( uuid, server_addr, peers, @@ -656,7 +673,7 @@ mod tests { .map(UserIdentifier::from) .collect::>(); - let client_kernel = PeerConnectionKernel::new_passwordless_defaults( + let client_kernel = PeerConnectionKernel::new_authless_defaults( uuid, server_addr, peers, @@ -784,7 +801,7 @@ mod tests { .map(UserIdentifier::from) .collect::>(); - let client_kernel = PeerConnectionKernel::new_passwordless_defaults( + let client_kernel = PeerConnectionKernel::new_authless_defaults( uuid, server_addr, peers, @@ -861,7 +878,7 @@ mod tests { .map(UserIdentifier::from) .collect::>(); - let client_kernel = PeerConnectionKernel::new_passwordless_defaults( + let client_kernel = PeerConnectionKernel::new_authless_defaults( uuid, server_addr, peers, @@ -903,4 +920,108 @@ mod tests { assert_eq!(client_success.load(Ordering::Relaxed), peer_count); Ok(()) } + + #[rstest] + #[case(SecrecyMode::BestEffort, Some("test-p2p-password"))] + #[timeout(std::time::Duration::from_secs(240))] + #[tokio::test(flavor = "multi_thread")] + async fn test_p2p_wrong_session_password( + #[case] secrecy_mode: SecrecyMode, + #[case] p2p_password: Option<&'static str>, + #[values(KemAlgorithm::Kyber)] kem: KemAlgorithm, + #[values(EncryptionAlgorithm::AES_GCM_256)] enx: EncryptionAlgorithm, + ) { + citadel_logging::setup_log_no_panic_hook(); + crate::test_common::TestBarrier::setup(2); + let (server, server_addr) = server_info(); + let peer_0_error_received = &AtomicBool::new(false); + let peer_1_error_received = &AtomicBool::new(false); + + let uuid0 = Uuid::new_v4(); + let uuid1 = Uuid::new_v4(); + let session_security = SessionSecuritySettingsBuilder::default() + .with_secrecy_mode(secrecy_mode) + .with_crypto_params(kem + enx) + .build() + .unwrap(); + + let mut peer0_agg = PeerConnectionSetupAggregator::default() + .with_peer_custom(uuid1) + .with_session_security_settings(session_security); + + if let Some(password) = p2p_password { + peer0_agg = peer0_agg.with_session_password(password); + } + + let peer0_connection = peer0_agg.add(); + + let mut peer1_agg = PeerConnectionSetupAggregator::default() + .with_peer_custom(uuid0) + .with_session_security_settings(session_security); + + if let Some(_password) = p2p_password { + peer1_agg = peer1_agg.with_session_password("wrong password"); + } + + let peer1_connection = peer1_agg.add(); + + let client_kernel0 = PeerConnectionKernel::new_authless( + uuid0, + server_addr, + peer0_connection, + UdpMode::Enabled, + session_security, + None, + move |mut connection, remote| async move { + wait_for_peers().await; + let conn = connection.recv().await.unwrap(); + log::trace!(target: "citadel", "Peer 0 {} received: {:?}", remote.conn_type.get_implicated_cid(), conn); + if conn.is_ok() { + peer_0_error_received.store(true, Ordering::SeqCst); + } + wait_for_peers().await; + remote.shutdown_kernel().await + }, + ) + .unwrap(); + + let client_kernel1 = PeerConnectionKernel::new_authless( + uuid1, + server_addr, + peer1_connection, + UdpMode::Enabled, + session_security, + None, + move |mut connection, remote| async move { + wait_for_peers().await; + let conn = connection.recv().await.unwrap(); + log::trace!(target: "citadel", "Peer 1 {} received: {:?}", remote.conn_type.get_implicated_cid(), conn); + if conn.is_ok() { + peer_1_error_received.store(true, Ordering::SeqCst); + } + wait_for_peers().await; + remote.shutdown_kernel().await + }, + ) + .unwrap(); + + let client0 = NodeBuilder::default().build(client_kernel0).unwrap(); + let client1 = NodeBuilder::default().build(client_kernel1).unwrap(); + let clients = futures::future::try_join(client0, client1); + + let task = async move { + tokio::select! { + server_res = server => Err(NetworkError::msg(format!("Server ended prematurely: {:?}", server_res.map(|_| ())))), + client_res = clients => client_res.map(|_| ()) + } + }; + + tokio::time::timeout(Duration::from_secs(120), task) + .await + .unwrap() + .unwrap(); + + assert!(!peer_0_error_received.load(Ordering::SeqCst)); + assert!(!peer_1_error_received.load(Ordering::SeqCst)); + } } diff --git a/citadel_sdk/src/prefabs/client/single_connection.rs b/citadel_sdk/src/prefabs/client/single_connection.rs index 3ce416714..16bbaa58f 100644 --- a/citadel_sdk/src/prefabs/client/single_connection.rs +++ b/citadel_sdk/src/prefabs/client/single_connection.rs @@ -20,6 +20,7 @@ pub struct SingleClientServerConnectionKernel { session_security_settings: SessionSecuritySettings, unprocessed_signal_filter_tx: Mutex>>, remote: Option, + server_password: Option, // by using fn() -> Fut, the future does not need to be Sync _pd: PhantomData Fut>, } @@ -53,6 +54,7 @@ where password: P, udp_mode: UdpMode, session_security_settings: SessionSecuritySettings, + server_password: Option, on_channel_received: F, ) -> Self { Self { @@ -65,6 +67,7 @@ where session_security_settings, unprocessed_signal_filter_tx: Default::default(), remote: None, + server_password, _pd: Default::default(), } } @@ -80,11 +83,13 @@ where password, Default::default(), Default::default(), + Default::default(), on_channel_received, ) } /// First registers with a central server with the proposed credentials, and thereafter, establishes a connection with custom parameters + #[allow(clippy::too_many_arguments)] pub fn new_register, R: Into, P: Into, V: ToSocketAddrs>( full_name: T, username: R, @@ -92,6 +97,7 @@ where server_addr: V, udp_mode: UdpMode, session_security_settings: SessionSecuritySettings, + server_password: Option, on_channel_received: F, ) -> Result { let server_addr = get_socket_addr(server_addr)?; @@ -107,6 +113,7 @@ where session_security_settings, unprocessed_signal_filter_tx: Default::default(), remote: None, + server_password, _pd: Default::default(), }) } @@ -131,16 +138,18 @@ where server_addr, Default::default(), Default::default(), + Default::default(), on_channel_received, ) } /// Creates a new authless connection with custom arguments - pub fn new_passwordless( + pub fn new_authless( uuid: Uuid, server_addr: V, udp_mode: UdpMode, session_security_settings: SessionSecuritySettings, + server_password: Option, on_channel_received: F, ) -> Result { let server_addr = get_socket_addr(server_addr)?; @@ -150,6 +159,7 @@ where auth_info: Mutex::new(Some(ConnectionType::Passwordless { uuid, server_addr })), session_security_settings, unprocessed_signal_filter_tx: Default::default(), + server_password, remote: None, _pd: Default::default(), }) @@ -161,11 +171,12 @@ where server_addr: V, on_channel_received: F, ) -> Result { - Self::new_passwordless( + Self::new_authless( uuid, server_addr, Default::default(), Default::default(), + Default::default(), on_channel_received, ) } @@ -216,6 +227,7 @@ where username.as_str(), password.clone(), self.session_security_settings, + self.server_password.clone(), ) .await?; } @@ -239,6 +251,7 @@ where self.udp_mode, None, self.session_security_settings, + self.server_password.clone(), ) .await?; let conn_type = VirtualTargetType::LocalGroupServer { @@ -357,6 +370,7 @@ mod tests { server_addr, udp_mode, Default::default(), + None, |channel, remote| async move { log::trace!(target: "citadel", "***CLIENT TEST SUCCESS***"); wait_for_peers().await; @@ -379,12 +393,14 @@ mod tests { } #[rstest] - #[case(false, UdpMode::Enabled)] + #[case(false, UdpMode::Enabled, None)] + #[case(false, UdpMode::Enabled, Some("test-password"))] #[timeout(std::time::Duration::from_secs(90))] #[tokio::test(flavor = "multi_thread")] async fn test_single_connection_passwordless( #[case] debug_force_nat_timeout: bool, #[case] udp_mode: UdpMode, + #[case] server_password: Option<&'static str>, ) { citadel_logging::setup_log(); TestBarrier::setup(2); @@ -401,16 +417,21 @@ mod tests { |conn, remote| async move { default_server_harness(udp_mode, conn, remote, server_success).await }, - |_| (), + |opts| { + if let Some(password) = server_password { + let _ = opts.with_server_password(password); + } + }, ); let uuid = Uuid::new_v4(); - let client_kernel = SingleClientServerConnectionKernel::new_passwordless( + let client_kernel = SingleClientServerConnectionKernel::new_authless( uuid, server_addr, udp_mode, Default::default(), + server_password.map(|x| x.into()), |channel, remote| async move { log::trace!(target: "citadel", "***CLIENT TEST SUCCESS***"); wait_for_peers().await; @@ -433,6 +454,59 @@ mod tests { assert!(server_success.load(Ordering::Relaxed)); } + #[cfg(feature = "multi-threaded")] + #[rstest] + #[case(false, UdpMode::Enabled, Some("test-password"))] + #[timeout(std::time::Duration::from_secs(90))] + #[tokio::test(flavor = "multi_thread")] + async fn test_single_connection_passwordless_wrong_password( + #[case] debug_force_nat_timeout: bool, + #[case] udp_mode: UdpMode, + #[case] server_password: Option<&'static str>, + ) { + citadel_logging::setup_log(); + TestBarrier::setup(2); + + if debug_force_nat_timeout { + std::env::set_var("debug_cause_timeout", "ON"); + } else { + std::env::remove_var("debug_cause_timeout"); + } + + let (server, server_addr) = server_info_reactive( + |_conn, _remote| async move { panic!("Server should not have connected") }, + |opts| { + if let Some(password) = server_password { + let _ = opts.with_server_password(password); + } + }, + ); + + let uuid = Uuid::new_v4(); + + let client_kernel = SingleClientServerConnectionKernel::new_authless( + uuid, + server_addr, + udp_mode, + Default::default(), + Some("wrong-password".into()), + |_channel, _remote| async move { panic!("Client should not have connected") }, + ) + .unwrap(); + + // Spawn the server, since the server won't quit when a bad connection is made; + let _server = tokio::spawn(server); + + let client = NodeBuilder::default().build(client_kernel).unwrap(); + + let result = client.await; + if let Err(error) = result { + assert!(error.into_string().contains("EncryptionFailure")); + } else { + panic!("Client should not have connected") + } + } + #[rstest] #[case(UdpMode::Disabled)] #[timeout(std::time::Duration::from_secs(90))] @@ -453,11 +527,12 @@ mod tests { let uuid = Uuid::new_v4(); - let client_kernel = SingleClientServerConnectionKernel::new_passwordless( + let client_kernel = SingleClientServerConnectionKernel::new_authless( uuid, server_addr, udp_mode, Default::default(), + None, |channel, remote| async move { log::trace!(target: "citadel", "***CLIENT TEST SUCCESS***"); wait_for_peers().await; @@ -500,11 +575,12 @@ mod tests { let uuid = Uuid::new_v4(); - let client_kernel = SingleClientServerConnectionKernel::new_passwordless( + let client_kernel = SingleClientServerConnectionKernel::new_authless( uuid, server_addr, udp_mode, Default::default(), + None, |channel, remote| async move { log::trace!(target: "citadel", "***CLIENT TEST SUCCESS***"); wait_for_peers().await; @@ -576,11 +652,12 @@ mod tests { let uuid = Uuid::new_v4(); - let client_kernel = SingleClientServerConnectionKernel::new_passwordless( + let client_kernel = SingleClientServerConnectionKernel::new_authless( uuid, server_addr, udp_mode, Default::default(), + None, |_channel, remote| async move { log::trace!(target: "citadel", "***CLIENT TEST SUCCESS***"); wait_for_peers().await; diff --git a/citadel_sdk/src/remote_ext.rs b/citadel_sdk/src/remote_ext.rs index 54d60d51d..87b990424 100644 --- a/citadel_sdk/src/remote_ext.rs +++ b/citadel_sdk/src/remote_ext.rs @@ -139,6 +139,7 @@ pub trait ProtocolRemoteExt: Remote { username: V, proposed_password: K, default_security_settings: SessionSecuritySettings, + server_password: Option, ) -> Result { let creds = ProposedCredentials::new_register(full_name, username, proposed_password.into()) @@ -150,6 +151,7 @@ pub trait ProtocolRemoteExt: Remote { .ok_or(NetworkError::InternalError("Invalid socket addr"))?, proposed_credentials: creds, static_security_settings: default_security_settings, + session_password: server_password.unwrap_or_default(), }); let mut subscription = self.send_callback_subscription(register_request).await?; @@ -192,6 +194,7 @@ pub trait ProtocolRemoteExt: Remote { username, proposed_password, Default::default(), + Default::default(), ) .await } @@ -205,6 +208,7 @@ pub trait ProtocolRemoteExt: Remote { udp_mode: UdpMode, keep_alive_timeout: Option, session_security_settings: SessionSecuritySettings, + server_password: Option, ) -> Result { let connect_request = NodeRequest::ConnectToHypernode(ConnectToHypernode { auth_request: auth, @@ -212,6 +216,7 @@ pub trait ProtocolRemoteExt: Remote { udp_mode, keep_alive_timeout: keep_alive_timeout.map(|r| r.as_secs()), session_security_settings, + session_password: server_password.unwrap_or_default(), }); let mut subscription = self.send_callback_subscription(connect_request).await?; @@ -265,6 +270,7 @@ pub trait ProtocolRemoteExt: Remote { Default::default(), None, Default::default(), + Default::default(), ) .await } @@ -451,6 +457,15 @@ pub trait ProtocolRemoteExt: Remote { pub fn map_errors(result: NodeResult) -> Result { match result { + NodeResult::ConnectFail(ConnectFail { + ticket: _, + cid_opt: _, + error_message: err, + }) => Err(NetworkError::Generic(err)), + NodeResult::RegisterFailure(RegisterFailure { + ticket: _, + error_message: err, + }) => Err(NetworkError::Generic(err)), NodeResult::InternalServerError(InternalServerError { ticket_opt: _, cid_opt: _, @@ -461,6 +476,7 @@ pub fn map_errors(result: NodeResult) -> Result { PeerSignal::SignalError { ticket: _, error: err, + peer_connection_type: _, }, ticket: _, .. @@ -697,6 +713,7 @@ pub trait ProtocolRemoteTargetExt: TargetLockedRemote { &self, session_security_settings: SessionSecuritySettings, udp_mode: UdpMode, + peer_session_password: Option, ) -> Result { let implicated_cid = self.user().get_implicated_cid(); let peer_target = self.try_as_peer_connection().await?; @@ -711,6 +728,7 @@ pub trait ProtocolRemoteTargetExt: TargetLockedRemote { invitee_response: None, session_security_settings, udp_mode, + session_password: peer_session_password, }, })) .await?; @@ -765,7 +783,7 @@ pub trait ProtocolRemoteTargetExt: TargetLockedRemote { /// Connects to the target peer with default settings async fn connect_to_peer(&self) -> Result { - self.connect_to_peer_custom(Default::default(), Default::default()) + self.connect_to_peer_custom(Default::default(), Default::default(), Default::default()) .await } @@ -1321,11 +1339,12 @@ mod tests { .build() .unwrap(); - let client_kernel = SingleClientServerConnectionKernel::new_passwordless( + let client_kernel = SingleClientServerConnectionKernel::new_authless( uuid, server_addr, UdpMode::Disabled, session_security_settings, + None, |_channel, remote| async move { log::trace!(target: "citadel", "***CLIENT LOGIN SUCCESS :: File transfer next ***"); remote diff --git a/citadel_sdk/src/responses.rs b/citadel_sdk/src/responses.rs index 9d2ad10eb..778e7ffa6 100644 --- a/citadel_sdk/src/responses.rs +++ b/citadel_sdk/src/responses.rs @@ -68,6 +68,7 @@ pub async fn peer_connect( input_signal: PeerSignal, accept: bool, remote: &impl Remote, + peer_session_password: Option, ) -> Result { if let PeerSignal::PostConnect { peer_conn_type: v_conn, @@ -75,6 +76,7 @@ pub async fn peer_connect( invitee_response: None, session_security_settings: sess_sec, udp_mode, + session_password: None, } = input_signal { let this_cid = v_conn.get_original_target_cid(); @@ -94,6 +96,7 @@ pub async fn peer_connect( invitee_response: Some(resp), session_security_settings: sess_sec, udp_mode, + session_password: peer_session_password, }, }); remote diff --git a/citadel_sdk/tests/stress_tests.rs b/citadel_sdk/tests/stress_tests.rs index 14f1a2b7b..a7ec08e8d 100644 --- a/citadel_sdk/tests/stress_tests.rs +++ b/citadel_sdk/tests/stress_tests.rs @@ -234,11 +234,12 @@ mod tests { .build() .unwrap(); - let client_kernel = SingleClientServerConnectionKernel::new_passwordless( + let client_kernel = SingleClientServerConnectionKernel::new_authless( uuid, server_addr, UdpMode::Enabled, session_security, + None, move |connection, remote| async move { log::trace!(target: "citadel", "*** CLIENT RECV CHANNEL ***"); handle_send_receive_e2e(get_barrier(), connection.channel, message_count).await?; @@ -262,13 +263,14 @@ mod tests { } #[rstest] - #[case(100, SecrecyMode::Perfect)] - #[case(100, SecrecyMode::BestEffort)] + #[case(100, SecrecyMode::Perfect, None)] + #[case(100, SecrecyMode::BestEffort, Some("test-password"))] #[timeout(std::time::Duration::from_secs(240))] #[tokio::test(flavor = "multi_thread")] async fn stress_test_c2s_messaging_kyber( #[case] message_count: usize, #[case] secrecy_mode: SecrecyMode, + #[case] server_password: Option<&'static str>, #[values(KemAlgorithm::Kyber)] kem: KemAlgorithm, #[values(EncryptionAlgorithm::Kyber)] enx: EncryptionAlgorithm, ) { @@ -289,7 +291,11 @@ mod tests { SERVER_SUCCESS.store(true, Ordering::Relaxed); remote.shutdown_kernel().await }, - |_| {}, + |node| { + if let Some(password) = server_password { + node.with_server_password(password); + } + }, ); let uuid = Uuid::new_v4(); @@ -299,11 +305,12 @@ mod tests { .build() .unwrap(); - let client_kernel = SingleClientServerConnectionKernel::new_passwordless( + let client_kernel = SingleClientServerConnectionKernel::new_authless( uuid, server_addr, UdpMode::Enabled, session_security, + server_password.map(|p| p.into()), move |connection, remote| async move { log::trace!(target: "citadel", "*** CLIENT RECV CHANNEL ***"); handle_send_receive_e2e(get_barrier(), connection.channel, message_count).await?; @@ -327,13 +334,14 @@ mod tests { } #[rstest] - #[case(500, SecrecyMode::Perfect)] - #[case(500, SecrecyMode::BestEffort)] + #[case(500, SecrecyMode::Perfect, None)] + #[case(500, SecrecyMode::BestEffort, Some("test-p2p-password"))] #[timeout(std::time::Duration::from_secs(240))] #[tokio::test(flavor = "multi_thread")] async fn stress_test_p2p_messaging( #[case] message_count: usize, #[case] secrecy_mode: SecrecyMode, + #[case] p2p_password: Option<&'static str>, #[values(KemAlgorithm::Kyber)] kem: KemAlgorithm, #[values( EncryptionAlgorithm::AES_GCM_256, @@ -357,14 +365,33 @@ mod tests { .build() .unwrap(); - // TODO: SinglePeerConnectionKernel - // to not hold up all conns - let client_kernel0 = PeerConnectionKernel::new_passwordless( + let mut peer0_agg = PeerConnectionSetupAggregator::default() + .with_peer_custom(uuid1) + .with_session_security_settings(session_security); + + if let Some(password) = p2p_password { + peer0_agg = peer0_agg.with_session_password(password); + } + + let peer0_connection = peer0_agg.add(); + + let mut peer1_agg = PeerConnectionSetupAggregator::default() + .with_peer_custom(uuid0) + .with_session_security_settings(session_security); + + if let Some(password) = p2p_password { + peer1_agg = peer1_agg.with_session_password(password); + } + + let peer1_connection = peer1_agg.add(); + + let client_kernel0 = PeerConnectionKernel::new_authless( uuid0, server_addr, - vec![uuid1.into()], + peer0_connection, UdpMode::Enabled, session_security, + None, move |mut connection, remote| async move { handle_send_receive_e2e( get_barrier(), @@ -379,12 +406,13 @@ mod tests { ) .unwrap(); - let client_kernel1 = PeerConnectionKernel::new_passwordless( + let client_kernel1 = PeerConnectionKernel::new_authless( uuid1, server_addr, - vec![uuid0.into()], + peer1_connection, UdpMode::Enabled, session_security, + None, move |mut connection, remote| async move { handle_send_receive_e2e( get_barrier(), @@ -457,7 +485,7 @@ mod tests { } }; - let client_kernel = BroadcastKernel::new_passwordless_defaults( + let client_kernel = BroadcastKernel::new_authless_defaults( uuid, server_addr, request, diff --git a/citadel_user/tests/primary.rs b/citadel_user/tests/primary.rs index fa5c46adf..fdf4a1c97 100644 --- a/citadel_user/tests/primary.rs +++ b/citadel_user/tests/primary.rs @@ -1450,11 +1450,19 @@ mod tests { ); let mut alice = StackedRatchetConstructor::new_alice(opts.clone(), cid, version, None).unwrap(); - let bob = - StackedRatchetConstructor::new_bob(cid, version, opts, alice.stage0_alice().unwrap()) - .unwrap(); + let bob = StackedRatchetConstructor::new_bob( + cid, + version, + opts, + alice.stage0_alice().unwrap(), + &[], + ) + .unwrap(); alice - .stage1_alice(BobToAliceTransferType::Default(bob.stage0_bob().unwrap())) + .stage1_alice( + BobToAliceTransferType::Default(bob.stage0_bob().unwrap()), + &[], + ) .unwrap(); let bob = if let Some(cid) = endpoint_bob_cid { bob.finish_with_custom_cid(cid).unwrap()