From 74ee31148700e7d065329706eb4c8be162bfd3ab Mon Sep 17 00:00:00 2001 From: boxdot Date: Mon, 16 Sep 2024 19:15:26 +0200 Subject: [PATCH] feat: handle read notifications from other clients The unread channel counters are now updated when a read notification is recieved from another client. Note that the counters are ephemeral and will be reset when the app is restarted. Marking a channel as read in Gurk does not update the unread counters in other clients yet. Related to #286 --- ...2bfce367a7cfc86201ddf11e25f18a9c8cdb2.json | 20 +++++ src/app.rs | 10 ++- src/handlers.rs | 74 ++++++++++++++++++- src/storage/forgetful.rs | 4 + src/storage/json.rs | 23 ++++++ src/storage/memcache.rs | 5 ++ src/storage/mod.rs | 2 + src/storage/sql/storage.rs | 43 +++++++++++ 8 files changed, 175 insertions(+), 6 deletions(-) create mode 100644 .sqlx/query-1b6fc3cd9b2c351443f980ab3212bfce367a7cfc86201ddf11e25f18a9c8cdb2.json diff --git a/.sqlx/query-1b6fc3cd9b2c351443f980ab3212bfce367a7cfc86201ddf11e25f18a9c8cdb2.json b/.sqlx/query-1b6fc3cd9b2c351443f980ab3212bfce367a7cfc86201ddf11e25f18a9c8cdb2.json new file mode 100644 index 0000000..d3077c7 --- /dev/null +++ b/.sqlx/query-1b6fc3cd9b2c351443f980ab3212bfce367a7cfc86201ddf11e25f18a9c8cdb2.json @@ -0,0 +1,20 @@ +{ + "db_name": "SQLite", + "query": "\n SELECT\n m.channel_id AS \"channel_id: _\"\n FROM messages AS m\n WHERE m.arrived_at = ?\n LIMIT 1\n ", + "describe": { + "columns": [ + { + "name": "channel_id: _", + "ordinal": 0, + "type_info": "Blob" + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false + ] + }, + "hash": "1b6fc3cd9b2c351443f980ab3212bfce367a7cfc86201ddf11e25f18a9c8cdb2" +} diff --git a/src/app.rs b/src/app.rs index 4161e6a..b634571 100644 --- a/src/app.rs +++ b/src/app.rs @@ -480,7 +480,7 @@ impl App { } pub async fn on_message(&mut self, content: Content) -> anyhow::Result<()> { - tracing::info!(?content, "incoming"); + // tracing::info!(?content, "incoming"); #[cfg(feature = "dev")] if self.config.developer.dump_raw_messages { @@ -491,6 +491,10 @@ impl App { let user_id = self.user_id; + if let ContentBody::SynchronizeMessage(SyncMessage { ref read, .. }) = content.body { + self.handle_read(read); + } + let (channel_idx, message) = match (content.metadata, content.body) { // Private note message ( @@ -1559,7 +1563,7 @@ fn add_emoji_from_sticker(body: &mut Option, sticker: Option) { } #[cfg(test)] -mod tests { +pub(crate) mod tests { use super::*; use crate::config::User; @@ -1570,7 +1574,7 @@ mod tests { use std::cell::RefCell; use std::rc::Rc; - fn test_app() -> ( + pub(crate) fn test_app() -> ( App, mpsc::UnboundedReceiver, Rc>>, diff --git a/src/handlers.rs b/src/handlers.rs index bce546c..7f1fb3f 100644 --- a/src/handlers.rs +++ b/src/handlers.rs @@ -1,6 +1,8 @@ +use std::collections::BTreeMap; + use anyhow::Context; use presage::libsignal_service::content::Metadata; -use presage::proto::sync_message::Sent; +use presage::proto::sync_message::{Read, Sent}; use presage::proto::{DataMessage, EditMessage, SyncMessage}; use tracing::debug; @@ -19,8 +21,6 @@ impl App { return Ok(()); }; - tracing::info!(?sync_message, "#########"); - // edit message if let Some(Sent { edit_message: @@ -88,6 +88,36 @@ impl App { Ok(()) } + + /// Handles read notifications + pub(crate) fn handle_read(&mut self, read: &[Read]) { + // First collect all the read counters to avoid hitting the storage for the same channel + let read_counters: BTreeMap = read + .iter() + .filter_map(|read| { + let arrived_at = read.timestamp?; + let channel_id = self.storage.message_channel(arrived_at)?; + let num_unread = self + .storage + .messages(channel_id) + .rev() + .take_while(|msg| arrived_at < msg.arrived_at) + .count(); + let num_unread: u32 = num_unread.try_into().ok()?; + Some((channel_id, num_unread)) + }) + .collect(); + // Update the unread counters + for (channel_id, num_unread) in read_counters { + if let Some(channel) = self.storage.channel(channel_id) { + if channel.unread_messages > 0 { + let mut channel = channel.into_owned(); + channel.unread_messages = num_unread; + self.storage.store_channel(channel); + } + } + } + } } trait MessageExt { @@ -122,3 +152,41 @@ impl MessageExt for SyncMessage { } } } + +#[cfg(test)] +mod tests { + use crate::app::tests::test_app; + + use super::*; + + #[test] + #[ignore = "forgetful storage does not support lookup by arrived_at"] + fn test_handle_read() { + let (mut app, _events, _sent_messages) = test_app(); + + let channel_id = *app.channels.items.first().unwrap(); + + // new incoming message + let message = app + .storage + .store_message( + channel_id, + Message::text(app.user_id, 42, "unread message".to_string()), + ) + .into_owned(); + + // mark as unread + app.storage + .channel(channel_id) + .unwrap() + .into_owned() + .unread_messages = 1; + + app.handle_read(&[Read { + timestamp: Some(message.arrived_at), + ..Default::default() + }]); + + assert_eq!(app.storage.channel(channel_id).unwrap().unread_messages, 0); + } +} diff --git a/src/storage/forgetful.rs b/src/storage/forgetful.rs index 146cdff..67621e6 100644 --- a/src/storage/forgetful.rs +++ b/src/storage/forgetful.rs @@ -65,4 +65,8 @@ impl Storage for ForgetfulStorage { } fn save(&mut self) {} + + fn message_channel(&self, _arrived_at: u64) -> Option { + None + } } diff --git a/src/storage/json.rs b/src/storage/json.rs index ec518cc..10a03a3 100644 --- a/src/storage/json.rs +++ b/src/storage/json.rs @@ -335,6 +335,16 @@ impl Storage for JsonStorage { error!(error =% e, "failed to save json storage"); } } + + fn message_channel(&self, arrived_at: u64) -> Option { + self.data.channels.items.iter().find_map(|channel| { + channel + .messages + .binary_search_by_key(&arrived_at, |msg| msg.arrived_at) + .is_ok() + .then_some(channel.id) + }) + } } #[cfg(test)] @@ -601,4 +611,17 @@ mod tests { ); assert_eq!(storage.metadata().contacts_sync_request_at, Some(dt)); } + + #[test] + fn test_json_storage_message_channel() { + let mut storage = json_storage_from_snapshot(); + let channel_id = ChannelId::User(uuid!("966960e0-a8cd-43f1-ac7a-2c986dd470cd")); + let from_id = uuid!("00000000-0000-0000-0000-000000000000"); + storage.store_message( + channel_id, + Message::text(from_id, 1664832050004, "hello".to_owned()), + ); + assert_eq!(storage.message_channel(1664832050004), Some(channel_id)); + assert_eq!(storage.message_channel(0), None); + } } diff --git a/src/storage/memcache.rs b/src/storage/memcache.rs index 10c0ccb..12ef408 100644 --- a/src/storage/memcache.rs +++ b/src/storage/memcache.rs @@ -172,4 +172,9 @@ impl Storage for MemCache { fn save(&mut self) { self.storage.save(); } + + fn message_channel(&self, arrived_at: u64) -> Option { + // message arrived_at to channel_id conversion is not cached + self.storage.message_channel(arrived_at) + } } diff --git a/src/storage/mod.rs b/src/storage/mod.rs index efff1ad..fe4eb3e 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -44,6 +44,8 @@ pub trait Storage { /// Gets the message by id fn message(&self, message_id: MessageId) -> Option>; + fn message_channel(&self, arrived_at: u64) -> Option; + fn edits( &self, message_id: MessageId, diff --git a/src/storage/sql/storage.rs b/src/storage/sql/storage.rs index 42e9b33..01beba1 100644 --- a/src/storage/sql/storage.rs +++ b/src/storage/sql/storage.rs @@ -693,6 +693,36 @@ impl Storage for SqliteStorage { } fn save(&mut self) {} + + fn message_channel(&self, arrived_at: u64) -> Option { + struct SqlChannelId { + channel_id: ChannelId, + } + + let arrived_at: i64 = arrived_at + .try_into() + .map_err(|_| MessageConvertError::InvalidTimestamp) + .ok_logged()?; + + self.execute(|ctx| { + Box::pin( + sqlx::query_as!( + SqlChannelId, + r#" + SELECT + m.channel_id AS "channel_id: _" + FROM messages AS m + WHERE m.arrived_at = ? + LIMIT 1 + "#, + arrived_at + ) + .fetch_optional(ctx.conn), + ) + }) + .ok_logged()? + .map(|channel_id| channel_id.channel_id) + } } #[cfg(test)] @@ -958,4 +988,17 @@ mod tests { assert_eq!(is_sqlite_encrypted_heuristics(&url), Some(true)); } + + #[test] + fn test_sqlite_storage_message_channel() { + let _ = tracing_subscriber::fmt().with_test_writer().try_init(); + let mut storage = fixtures(); + let from_id = uuid!("966960e0-a8cd-43f1-ac7a-2c986dd470cd"); + let channel_id = ChannelId::User(uuid!("a955d20f-6b83-4e69-846e-a99b1779ff7a")); + storage.store_message( + channel_id, + Message::text(from_id, 1664832050000, "hello".to_owned()), + ); + assert_eq!(storage.message_channel(1664832050000), Some(channel_id)); + } }