Skip to content

Commit

Permalink
replace dashmap
Browse files Browse the repository at this point in the history
  • Loading branch information
fikersd committed Nov 17, 2023
1 parent 54cd466 commit 828764a
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 47 deletions.
13 changes: 6 additions & 7 deletions akasa-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@ edition = "2021"

[dependencies]
bytes = "1.2.1"
dashmap = "5.4.0"
flume = { version = "0.10", features = ["async"] }
flume = { version = "0.11", features = ["async"] }
futures-io = "0.3.25"
futures-lite = "1.12.0"
hashbrown = "0.13.1"
futures-lite = "2.0.1"
hashbrown = "0.14.2"
hex = { version = "0.4", features = ["serde"] }
log = "0.4.17"
mqtt-proto = { git = "https://github.com/akasamq/mqtt-proto.git", branch = "master" }
Expand All @@ -32,15 +31,15 @@ futures-sink = "0.3.26"
futures-util = "0.3.26"
async-trait = "0.1.64"
base64 = "0.21.0"
ring = "0.16"
ring = "0.17"
crc32c = "0.6.3"
openssl = "0.10.51"
async-tungstenite = "0.21.0"
async-tungstenite = "0.23.0"

[target.'cfg(target_os = "linux")'.dependencies]
glommio = { version = "0.8.0" }

[dev-dependencies]
futures-sink = "0.3.26"
tokio-util = "0.7.7"
env_logger = "0.9.3"
env_logger = "0.10.1"
21 changes: 11 additions & 10 deletions akasa-core/src/protocols/mqtt/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use std::io::{self, BufRead, BufReader, BufWriter, Read, Write};
use std::num::NonZeroU32;

use base64::{engine::general_purpose::STANDARD_NO_PAD, Engine};
use dashmap::DashMap;
use hashbrown::HashMap;
use parking_lot::RwLock;
use ring::{
digest::{Context, SHA256, SHA256_OUTPUT_LEN, SHA512, SHA512_OUTPUT_LEN},
pbkdf2::{self, PBKDF2_HMAC_SHA256, PBKDF2_HMAC_SHA512},
Expand All @@ -12,9 +13,9 @@ use crate::state::{AuthPassword, HashAlgorithm};

pub const MIN_SALT_LEN: usize = 12;

pub fn load_passwords<R: Read>(input: R) -> io::Result<DashMap<String, AuthPassword>> {
pub fn load_passwords<R: Read>(input: R) -> io::Result<RwLock<HashMap<String, AuthPassword>>> {
let reader = BufReader::with_capacity(2048, input);
let passwords = DashMap::new();
let passwords = RwLock::new(HashMap::new());
for (line_num, line_result) in reader.lines().enumerate() {
let line = line_result?;
let text = line.trim();
Expand Down Expand Up @@ -121,7 +122,7 @@ pub fn load_passwords<R: Read>(input: R) -> io::Result<DashMap<String, AuthPassw
hashed_password,
salt,
};
passwords.insert(username.to_owned(), item);
passwords.write().insert(username.to_owned(), item);
} else {
log::error!("invalid password line(#{}): {}", line_num, line);
return Err(io::ErrorKind::InvalidData.into());
Expand All @@ -132,16 +133,16 @@ pub fn load_passwords<R: Read>(input: R) -> io::Result<DashMap<String, AuthPassw

pub fn dump_passwords<W: Write>(
output: W,
passwords: &DashMap<String, AuthPassword>,
passwords: &RwLock<HashMap<String, AuthPassword>>,
) -> io::Result<()> {
let mut writer = BufWriter::new(output);
for item in passwords.iter() {
let username = item.key();
for item in passwords.read().iter() {
let username = item.0;
let AuthPassword {
hash_algorithm,
hashed_password,
salt,
} = item.value();
} = item.1;
let base64_hashed_password = STANDARD_NO_PAD.encode(hashed_password);
let base64_salt = STANDARD_NO_PAD.encode(salt);
let line = match hash_algorithm {
Expand All @@ -164,11 +165,11 @@ pub fn dump_passwords<W: Write>(
}

pub fn check_password(
passwords: &DashMap<String, AuthPassword>,
passwords: &RwLock<HashMap<String, AuthPassword>>,
username: &str,
password: &[u8],
) -> bool {
if let Some(item) = passwords.get(username) {
if let Some(item) = passwords.read().get(username) {
match item.hash_algorithm {
HashAlgorithm::Sha256 => {
let mut ctx = Context::new(&SHA256);
Expand Down
51 changes: 29 additions & 22 deletions akasa-core/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ use std::sync::Arc;
use std::time::{Duration, Instant};

use bytes::Bytes;
use dashmap::DashMap;
use flume::{bounded, Receiver, Sender};
use hashbrown::HashMap;
use mqtt_proto::{v5::PublishProperties, Protocol, QoS, TopicFilter, TopicName};
use parking_lot::Mutex;
use parking_lot::{Mutex, RwLock};

use crate::config::Config;
use crate::protocols::mqtt::{self, RetainTable, RouteTable};
Expand All @@ -23,14 +23,14 @@ pub struct GlobalState {
// online clients count
online_clients: AtomicU64,
// client internal id => (MQTT client identifier, online)
client_id_map: DashMap<ClientId, (String, bool)>,
client_id_map: RwLock<HashMap<ClientId, (String, bool)>>,
// MQTT client identifier => client internal id
client_identifier_map: DashMap<String, ClientId>,
client_identifier_map: RwLock<HashMap<String, ClientId>>,
// All clients (online/offline clients)
clients: DashMap<ClientId, ClientSender>,
clients: RwLock<HashMap<ClientId, ClientSender>>,

pub config: Config,
pub auth_passwords: DashMap<String, AuthPassword>,
pub auth_passwords: RwLock<HashMap<String, AuthPassword>>,

/// MQTT route table
pub route_table: RouteTable,
Expand Down Expand Up @@ -72,12 +72,12 @@ impl GlobalState {
// FIXME: load from db (rosksdb or sqlite3)
next_client_id: Mutex::new(ClientId(0)),
online_clients: AtomicU64::new(0),
client_id_map: DashMap::new(),
client_identifier_map: DashMap::new(),
clients: DashMap::new(),
client_id_map: RwLock::new(HashMap::new()),
client_identifier_map: RwLock::new(HashMap::new()),
clients: RwLock::new(HashMap::new()),

config,
auth_passwords: DashMap::new(),
auth_passwords: RwLock::new(HashMap::new()),
route_table: RouteTable::default(),
retain_table: RetainTable::default(),
}
Expand All @@ -90,7 +90,7 @@ impl GlobalState {
// self.clients.len() - *self.online_clients.lock()
// }
pub fn clients_count(&self) -> usize {
self.clients.len()
self.clients.read().len()
}

// When clean_session=1 and client disconnected
Expand All @@ -101,13 +101,15 @@ impl GlobalState {
) {
// keep client operation atomic
let _guard = self.next_client_id.lock();
if let Some((_, (client_identifier, online))) = self.client_id_map.remove(&client_id) {
self.client_identifier_map.remove(&client_identifier);
if let Some((client_identifier, online)) = self.client_id_map.write().remove(&client_id) {
self.client_identifier_map
.write()
.remove(&client_identifier);
if online {
assert_ne!(self.online_clients.fetch_sub(1, Ordering::AcqRel), 0);
}
}
self.clients.remove(&client_id);
self.clients.write().remove(&client_id);
for filter in subscribes {
self.route_table.unsubscribe(filter, client_id);
}
Expand All @@ -117,8 +119,8 @@ impl GlobalState {
pub fn offline_client(&self, client_id: ClientId) {
let _guard = self.next_client_id.lock();
assert_ne!(self.online_clients.fetch_sub(1, Ordering::AcqRel), 0);
if let Some(mut pair) = self.client_id_map.get_mut(&client_id) {
pair.value_mut().1 = false;
if let Some(pair) = self.client_id_map.write().get_mut(&client_id) {
pair.1 = false;
}
}

Expand All @@ -127,16 +129,18 @@ impl GlobalState {
client_id: &ClientId,
) -> Option<Sender<(ClientId, NormalMessage)>> {
self.clients
.read()
.get(client_id)
.map(|pair| pair.value().normal.clone())
.map(|pair| pair.normal.clone())
}
pub fn get_client_control_sender(
&self,
client_id: &ClientId,
) -> Option<Sender<ControlMessage>> {
self.clients
.read()
.get(client_id)
.map(|pair| pair.value().control.clone())
.map(|pair| pair.control.clone())
}

// Client connected
Expand All @@ -151,18 +155,21 @@ impl GlobalState {
self.online_clients.fetch_add(1, Ordering::AcqRel);
let client_id_opt: Option<ClientId> = self
.client_identifier_map
.read()
.get(client_identifier)
.map(|pair| *pair.value());
.copied();
if let Some(old_id) = client_id_opt {
if let Some(mut pair) = self.client_id_map.get_mut(&old_id) {
pair.value_mut().1 = true;
if let Some(pair) = self.client_id_map.write().get_mut(&old_id) {
pair.1 = true;
}
self.get_client_control_sender(&old_id).unwrap()
} else {
let client_id = *next_client_id;
self.client_id_map
.write()
.insert(client_id, (client_identifier.to_string(), true));
self.client_identifier_map
.write()
.insert(client_identifier.to_string(), client_id);
// FIXME: if some one subscribe topic "#" and never receive the message it will block all sender clients.
// Suggestion: Add QoS0 message to pending queue
Expand All @@ -172,7 +179,7 @@ impl GlobalState {
normal: normal_sender,
control: control_sender,
};
self.clients.insert(client_id, sender);
self.clients.write().insert(client_id, sender);
next_client_id.0 += 1;
return Ok(AddClientReceipt::New {
client_id,
Expand Down
2 changes: 1 addition & 1 deletion akasa-core/src/tests/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl GlobalState {
let mut salt = vec![0u8; MIN_SALT_LEN];
OsRng.fill_bytes(&mut salt);
let hashed_password = hash_password(algo, &salt, password.as_bytes());
self.auth_passwords.insert(
self.auth_passwords.write().insert(
username.to_owned(),
AuthPassword {
hash_algorithm: algo,
Expand Down
5 changes: 3 additions & 2 deletions akasa/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ edition = "2021"
akasa-core = { version = "0.1.0", path = "../akasa-core" }
anyhow = "1.0.66"
clap = { version = "4.0.26", features = ["derive"] }
hashbrown = "0.14.2"
serde_yaml = "0.9.14"
async-trait = "0.1.64"
env_logger = "0.9.3"
env_logger = "0.10.1"
log = "0.4.17"
dashmap = "5.4.0"
rpassword = "7.2.0"
rand = { version = "0.8.5", features = ["getrandom"] }
parking_lot = "0.12.1"
11 changes: 6 additions & 5 deletions akasa/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ use akasa_core::{
};
use anyhow::{anyhow, bail};
use clap::{Parser, Subcommand, ValueEnum};
use dashmap::DashMap;
use hashbrown::HashMap;
use parking_lot::RwLock;
use rand::{rngs::OsRng, RngCore};

use default_hook::DefaultHook;
Expand Down Expand Up @@ -124,7 +125,7 @@ fn main() -> anyhow::Result<()> {
fs::File::open(path).map_err(|err| anyhow!("load passwords: {}", err))?;
load_passwords(file)?
} else {
DashMap::new()
RwLock::new(HashMap::new())
};
let mut global_state = GlobalState::new(config);
global_state.auth_passwords = auth_passwords;
Expand Down Expand Up @@ -174,12 +175,12 @@ fn main() -> anyhow::Result<()> {
let hashed_password = hash_password(hash_algorithm, &salt, password.as_bytes());

let auth_passwords = if create {
DashMap::new()
RwLock::new(HashMap::new())
} else {
load_passwords(&fs::File::open(&path)?)?
};

auth_passwords.insert(
auth_passwords.write().insert(
username.clone(),
AuthPassword {
hash_algorithm,
Expand All @@ -193,7 +194,7 @@ fn main() -> anyhow::Result<()> {
}
Commands::RemovePassword { path, username } => {
let auth_passwords = load_passwords(&fs::File::open(&path)?)?;
if auth_passwords.remove(&username).is_some() {
if auth_passwords.write().remove(&username).is_some() {
let file = fs::File::create(&path)?;
dump_passwords(&file, &auth_passwords)?;
println!("removed user={username} from {path:?}");
Expand Down

0 comments on commit 828764a

Please sign in to comment.