diff --git a/__init__.py b/__init__.py index 9582e68..7eb40a3 100644 --- a/__init__.py +++ b/__init__.py @@ -1,16 +1,13 @@ import asyncio -from typing import List from fastapi import APIRouter -from lnbits.db import Database -from lnbits.helpers import template_renderer -from lnbits.tasks import create_permanent_unique_task from loguru import logger -from .nostr.client.client import NostrClient -from .router import NostrRouter - -db = Database("ext_nostrclient") +from .crud import db +from .nostr_client import all_routers, nostr_client +from .tasks import check_relays, init_relays, subscribe_events +from .views import nostrclient_generic_router +from .views_api import nostrclient_api_router nostrclient_static_files = [ { @@ -20,23 +17,11 @@ ] nostrclient_ext: APIRouter = APIRouter(prefix="/nostrclient", tags=["nostrclient"]) - -nostr_client: NostrClient = NostrClient() - -# we keep this in -all_routers: list[NostrRouter] = [] +nostrclient_ext.include_router(nostrclient_generic_router) +nostrclient_ext.include_router(nostrclient_api_router) scheduled_tasks: list[asyncio.Task] = [] -def nostr_renderer(): - return template_renderer(["nostrclient/templates"]) - - -from .tasks import check_relays, init_relays, subscribe_events # noqa -from .views import * # noqa -from .views_api import * # noqa - - async def nostrclient_stop(): for task in scheduled_tasks: try: @@ -55,9 +40,20 @@ async def nostrclient_stop(): def nostrclient_start(): + from lnbits.tasks import create_permanent_unique_task + task1 = create_permanent_unique_task("ext_nostrclient_init_relays", init_relays) task2 = create_permanent_unique_task( "ext_nostrclient_subscrive_events", subscribe_events ) task3 = create_permanent_unique_task("ext_nostrclient_check_relays", check_relays) scheduled_tasks.extend([task1, task2, task3]) + + +__all__ = [ + "db", + "nostrclient_ext", + "nostrclient_static_files", + "nostrclient_stop", + "nostrclient_start", +] diff --git a/crud.py b/crud.py index 7962d75..1462b51 100644 --- a/crud.py +++ b/crud.py @@ -1,11 +1,14 @@ import json -from typing import List, Optional +from typing import Optional + +from lnbits.db import Database -from . import db from .models import Config, Relay +db = Database("ext_nostrclient") + -async def get_relays() -> List[Relay]: +async def get_relays() -> list[Relay]: rows = await db.fetchall("SELECT * FROM nostrclient.relays") return [Relay.from_row(r) for r in rows] diff --git a/nostr_client.py b/nostr_client.py new file mode 100644 index 0000000..719674f --- /dev/null +++ b/nostr_client.py @@ -0,0 +1,5 @@ +from .nostr.client.client import NostrClient +from .router import NostrRouter + +nostr_client: NostrClient = NostrClient() +all_routers: list[NostrRouter] = [] diff --git a/pyproject.toml b/pyproject.toml index cf8a13a..ce32e66 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,8 +22,10 @@ build-backend = "poetry.core.masonry.api" [tool.mypy] exclude = "(nostr/*)" + [[tool.mypy.overrides]] module = [ + "nostr.*", "lnbits.*", "lnurl.*", "loguru.*", @@ -32,7 +34,9 @@ module = [ "pyqrcode.*", "shortuuid.*", "httpx.*", + "secp256k1.*", ] +follow_imports = "skip" ignore_missing_imports = "True" [tool.pytest.ini_options] diff --git a/router.py b/router.py index 942407a..e5665d4 100644 --- a/router.py +++ b/router.py @@ -11,9 +11,9 @@ class NostrRouter: - received_subscription_events: dict[str, List[EventMessage]] = {} - received_subscription_notices: list[NoticeMessage] = [] - received_subscription_eosenotices: dict[str, EndOfStoredEventsMessage] = {} + received_subscription_events: dict[str, List[EventMessage]] + received_subscription_notices: list[NoticeMessage] + received_subscription_eosenotices: dict[str, EndOfStoredEventsMessage] def __init__(self, websocket: WebSocket): self.connected: bool = True @@ -154,7 +154,10 @@ def _handle_client_close(self, subscription_id): self.original_subscription_ids.pop(subscription_id_rewritten) nostr_client.relay_manager.close_subscription(subscription_id_rewritten) logger.info( - f"Unsubscribe from '{subscription_id_rewritten}'. Original id: '{subscription_id}.'" + f""" + Unsubscribe from '{subscription_id_rewritten}'. + Original id: '{subscription_id}.' + """ ) else: logger.info(f"Failed to unsubscribe from '{subscription_id}.'") diff --git a/tasks.py b/tasks.py index 4f51219..441e790 100644 --- a/tasks.py +++ b/tasks.py @@ -13,7 +13,7 @@ async def init_relays(): # get relays from db relays = await get_relays() # set relays and connect to them - valid_relays = list(set([r.url for r in relays if r.url])) + valid_relays = [r.url for r in relays if r.url] nostr_client.reconnect(valid_relays) @@ -29,34 +29,32 @@ async def check_relays(): async def subscribe_events(): - while not any([r.connected for r in nostr_client.relay_manager.relays.values()]): + while not [r.connected for r in nostr_client.relay_manager.relays.values()]: await asyncio.sleep(2) - def callback_events(eventMessage: EventMessage): - sub_id = eventMessage.subscription_id + def callback_events(event_message: EventMessage): + sub_id = event_message.subscription_id if sub_id not in NostrRouter.received_subscription_events: - NostrRouter.received_subscription_events[sub_id] = [eventMessage] + NostrRouter.received_subscription_events[sub_id] = [event_message] return # do not add duplicate events (by event id) - ids = set( - [e.event_id for e in NostrRouter.received_subscription_events[sub_id]] - ) - if eventMessage.event_id in ids: + ids = [e.event_id for e in NostrRouter.received_subscription_events[sub_id]] + if event_message.event_id in ids: return - NostrRouter.received_subscription_events[sub_id].append(eventMessage) + NostrRouter.received_subscription_events[sub_id].append(event_message) - def callback_notices(noticeMessage: NoticeMessage): - if noticeMessage not in NostrRouter.received_subscription_notices: - NostrRouter.received_subscription_notices.append(noticeMessage) + def callback_notices(notice_message: NoticeMessage): + if notice_message not in NostrRouter.received_subscription_notices: + NostrRouter.received_subscription_notices.append(notice_message) - def callback_eose_notices(eventMessage: EndOfStoredEventsMessage): - sub_id = eventMessage.subscription_id + def callback_eose_notices(event_message: EndOfStoredEventsMessage): + sub_id = event_message.subscription_id if sub_id in NostrRouter.received_subscription_eosenotices: return - NostrRouter.received_subscription_eosenotices[sub_id] = eventMessage + NostrRouter.received_subscription_eosenotices[sub_id] = event_message def wrap_async_subscribe(): asyncio.run( diff --git a/views.py b/views.py index 9b08559..f9a8c1e 100644 --- a/views.py +++ b/views.py @@ -1,15 +1,20 @@ -from fastapi import Depends, Request +from fastapi import APIRouter, Depends, Request +from fastapi.responses import HTMLResponse from fastapi.templating import Jinja2Templates from lnbits.core.models import User from lnbits.decorators import check_admin -from starlette.responses import HTMLResponse - -from . import nostr_renderer, nostrclient_ext +from lnbits.helpers import template_renderer templates = Jinja2Templates(directory="templates") +nostrclient_generic_router = APIRouter() + + +def nostr_renderer(): + return template_renderer(["nostrclient/templates"]) + -@nostrclient_ext.get("/", response_class=HTMLResponse) +@nostrclient_generic_router.get("/", response_class=HTMLResponse) async def index(request: Request, user: User = Depends(check_admin)): return nostr_renderer().TemplateResponse( "nostrclient/index.html", {"request": request, "user": user.dict()} diff --git a/views_api.py b/views_api.py index b16dc12..6426a35 100644 --- a/views_api.py +++ b/views_api.py @@ -1,14 +1,11 @@ import asyncio from http import HTTPStatus -from typing import List -from fastapi import Depends, WebSocket +from fastapi import APIRouter, Depends, HTTPException, WebSocket from lnbits.decorators import check_admin from lnbits.helpers import decrypt_internal_message, urlsafe_short_hash from loguru import logger -from starlette.exceptions import HTTPException -from . import all_routers, nostr_client, nostrclient_ext from .crud import ( add_relay, create_config, @@ -18,13 +15,16 @@ update_config, ) from .helpers import normalize_public_key -from .models import Config, Relay, TestMessage, TestMessageResponse +from .models import Config, Relay, RelayStatus, TestMessage, TestMessageResponse from .nostr.key import EncryptedDirectMessage, PrivateKey +from .nostr_client import all_routers, nostr_client from .router import NostrRouter +nostrclient_api_router = APIRouter() -@nostrclient_ext.get("/api/v1/relays", dependencies=[Depends(check_admin)]) -async def api_get_relays() -> List[Relay]: + +@nostrclient_api_router.get("/api/v1/relays", dependencies=[Depends(check_admin)]) +async def api_get_relays() -> list[Relay]: relays = [] for url, r in nostr_client.relay_manager.relays.items(): relay_id = urlsafe_short_hash() @@ -33,13 +33,13 @@ async def api_get_relays() -> List[Relay]: id=relay_id, url=url, connected=r.connected, - status={ - "num_sent_events": r.num_sent_events, - "num_received_events": r.num_received_events, - "error_counter": r.error_counter, - "error_list": r.error_list, - "notice_list": r.notice_list, - }, + status=RelayStatus( + num_sent_events=r.num_sent_events, + num_received_events=r.num_received_events, + error_counter=r.error_counter, + error_list=r.error_list, + notice_list=r.notice_list, + ), ping=r.ping, active=True, ) @@ -47,10 +47,10 @@ async def api_get_relays() -> List[Relay]: return relays -@nostrclient_ext.post( +@nostrclient_api_router.post( "/api/v1/relay", status_code=HTTPStatus.OK, dependencies=[Depends(check_admin)] ) -async def api_add_relay(relay: Relay) -> List[Relay]: +async def api_add_relay(relay: Relay) -> list[Relay]: if not relay.url: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, detail="Relay url not provided." @@ -68,7 +68,7 @@ async def api_add_relay(relay: Relay) -> List[Relay]: return await get_relays() -@nostrclient_ext.delete( +@nostrclient_api_router.delete( "/api/v1/relay", status_code=HTTPStatus.OK, dependencies=[Depends(check_admin)] ) async def api_delete_relay(relay: Relay) -> None: @@ -81,7 +81,7 @@ async def api_delete_relay(relay: Relay) -> None: await delete_relay(relay) -@nostrclient_ext.put( +@nostrclient_api_router.put( "/api/v1/relay/test", status_code=HTTPStatus.OK, dependencies=[Depends(check_admin)] ) async def api_test_endpoint(data: TestMessage) -> TestMessageResponse: @@ -105,33 +105,34 @@ async def api_test_endpoint(data: TestMessage) -> TestMessageResponse: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, detail=str(ex), - ) + ) from ex except Exception as ex: logger.warning(ex) raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Cannot generate test event", - ) + ) from ex -@nostrclient_ext.websocket("/api/v1/{id}") -async def ws_relay(id: str, websocket: WebSocket) -> None: +@nostrclient_api_router.websocket("/api/v1/{id}") +async def ws_relay(ws_id: str, websocket: WebSocket) -> None: """Relay multiplexer: one client (per endpoint) <-> multiple relays""" logger.info("New websocket connection at: '/api/v1/relay'") try: config = await get_config() + assert config, "Failed to get config" if not config.private_ws and not config.public_ws: raise ValueError("Websocket connections not accepted.") - if id == "relay": + if ws_id == "relay": if not config.public_ws: raise ValueError("Public websocket connections not accepted.") else: if not config.private_ws: raise ValueError("Private websocket connections not accepted.") - if decrypt_internal_message(id) != "relay": + if decrypt_internal_message(ws_id) != "relay": raise ValueError("Invalid websocket endpoint.") await websocket.accept() @@ -160,10 +161,10 @@ async def ws_relay(id: str, websocket: WebSocket) -> None: raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Cannot accept websocket connection", - ) + ) from ex -@nostrclient_ext.get("/api/v1/config", dependencies=[Depends(check_admin)]) +@nostrclient_api_router.get("/api/v1/config", dependencies=[Depends(check_admin)]) async def api_get_config() -> Config: config = await get_config() if not config: @@ -172,7 +173,7 @@ async def api_get_config() -> Config: return config -@nostrclient_ext.put("/api/v1/config", dependencies=[Depends(check_admin)]) +@nostrclient_api_router.put("/api/v1/config", dependencies=[Depends(check_admin)]) async def api_update_config(data: Config): config = await update_config(data) assert config