Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Internal endpoint refactoring, handler optimizations #771

Merged
merged 11 commits into from
Dec 22, 2023
364 changes: 159 additions & 205 deletions moonraker/app.py

Large diffs are not rendered by default.

438 changes: 341 additions & 97 deletions moonraker/common.py

Large diffs are not rendered by default.

17 changes: 9 additions & 8 deletions moonraker/components/announcements.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import logging
import email.utils
import xml.etree.ElementTree as etree
from ..common import RequestType
from typing import (
TYPE_CHECKING,
Awaitable,
Expand Down Expand Up @@ -57,23 +58,23 @@ def __init__(self, config: ConfigHelper) -> None:
)

self.server.register_endpoint(
"/server/announcements/list", ["GET"],
"/server/announcements/list", RequestType.GET,
self._list_announcements
)
self.server.register_endpoint(
"/server/announcements/dismiss", ["POST"],
"/server/announcements/dismiss", RequestType.POST,
self._handle_dismiss_request
)
self.server.register_endpoint(
"/server/announcements/update", ["POST"],
"/server/announcements/update", RequestType.POST,
self._handle_update_request
)
self.server.register_endpoint(
"/server/announcements/feed", ["POST", "DELETE"],
"/server/announcements/feed", RequestType.POST | RequestType.DELETE,
self._handle_feed_request
)
self.server.register_endpoint(
"/server/announcements/feeds", ["GET"],
"/server/announcements/feeds", RequestType.GET,
self._handle_list_feeds
)
self.server.register_notification(
Expand Down Expand Up @@ -170,13 +171,13 @@ async def _handle_list_feeds(
async def _handle_feed_request(
self, web_request: WebRequest
) -> Dict[str, Any]:
action = web_request.get_action()
req_type = web_request.get_request_type()
name: str = web_request.get("name")
name = name.lower()
changed: bool = False
db: MoonrakerDatabase = self.server.lookup_component("database")
result = "skipped"
if action == "POST":
if req_type == RequestType.POST:
if name not in self.subscriptions:
feed = RssFeed(name, self.entry_mgr, self.dev_mode)
self.subscriptions[name] = feed
Expand All @@ -187,7 +188,7 @@ async def _handle_feed_request(
"moonraker", "announcements.stored_feeds", self.stored_feeds
)
result = "added"
elif action == "DELETE":
elif req_type == RequestType.DELETE:
if name not in self.stored_feeds:
raise self.server.error(f"Feed '{name}' not stored")
if name in self.configured_feeds:
Expand Down
99 changes: 50 additions & 49 deletions moonraker/components/authorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@
from tornado.web import HTTPError
from libnacl.sign import Signer, Verifier
from ..utils import json_wrapper as jsonw
from ..common import RequestType, TransportType

# Annotation imports
from typing import (
TYPE_CHECKING,
Any,
Tuple,
Set,
Optional,
Union,
Dict,
Expand Down Expand Up @@ -151,7 +151,6 @@ def __init__(self, config: ConfigHelper) -> None:
self.user_db.sync(self.users)
self.trusted_users: Dict[IPAddr, Any] = {}
self.oneshot_tokens: Dict[str, OneshotToken] = {}
self.permitted_paths: Set[str] = set()

# Get allowed cors domains
self.cors_domains: List[str] = []
Expand Down Expand Up @@ -221,37 +220,46 @@ def __init__(self, config: ConfigHelper) -> None:
self._prune_conn_handler)

# Register Authorization Endpoints
self.permitted_paths.add("/server/redirect")
self.permitted_paths.add("/access/login")
self.permitted_paths.add("/access/refresh_jwt")
self.permitted_paths.add("/access/info")
self.server.register_endpoint(
"/access/login", ['POST'], self._handle_login,
transports=['http', 'websocket'])
"/access/login", RequestType.POST, self._handle_login,
transports=TransportType.HTTP | TransportType.WEBSOCKET,
auth_required=False
)
self.server.register_endpoint(
"/access/logout", ['POST'], self._handle_logout,
transports=['http', 'websocket'])
"/access/logout", RequestType.POST, self._handle_logout,
transports=TransportType.HTTP | TransportType.WEBSOCKET
)
self.server.register_endpoint(
"/access/refresh_jwt", ['POST'], self._handle_refresh_jwt,
transports=['http', 'websocket'])
"/access/refresh_jwt", RequestType.POST, self._handle_refresh_jwt,
transports=TransportType.HTTP | TransportType.WEBSOCKET,
auth_required=False
)
self.server.register_endpoint(
"/access/user", ['GET', 'POST', 'DELETE'],
self._handle_user_request, transports=['http', 'websocket'])
"/access/user", RequestType.all(), self._handle_user_request,
transports=TransportType.HTTP | TransportType.WEBSOCKET
)
self.server.register_endpoint(
"/access/users/list", ['GET'], self._handle_list_request,
transports=['http', 'websocket'])
"/access/users/list", RequestType.GET, self._handle_list_request,
transports=TransportType.HTTP | TransportType.WEBSOCKET
)
self.server.register_endpoint(
"/access/user/password", ['POST'], self._handle_password_reset,
transports=['http', 'websocket'])
"/access/user/password", RequestType.POST, self._handle_password_reset,
transports=TransportType.HTTP | TransportType.WEBSOCKET
)
self.server.register_endpoint(
"/access/api_key", ['GET', 'POST'],
self._handle_apikey_request, transports=['http', 'websocket'])
"/access/api_key", RequestType.GET | RequestType.POST,
self._handle_apikey_request,
transports=TransportType.HTTP | TransportType.WEBSOCKET
)
self.server.register_endpoint(
"/access/oneshot_token", ['GET'],
self._handle_oneshot_request, transports=['http', 'websocket'])
"/access/oneshot_token", RequestType.GET, self._handle_oneshot_request,
transports=TransportType.HTTP | TransportType.WEBSOCKET
)
self.server.register_endpoint(
"/access/info", ['GET'],
self._handle_info_request, transports=['http', 'websocket'])
"/access/info", RequestType.GET, self._handle_info_request,
transports=TransportType.HTTP | TransportType.WEBSOCKET,
auth_required=False
)
wsm: WebsocketManager = self.server.lookup_component("websockets")
wsm.register_notification("authorization:user_created")
wsm.register_notification(
Expand All @@ -261,21 +269,14 @@ def __init__(self, config: ConfigHelper) -> None:
"authorization:user_logged_out", event_type="logout"
)

def register_permited_path(self, path: str) -> None:
self.permitted_paths.add(path)

def is_path_permitted(self, path: str) -> bool:
return path in self.permitted_paths

def _sync_user(self, username: str) -> None:
self.user_db[username] = self.users[username]

async def component_init(self) -> None:
self.prune_timer.start(delay=PRUNE_CHECK_TIME)

async def _handle_apikey_request(self, web_request: WebRequest) -> str:
action = web_request.get_action()
if action.upper() == 'POST':
if web_request.get_request_type() == RequestType.POST:
self.api_key = uuid.uuid4().hex
self.users[API_USER]['api_key'] = self.api_key
self._sync_user(API_USER)
Expand Down Expand Up @@ -360,11 +361,11 @@ async def _handle_refresh_jwt(self,
'action': 'user_jwt_refresh'
}

async def _handle_user_request(self,
web_request: WebRequest
) -> Dict[str, Any]:
action = web_request.get_action()
if action == "GET":
async def _handle_user_request(
self, web_request: WebRequest
) -> Dict[str, Any]:
req_type = web_request.get_request_type()
if req_type == RequestType.GET:
user = web_request.get_current_user()
if user is None:
return {
Expand All @@ -378,10 +379,10 @@ async def _handle_user_request(self,
'source': user.get("source", "moonraker"),
'created_on': user.get('created_on')
}
elif action == "POST":
elif req_type == RequestType.POST:
# Create User
return await self._login_jwt_user(web_request, create=True)
elif action == "DELETE":
elif req_type == RequestType.DELETE:
# Delete User
return self._delete_jwt_user(web_request)
raise self.server.error("Invalid Request Method")
Expand Down Expand Up @@ -760,13 +761,10 @@ def check_logins_maxed(self, ip_addr: IPAddr) -> bool:
return False
return self.failed_logins.get(ip_addr, 0) >= self.max_logins

def check_authorized(
self, request: HTTPServerRequest, endpoint: str = "",
def authenticate_request(
self, request: HTTPServerRequest, auth_required: bool = True
) -> Optional[Dict[str, Any]]:
if (
endpoint in self.permitted_paths
or request.method == "OPTIONS"
):
if request.method == "OPTIONS":
return None

# Check JSON Web Token
Expand Down Expand Up @@ -794,14 +792,17 @@ def check_authorized(
if key and key == self.api_key:
return self.users[API_USER]

# If the force_logins option is enabled and at least one
# user is created this is an unauthorized request
# If the force_logins option is enabled and at least one user is created
# then trusted user authentication is disabled
if self.force_logins and len(self.users) > 1:
if not auth_required:
return None
raise HTTPError(401, "Unauthorized, Force Logins Enabled")

# Check if IP is trusted
# Check if IP is trusted. If this endpoint doesn't require authentication
# then it is acceptable to return None
trusted_user = self._check_trusted_connection(ip)
if trusted_user is not None:
if trusted_user is not None or not auth_required:
return trusted_user

raise HTTPError(401, "Unauthorized")
Expand Down
11 changes: 7 additions & 4 deletions moonraker/components/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import logging
import time
from collections import deque
from ..common import RequestType

# Annotation imports
from typing import (
Expand Down Expand Up @@ -59,11 +60,13 @@ def __init__(self, config: ConfigHelper) -> None:

# Register endpoints
self.server.register_endpoint(
"/server/temperature_store", ['GET'],
self._handle_temp_store_request)
"/server/temperature_store", RequestType.GET,
self._handle_temp_store_request
)
self.server.register_endpoint(
"/server/gcode_store", ['GET'],
self._handle_gcode_store_request)
"/server/gcode_store", RequestType.GET,
self._handle_gcode_store_request
)

async def _init_sensors(self) -> None:
klippy_apis: APIComp = self.server.lookup_component('klippy_apis')
Expand Down
28 changes: 16 additions & 12 deletions moonraker/components/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import lmdb
from ..utils import Sentinel, ServerError
from ..utils import json_wrapper as jsonw
from ..common import RequestType

# Annotation imports
from typing import (
Expand Down Expand Up @@ -174,15 +175,17 @@ def __init__(self, config: ConfigHelper) -> None:
self.insert_item("moonraker", "database.unsafe_shutdowns",
unsafe_shutdowns + 1)
self.server.register_endpoint(
"/server/database/list", ['GET'], self._handle_list_request)
"/server/database/list", RequestType.GET, self._handle_list_request
)
self.server.register_endpoint(
"/server/database/item", ["GET", "POST", "DELETE"],
self._handle_item_request)
"/server/database/item", RequestType.all(), self._handle_item_request
)
self.server.register_debug_endpoint(
"/debug/database/list", ['GET'], self._handle_list_request)
"/debug/database/list", RequestType.GET, self._handle_list_request
)
self.server.register_debug_endpoint(
"/debug/database/item", ["GET", "POST", "DELETE"],
self._handle_item_request)
"/debug/database/item", RequestType.all(), self._handle_item_request
)

def get_database_path(self) -> str:
return self.database_path
Expand Down Expand Up @@ -735,7 +738,7 @@ async def _handle_list_request(self,
async def _handle_item_request(self,
web_request: WebRequest
) -> Dict[str, Any]:
action = web_request.get_action()
req_type = web_request.get_request_type()
is_debug = web_request.get_endpoint().startswith("/debug/")
namespace = web_request.get_str("namespace")
if namespace in self.forbidden_namespaces and not is_debug:
Expand All @@ -744,7 +747,7 @@ async def _handle_item_request(self,
" is forbidden", 403)
key: Any
valid_types: Tuple[type, ...]
if action != "GET":
if req_type != RequestType.GET:
if namespace in self.protected_namespaces and not is_debug:
raise self.server.error(
f"Write access to namespace '{namespace}'"
Expand All @@ -758,16 +761,17 @@ async def _handle_item_request(self,
raise self.server.error(
"Value for argument 'key' is an invalid type: "
f"{type(key).__name__}")
if action == "GET":
if req_type == RequestType.GET:
val = await self.get_item(namespace, key)
elif action == "POST":
elif req_type == RequestType.POST:
val = web_request.get("value")
await self.insert_item(namespace, key, val)
elif action == "DELETE":
elif req_type == RequestType.DELETE:
val = await self.delete_item(namespace, key, drop_empty_db=True)

if is_debug:
self.debug_counter[action.lower()] += 1
name = req_type.name or str(req_type).split(".", 1)[-1]
self.debug_counter[name.lower()] += 1
await self.insert_item(
"moonraker", "database.debug_counter", self.debug_counter
)
Expand Down
14 changes: 7 additions & 7 deletions moonraker/components/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import asyncio
import pathlib
import logging
from ..common import BaseRemoteConnection
from ..common import BaseRemoteConnection, RequestType, TransportType
from ..utils import get_unix_peer_credentials

# Annotation imports
Expand Down Expand Up @@ -35,19 +35,19 @@ def __init__(self, config: ConfigHelper) -> None:
self.agent_methods: Dict[int, List[str]] = {}
self.uds_server: Optional[asyncio.AbstractServer] = None
self.server.register_endpoint(
"/connection/register_remote_method", ["POST"],
"/connection/register_remote_method", RequestType.POST,
self._register_agent_method,
transports=["websocket"]
transports=TransportType.WEBSOCKET
)
self.server.register_endpoint(
"/connection/send_event", ["POST"], self._handle_agent_event,
transports=["websocket"]
"/connection/send_event", RequestType.POST, self._handle_agent_event,
transports=TransportType.WEBSOCKET
)
self.server.register_endpoint(
"/server/extensions/list", ["GET"], self._handle_list_extensions
"/server/extensions/list", RequestType.GET, self._handle_list_extensions
)
self.server.register_endpoint(
"/server/extensions/request", ["POST"], self._handle_call_agent
"/server/extensions/request", RequestType.POST, self._handle_call_agent
)

def register_agent(self, connection: BaseRemoteConnection) -> None:
Expand Down
Loading
Loading