Skip to content

Commit

Permalink
Merge pull request #13 from elokapina/retry-decrypt
Browse files Browse the repository at this point in the history
Retry decrypting events when keys received
  • Loading branch information
jaywink authored Jan 15, 2022
2 parents b861535 + f32e41f commit db0dd83
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 8 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@
* Added rooms `unlink` and `unlink-and-leave` subcommands. The first variant unlinks a room
tracked by Bubo, the second also leaves the room.

* When receiving an event the bot cannot decrypt, the event will now be stored for
later. When keys are received later matching any stored encrypted events, a new attempt
will be made to decrypt them.

### Changed

* Message edits are now understood as new commands from clients that send them
Expand Down
54 changes: 50 additions & 4 deletions bubo/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import json
from typing import Union

# noinspection PyPackageRequirements
from nio import JoinError, MatrixRoom, Event
from nio import JoinError, MatrixRoom, MegolmEvent, RoomKeyEvent, Event, RoomMessageText, UnknownEvent

from bubo.bot_commands import Command
from bubo.chat_functions import send_text_to_room, invite_to_room
Expand All @@ -25,6 +28,14 @@ def __init__(self, client, store, config):
self.config = config
self.command_prefix = config.command_prefix

async def decrypted_callback(self, room_id: str, event: Union[RoomMessageText, UnknownEvent]):
if isinstance(event, RoomMessageText):
await self.message(self.client.rooms[room_id], event)
elif isinstance(event, UnknownEvent):
await self.reaction(self.client.rooms[room_id], event)
else:
logger.warning(f"Unknown event %s passed to decrypted_callback" % event)

async def message(self, room, event):
"""Callback for when a message event is received
Expand Down Expand Up @@ -92,6 +103,39 @@ async def reaction(self, room, event):
self.client, room_id, event.sender,
)

async def room_key(self, event: RoomKeyEvent):
"""Callback for ToDevice events like room key events."""
events = self.store.get_encrypted_events(event.session_id)
logger.info("Got room key event for session %s, matched sessions: %s" % (event.session_id, len(events)))
if not events:
return

for encrypted_event in events:
try:
event_dict = json.loads(encrypted_event["event"])
params = event_dict["source"]
params["room_id"] = event_dict["room_id"]
params["transaction_id"] = event_dict["transaction_id"]
megolm_event = MegolmEvent.from_dict(params)
except Exception as ex:
logger.warning("Failed to restore MegolmEvent for %s: %s" % (encrypted_event["event_id"], ex))
continue
try:
# noinspection PyTypeChecker
decrypted = self.client.decrypt_event(megolm_event)
except Exception as ex:
logger.warning(f"Error decrypting event %s: %s" % (megolm_event.event_id, ex))
continue
if isinstance(decrypted, Event):
logger.info(f"Successfully decrypted stored event %s" % decrypted.event_id)
parsed_event = Event.parse_event(decrypted.source)
logger.info(f"Parsed event: %s" % parsed_event)
self.store.remove_encrypted_event(decrypted.event_id)
# noinspection PyTypeChecker
await self.decrypted_callback(encrypted_event["room_id"], parsed_event)
else:
logger.warning(f"Failed to decrypt event %s" % (decrypted.event_id,))

async def invite(self, room, event):
"""Callback for when an invite is received. Join the room specified in the invite"""
logger.debug(f"Got invite to {room.room_id} from {event.sender}.")
Expand All @@ -112,11 +156,11 @@ async def invite(self, room, event):
# Successfully joined room
logger.info(f"Joined {room.room_id}")

async def decryption_failure(self, room: MatrixRoom, event: Event):
async def decryption_failure(self, room: MatrixRoom, event: MegolmEvent):
"""Callback for when an event fails to decrypt."""
logger.error(
logger.warning(
f"Failed to decrypt event {event.event_id} in room {room.name} ({room.canonical_alias} / {room.room_id}) "
f"from sender {event.sender}."
f"from sender {event.sender} - possibly missing session, storing for later."
)
if self.config.callbacks.get("unable_to_decrypt_responses", True):
user_msg = (
Expand All @@ -127,3 +171,5 @@ async def decryption_failure(self, room: MatrixRoom, event: Event):
await send_text_to_room(
self.client, room.room_id, user_msg, reply_to_event_id=event.event_id,
)

self.store.store_encrypted_event(event)
14 changes: 14 additions & 0 deletions bubo/migrations/009.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
def forward(cursor):
cursor.execute("""
CREATE TABLE encrypted_events (
id INTEGER PRIMARY KEY autoincrement,
device_id text,
event_id text unique,
room_id text,
session_id text,
event text
)
""")
cursor.execute("""
CREATE INDEX encrypted_events_session_id_idx on encrypted_events (session_id);
""")
31 changes: 30 additions & 1 deletion bubo/storage.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import json
import logging
import time
from dataclasses import asdict
from importlib import import_module
from typing import Optional, List

import sqlite3
# noinspection PyPackageRequirements
from nio import MegolmEvent

latest_db_version = 8
latest_db_version = 9

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -83,6 +87,12 @@ def get_breakout_room_id(self, event_id: str):
if room:
return room[0]

def get_encrypted_events(self, session_id: str):
results = self.cursor.execute("""
select * from encrypted_events where session_id = ?;
""", (session_id,))
return results.fetchall()

def get_recreate_room(self, room_id: str):
results = self.cursor.execute("""
select requester, timestamp, applied from recreate_rooms where room_id = ?;
Expand All @@ -109,6 +119,12 @@ def get_rooms(self) -> List[sqlite3.Row]:
""")
return results.fetchall()

def remove_encrypted_event(self, event_id: str):
self.cursor.execute("""
delete from encrypted_events where event_id = ?;
""", (event_id,))
self.conn.commit()

def set_recreate_room_applied(self, room_id: str):
self.cursor.execute("""
update recreate_rooms set applied = 1 where room_id = ?;
Expand Down Expand Up @@ -137,6 +153,19 @@ def store_community(self, name: str, alias: str, title: str):
""", (name, alias, title))
self.conn.commit()

def store_encrypted_event(self, event: MegolmEvent):
try:
event_dict = asdict(event)
event_json = json.dumps(event_dict)
self.cursor.execute("""
insert into encrypted_events
(device_id, event_id, room_id, session_id, event) values
(?, ?, ?, ?, ?)
""", (event.device_id, event.event_id, event.room_id, event.session_id, event_json))
self.conn.commit()
except Exception as ex:
logger.error("Failed to store encrypted event %s: %s" % (event.event_id, ex))

def store_recreate_room(self, requester: str, room_id: str):
timestamp = int(time.time())
self.cursor.execute("""
Expand Down
10 changes: 7 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
from nio import (
AsyncClient,
AsyncClientConfig,
RoomMessageText,
ForwardedRoomKeyEvent,
InviteMemberEvent,
LoginError,
LocalProtocolError,
LoginError,
MegolmEvent,
UnknownEvent,
RoomKeyEvent,
RoomMessageText,
UnknownEvent,
)

from bubo.callbacks import Callbacks
Expand Down Expand Up @@ -66,6 +68,8 @@ async def main(config: Config):
# Nio doesn't currently have m.reaction events so we catch UnknownEvent for reactions and filter there
# noinspection PyTypeChecker
client.add_event_callback(callbacks.reaction, (UnknownEvent,))
# noinspection PyTypeChecker
client.add_to_device_callback(callbacks.room_key, (ForwardedRoomKeyEvent, RoomKeyEvent))

# Keep trying to reconnect on failure (with some time in-between)
while True:
Expand Down

0 comments on commit db0dd83

Please sign in to comment.