Skip to content

Commit

Permalink
Implement tracked radios cache
Browse files Browse the repository at this point in the history
  • Loading branch information
kurotych committed Jan 8, 2025
1 parent cd3f79a commit d91ef05
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 101 deletions.
46 changes: 2 additions & 44 deletions mobile_config/src/gateway_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,13 +363,10 @@ pub(crate) mod db {
use super::{DeviceType, GatewayInfo, GatewayMetadata};
use crate::gateway_info::DeploymentInfo;
use chrono::{DateTime, Utc};
use futures::{
stream::{Stream, StreamExt},
TryStreamExt,
};
use futures::stream::{Stream, StreamExt};
use helium_crypto::PublicKeyBinary;
use sqlx::{types::Json, PgExecutor, Row};
use std::{collections::HashMap, str::FromStr};
use std::str::FromStr;

const GET_METADATA_SQL: &str = r#"
select kta.entity_key, infos.location::bigint, infos.device_type,
Expand All @@ -380,50 +377,11 @@ pub(crate) mod db {
const BATCH_SQL_WHERE_SNIPPET: &str = " where kta.entity_key = any($1::bytea[]) ";
const DEVICE_TYPES_WHERE_SNIPPET: &str = " where device_type::text = any($1) ";

const GET_UPDATED_RADIOS: &str =
"SELECT entity_key, last_changed_at FROM mobile_radio_tracker WHERE last_changed_at >= $1";

const GET_UPDATED_AT: &str =
"SELECT last_changed_at FROM mobile_radio_tracker WHERE entity_key = $1";

lazy_static::lazy_static! {
static ref BATCH_METADATA_SQL: String = format!("{GET_METADATA_SQL} {BATCH_SQL_WHERE_SNIPPET}");
static ref DEVICE_TYPES_METADATA_SQL: String = format!("{GET_METADATA_SQL} {DEVICE_TYPES_WHERE_SNIPPET}");
}

pub async fn get_updated_radios(
db: impl PgExecutor<'_>,
min_updated_at: DateTime<Utc>,
) -> anyhow::Result<HashMap<PublicKeyBinary, DateTime<Utc>>> {
sqlx::query(GET_UPDATED_RADIOS)
.bind(min_updated_at)
.fetch(db)
.map_err(anyhow::Error::from)
.try_fold(
HashMap::new(),
|mut map: HashMap<PublicKeyBinary, DateTime<Utc>>, row| async move {
let entity_key_b = row.get::<&[u8], &str>("entity_key");
let entity_key = bs58::encode(entity_key_b).into_string();
let updated_at = row.get::<DateTime<Utc>, &str>("last_changed_at");
map.insert(PublicKeyBinary::from_str(&entity_key)?, updated_at);
Ok(map)
},
)
.await
}

pub async fn get_updated_at(
db: impl PgExecutor<'_>,
address: &PublicKeyBinary,
) -> anyhow::Result<Option<DateTime<Utc>>> {
let entity_key = bs58::decode(address.to_string()).into_vec()?;
sqlx::query_scalar(GET_UPDATED_AT)
.bind(entity_key)
.fetch_optional(db)
.await
.map_err(anyhow::Error::from)
}

pub async fn get_info(
db: impl PgExecutor<'_>,
address: &PublicKeyBinary,
Expand Down
59 changes: 26 additions & 33 deletions mobile_config/src/gateway_service.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use crate::{
gateway_info::{self, db::get_updated_radios, DeviceType, GatewayInfo},
gateway_info::{self, DeviceType, GatewayInfo},
key_cache::KeyCache,
mobile_radio_tracker::TrackedRadiosMap,
telemetry, verify_public_key, GrpcResult, GrpcStreamResult,
};
use chrono::{DateTime, TimeZone, Utc};
use file_store::traits::{MsgVerify, TimestampEncode};
use futures::{
future,
stream::{Stream, StreamExt, TryStreamExt},
TryFutureExt,
};
Expand All @@ -20,28 +20,29 @@ use helium_proto::{
Message,
};
use sqlx::{Pool, Postgres};
use std::{collections::HashMap, sync::Arc};
use std::sync::Arc;
use tokio::sync::RwLock;
use tonic::{Request, Response, Status};

pub struct GatewayService {
key_cache: KeyCache,
mobile_config_db_pool: Pool<Postgres>,
metadata_pool: Pool<Postgres>,
signing_key: Arc<Keypair>,
tracked_radios_cache: Arc<RwLock<TrackedRadiosMap>>,
}

impl GatewayService {
pub fn new(
key_cache: KeyCache,
metadata_pool: Pool<Postgres>,
signing_key: Keypair,
mobile_config_db_pool: Pool<Postgres>,
tracked_radios_cache: Arc<RwLock<TrackedRadiosMap>>,
) -> Self {
Self {
key_cache,
metadata_pool,
signing_key: Arc::new(signing_key),
mobile_config_db_pool,
tracked_radios_cache,
}
}

Expand Down Expand Up @@ -129,11 +130,10 @@ impl mobile_config::Gateway for GatewayService {
let pubkey: PublicKeyBinary = request.address.into();
tracing::debug!(pubkey = pubkey.to_string(), "fetching gateway info (v2)");

let updated_at = gateway_info::db::get_updated_at(&self.mobile_config_db_pool, &pubkey)
.await
.map_err(|_| {
Status::internal("error fetching updated_at field for gateway info (v2)")
})?;
let updated_at = {
let tracked_radios = self.tracked_radios_cache.read().await;
tracked_radios.get(&pubkey).cloned()
};

gateway_info::db::get_info(&self.metadata_pool, &pubkey)
.await
Expand Down Expand Up @@ -230,7 +230,6 @@ impl mobile_config::Gateway for GatewayService {
);

let metadata_db_pool = self.metadata_pool.clone();
let mobile_config_db_pool = self.mobile_config_db_pool.clone();
let signing_key = self.signing_key.clone();
let batch_size = request.batch_size;
let addresses = request
Expand All @@ -241,18 +240,14 @@ impl mobile_config::Gateway for GatewayService {

let (tx, rx) = tokio::sync::mpsc::channel(100);

let radios_cache = Arc::clone(&self.tracked_radios_cache);
tokio::spawn(async move {
let min_updated_at = DateTime::UNIX_EPOCH;
let updated_radios = get_updated_radios(&mobile_config_db_pool, min_updated_at).await?;

let stream = gateway_info::db::batch_info_stream(&metadata_db_pool, &addresses)?;
let stream = stream
.filter_map(|gateway_info| {
future::ready(handle_updated_at(
gateway_info,
&updated_radios,
min_updated_at,
))
handle_updated_at(gateway_info, Arc::clone(&radios_cache), min_updated_at)
})
.boxed();
stream_multi_gateways_info(stream, tx.clone(), signing_key.clone(), batch_size).await
Expand Down Expand Up @@ -307,7 +302,6 @@ impl mobile_config::Gateway for GatewayService {
self.verify_request_signature(&signer, &request)?;

let metadata_db_pool = self.metadata_pool.clone();
let mobile_config_db_pool = self.mobile_config_db_pool.clone();
let signing_key = self.signing_key.clone();
let batch_size = request.batch_size;

Expand All @@ -320,6 +314,7 @@ impl mobile_config::Gateway for GatewayService {
device_types
);

let radios_cache = Arc::clone(&self.tracked_radios_cache);
tokio::spawn(async move {
let min_updated_at = Utc
.timestamp_opt(request.min_updated_at as i64, 0)
Expand All @@ -328,15 +323,10 @@ impl mobile_config::Gateway for GatewayService {
"Invalid min_refreshed_at argument",
))?;

let updated_radios = get_updated_radios(&mobile_config_db_pool, min_updated_at).await?;
let stream = gateway_info::db::all_info_stream(&metadata_db_pool, &device_types);
let stream = stream
.filter_map(|gateway_info| {
future::ready(handle_updated_at(
gateway_info,
&updated_radios,
min_updated_at,
))
handle_updated_at(gateway_info, Arc::clone(&radios_cache), min_updated_at)
})
.boxed();
stream_multi_gateways_info(stream, tx.clone(), signing_key.clone(), batch_size).await
Expand All @@ -346,20 +336,23 @@ impl mobile_config::Gateway for GatewayService {
}
}

fn handle_updated_at(
async fn handle_updated_at(
mut gateway_info: GatewayInfo,
updated_radios: &HashMap<PublicKeyBinary, chrono::DateTime<Utc>>,
updated_radios: Arc<RwLock<TrackedRadiosMap>>,
min_updated_at: chrono::DateTime<Utc>,
) -> Option<GatewayInfo> {
// Check mobile_radio_tracker HashMap
if let Some(updated_at) = updated_radios.get(&gateway_info.address) {
// It could be already filtered by min_updated_at but recheck won't hurt
if updated_at >= &min_updated_at {
gateway_info.updated_at = Some(*updated_at);
return Some(gateway_info);
{
let updated_radios = updated_radios.read().await;
if let Some(updated_at) = updated_radios.get(&gateway_info.address) {
if updated_at >= &min_updated_at {
gateway_info.updated_at = Some(*updated_at);
return Some(gateway_info);
}
return None;
}
return None;
}

// Fallback solution #1. Try to use refreshed_at as updated_at field and check
// min_updated_at
if let Some(refreshed_at) = gateway_info.refreshed_at {
Expand Down
37 changes: 26 additions & 11 deletions mobile_config/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,19 @@ use helium_proto::services::mobile_config::{
HexBoostingServer,
};
use mobile_config::{
admin_service::AdminService, authorization_service::AuthorizationService,
carrier_service::CarrierService, entity_service::EntityService,
gateway_service::GatewayService, hex_boosting_service::HexBoostingService, key_cache::KeyCache,
mobile_radio_tracker::MobileRadioTracker, settings::Settings,
admin_service::AdminService,
authorization_service::AuthorizationService,
carrier_service::CarrierService,
entity_service::EntityService,
gateway_service::GatewayService,
hex_boosting_service::HexBoostingService,
key_cache::KeyCache,
mobile_radio_tracker::{MobileRadioTracker, TrackedRadiosMap},
settings::Settings,
};
use std::{net::SocketAddr, path::PathBuf, time::Duration};
use std::{net::SocketAddr, path::PathBuf, sync::Arc, time::Duration};
use task_manager::{ManagedTask, TaskManager};
use tokio::sync::RwLock;
use tonic::transport;

#[derive(Debug, clap::Parser)]
Expand Down Expand Up @@ -71,11 +77,15 @@ impl Daemon {

let admin_svc =
AdminService::new(settings, key_cache.clone(), key_cache_updater, pool.clone())?;

let tracked_radios_cache: Arc<RwLock<TrackedRadiosMap>> =
Arc::new(RwLock::new(TrackedRadiosMap::new()));

let gateway_svc = GatewayService::new(
key_cache.clone(),
metadata_pool.clone(),
settings.signing_keypair()?,
pool.clone(),
Arc::clone(&tracked_radios_cache),
);
let auth_svc = AuthorizationService::new(key_cache.clone(), settings.signing_keypair()?);
let entity_svc = EntityService::new(
Expand Down Expand Up @@ -107,13 +117,18 @@ impl Daemon {
hex_boosting_svc,
};

let mobile_tracker = MobileRadioTracker::new(
pool.clone(),
metadata_pool.clone(),
settings.mobile_radio_tracker_interval,
Arc::clone(&tracked_radios_cache),
);
// Preinitialize tracked_radios_cache to avoid race condition in GatewayService
mobile_tracker.track_changes().await?;

TaskManager::builder()
.add_task(grpc_server)
.add_task(MobileRadioTracker::new(
pool.clone(),
metadata_pool.clone(),
settings.mobile_radio_tracker_interval,
))
.add_task(mobile_tracker)
.build()
.start()
.await
Expand Down
Loading

0 comments on commit d91ef05

Please sign in to comment.