Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] Add type hints #54

Draft
wants to merge 12 commits into
base: master
Choose a base branch
from
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ test: venv ## Run unit tests in virtual environment
test_direct: ## Run unit tests without virtual environment (typically for CI)
pip install .[tests] && python -m unittest discover -v ./tests

mypy: venv ## Run mypy
source $(VENV_DIR)/bin/activate && \
mypy --check-untyped-defs ./stubs $(PACKAGE)

cover: venv
source $(VENV_DIR)/bin/activate && green -a -r -s 1 -vv ./tests

Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@
'dataclasses;python_version=="3.6"',
"pure-protobuf",
"linkify-it-py",
"filetype"
"filetype",
]

docs_require = []
test_require = []
dev_require = ["green", "black", "isort"]
test_require = ["green", "mypy"]
dev_require = ["black", "isort"]

# What packages are optional?
EXTRAS = {
Expand Down
45 changes: 28 additions & 17 deletions signal2html/addressbook.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@

import abc
import logging
import sqlite3

from typing import Dict
from typing import Optional

from .html_colors import get_random_color
from .models import Recipient
Expand All @@ -21,24 +25,24 @@ class Addressbook(metaclass=abc.ABCMeta):
- `_load_recipients()` to load all recipients
- `get_recipient_by_address()` to return a specific recipient"""

def __init__(self, db):
def __init__(self, db: sqlite3.Cursor):
"""Initializes the addressbook and load all known recipients."""
self.logger = logging.getLogger(__name__)
self.db = db
self.rid_to_recipient: dict[str, Recipient] = {}
self.phone_to_rid: dict[str, str] = {}
self.uuid_to_rid: dict[str, str] = {}
self.groups: dict[int, str] = {}
self.rid_to_recipient: Dict[str, Recipient] = {}
self.phone_to_rid: Dict[str, str] = {}
self.uuid_to_rid: Dict[str, str] = {}
self.groups: Dict[int, str] = {}

self._load_groups()
self._load_recipients() # Must be implemented by subclass
self.next_rid = 10000

@abc.abstractmethod
def _load_recipients():
def _load_recipients(self):
"""Load all recipients in the recipient_preferences table."""

def get_group_title(self, group_id: str) -> str:
def get_group_title(self, group_id: str) -> Optional[str]:
"""Retrieves the title of a group given the group_id (long
hexadecimal-based identifier)."""
return self.groups.get(group_id)
Expand All @@ -54,17 +58,25 @@ def get_recipient_by_address(self, address: str) -> Recipient:
If an address is provided that does not exist in the addressbook,
it is created on the spot."""

def get_recipient_by_phone(self, phone: str) -> Recipient:
def get_recipient_by_phone(self, phone: str) -> Optional[Recipient]:
"""Returns a Recipient object that matches the phone number provided."""
rid = self.phone_to_rid.get(phone)
return self.rid_to_recipient.get(rid)

def get_recipient_by_uuid(self, uuid: str) -> Recipient:
def get_recipient_by_uuid(self, uuid: str) -> Optional[Recipient]:
"""Returns a Recipient object that matches the UUID provided."""
rid = self.uuid_to_rid.get(uuid)
return self.rid_to_recipient.get(rid)

def _add_recipient(self, recipient_id, uuid, name, color, isgroup, phone):
def _add_recipient(
self,
recipient_id: int,
uuid: str,
name: str,
color: str,
isgroup: bool,
phone: str,
) -> Recipient:
"""Adds a recipient to the internal data structures."""
recipient = Recipient(
recipient_id,
Expand All @@ -83,20 +95,19 @@ def _add_recipient(self, recipient_id, uuid, name, color, isgroup, phone):
self.phone_to_rid[str(phone)] = str(recipient_id)
if uuid:
self.uuid_to_rid[uuid] = str(recipient_id)

return recipient

def _get_friendly_name_for_group(self, address: str):
def _get_friendly_name_for_group(self, address: str) -> str:
"""Creates a readable group name, either the title or a name derived from the group id."""
name = self.get_group_title(address)
if not name:
gid = self._get_group_id(address)
if gid:
return f"Group {gid}"
else:
return ""
return ""
return name

def _get_group_id(self, group_id: str) -> str:
def _get_group_id(self, group_id: str) -> Optional[str]:
"""Gets the integer ID of a group from the Signal database."""
qry = self.db.execute(
"SELECT group_id, _id FROM groups WHERE group_id LIKE ?",
Expand All @@ -105,6 +116,7 @@ def _get_group_id(self, group_id: str) -> str:
qry_res = qry.fetchone()
if qry_res:
return str(qry_res[1])
return None

def _get_new_rid(self) -> str:
"""Creates a new recipient ID for recipients not in the initial
Expand Down Expand Up @@ -231,8 +243,7 @@ def get_recipient_by_address(self, address: str) -> Recipient:
return self._add_recipient(
rid, "", "", get_random_color(), False, ""
)
else:
return recipient
return recipient

def _isgroup(self, group_id) -> bool:
"""Decides whether a group_id refers to a group."""
Expand Down
Loading