diff --git a/changelog.d/17928.misc b/changelog.d/17928.misc new file mode 100644 index 00000000000..b5aef4457a4 --- /dev/null +++ b/changelog.d/17928.misc @@ -0,0 +1 @@ +Move server event filtering logic to rust. diff --git a/rust/src/events/filter.rs b/rust/src/events/filter.rs new file mode 100644 index 00000000000..7e39972c62d --- /dev/null +++ b/rust/src/events/filter.rs @@ -0,0 +1,107 @@ +/* + * This file is licensed under the Affero General Public License (AGPL) version 3. + * + * Copyright (C) 2024 New Vector, Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * See the GNU Affero General Public License for more details: + * . + */ + +use std::collections::HashMap; + +use pyo3::{exceptions::PyValueError, pyfunction, PyResult}; + +use crate::{ + identifier::UserID, + matrix_const::{ + HISTORY_VISIBILITY_INVITED, HISTORY_VISIBILITY_JOINED, MEMBERSHIP_INVITE, MEMBERSHIP_JOIN, + }, +}; + +#[pyfunction(name = "event_visible_to_server")] +pub fn event_visible_to_server_py( + sender: String, + target_server_name: String, + history_visibility: String, + erased_senders: HashMap, + partial_state_invisible: bool, + memberships: Vec<(String, String)>, // (state_key, membership) +) -> PyResult { + event_visible_to_server( + sender, + target_server_name, + history_visibility, + erased_senders, + partial_state_invisible, + memberships, + ) + .map_err(|e| PyValueError::new_err(format!("{e}"))) +} + +/// Return whether the target server is allowed to see the event. +/// +/// For a fully stated room, the target server is allowed to see an event E if: +/// - the state at E has world readable or shared history vis, OR +/// - the state at E says that the target server is in the room. +/// +/// For a partially stated room, the target server is allowed to see E if: +/// - E was created by this homeserver, AND: +/// - the partial state at E has world readable or shared history vis, OR +/// - the partial state at E says that the target server is in the room. +pub fn event_visible_to_server( + sender: String, + target_server_name: String, + history_visibility: String, + erased_senders: HashMap, + partial_state_invisible: bool, + memberships: Vec<(String, String)>, // (state_key, membership) +) -> anyhow::Result { + if let Some(&erased) = erased_senders.get(&sender) { + if erased { + return Ok(false); + } + } + + if partial_state_invisible { + return Ok(false); + } + + if history_visibility != HISTORY_VISIBILITY_INVITED + && history_visibility != HISTORY_VISIBILITY_JOINED + { + return Ok(true); + } + + let mut visible = false; + for (state_key, membership) in memberships { + let state_key = UserID::try_from(state_key.as_ref()) + .map_err(|e| anyhow::anyhow!(format!("invalid user_id ({state_key}): {e}")))?; + if state_key.server_name() != target_server_name { + return Err(anyhow::anyhow!( + "state_key.server_name ({}) does not match target_server_name ({target_server_name})", + state_key.server_name() + )); + } + + match membership.as_str() { + MEMBERSHIP_INVITE => { + if history_visibility == HISTORY_VISIBILITY_INVITED { + visible = true; + break; + } + } + MEMBERSHIP_JOIN => { + visible = true; + break; + } + _ => continue, + } + } + + Ok(visible) +} diff --git a/rust/src/events/mod.rs b/rust/src/events/mod.rs index a4ade1a1786..0bb6cdb181c 100644 --- a/rust/src/events/mod.rs +++ b/rust/src/events/mod.rs @@ -22,15 +22,17 @@ use pyo3::{ types::{PyAnyMethods, PyModule, PyModuleMethods}, - Bound, PyResult, Python, + wrap_pyfunction, Bound, PyResult, Python, }; +pub mod filter; mod internal_metadata; /// Called when registering modules with python. pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { let child_module = PyModule::new_bound(py, "events")?; child_module.add_class::()?; + child_module.add_function(wrap_pyfunction!(filter::event_visible_to_server_py, m)?)?; m.add_submodule(&child_module)?; diff --git a/rust/src/identifier.rs b/rust/src/identifier.rs new file mode 100644 index 00000000000..b199c5838eb --- /dev/null +++ b/rust/src/identifier.rs @@ -0,0 +1,86 @@ +/* + * This file is licensed under the Affero General Public License (AGPL) version 3. + * + * Copyright (C) 2024 New Vector, Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * See the GNU Affero General Public License for more details: + * . + */ + +//! # Matrix Identifiers +//! +//! This module contains definitions and utilities for working with matrix identifiers. + +use std::{fmt, ops::Deref}; + +/// Errors that can occur when parsing a matrix identifier. +#[derive(Clone, Debug, PartialEq)] +pub enum IdentifierError { + IncorrectSigil, + MissingColon, +} + +impl fmt::Display for IdentifierError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self) + } +} + +/// A Matrix user_id. +#[derive(Clone, Debug, PartialEq)] +pub struct UserID(String); + +impl UserID { + /// Returns the `localpart` of the user_id. + pub fn localpart(&self) -> &str { + &self[1..self.colon_pos()] + } + + /// Returns the `server_name` / `domain` of the user_id. + pub fn server_name(&self) -> &str { + &self[self.colon_pos() + 1..] + } + + /// Returns the position of the ':' inside of the user_id. + /// Used when splitting the user_id into it's respective parts. + fn colon_pos(&self) -> usize { + self.find(':').unwrap() + } +} + +impl TryFrom<&str> for UserID { + type Error = IdentifierError; + + /// Will try creating a `UserID` from the provided `&str`. + /// Can fail if the user_id is incorrectly formatted. + fn try_from(s: &str) -> Result { + if !s.starts_with('@') { + return Err(IdentifierError::IncorrectSigil); + } + + if s.find(':').is_none() { + return Err(IdentifierError::MissingColon); + } + + Ok(UserID(s.to_string())) + } +} + +impl Deref for UserID { + type Target = str; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl fmt::Display for UserID { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 06477880b94..5de92383269 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -6,6 +6,8 @@ pub mod acl; pub mod errors; pub mod events; pub mod http; +pub mod identifier; +pub mod matrix_const; pub mod push; pub mod rendezvous; diff --git a/rust/src/matrix_const.rs b/rust/src/matrix_const.rs new file mode 100644 index 00000000000..f75f3bd7c34 --- /dev/null +++ b/rust/src/matrix_const.rs @@ -0,0 +1,28 @@ +/* + * This file is licensed under the Affero General Public License (AGPL) version 3. + * + * Copyright (C) 2024 New Vector, Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * See the GNU Affero General Public License for more details: + * . + */ + +//! # Matrix Constants +//! +//! This module contains definitions for constant values described by the matrix specification. + +pub const HISTORY_VISIBILITY_WORLD_READABLE: &str = "world_readable"; +pub const HISTORY_VISIBILITY_SHARED: &str = "shared"; +pub const HISTORY_VISIBILITY_INVITED: &str = "invited"; +pub const HISTORY_VISIBILITY_JOINED: &str = "joined"; + +pub const MEMBERSHIP_BAN: &str = "ban"; +pub const MEMBERSHIP_LEAVE: &str = "leave"; +pub const MEMBERSHIP_KNOCK: &str = "knock"; +pub const MEMBERSHIP_INVITE: &str = "invite"; +pub const MEMBERSHIP_JOIN: &str = "join"; diff --git a/rust/src/push/utils.rs b/rust/src/push/utils.rs index 28ebed62c88..59536c99548 100644 --- a/rust/src/push/utils.rs +++ b/rust/src/push/utils.rs @@ -23,7 +23,6 @@ use anyhow::bail; use anyhow::Context; use anyhow::Error; use lazy_static::lazy_static; -use regex; use regex::Regex; use regex::RegexBuilder; diff --git a/synapse/synapse_rust/events.pyi b/synapse/synapse_rust/events.pyi index 1682d0d151a..7d3422572dd 100644 --- a/synapse/synapse_rust/events.pyi +++ b/synapse/synapse_rust/events.pyi @@ -10,7 +10,7 @@ # See the GNU Affero General Public License for more details: # . -from typing import Optional +from typing import List, Mapping, Optional, Tuple from synapse.types import JsonDict @@ -105,3 +105,29 @@ class EventInternalMetadata: def is_notifiable(self) -> bool: """Whether this event can trigger a push notification""" + +def event_visible_to_server( + sender: str, + target_server_name: str, + history_visibility: str, + erased_senders: Mapping[str, bool], + partial_state_invisible: bool, + memberships: List[Tuple[str, str]], +) -> bool: + """Determine whether the server is allowed to see the unredacted event. + + Args: + sender: The sender of the event. + target_server_name: The server we want to send the event to. + history_visibility: The history_visibility value at the event. + erased_senders: A mapping of users and whether they have requested erasure. If a + user is not in the map, it is treated as though they haven't requested erasure. + partial_state_invisible: Whether the event should be treated as invisible due to + the partial state status of the room. + memberships: A list of membership state information at the event for users + matching the `target_server_name`. Each list item must contain a tuple of + (state_key, membership). + + Returns: + Whether the server is allowed to see the unredacted event. + """ diff --git a/synapse/visibility.py b/synapse/visibility.py index 3a2782bade7..dc7b6e4065e 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -27,7 +27,6 @@ Final, FrozenSet, List, - Mapping, Optional, Sequence, Set, @@ -48,6 +47,7 @@ from synapse.logging.opentracing import trace from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore +from synapse.synapse_rust.events import event_visible_to_server from synapse.types import RetentionPolicy, StateMap, StrCollection, get_domain_from_id from synapse.types.state import StateFilter from synapse.util import Clock @@ -628,17 +628,6 @@ async def filter_events_for_server( """Filter a list of events based on whether the target server is allowed to see them. - For a fully stated room, the target server is allowed to see an event E if: - - the state at E has world readable or shared history vis, OR - - the state at E says that the target server is in the room. - - For a partially stated room, the target server is allowed to see E if: - - E was created by this homeserver, AND: - - the partial state at E has world readable or shared history vis, OR - - the partial state at E says that the target server is in the room. - - TODO: state before or state after? - Args: storage target_server_name @@ -655,35 +644,6 @@ async def filter_events_for_server( The filtered events. """ - def is_sender_erased(event: EventBase, erased_senders: Mapping[str, bool]) -> bool: - if erased_senders and erased_senders[event.sender]: - logger.info("Sender of %s has been erased, redacting", event.event_id) - return True - return False - - def check_event_is_visible( - visibility: str, memberships: StateMap[EventBase] - ) -> bool: - if visibility not in (HistoryVisibility.INVITED, HistoryVisibility.JOINED): - return True - - # We now loop through all membership events looking for - # membership states for the requesting server to determine - # if the server is either in the room or has been invited - # into the room. - for ev in memberships.values(): - assert get_domain_from_id(ev.state_key) == target_server_name - - memtype = ev.membership - if memtype == Membership.JOIN: - return True - elif memtype == Membership.INVITE: - if visibility == HistoryVisibility.INVITED: - return True - - # server has no users in the room: redact - return False - if filter_out_erased_senders: erased_senders = await storage.main.are_users_erased(e.sender for e in events) else: @@ -726,20 +686,16 @@ def check_event_is_visible( target_server_name, ) - def include_event_in_output(e: EventBase) -> bool: - erased = is_sender_erased(e, erased_senders) - visible = check_event_is_visible( - event_to_history_vis[e.event_id], event_to_memberships.get(e.event_id, {}) - ) - - if e.event_id in partial_state_invisible_event_ids: - visible = False - - return visible and not erased - to_return = [] for e in events: - if include_event_in_output(e): + if event_visible_to_server( + sender=e.sender, + target_server_name=target_server_name, + history_visibility=event_to_history_vis[e.event_id], + erased_senders=erased_senders, + partial_state_invisible=e.event_id in partial_state_invisible_event_ids, + memberships=list(event_to_memberships.get(e.event_id, {}).values()), + ): to_return.append(e) elif redact: to_return.append(prune_event(e)) @@ -796,7 +752,7 @@ async def _event_to_history_vis( async def _event_to_memberships( storage: StorageControllers, events: Collection[EventBase], server_name: str -) -> Dict[str, StateMap[EventBase]]: +) -> Dict[str, StateMap[Tuple[str, str]]]: """Get the remote membership list at each of the given events Returns a map from event id to state map, which will contain only membership events @@ -849,7 +805,7 @@ def include(state_key: str) -> bool: return { e_id: { - key: event_map[inner_e_id] + key: (event_map[inner_e_id].state_key, event_map[inner_e_id].membership) for key, inner_e_id in key_to_eid.items() if inner_e_id in event_map }