diff --git a/moonraker/app.py b/moonraker/app.py index b5ee1b7f2..47d2fb52a 100644 --- a/moonraker/app.py +++ b/moonraker/app.py @@ -22,8 +22,16 @@ from tornado.routing import Rule, PathMatches, AnyMatches from tornado.http1connection import HTTP1Connection from tornado.log import access_log -from .common import WebRequest, APIDefinition, APITransport -from .utils import ServerError, source_info +from .utils import ServerError, source_info, parse_ip_address +from .common import ( + JsonRPC, + WebRequest, + APIDefinition, + APITransport, + TransportType, + RequestType, + KlippyState +) from .utils import json_wrapper as jsonw from .websockets import ( WebsocketManager, @@ -43,7 +51,6 @@ Union, Dict, List, - Tuple, AsyncGenerator, ) if TYPE_CHECKING: @@ -52,6 +59,7 @@ from .eventloop import EventLoop from .confighelper import ConfigHelper from .klippy_connection import KlippyConnection as Klippy + from .utils import IPAddress from .components.file_manager.file_manager import FileManager from .components.announcements import Announcements from .components.machine import Machine @@ -69,7 +77,6 @@ EXCLUDED_ARGS = ["_", "token", "access_token", "connection_id"] AUTHORIZED_EXTS = [".png", ".jpg"] DEFAULT_KLIPPY_LOG_PATH = "/tmp/klippy.log" -ALL_TRANSPORTS = ["http", "websocket", "mqtt", "internal"] class MutableRouter(tornado.web.ReversibleRuleRouter): def __init__(self, application: MoonrakerApp) -> None: @@ -115,48 +122,33 @@ def remove_handler(self, pattern: str) -> None: class InternalTransport(APITransport): def __init__(self, server: Server) -> None: self.server = server - self.callbacks: Dict[str, Tuple[str, str, APICallback]] = {} - - def register_api_handler(self, api_def: APIDefinition) -> None: - ep = api_def.endpoint - cb = api_def.callback - if cb is None: - # Request to Klippy - method = api_def.jrpc_methods[0] - action = "" - klippy: Klippy = self.server.lookup_component("klippy_connection") - cb = klippy.request - self.callbacks[method] = (ep, action, cb) - else: - for method, action in \ - zip(api_def.jrpc_methods, api_def.request_methods): - self.callbacks[method] = (ep, action, cb) - - def remove_api_handler(self, api_def: APIDefinition) -> None: - for method in api_def.jrpc_methods: - self.callbacks.pop(method, None) async def call_method(self, method_name: str, request_arguments: Dict[str, Any] = {}, **kwargs ) -> Any: - if method_name not in self.callbacks: + rpc: JsonRPC = self.server.lookup_component("jsonrpc") + method_info = rpc.get_method(method_name) + if method_info is None: + raise self.server.error(f"No method {method_name} available") + req_type, api_definition = method_info + if TransportType.INTERNAL not in api_definition.transports: raise self.server.error(f"No method {method_name} available") - ep, action, func = self.callbacks[method_name] - # Request arguments can be suppplied either through a dict object - # or via keyword arugments args = request_arguments or kwargs - return await func(WebRequest(ep, dict(args), action)) + return await api_definition.request(args, req_type, self) class MoonrakerApp: def __init__(self, config: ConfigHelper) -> None: self.server = config.get_server() + self.json_rpc = JsonRPC(self.server) self.http_server: Optional[HTTPServer] = None self.secure_server: Optional[HTTPServer] = None - self.api_cache: Dict[str, APIDefinition] = {} self.template_cache: Dict[str, JinjaTemplate] = {} - self.registered_base_handlers: List[str] = [] + self.registered_base_handlers: List[str] = [ + "/server/redirect", + "/server/jsonrpc" + ] self.max_upload_size = config.getint('max_upload_size', 1024) self.max_upload_size *= 1024 * 1024 max_ws_conns = config.getint( @@ -184,14 +176,7 @@ def __init__(self, config: ConfigHelper) -> None: ) self._route_prefix = f"/{rp}" home_pattern = f"{self._route_prefix}/?" - - # Set Up Websocket and Authorization Managers - self.wsm = WebsocketManager(self.server) self.internal_transport = InternalTransport(self.server) - self.api_transports: Dict[str, APITransport] = { - "websocket": self.wsm, - "internal": self.internal_transport - } mimetypes.add_type('text/plain', '.log') mimetypes.add_type('text/plain', '.gcode') @@ -216,7 +201,8 @@ def __init__(self, config: ConfigHelper) -> None: (home_pattern, WelcomeHandler), (f"{self._route_prefix}/websocket", WebSocket), (f"{self._route_prefix}/klippysocket", BridgeSocket), - (f"{self._route_prefix}/server/redirect", RedirectHandler) + (f"{self._route_prefix}/server/redirect", RedirectHandler), + (f"{self._route_prefix}/server/jsonrpc", RPCHandler) ] self.app = tornado.web.Application(app_handlers, **app_args) self.get_handler_delegate = self.app.get_handler_delegate @@ -232,9 +218,8 @@ def __init__(self, config: ConfigHelper) -> None: # Register Server Components self.server.register_component("application", self) - self.server.register_component("websockets", self.wsm) - self.server.register_component("internal_transport", - self.internal_transport) + self.server.register_component("jsonrpc", self.json_rpc) + self.server.register_component("internal_transport", self.internal_transport) def _get_path_option( self, config: ConfigHelper, option: str @@ -322,63 +307,44 @@ async def close(self) -> None: if self.secure_server is not None: self.secure_server.stop() await self.secure_server.close_all_connections() - await self.wsm.close() - - def register_api_transport( - self, name: str, transport: APITransport - ) -> Dict[str, APIDefinition]: - self.api_transports[name] = transport - return self.api_cache - - def register_remote_handler(self, endpoint: str) -> None: - api_def = self._create_api_definition(endpoint) - if api_def.uri in self.registered_base_handlers: - # reserved handler or already registered - return - logging.info( - f"Registering HTTP endpoint: " - f"({' '.join(api_def.request_methods)}) {api_def.uri}") - params: Dict[str, Any] = {} - params['methods'] = api_def.request_methods - params['callback'] = api_def.endpoint - params['need_object_parser'] = api_def.need_object_parser - self.mutable_router.add_handler( - f"{self._route_prefix}{api_def.uri}", DynamicRequestHandler, params - ) - self.registered_base_handlers.append(api_def.uri) - for name, transport in self.api_transports.items(): - transport.register_api_handler(api_def) - def register_local_handler( + def register_endpoint( self, - uri: str, - request_methods: List[str], + endpoint: str, + request_types: Union[List[str], RequestType], callback: APICallback, - transports: List[str] = ALL_TRANSPORTS, + transports: Union[List[str], TransportType] = TransportType.all(), wrap_result: bool = True, - content_type: Optional[str] = None + content_type: Optional[str] = None, + auth_required: bool = True, + is_remote: bool = False ) -> None: - if uri in self.registered_base_handlers: + if isinstance(request_types, list): + request_types = RequestType.from_string_list(request_types) + if isinstance(transports, list): + transports = TransportType.from_string_list(transports) + api_def = APIDefinition.create( + endpoint, request_types, callback, transports, auth_required, is_remote + ) + http_path = api_def.http_path + if http_path in self.registered_base_handlers: + if not is_remote: + raise self.server.error( + f"Local endpoint '{endpoint}' already registered" + ) return - api_def = self._create_api_definition( - uri, request_methods, callback, transports=transports) - if "http" in transports: - logging.info( - f"Registering HTTP Endpoint: " - f"({' '.join(request_methods)}) {uri}") + logging.debug(f"Registering API: {api_def}") + if TransportType.HTTP in transports: params: dict[str, Any] = {} - params['methods'] = request_methods - params['callback'] = callback - params['wrap_result'] = wrap_result - params['is_remote'] = False - params['content_type'] = content_type + params["api_definition"] = api_def + params["wrap_result"] = wrap_result + params["content_type"] = content_type self.mutable_router.add_handler( - f"{self._route_prefix}{uri}", DynamicRequestHandler, params + f"{self._route_prefix}{http_path}", DynamicRequestHandler, params ) - self.registered_base_handlers.append(uri) - for name, transport in self.api_transports.items(): - if name in transports: - transport.register_api_handler(api_def) + self.registered_base_handlers.append(http_path) + for request_type, method_name in api_def.rpc_items(): + self.json_rpc.register_method(method_name, request_type, api_def) def register_static_file_handler( self, pattern: str, file_path: str, force: bool = False @@ -412,72 +378,33 @@ def register_upload_handler( f"{self._route_prefix}{pattern}", FileUploadHandler, params ) - def register_debug_handler( + def register_debug_endpoint( self, - uri: str, - request_methods: List[str], + endpoint: str, + request_types: Union[List[str], RequestType], callback: APICallback, - transports: List[str] = ALL_TRANSPORTS, + transports: Union[List[str], TransportType] = TransportType.all(), wrap_result: bool = True ) -> None: if not self.server.is_debug_enabled(): return - if not uri.startswith("/debug"): + if not endpoint.startswith("/debug"): raise self.server.error( - "Debug Endpoints must be registerd in the '/debug' path" + "Debug Endpoints must be registered in the '/debug' path" ) - self.register_local_handler( - uri, request_methods, callback, transports, wrap_result + self.register_endpoint( + endpoint, request_types, callback, transports, wrap_result ) - def remove_handler(self, endpoint: str) -> None: - api_def = self.api_cache.pop(endpoint, None) + def remove_endpoint(self, endpoint: str) -> None: + api_def = APIDefinition.pop_cached_def(endpoint) if api_def is not None: - self.mutable_router.remove_handler(api_def.uri) - for name, transport in self.api_transports.items(): - transport.remove_api_handler(api_def) - - def _create_api_definition( - self, - endpoint: str, - request_methods: List[str] = [], - callback: Optional[APICallback] = None, - transports: List[str] = ALL_TRANSPORTS - ) -> APIDefinition: - is_remote = callback is None - if endpoint in self.api_cache: - return self.api_cache[endpoint] - if endpoint[0] == '/': - uri = endpoint - elif is_remote: - uri = "/printer/" + endpoint - else: - uri = "/server/" + endpoint - jrpc_methods = [] - if is_remote: - # Remote requests accept both GET and POST requests. These - # requests execute the same callback, thus they resolve to - # only a single websocket method. - jrpc_methods.append(uri[1:].replace('/', '.')) - request_methods = ['GET', 'POST'] - else: - name_parts = uri[1:].split('/') - if len(request_methods) > 1: - for req_mthd in request_methods: - func_name = req_mthd.lower() + "_" + name_parts[-1] - jrpc_methods.append(".".join( - name_parts[:-1] + [func_name])) - else: - jrpc_methods.append(".".join(name_parts)) - if not is_remote and len(request_methods) != len(jrpc_methods): - raise self.server.error( - "Invalid API definition. Number of websocket methods must " - "match the number of request methods") - need_object_parser = endpoint.startswith("objects/") - api_def = APIDefinition(endpoint, uri, jrpc_methods, request_methods, - transports, callback, need_object_parser) - self.api_cache[endpoint] = api_def - return api_def + logging.debug(f"Removing Endpoint: {endpoint}") + if api_def.http_path in self.registered_base_handlers: + self.registered_base_handlers.remove(api_def.http_path) + self.mutable_router.remove_handler(api_def.http_path) + for method_name in api_def.rpc_methods: + self.json_rpc.remove_method(method_name) async def load_template(self, asset_name: str) -> JinjaTemplate: if asset_name in self.template_cache: @@ -496,7 +423,7 @@ async def load_template(self, asset_name: str) -> JinjaTemplate: class AuthorizedRequestHandler(tornado.web.RequestHandler): def initialize(self) -> None: self.server: Server = self.settings['server'] - self.endpoint: str = "" + self.auth_required: bool = True def set_default_headers(self) -> None: origin: Optional[str] = self.request.headers.get("Origin") @@ -509,11 +436,11 @@ def set_default_headers(self) -> None: self.cors_enabled = auth.check_cors(origin, self) def prepare(self) -> None: - app: MoonrakerApp = self.server.lookup_component("application") - self.endpoint = app.parse_endpoint(self.request.path or "") auth: AuthComp = self.server.lookup_component('authorization', None) if auth is not None: - self.current_user = auth.check_authorized(self.request, self.endpoint) + self.current_user = auth.authenticate_request( + self.request, self.auth_required + ) def options(self, *args, **kwargs) -> None: # Enable CORS if configured @@ -534,8 +461,7 @@ def get_associated_websocket(self) -> Optional[WebSocket]: except Exception: pass else: - wsm: WebsocketManager = self.server.lookup_component( - "websockets") + wsm: WebsocketManager = self.server.lookup_component("websockets") conn = wsm.get_client(conn_id) if not isinstance(conn, WebSocket): return None @@ -558,7 +484,6 @@ def initialize(self, ) -> None: super(AuthorizedFileHandler, self).initialize(path, default_filename) self.server: Server = self.settings['server'] - self.endpoint: str = "" def set_default_headers(self) -> None: origin: Optional[str] = self.request.headers.get("Origin") @@ -571,11 +496,11 @@ def set_default_headers(self) -> None: self.cors_enabled = auth.check_cors(origin, self) def prepare(self) -> None: - app: MoonrakerApp = self.server.lookup_component("application") - self.endpoint = app.parse_endpoint(self.request.path or "") auth: AuthComp = self.server.lookup_component('authorization', None) - if auth is not None and self._check_need_auth(): - self.current_user = auth.check_authorized(self.request, self.endpoint) + if auth is not None: + self.current_user = auth.authenticate_request( + self.request, self._check_need_auth() + ) def options(self, *args, **kwargs) -> None: # Enable CORS if configured @@ -604,22 +529,16 @@ def _check_need_auth(self) -> bool: class DynamicRequestHandler(AuthorizedRequestHandler): def initialize( self, - callback: Union[str, Callable[[WebRequest], Coroutine]] = "", - methods: List[str] = [], - need_object_parser: bool = False, - is_remote: bool = True, + api_definition: Optional[APIDefinition] = None, wrap_result: bool = True, content_type: Optional[str] = None ) -> None: super(DynamicRequestHandler, self).initialize() - self.callback = callback - self.methods = methods + assert api_definition is not None + self.api_defintion = api_definition self.wrap_result = wrap_result - self._do_request = self._do_remote_request if is_remote \ - else self._do_local_request - self._parse_query = self._object_parser if need_object_parser \ - else self._default_parser self.content_type = content_type + self.auth_required = api_definition.auth_required # Converts query string values with type hints def _convert_type(self, value: str, hint: str) -> Any: @@ -667,7 +586,10 @@ def _object_parser(self) -> Dict[str, Dict[str, Any]]: def parse_args(self) -> Dict[str, Any]: try: - args = self._parse_query() + if self.api_defintion.need_object_parser: + args: Dict[str, Any] = self._object_parser() + else: + args = self._default_parser() except Exception: raise ServerError( "Error Parsing Request Arguments. " @@ -686,10 +608,11 @@ def parse_args(self) -> Dict[str, Any]: def _log_debug(self, header: str, args: Any) -> None: if self.server.is_verbose_enabled(): resp = args + endpoint = self.api_defintion.endpoint if isinstance(args, dict): if ( - self.endpoint.startswith("/access") or - self.endpoint.startswith("/machine/sudo/password") + endpoint.startswith("/access") or + endpoint.startswith("/machine/sudo/password") ): resp = {key: "" for key in args} elif isinstance(args, str): @@ -698,44 +621,26 @@ def _log_debug(self, header: str, args: Any) -> None: logging.debug(f"{header}::{resp}") async def get(self, *args, **kwargs) -> None: - await self._process_http_request() + await self._process_http_request(RequestType.GET) async def post(self, *args, **kwargs) -> None: - await self._process_http_request() + await self._process_http_request(RequestType.POST) async def delete(self, *args, **kwargs) -> None: - await self._process_http_request() - - async def _do_local_request(self, - args: Dict[str, Any], - conn: Optional[WebSocket] - ) -> Any: - assert callable(self.callback) - return await self.callback( - WebRequest(self.endpoint, args, self.request.method, - conn=conn, ip_addr=self.request.remote_ip or "", - user=self.current_user)) - - async def _do_remote_request(self, - args: Dict[str, Any], - conn: Optional[WebSocket] - ) -> Any: - assert isinstance(self.callback, str) - klippy: Klippy = self.server.lookup_component("klippy_connection") - return await klippy.request( - WebRequest(self.callback, args, conn=conn, - ip_addr=self.request.remote_ip or "", - user=self.current_user)) - - async def _process_http_request(self) -> None: - if self.request.method not in self.methods: + await self._process_http_request(RequestType.DELETE) + + async def _process_http_request(self, req_type: RequestType) -> None: + if req_type not in self.api_defintion.request_types: raise tornado.web.HTTPError(405) - conn = self.get_associated_websocket() args = self.parse_args() + transport = self.get_associated_websocket() req = f"{self.request.method} {self.request.path}" self._log_debug(f"HTTP Request::{req}", args) try: - result = await self._do_request(args, conn) + ip = parse_ip_address(self.request.remote_ip or "") + result = await self.api_defintion.request( + args, req_type, transport, ip, self.current_user + ) except ServerError as e: raise tornado.web.HTTPError( e.status_code, reason=str(e)) from e @@ -751,6 +656,50 @@ async def _process_http_request(self) -> None: self.set_header("Content-Type", self.content_type) self.finish(result) +class RPCHandler(AuthorizedRequestHandler, APITransport): + def initialize(self) -> None: + super(RPCHandler, self).initialize() + self.auth_required = False + + @property + def transport_type(self) -> TransportType: + return TransportType.HTTP + + @property + def user_info(self) -> Optional[Dict[str, Any]]: + return self.current_user + + @property + def ip_addr(self) -> Optional[IPAddress]: + return parse_ip_address(self.request.remote_ip or "") + + def screen_rpc_request( + self, api_def: APIDefinition, req_type: RequestType, args: Dict[str, Any] + ) -> None: + if self.current_user is None and api_def.auth_required: + raise self.server.error("Unauthorized", 401) + if api_def.endpoint == "objects/subscribe": + raise self.server.error( + "Subscriptions not available for HTTP transport", 404 + ) + + def send_status(self, status: Dict[str, Any], eventtime: float) -> None: + # Can't handle status updates. This should not be called, but + # we don't want to raise an exception if it is + pass + + async def post(self, *args, **kwargs) -> None: + content_type = self.request.headers.get('Content-Type', "").strip() + if not content_type.startswith("application/json"): + raise tornado.web.HTTPError( + 400, "Invalid content type, application/json required" + ) + rpc: JsonRPC = self.server.lookup_component("jsonrpc") + result = await rpc.dispatch(self.request.body, self) + if result is not None: + self.set_header("Content-Type", "application/json; charset=UTF-8") + self.finish(result) + class FileRequestHandler(AuthorizedFileHandler): def set_extra_headers(self, path: str) -> None: # The call below shold never return an empty string, @@ -766,7 +715,9 @@ def set_extra_headers(self, path: str) -> None: f"filename*=UTF-8\'\'{utf8_basename}") async def delete(self, path: str) -> None: - path = self.endpoint.lstrip("/").split("/", 2)[-1] + app: MoonrakerApp = self.server.lookup_component("application") + endpoint = app.parse_endpoint(self.request.path or "") + path = endpoint.lstrip("/").split("/", 2)[-1] path = url_unescape(path, plus=False) file_manager: FileManager file_manager = self.server.lookup_component('file_manager') @@ -1047,6 +998,10 @@ def write_error(self, status_code: int, **kwargs) -> None: self.finish(jsonw.dumps({'error': err})) class RedirectHandler(AuthorizedRequestHandler): + def initialize(self) -> None: + super().initialize() + self.auth_required = False + def get(self, *args, **kwargs) -> None: url: Optional[str] = self.get_argument('url', None) if url is None: @@ -1075,7 +1030,7 @@ async def get(self) -> None: auth: AuthComp = self.server.lookup_component("authorization", None) if auth is not None: try: - auth.check_authorized(self.request) + auth.authenticate_request(self.request) except tornado.web.HTTPError: authorized = False else: @@ -1118,11 +1073,10 @@ async def get(self) -> None: "The [authorization] section in moonraker.conf must be " "configured to enable CORS." ) - kstate = self.server.get_klippy_state() - if kstate != "disconnected": - kinfo = self.server.get_klippy_info() - kmsg = kinfo.get("state_message", kstate) - summary.append(f"Klipper reports {kmsg.lower()}") + kconn: Klippy = self.server.lookup_component("klippy_connection") + kstate = kconn.state + if kstate != KlippyState.DISCONNECTED: + summary.append(f"Klipper reports {kstate.message.lower()}") else: summary.append( "Moonraker is not currently connected to Klipper. Make sure " diff --git a/moonraker/common.py b/moonraker/common.py index ff93843ce..d0c63d395 100644 --- a/moonraker/common.py +++ b/moonraker/common.py @@ -5,9 +5,12 @@ # This file may be distributed under the terms of the GNU GPLv3 license from __future__ import annotations -import ipaddress +import sys import logging import copy +import re +from enum import Enum, Flag, auto +from dataclasses import dataclass from .utils import ServerError, Sentinel from .utils import json_wrapper as jsonw @@ -23,63 +26,275 @@ Union, Dict, List, - Awaitable + Awaitable, + ClassVar, + Tuple ) if TYPE_CHECKING: from .server import Server from .websockets import WebsocketManager from .components.authorization import Authorization + from .utils import IPAddress from asyncio import Future _T = TypeVar("_T") _C = TypeVar("_C", str, bool, float, int) - IPUnion = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] + _F = TypeVar("_F", bound="ExtendedFlag") ConvType = Union[str, bool, float, int] ArgVal = Union[None, int, float, bool, str] RPCCallback = Callable[..., Coroutine] AuthComp = Optional[Authorization] -class Subscribable: - def send_status(self, - status: Dict[str, Any], - eventtime: float - ) -> None: - raise NotImplementedError +ENDPOINT_PREFIXES = ["printer", "server", "machine", "access", "api", "debug"] + +class ExtendedFlag(Flag): + @classmethod + def from_string(cls: Type[_F], flag_name: str) -> _F: + str_name = flag_name.upper() + for name, member in cls.__members__.items(): + if name == str_name: + return cls(member.value) + raise ValueError(f"No flag member named {flag_name}") + + @classmethod + def from_string_list(cls: Type[_F], flag_list: List[str]) -> _F: + ret = cls(0) + for flag in flag_list: + flag = flag.upper() + ret |= cls.from_string(flag) + return ret + + @classmethod + def all(cls: Type[_F]) -> _F: + return ~cls(0) + + if sys.version_info < (3, 11): + def __len__(self) -> int: + return bin(self._value_).count("1") + + def __iter__(self): + for i in range(self._value_.bit_length()): + val = 1 << i + if val & self._value_ == val: + yield self.__class__(val) + +class RequestType(ExtendedFlag): + """ + The Request Type is also known as the "Request Method" for + HTTP/REST APIs. The use of "Request Method" nomenclature + is discouraged in Moonraker as it could be confused with + the JSON-RPC "method" field. + """ + GET = auto() + POST = auto() + DELETE = auto() + +class TransportType(ExtendedFlag): + HTTP = auto() + WEBSOCKET = auto() + MQTT = auto() + INTERNAL = auto() + +class ExtendedEnum(Enum): + @classmethod + def from_string(cls, enum_name: str): + str_name = enum_name.upper() + for name, member in cls.__members__.items(): + if name == str_name: + return cls(member.value) + raise ValueError(f"No enum member named {enum_name}") + + def __str__(self) -> str: + return self._name_.lower() # type: ignore + +class JobEvent(ExtendedEnum): + STANDBY = 1 + STARTED = 2 + PAUSED = 3 + RESUMED = 4 + COMPLETE = 5 + ERROR = 6 + CANCELLED = 7 + + @property + def finished(self) -> bool: + return self.value >= 5 + @property + def aborted(self) -> bool: + return self.value >= 6 + + @property + def is_printing(self) -> bool: + return self.value in [2, 4] + +class KlippyState(ExtendedEnum): + DISCONNECTED = 1 + STARTUP = 2 + READY = 3 + ERROR = 4 + SHUTDOWN = 5 + + @classmethod + def from_string(cls, enum_name: str, msg: str = ""): + str_name = enum_name.upper() + for name, member in cls.__members__.items(): + if name == str_name: + instance = cls(member.value) + if msg: + instance.set_message(msg) + return instance + raise ValueError(f"No enum member named {enum_name}") + + + def set_message(self, msg: str) -> None: + self._state_message: str = msg + + @property + def message(self) -> str: + if hasattr(self, "_state_message"): + return self._state_message + return "" + + def startup_complete(self) -> bool: + return self.value > 2 + +@dataclass(frozen=True) class APIDefinition: - def __init__(self, - endpoint: str, - http_uri: str, - jrpc_methods: List[str], - request_methods: Union[str, List[str]], - transports: List[str], - callback: Optional[Callable[[WebRequest], Coroutine]], - need_object_parser: bool): - self.endpoint = endpoint - self.uri = http_uri - self.jrpc_methods = jrpc_methods - if not isinstance(request_methods, list): - request_methods = [request_methods] - self.request_methods = request_methods - self.supported_transports = transports - self.callback = callback - self.need_object_parser = need_object_parser + endpoint: str + http_path: str + rpc_methods: List[str] + request_types: RequestType + transports: TransportType + callback: Callable[[WebRequest], Coroutine] + auth_required: bool + _cache: ClassVar[Dict[str, APIDefinition]] = {} + + def __str__(self) -> str: + tprt_str = "|".join([tprt.name for tprt in self.transports if tprt.name]) + val: str = f"(Transports: {tprt_str})" + if TransportType.HTTP in self.transports: + req_types = "|".join([rt.name for rt in self.request_types if rt.name]) + val += f" (HTTP Request: {req_types} {self.http_path})" + if self.rpc_methods: + methods = " ".join(self.rpc_methods) + val += f" (RPC Methods: {methods})" + val += f" (Auth Required: {self.auth_required})" + return val + + def request( + self, + args: Dict[str, Any], + request_type: RequestType, + transport: Optional[APITransport] = None, + ip_addr: Optional[IPAddress] = None, + user: Optional[Dict[str, Any]] = None + ) -> Coroutine: + return self.callback( + WebRequest(self.endpoint, args, request_type, transport, ip_addr, user) + ) + + @property + def need_object_parser(self) -> bool: + return self.endpoint.startswith("objects/") + + def rpc_items(self) -> zip[Tuple[RequestType, str]]: + return zip(self.request_types, self.rpc_methods) + + @classmethod + def create( + cls, + endpoint: str, + request_types: Union[List[str], RequestType], + callback: Callable[[WebRequest], Coroutine], + transports: Union[List[str], TransportType] = TransportType.all(), + auth_required: bool = True, + is_remote: bool = False + ) -> APIDefinition: + if isinstance(request_types, list): + request_types = RequestType.from_string_list(request_types) + if isinstance(transports, list): + transports = TransportType.from_string_list(transports) + if endpoint in cls._cache: + return cls._cache[endpoint] + http_path = f"/printer/{endpoint.strip('/')}" if is_remote else endpoint + prf_match = re.match(r"/([^/]+)", http_path) + if TransportType.HTTP in transports: + # Validate the first path segment for definitions that support the + # HTTP transport. We want to restrict components from registering + # using unknown paths. + if prf_match is None or prf_match.group(1) not in ENDPOINT_PREFIXES: + prefixes = [f"/{prefix} " for prefix in ENDPOINT_PREFIXES] + raise ServerError( + f"Invalid endpoint name '{endpoint}', must start with one of " + f"the following: {prefixes}" + ) + rpc_methods: List[str] = [] + if is_remote: + # Request Types have no meaning for remote requests. Therefore + # both GET and POST http requests are accepted. JRPC requests do + # not need an associated RequestType, so the unknown value is used. + request_types = RequestType.GET | RequestType.POST + rpc_methods.append(http_path[1:].replace('/', '.')) + elif transports != TransportType.HTTP: + name_parts = http_path[1:].split('/') + if len(request_types) > 1: + for rtype in request_types: + func_name = rtype.name.lower() + "_" + name_parts[-1] + rpc_methods.append(".".join(name_parts[:-1] + [func_name])) + else: + rpc_methods.append(".".join(name_parts)) + if len(request_types) != len(rpc_methods): + raise ServerError( + "Invalid API definition. Number of websocket methods must " + "match the number of request methods" + ) + + api_def = cls( + endpoint, http_path, rpc_methods, request_types, + transports, callback, auth_required + ) + cls._cache[endpoint] = api_def + return api_def + + @classmethod + def pop_cached_def(cls, endpoint: str) -> Optional[APIDefinition]: + return cls._cache.pop(endpoint, None) + + @classmethod + def get_cache(cls) -> Dict[str, APIDefinition]: + return cls._cache class APITransport: - def register_api_handler(self, api_def: APIDefinition) -> None: - raise NotImplementedError + @property + def transport_type(self) -> TransportType: + return TransportType.INTERNAL + + @property + def user_info(self) -> Optional[Dict[str, Any]]: + return None + + @property + def ip_addr(self) -> Optional[IPAddress]: + return None - def remove_api_handler(self, api_def: APIDefinition) -> None: + def screen_rpc_request( + self, api_def: APIDefinition, req_type: RequestType, args: Dict[str, Any] + ) -> None: + return None + + def send_status( + self, status: Dict[str, Any], eventtime: float + ) -> None: raise NotImplementedError -class BaseRemoteConnection(Subscribable): +class BaseRemoteConnection(APITransport): def on_create(self, server: Server) -> None: self.server = server self.eventloop = server.get_event_loop() self.wsm: WebsocketManager = self.server.lookup_component("websockets") - self.rpc = self.wsm.rpc + self.rpc: JsonRPC = self.server.lookup_component("jsonrpc") self._uid = id(self) - self.ip_addr = "" self.is_closed: bool = False self.queue_busy: bool = False self.pending_responses: Dict[int, Future] = {} @@ -133,6 +348,15 @@ def client_data(self, data: Dict[str, str]) -> None: self._client_data = data self._identified = True + @property + def transport_type(self) -> TransportType: + return TransportType.WEBSOCKET + + def screen_rpc_request( + self, api_def: APIDefinition, req_type: RequestType, args: Dict[str, Any] + ) -> None: + self.check_authenticated(api_def) + async def _process_message(self, message: str) -> None: try: response = await self.rpc.dispatch(message, self) @@ -162,16 +386,16 @@ def authenticate( self.user_info = auth.validate_jwt(token) elif api_key is not None and self.user_info is None: self.user_info = auth.validate_api_key(api_key) - else: - self.check_authenticated() + elif self._need_auth: + raise self.server.error("Unauthorized", 401) - def check_authenticated(self, path: str = "") -> None: + def check_authenticated(self, api_def: APIDefinition) -> None: if not self._need_auth: return auth: AuthComp = self.server.lookup_component("authorization", None) if auth is None: return - if not auth.is_path_permitted(path): + if api_def.auth_required: raise self.server.error("Unauthorized", 401) def on_user_logout(self, user: str) -> bool: @@ -256,43 +480,43 @@ def close_socket(self, code: int, reason: str) -> None: class WebRequest: - def __init__(self, - endpoint: str, - args: Dict[str, Any], - action: Optional[str] = "", - conn: Optional[Subscribable] = None, - ip_addr: str = "", - user: Optional[Dict[str, Any]] = None - ) -> None: + def __init__( + self, + endpoint: str, + args: Dict[str, Any], + request_type: RequestType = RequestType(0), + transport: Optional[APITransport] = None, + ip_addr: Optional[IPAddress] = None, + user: Optional[Dict[str, Any]] = None + ) -> None: self.endpoint = endpoint - self.action = action or "" self.args = args - self.conn = conn - self.ip_addr: Optional[IPUnion] = None - try: - self.ip_addr = ipaddress.ip_address(ip_addr) - except Exception: - self.ip_addr = None + self.transport = transport + self.request_type = request_type + self.ip_addr: Optional[IPAddress] = ip_addr self.current_user = user def get_endpoint(self) -> str: return self.endpoint + def get_request_type(self) -> RequestType: + return self.request_type + def get_action(self) -> str: - return self.action + return self.request_type.name or "" def get_args(self) -> Dict[str, Any]: return self.args - def get_subscribable(self) -> Optional[Subscribable]: - return self.conn + def get_subscribable(self) -> Optional[APITransport]: + return self.transport def get_client_connection(self) -> Optional[BaseRemoteConnection]: - if isinstance(self.conn, BaseRemoteConnection): - return self.conn + if isinstance(self.transport, BaseRemoteConnection): + return self.transport return None - def get_ip_address(self) -> Optional[IPUnion]: + def get_ip_address(self) -> Optional[IPAddress]: return self.ip_addr def get_current_user(self) -> Optional[Dict[str, Any]]: @@ -410,15 +634,12 @@ def get_list( class JsonRPC: - def __init__( - self, server: Server, transport: str = "Websocket" - ) -> None: - self.methods: Dict[str, RPCCallback] = {} - self.transport = transport + def __init__(self, server: Server) -> None: + self.methods: Dict[str, Tuple[RequestType, APIDefinition]] = {} self.sanitize_response = False self.verbose = server.is_verbose_enabled() - def _log_request(self, rpc_obj: Dict[str, Any], ) -> None: + def _log_request(self, rpc_obj: Dict[str, Any], trtype: TransportType) -> None: if not self.verbose: return self.sanitize_response = False @@ -439,9 +660,11 @@ def _log_request(self, rpc_obj: Dict[str, Any], ) -> None: for field in ["access_token", "api_key"]: if field in params: output["params"][field] = "" - logging.debug(f"{self.transport} Received::{jsonw.dumps(output).decode()}") + logging.debug(f"{trtype} Received::{jsonw.dumps(output).decode()}") - def _log_response(self, resp_obj: Optional[Dict[str, Any]]) -> None: + def _log_response( + self, resp_obj: Optional[Dict[str, Any]], trtype: TransportType + ) -> None: if not self.verbose: return if resp_obj is None: @@ -451,67 +674,84 @@ def _log_response(self, resp_obj: Optional[Dict[str, Any]]) -> None: output = copy.deepcopy(resp_obj) output["result"] = "" self.sanitize_response = False - logging.debug(f"{self.transport} Response::{jsonw.dumps(output).decode()}") + logging.debug(f"{trtype} Response::{jsonw.dumps(output).decode()}") - def register_method(self, - name: str, - method: RPCCallback - ) -> None: - self.methods[name] = method + def register_method( + self, + name: str, + request_type: RequestType, + api_definition: APIDefinition + ) -> None: + self.methods[name] = (request_type, api_definition) + + def get_method(self, name: str) -> Optional[Tuple[RequestType, APIDefinition]]: + return self.methods.get(name, None) def remove_method(self, name: str) -> None: self.methods.pop(name, None) - async def dispatch(self, - data: str, - conn: Optional[BaseRemoteConnection] = None - ) -> Optional[bytes]: + async def dispatch( + self, + data: Union[str, bytes], + transport: APITransport + ) -> Optional[bytes]: + transport_type = transport.transport_type try: obj: Union[Dict[str, Any], List[dict]] = jsonw.loads(data) except Exception: - msg = f"{self.transport} data not json: {data}" + if isinstance(data, bytes): + data = data.decode() + msg = f"{transport_type} data not valid json: {data}" logging.exception(msg) err = self.build_error(-32700, "Parse error") return jsonw.dumps(err) if isinstance(obj, list): responses: List[Dict[str, Any]] = [] for item in obj: - self._log_request(item) - resp = await self.process_object(item, conn) + self._log_request(item, transport_type) + resp = await self.process_object(item, transport) if resp is not None: - self._log_response(resp) + self._log_response(resp, transport_type) responses.append(resp) if responses: return jsonw.dumps(responses) else: - self._log_request(obj) - response = await self.process_object(obj, conn) + self._log_request(obj, transport_type) + response = await self.process_object(obj, transport) if response is not None: - self._log_response(response) + self._log_response(response, transport_type) return jsonw.dumps(response) return None - async def process_object(self, - obj: Dict[str, Any], - conn: Optional[BaseRemoteConnection] - ) -> Optional[Dict[str, Any]]: + async def process_object( + self, + obj: Dict[str, Any], + transport: APITransport + ) -> Optional[Dict[str, Any]]: req_id: Optional[int] = obj.get('id', None) rpc_version: str = obj.get('jsonrpc', "") if rpc_version != "2.0": return self.build_error(-32600, "Invalid Request", req_id) method_name = obj.get('method', Sentinel.MISSING) if method_name is Sentinel.MISSING: - self.process_response(obj, conn) + self.process_response(obj, transport) return None if not isinstance(method_name, str): return self.build_error( -32600, "Invalid Request", req_id, method_name=str(method_name) ) - method = self.methods.get(method_name, None) - if method is None: + method_info = self.methods.get(method_name, None) + if method_info is None: return self.build_error( -32601, "Method not found", req_id, method_name=method_name ) + request_type, api_definition = method_info + transport_type = transport.transport_type + if transport_type not in api_definition.transports: + return self.build_error( + -32601, f"Method not found for transport {transport_type.name}", + req_id, method_name=method_name + ) params: Dict[str, Any] = {} if 'params' in obj: params = obj['params'] @@ -519,12 +759,14 @@ async def process_object(self, return self.build_error( -32602, "Invalid params:", req_id, method_name=method_name ) - return await self.execute_method(method_name, method, req_id, conn, params) + return await self.execute_method( + method_name, request_type, api_definition, req_id, transport, params + ) def process_response( - self, obj: Dict[str, Any], conn: Optional[BaseRemoteConnection] + self, obj: Dict[str, Any], conn: APITransport ) -> None: - if conn is None: + if not isinstance(conn, BaseRemoteConnection): logging.debug(f"RPC Response to non-socket request: {obj}") return response_id = obj.get("id") @@ -549,15 +791,17 @@ def process_response( async def execute_method( self, method_name: str, - callback: RPCCallback, + request_type: RequestType, + api_definition: APIDefinition, req_id: Optional[int], - conn: Optional[BaseRemoteConnection], + transport: APITransport, params: Dict[str, Any] ) -> Optional[Dict[str, Any]]: - if conn is not None: - params["_socket_"] = conn try: - result = await callback(params) + transport.screen_rpc_request(api_definition, request_type, params) + result = await api_definition.request( + params, request_type, transport, transport.ip_addr, transport.user_info + ) except TypeError as e: return self.build_error( -32602, f"Invalid params:\n{e}", req_id, True, method_name diff --git a/moonraker/components/announcements.py b/moonraker/components/announcements.py index 68c8c161e..763197e5b 100644 --- a/moonraker/components/announcements.py +++ b/moonraker/components/announcements.py @@ -11,6 +11,7 @@ import logging import email.utils import xml.etree.ElementTree as etree +from ..common import RequestType from typing import ( TYPE_CHECKING, Awaitable, @@ -57,23 +58,23 @@ def __init__(self, config: ConfigHelper) -> None: ) self.server.register_endpoint( - "/server/announcements/list", ["GET"], + "/server/announcements/list", RequestType.GET, self._list_announcements ) self.server.register_endpoint( - "/server/announcements/dismiss", ["POST"], + "/server/announcements/dismiss", RequestType.POST, self._handle_dismiss_request ) self.server.register_endpoint( - "/server/announcements/update", ["POST"], + "/server/announcements/update", RequestType.POST, self._handle_update_request ) self.server.register_endpoint( - "/server/announcements/feed", ["POST", "DELETE"], + "/server/announcements/feed", RequestType.POST | RequestType.DELETE, self._handle_feed_request ) self.server.register_endpoint( - "/server/announcements/feeds", ["GET"], + "/server/announcements/feeds", RequestType.GET, self._handle_list_feeds ) self.server.register_notification( @@ -170,13 +171,13 @@ async def _handle_list_feeds( async def _handle_feed_request( self, web_request: WebRequest ) -> Dict[str, Any]: - action = web_request.get_action() + req_type = web_request.get_request_type() name: str = web_request.get("name") name = name.lower() changed: bool = False db: MoonrakerDatabase = self.server.lookup_component("database") result = "skipped" - if action == "POST": + if req_type == RequestType.POST: if name not in self.subscriptions: feed = RssFeed(name, self.entry_mgr, self.dev_mode) self.subscriptions[name] = feed @@ -187,7 +188,7 @@ async def _handle_feed_request( "moonraker", "announcements.stored_feeds", self.stored_feeds ) result = "added" - elif action == "DELETE": + elif req_type == RequestType.DELETE: if name not in self.stored_feeds: raise self.server.error(f"Feed '{name}' not stored") if name in self.configured_feeds: diff --git a/moonraker/components/authorization.py b/moonraker/components/authorization.py index 47d861ae1..b6dbc1c9f 100644 --- a/moonraker/components/authorization.py +++ b/moonraker/components/authorization.py @@ -20,13 +20,13 @@ from tornado.web import HTTPError from libnacl.sign import Signer, Verifier from ..utils import json_wrapper as jsonw +from ..common import RequestType, TransportType # Annotation imports from typing import ( TYPE_CHECKING, Any, Tuple, - Set, Optional, Union, Dict, @@ -151,7 +151,6 @@ def __init__(self, config: ConfigHelper) -> None: self.user_db.sync(self.users) self.trusted_users: Dict[IPAddr, Any] = {} self.oneshot_tokens: Dict[str, OneshotToken] = {} - self.permitted_paths: Set[str] = set() # Get allowed cors domains self.cors_domains: List[str] = [] @@ -221,37 +220,46 @@ def __init__(self, config: ConfigHelper) -> None: self._prune_conn_handler) # Register Authorization Endpoints - self.permitted_paths.add("/server/redirect") - self.permitted_paths.add("/access/login") - self.permitted_paths.add("/access/refresh_jwt") - self.permitted_paths.add("/access/info") self.server.register_endpoint( - "/access/login", ['POST'], self._handle_login, - transports=['http', 'websocket']) + "/access/login", RequestType.POST, self._handle_login, + transports=TransportType.HTTP | TransportType.WEBSOCKET, + auth_required=False + ) self.server.register_endpoint( - "/access/logout", ['POST'], self._handle_logout, - transports=['http', 'websocket']) + "/access/logout", RequestType.POST, self._handle_logout, + transports=TransportType.HTTP | TransportType.WEBSOCKET + ) self.server.register_endpoint( - "/access/refresh_jwt", ['POST'], self._handle_refresh_jwt, - transports=['http', 'websocket']) + "/access/refresh_jwt", RequestType.POST, self._handle_refresh_jwt, + transports=TransportType.HTTP | TransportType.WEBSOCKET, + auth_required=False + ) self.server.register_endpoint( - "/access/user", ['GET', 'POST', 'DELETE'], - self._handle_user_request, transports=['http', 'websocket']) + "/access/user", RequestType.all(), self._handle_user_request, + transports=TransportType.HTTP | TransportType.WEBSOCKET + ) self.server.register_endpoint( - "/access/users/list", ['GET'], self._handle_list_request, - transports=['http', 'websocket']) + "/access/users/list", RequestType.GET, self._handle_list_request, + transports=TransportType.HTTP | TransportType.WEBSOCKET + ) self.server.register_endpoint( - "/access/user/password", ['POST'], self._handle_password_reset, - transports=['http', 'websocket']) + "/access/user/password", RequestType.POST, self._handle_password_reset, + transports=TransportType.HTTP | TransportType.WEBSOCKET + ) self.server.register_endpoint( - "/access/api_key", ['GET', 'POST'], - self._handle_apikey_request, transports=['http', 'websocket']) + "/access/api_key", RequestType.GET | RequestType.POST, + self._handle_apikey_request, + transports=TransportType.HTTP | TransportType.WEBSOCKET + ) self.server.register_endpoint( - "/access/oneshot_token", ['GET'], - self._handle_oneshot_request, transports=['http', 'websocket']) + "/access/oneshot_token", RequestType.GET, self._handle_oneshot_request, + transports=TransportType.HTTP | TransportType.WEBSOCKET + ) self.server.register_endpoint( - "/access/info", ['GET'], - self._handle_info_request, transports=['http', 'websocket']) + "/access/info", RequestType.GET, self._handle_info_request, + transports=TransportType.HTTP | TransportType.WEBSOCKET, + auth_required=False + ) wsm: WebsocketManager = self.server.lookup_component("websockets") wsm.register_notification("authorization:user_created") wsm.register_notification( @@ -261,12 +269,6 @@ def __init__(self, config: ConfigHelper) -> None: "authorization:user_logged_out", event_type="logout" ) - def register_permited_path(self, path: str) -> None: - self.permitted_paths.add(path) - - def is_path_permitted(self, path: str) -> bool: - return path in self.permitted_paths - def _sync_user(self, username: str) -> None: self.user_db[username] = self.users[username] @@ -274,8 +276,7 @@ async def component_init(self) -> None: self.prune_timer.start(delay=PRUNE_CHECK_TIME) async def _handle_apikey_request(self, web_request: WebRequest) -> str: - action = web_request.get_action() - if action.upper() == 'POST': + if web_request.get_request_type() == RequestType.POST: self.api_key = uuid.uuid4().hex self.users[API_USER]['api_key'] = self.api_key self._sync_user(API_USER) @@ -360,11 +361,11 @@ async def _handle_refresh_jwt(self, 'action': 'user_jwt_refresh' } - async def _handle_user_request(self, - web_request: WebRequest - ) -> Dict[str, Any]: - action = web_request.get_action() - if action == "GET": + async def _handle_user_request( + self, web_request: WebRequest + ) -> Dict[str, Any]: + req_type = web_request.get_request_type() + if req_type == RequestType.GET: user = web_request.get_current_user() if user is None: return { @@ -378,10 +379,10 @@ async def _handle_user_request(self, 'source': user.get("source", "moonraker"), 'created_on': user.get('created_on') } - elif action == "POST": + elif req_type == RequestType.POST: # Create User return await self._login_jwt_user(web_request, create=True) - elif action == "DELETE": + elif req_type == RequestType.DELETE: # Delete User return self._delete_jwt_user(web_request) raise self.server.error("Invalid Request Method") @@ -760,13 +761,10 @@ def check_logins_maxed(self, ip_addr: IPAddr) -> bool: return False return self.failed_logins.get(ip_addr, 0) >= self.max_logins - def check_authorized( - self, request: HTTPServerRequest, endpoint: str = "", + def authenticate_request( + self, request: HTTPServerRequest, auth_required: bool = True ) -> Optional[Dict[str, Any]]: - if ( - endpoint in self.permitted_paths - or request.method == "OPTIONS" - ): + if request.method == "OPTIONS": return None # Check JSON Web Token @@ -794,14 +792,17 @@ def check_authorized( if key and key == self.api_key: return self.users[API_USER] - # If the force_logins option is enabled and at least one - # user is created this is an unauthorized request + # If the force_logins option is enabled and at least one user is created + # then trusted user authentication is disabled if self.force_logins and len(self.users) > 1: + if not auth_required: + return None raise HTTPError(401, "Unauthorized, Force Logins Enabled") - # Check if IP is trusted + # Check if IP is trusted. If this endpoint doesn't require authentication + # then it is acceptable to return None trusted_user = self._check_trusted_connection(ip) - if trusted_user is not None: + if trusted_user is not None or not auth_required: return trusted_user raise HTTPError(401, "Unauthorized") diff --git a/moonraker/components/data_store.py b/moonraker/components/data_store.py index 869aa8dae..b45f162c5 100644 --- a/moonraker/components/data_store.py +++ b/moonraker/components/data_store.py @@ -8,6 +8,7 @@ import logging import time from collections import deque +from ..common import RequestType # Annotation imports from typing import ( @@ -59,11 +60,13 @@ def __init__(self, config: ConfigHelper) -> None: # Register endpoints self.server.register_endpoint( - "/server/temperature_store", ['GET'], - self._handle_temp_store_request) + "/server/temperature_store", RequestType.GET, + self._handle_temp_store_request + ) self.server.register_endpoint( - "/server/gcode_store", ['GET'], - self._handle_gcode_store_request) + "/server/gcode_store", RequestType.GET, + self._handle_gcode_store_request + ) async def _init_sensors(self) -> None: klippy_apis: APIComp = self.server.lookup_component('klippy_apis') diff --git a/moonraker/components/database.py b/moonraker/components/database.py index 7f174507e..70fb1bede 100644 --- a/moonraker/components/database.py +++ b/moonraker/components/database.py @@ -15,6 +15,7 @@ import lmdb from ..utils import Sentinel, ServerError from ..utils import json_wrapper as jsonw +from ..common import RequestType # Annotation imports from typing import ( @@ -174,15 +175,17 @@ def __init__(self, config: ConfigHelper) -> None: self.insert_item("moonraker", "database.unsafe_shutdowns", unsafe_shutdowns + 1) self.server.register_endpoint( - "/server/database/list", ['GET'], self._handle_list_request) + "/server/database/list", RequestType.GET, self._handle_list_request + ) self.server.register_endpoint( - "/server/database/item", ["GET", "POST", "DELETE"], - self._handle_item_request) + "/server/database/item", RequestType.all(), self._handle_item_request + ) self.server.register_debug_endpoint( - "/debug/database/list", ['GET'], self._handle_list_request) + "/debug/database/list", RequestType.GET, self._handle_list_request + ) self.server.register_debug_endpoint( - "/debug/database/item", ["GET", "POST", "DELETE"], - self._handle_item_request) + "/debug/database/item", RequestType.all(), self._handle_item_request + ) def get_database_path(self) -> str: return self.database_path @@ -735,7 +738,7 @@ async def _handle_list_request(self, async def _handle_item_request(self, web_request: WebRequest ) -> Dict[str, Any]: - action = web_request.get_action() + req_type = web_request.get_request_type() is_debug = web_request.get_endpoint().startswith("/debug/") namespace = web_request.get_str("namespace") if namespace in self.forbidden_namespaces and not is_debug: @@ -744,7 +747,7 @@ async def _handle_item_request(self, " is forbidden", 403) key: Any valid_types: Tuple[type, ...] - if action != "GET": + if req_type != RequestType.GET: if namespace in self.protected_namespaces and not is_debug: raise self.server.error( f"Write access to namespace '{namespace}'" @@ -758,16 +761,17 @@ async def _handle_item_request(self, raise self.server.error( "Value for argument 'key' is an invalid type: " f"{type(key).__name__}") - if action == "GET": + if req_type == RequestType.GET: val = await self.get_item(namespace, key) - elif action == "POST": + elif req_type == RequestType.POST: val = web_request.get("value") await self.insert_item(namespace, key, val) - elif action == "DELETE": + elif req_type == RequestType.DELETE: val = await self.delete_item(namespace, key, drop_empty_db=True) if is_debug: - self.debug_counter[action.lower()] += 1 + name = req_type.name or str(req_type).split(".", 1)[-1] + self.debug_counter[name.lower()] += 1 await self.insert_item( "moonraker", "database.debug_counter", self.debug_counter ) diff --git a/moonraker/components/extensions.py b/moonraker/components/extensions.py index e328d5977..3031c7978 100644 --- a/moonraker/components/extensions.py +++ b/moonraker/components/extensions.py @@ -7,7 +7,7 @@ import asyncio import pathlib import logging -from ..common import BaseRemoteConnection +from ..common import BaseRemoteConnection, RequestType, TransportType from ..utils import get_unix_peer_credentials # Annotation imports @@ -35,19 +35,19 @@ def __init__(self, config: ConfigHelper) -> None: self.agent_methods: Dict[int, List[str]] = {} self.uds_server: Optional[asyncio.AbstractServer] = None self.server.register_endpoint( - "/connection/register_remote_method", ["POST"], + "/connection/register_remote_method", RequestType.POST, self._register_agent_method, - transports=["websocket"] + transports=TransportType.WEBSOCKET ) self.server.register_endpoint( - "/connection/send_event", ["POST"], self._handle_agent_event, - transports=["websocket"] + "/connection/send_event", RequestType.POST, self._handle_agent_event, + transports=TransportType.WEBSOCKET ) self.server.register_endpoint( - "/server/extensions/list", ["GET"], self._handle_list_extensions + "/server/extensions/list", RequestType.GET, self._handle_list_extensions ) self.server.register_endpoint( - "/server/extensions/request", ["POST"], self._handle_call_agent + "/server/extensions/request", RequestType.POST, self._handle_call_agent ) def register_agent(self, connection: BaseRemoteConnection) -> None: diff --git a/moonraker/components/file_manager/file_manager.py b/moonraker/components/file_manager/file_manager.py index 0fc29a2d0..5d8161c2b 100644 --- a/moonraker/components/file_manager/file_manager.py +++ b/moonraker/components/file_manager/file_manager.py @@ -20,6 +20,7 @@ from inotify_simple import flags as iFlags from ...utils import source_info from ...utils import json_wrapper as jsonw +from ...common import RequestType, TransportType # Annotation imports from typing import ( @@ -108,27 +109,37 @@ def __init__(self, config: ConfigHelper) -> None: # Register file management endpoints self.server.register_endpoint( - "/server/files/list", ['GET'], self._handle_filelist_request) + "/server/files/list", RequestType.GET, self._handle_filelist_request + ) self.server.register_endpoint( - "/server/files/metadata", ['GET'], self._handle_metadata_request) + "/server/files/metadata", RequestType.GET, self._handle_metadata_request + ) self.server.register_endpoint( - "/server/files/metascan", ['POST'], self._handle_metascan_request) + "/server/files/metascan", RequestType.POST, self._handle_metascan_request + ) self.server.register_endpoint( - "/server/files/thumbnails", ['GET'], self._handle_list_thumbs) + "/server/files/thumbnails", RequestType.GET, self._handle_list_thumbs + ) self.server.register_endpoint( - "/server/files/roots", ['GET'], self._handle_list_roots) + "/server/files/roots", RequestType.GET, self._handle_list_roots + ) self.server.register_endpoint( - "/server/files/directory", ['GET', 'POST', 'DELETE'], - self._handle_directory_request) + "/server/files/directory", RequestType.all(), + self._handle_directory_request + ) self.server.register_endpoint( - "/server/files/move", ['POST'], self._handle_file_move_copy) + "/server/files/move", RequestType.POST, self._handle_file_move_copy + ) self.server.register_endpoint( - "/server/files/copy", ['POST'], self._handle_file_move_copy) + "/server/files/copy", RequestType.POST, self._handle_file_move_copy + ) self.server.register_endpoint( - "/server/files/zip", ['POST'], self._handle_zip_files) + "/server/files/zip", RequestType.POST, self._handle_zip_files + ) self.server.register_endpoint( - "/server/files/delete_file", ['DELETE'], self._handle_file_delete, - transports=["websocket"]) + "/server/files/delete_file", RequestType.DELETE, self._handle_file_delete, + transports=TransportType.WEBSOCKET + ) # register client notificaitons self.server.register_notification("file_manager:filelist_changed") @@ -474,8 +485,8 @@ async def _handle_directory_request(self, ) -> Dict[str, Any]: directory = web_request.get_str('path', "gcodes") root, dir_path = self._convert_request_path(directory) - method = web_request.get_action() - if method == 'GET': + req_type = web_request.get_request_type() + if req_type == RequestType.GET: is_extended = web_request.get_boolean('extended', False) # Get list of files and subdirectories for this target dir_info = self._list_directory(dir_path, root, is_extended) @@ -483,7 +494,7 @@ async def _handle_directory_request(self, async with self.sync_lock: self.check_reserved_path(dir_path, True) action = "create_dir" - if method == 'POST' and root in self.full_access_roots: + if req_type == RequestType.POST and root in self.full_access_roots: # Create a new directory self.sync_lock.setup("create_dir", dir_path) try: @@ -491,7 +502,7 @@ async def _handle_directory_request(self, except Exception as e: raise self.server.error(str(e)) self.fs_observer.on_item_create(root, dir_path, is_dir=True) - elif method == 'DELETE' and root in self.full_access_roots: + elif req_type == RequestType.DELETE and root in self.full_access_roots: # Remove a directory action = "delete_dir" if directory.strip("/") == root: diff --git a/moonraker/components/history.py b/moonraker/components/history.py index 6c0fbd160..cd7ed39b6 100644 --- a/moonraker/components/history.py +++ b/moonraker/components/history.py @@ -6,6 +6,7 @@ import time import logging from asyncio import Lock +from ..common import JobEvent, RequestType # Annotation imports from typing import ( @@ -49,26 +50,23 @@ def __init__(self, config: ConfigHelper) -> None: self.server.register_event_handler( "server:klippy_shutdown", self._handle_shutdown) self.server.register_event_handler( - "job_state:started", self._on_job_started) - self.server.register_event_handler( - "job_state:complete", self._on_job_complete) - self.server.register_event_handler( - "job_state:cancelled", self._on_job_cancelled) - self.server.register_event_handler( - "job_state:standby", self._on_job_standby) - self.server.register_event_handler( - "job_state:error", self._on_job_error) + "job_state:state_changed", self._on_job_state_changed) self.server.register_notification("history:history_changed") self.server.register_endpoint( - "/server/history/job", ['GET', 'DELETE'], self._handle_job_request) + "/server/history/job", RequestType.GET | RequestType.DELETE, + self._handle_job_request + ) self.server.register_endpoint( - "/server/history/list", ['GET'], self._handle_jobs_list) + "/server/history/list", RequestType.GET, self._handle_jobs_list + ) self.server.register_endpoint( - "/server/history/totals", ['GET'], self._handle_job_totals) + "/server/history/totals", RequestType.GET, self._handle_job_totals + ) self.server.register_endpoint( - "/server/history/reset_totals", ['POST'], - self._handle_job_total_reset) + "/server/history/reset_totals", RequestType.POST, + self._handle_job_total_reset + ) database.register_local_namespace(HIST_NAMESPACE) self.history_ns = database.wrap_namespace(HIST_NAMESPACE, @@ -85,14 +83,14 @@ async def _handle_job_request(self, web_request: WebRequest ) -> Dict[str, Any]: async with self.request_lock: - action = web_request.get_action() - if action == "GET": + req_type = web_request.get_request_type() + if req_type == RequestType.GET: job_id = web_request.get_str("uid") if job_id not in self.cached_job_ids: raise self.server.error(f"Invalid job uid: {job_id}", 404) job = await self.history_ns[job_id] return {"job": self._prep_requested_job(job, job_id)} - if action == "DELETE": + if req_type == RequestType.DELETE: all = web_request.get_boolean("all", False) if all: deljobs = self.cached_job_ids @@ -192,40 +190,25 @@ async def _handle_job_total_reset(self, "moonraker", "history.job_totals", self.job_totals) return {'last_totals': last_totals} - def _on_job_started(self, - prev_stats: Dict[str, Any], - new_stats: Dict[str, Any] - ) -> None: - if self.current_job is not None: - # Finish with the previous state + def _on_job_state_changed( + self, + job_event: JobEvent, + prev_stats: Dict[str, Any], + new_stats: Dict[str, Any] + ) -> None: + if job_event == JobEvent.STARTED: + if self.current_job is not None: + # Finish with the previous state + self.finish_job("cancelled", prev_stats) + self.add_job(PrinterJob(new_stats)) + elif job_event == JobEvent.COMPLETE: + self.finish_job("completed", new_stats) + elif job_event == JobEvent.ERROR: + self.finish_job("error", new_stats) + elif job_event in (JobEvent.CANCELLED, JobEvent.STANDBY): + # Cancel on "standby" for backward compatibility with + # `CLEAR_PAUSE/SDCARD_RESET_FILE` workflow self.finish_job("cancelled", prev_stats) - self.add_job(PrinterJob(new_stats)) - - def _on_job_complete(self, - prev_stats: Dict[str, Any], - new_stats: Dict[str, Any] - ) -> None: - self.finish_job("completed", new_stats) - - def _on_job_cancelled(self, - prev_stats: Dict[str, Any], - new_stats: Dict[str, Any] - ) -> None: - self.finish_job("cancelled", new_stats) - - def _on_job_error(self, - prev_stats: Dict[str, Any], - new_stats: Dict[str, Any] - ) -> None: - self.finish_job("error", new_stats) - - def _on_job_standby(self, - prev_stats: Dict[str, Any], - new_stats: Dict[str, Any] - ) -> None: - # Backward compatibility with - # `CLEAR_PAUSE/SDCARD_RESET_FILE` workflow - self.finish_job("cancelled", prev_stats) def _handle_shutdown(self) -> None: jstate: JobState = self.server.lookup_component("job_state") diff --git a/moonraker/components/job_queue.py b/moonraker/components/job_queue.py index 274895618..279b3cacc 100644 --- a/moonraker/components/job_queue.py +++ b/moonraker/components/job_queue.py @@ -8,6 +8,7 @@ import asyncio import time import logging +from ..common import JobEvent, RequestType # Annotation imports from typing import ( @@ -46,11 +47,8 @@ def __init__(self, config: ConfigHelper) -> None: self.server.register_event_handler( "server:klippy_shutdown", self._handle_shutdown) self.server.register_event_handler( - "job_state:complete", self._on_job_complete) - self.server.register_event_handler( - "job_state:error", self._on_job_abort) - self.server.register_event_handler( - "job_state:cancelled", self._on_job_abort) + "job_state:state_changed", self._on_job_state_changed + ) self.server.register_notification("job_queue:job_queue_changed") self.server.register_remote_method("pause_job_queue", self.pause_queue) @@ -58,16 +56,21 @@ def __init__(self, config: ConfigHelper) -> None: self.start_queue) self.server.register_endpoint( - "/server/job_queue/job", ['POST', 'DELETE'], - self._handle_job_request) + "/server/job_queue/job", RequestType.POST | RequestType.DELETE, + self._handle_job_request + ) self.server.register_endpoint( - "/server/job_queue/pause", ['POST'], self._handle_pause_queue) + "/server/job_queue/pause", RequestType.POST, self._handle_pause_queue + ) self.server.register_endpoint( - "/server/job_queue/start", ['POST'], self._handle_start_queue) + "/server/job_queue/start", RequestType.POST, self._handle_start_queue + ) self.server.register_endpoint( - "/server/job_queue/status", ['GET'], self._handle_queue_status) + "/server/job_queue/status", RequestType.GET, self._handle_queue_status + ) self.server.register_endpoint( - "/server/job_queue/jump", ['POST'], self._handle_jump) + "/server/job_queue/jump", RequestType.POST, self._handle_jump + ) async def _handle_ready(self) -> None: async with self.lock: @@ -85,10 +88,13 @@ async def _handle_shutdown(self) -> None: if not self.queued_jobs and self.automatic: self._set_queue_state("ready") - async def _on_job_complete(self, - prev_stats: Dict[str, Any], - new_stats: Dict[str, Any] - ) -> None: + async def _on_job_state_changed(self, job_event: JobEvent, *args) -> None: + if job_event == JobEvent.COMPLETE: + await self._on_job_complete() + elif job_event.aborted: + await self._on_job_abort() + + async def _on_job_complete(self) -> None: if not self.automatic: return async with self.lock: @@ -99,10 +105,7 @@ async def _on_job_complete(self, self.pop_queue_handle = event_loop.delay_callback( self.job_delay, self._pop_job) - async def _on_job_abort(self, - prev_stats: Dict[str, Any], - new_stats: Dict[str, Any] - ) -> None: + async def _on_job_abort(self) -> None: async with self.lock: if self.queued_jobs: self._set_queue_state("paused") @@ -250,23 +253,23 @@ def _send_queue_event(self, action: str = "state_changed"): 'queue_state': self.queue_state }) - async def _handle_job_request(self, - web_request: WebRequest - ) -> Dict[str, Any]: - action = web_request.get_action() - if action == "POST": + async def _handle_job_request( + self, web_request: WebRequest + ) -> Dict[str, Any]: + req_type = web_request.get_request_type() + if req_type == RequestType.POST: files = web_request.get_list('filenames') reset = web_request.get_boolean("reset", False) # Validate that all files exist before queueing await self.queue_job(files, reset=reset) - elif action == "DELETE": + elif req_type == RequestType.DELETE: if web_request.get_boolean("all", False): await self.delete_job([], all=True) else: job_ids = web_request.get_list('job_ids') await self.delete_job(job_ids) else: - raise self.server.error(f"Invalid action: {action}") + raise self.server.error(f"Invalid request type: {req_type}") return { 'queued_jobs': self._job_map_to_list(), 'queue_state': self.queue_state diff --git a/moonraker/components/job_state.py b/moonraker/components/job_state.py index 2d5abf8df..eec17cd83 100644 --- a/moonraker/components/job_state.py +++ b/moonraker/components/job_state.py @@ -15,6 +15,7 @@ Dict, List, ) +from ..common import JobEvent, KlippyState if TYPE_CHECKING: from ..confighelper import ConfigHelper from .klippy_apis import KlippyAPI @@ -26,8 +27,8 @@ def __init__(self, config: ConfigHelper) -> None: self.server.register_event_handler( "server:klippy_started", self._handle_started) - async def _handle_started(self, state: str) -> None: - if state != "ready": + async def _handle_started(self, state: KlippyState) -> None: + if state != KlippyState.READY: return kapis: KlippyAPI = self.server.lookup_component('klippy_apis') sub: Dict[str, Optional[List[str]]] = {"print_stats": None} @@ -65,8 +66,16 @@ async def _status_update(self, data: Dict[str, Any], _: float) -> None: f"Job State Changed - Prev State: {old_state}, " f"New State: {new_state}" ) + # NOTE: Individual job_state events are DEPRECATED. New modules + # should register handlers for "job_state: status_changed" and + # match against the JobEvent object provided. + self.server.send_event(f"job_state:{new_state}", prev_ps, new_ps) self.server.send_event( - f"job_state:{new_state}", prev_ps, new_ps) + "job_state:state_changed", + JobEvent.from_string(new_state), + prev_ps, + new_ps + ) if "info" in ps: cur_layer: Optional[int] = ps["info"].get("current_layer") if cur_layer is not None: diff --git a/moonraker/components/klippy_apis.py b/moonraker/components/klippy_apis.py index 8bdfd253a..88cd4ad90 100644 --- a/moonraker/components/klippy_apis.py +++ b/moonraker/components/klippy_apis.py @@ -6,7 +6,7 @@ from __future__ import annotations from ..utils import Sentinel -from ..common import WebRequest, Subscribable +from ..common import WebRequest, APITransport, RequestType # Annotation imports from typing import ( @@ -38,7 +38,7 @@ OBJ_LIST_ENDPOINT = "objects/list" REG_METHOD_ENDPOINT = "register_remote_method" -class KlippyAPI(Subscribable): +class KlippyAPI(APITransport): def __init__(self, config: ConfigHelper) -> None: self.server = config.get_server() self.klippy: Klippy = self.server.lookup_component("klippy_connection") @@ -52,17 +52,23 @@ def __init__(self, config: ConfigHelper) -> None: # Register GCode Aliases self.server.register_endpoint( - "/printer/print/pause", ['POST'], self._gcode_pause) + "/printer/print/pause", RequestType.POST, self._gcode_pause + ) self.server.register_endpoint( - "/printer/print/resume", ['POST'], self._gcode_resume) + "/printer/print/resume", RequestType.POST, self._gcode_resume + ) self.server.register_endpoint( - "/printer/print/cancel", ['POST'], self._gcode_cancel) + "/printer/print/cancel", RequestType.POST, self._gcode_cancel + ) self.server.register_endpoint( - "/printer/print/start", ['POST'], self._gcode_start_print) + "/printer/print/start", RequestType.POST, self._gcode_start_print + ) self.server.register_endpoint( - "/printer/restart", ['POST'], self._gcode_restart) + "/printer/restart", RequestType.POST, self._gcode_restart + ) self.server.register_endpoint( - "/printer/firmware_restart", ['POST'], self._gcode_firmware_restart) + "/printer/firmware_restart", RequestType.POST, self._gcode_firmware_restart + ) self.server.register_event_handler( "server:klippy_disconnect", self._on_klippy_disconnect ) @@ -94,10 +100,11 @@ async def _send_klippy_request( self, method: str, params: Dict[str, Any], - default: Any = Sentinel.MISSING + default: Any = Sentinel.MISSING, + transport: Optional[APITransport] = None ) -> Any: try: - req = WebRequest(method, params, conn=self) + req = WebRequest(method, params, transport=transport or self) result = await self.klippy.request(req) except self.server.error: if default is Sentinel.MISSING: @@ -221,6 +228,7 @@ async def subscribe_objects( callback: Optional[SubCallback] = None, default: Union[Sentinel, _T] = Sentinel.MISSING ) -> Union[_T, Dict[str, Any]]: + # The host transport shares subscriptions amongst all components for obj, items in objects.items(): if obj in self.host_subscription: prev = self.host_subscription[obj] @@ -231,9 +239,8 @@ async def subscribe_objects( self.host_subscription[obj] = uitems else: self.host_subscription[obj] = items - params = {'objects': dict(self.host_subscription)} - result = await self._send_klippy_request( - SUBSCRIPTION_ENDPOINT, params, default) + params = {"objects": dict(self.host_subscription)} + result = await self._send_klippy_request(SUBSCRIPTION_ENDPOINT, params, default) if isinstance(result, dict) and "status" in result: if callback is not None: self.subscription_callbacks.append(callback) @@ -242,6 +249,22 @@ async def subscribe_objects( return default raise self.server.error("Invalid response received from Klippy", 500) + async def subscribe_from_transport( + self, + objects: Mapping[str, Optional[List[str]]], + transport: APITransport, + default: Union[Sentinel, _T] = Sentinel.MISSING, + ) -> Union[_T, Dict[str, Any]]: + params = {"objects": dict(objects)} + result = await self._send_klippy_request( + SUBSCRIPTION_ENDPOINT, params, default, transport + ) + if isinstance(result, dict) and "status" in result: + return result["status"] + if default is not Sentinel.MISSING: + return default + raise self.server.error("Invalid response received from Klippy", 500) + async def subscribe_gcode_output(self) -> str: template = {'response_template': {'method': "process_gcode_response"}} diff --git a/moonraker/components/machine.py b/moonraker/components/machine.py index bd003e2d7..56eb1ff07 100644 --- a/moonraker/components/machine.py +++ b/moonraker/components/machine.py @@ -23,6 +23,7 @@ from ..confighelper import FileSourceWrapper from ..utils import source_info from ..utils import json_wrapper as jsonw +from ..common import RequestType # Annotation imports from typing import ( @@ -46,7 +47,6 @@ from .shell_command import ShellCommandFactory as SCMDComp from .database import MoonrakerDatabase from .file_manager.file_manager import FileManager - from .authorization import Authorization from .announcements import Announcements from .proc_stats import ProcStats from .dbus_manager import DbusManager @@ -132,26 +132,29 @@ def __init__(self, config: ConfigHelper) -> None: self.sudo_requests: List[Tuple[SudoCallback, str]] = [] self.server.register_endpoint( - "/machine/reboot", ['POST'], self._handle_machine_request) + "/machine/reboot", RequestType.POST, self._handle_machine_request + ) self.server.register_endpoint( - "/machine/shutdown", ['POST'], self._handle_machine_request) + "/machine/shutdown", RequestType.POST, self._handle_machine_request + ) self.server.register_endpoint( - "/machine/services/restart", ['POST'], - self._handle_service_request) + "/machine/services/restart", RequestType.POST, self._handle_service_request + ) self.server.register_endpoint( - "/machine/services/stop", ['POST'], - self._handle_service_request) + "/machine/services/stop", RequestType.POST, self._handle_service_request + ) self.server.register_endpoint( - "/machine/services/start", ['POST'], - self._handle_service_request) + "/machine/services/start", RequestType.POST, self._handle_service_request + ) self.server.register_endpoint( - "/machine/system_info", ['GET'], - self._handle_sysinfo_request) + "/machine/system_info", RequestType.GET, self._handle_sysinfo_request + ) self.server.register_endpoint( - "/machine/sudo/info", ["GET"], self._handle_sudo_info) + "/machine/sudo/info", RequestType.GET, self._handle_sudo_info + ) self.server.register_endpoint( - "/machine/sudo/password", ["POST"], - self._set_sudo_password) + "/machine/sudo/password", RequestType.POST, self._set_sudo_password + ) self.server.register_notification("machine:service_state_changed") self.server.register_notification("machine:sudo_alert") @@ -1929,11 +1932,6 @@ def _request_sudo_access(self) -> None: if self._sudo_requested: return self._sudo_requested = True - auth: Optional[Authorization] - auth = self.server.lookup_component("authorization", None) - if auth is not None: - # Bypass authentication requirements - auth.register_permited_path("/machine/sudo/password") machine: Machine = self.server.lookup_component("machine") machine.register_sudo_request( self._on_password_received, diff --git a/moonraker/components/mqtt.py b/moonraker/components/mqtt.py index 976b86d84..f998c1410 100644 --- a/moonraker/components/mqtt.py +++ b/moonraker/components/mqtt.py @@ -12,7 +12,13 @@ import ssl from collections import deque import paho.mqtt.client as paho_mqtt -from ..common import Subscribable, WebRequest, APITransport, JsonRPC +from ..common import ( + TransportType, + RequestType, + WebRequest, + APITransport, + KlippyState +) from ..utils import json_wrapper as jsonw # Annotation imports @@ -30,9 +36,9 @@ Deque, ) if TYPE_CHECKING: - from ..app import APIDefinition from ..confighelper import ConfigHelper - from ..klippy_connection import KlippyConnection as Klippy + from ..common import JsonRPC, APIDefinition + from .klippy_apis import KlippyAPI FlexCallback = Callable[[bytes], Optional[Coroutine]] RPCCallback = Callable[..., Coroutine] @@ -241,11 +247,10 @@ async def misc_loop(self) -> None: logging.info("MQTT Misc Loop Complete") -class MQTTClient(APITransport, Subscribable): +class MQTTClient(APITransport): def __init__(self, config: ConfigHelper) -> None: self.server = config.get_server() self.eventloop = self.server.get_event_loop() - self.klippy: Klippy = self.server.lookup_component("klippy_connection") self.address: str = config.get('address') self.port: int = config.getint('port', 1883) user = config.gettemplate('username', None) @@ -296,29 +301,32 @@ def __init__(self, config: ConfigHelper) -> None: self.pending_responses: List[asyncio.Future] = [] self.pending_acks: Dict[int, asyncio.Future] = {} + # We don't need to register these endpoints over the MQTT transport as they + # are redundant. MQTT clients can already publish and subscribe. + ep_transports = TransportType.all() & ~TransportType.MQTT self.server.register_endpoint( - "/server/mqtt/publish", ["POST"], - self._handle_publish_request, - transports=["http", "websocket", "internal"]) + "/server/mqtt/publish", RequestType.POST, self._handle_publish_request, + transports=ep_transports + ) self.server.register_endpoint( - "/server/mqtt/subscribe", ["POST"], + "/server/mqtt/subscribe", RequestType.POST, self._handle_subscription_request, - transports=["http", "websocket", "internal"]) + transports=ep_transports + ) # Subscribe to API requests - self.json_rpc = JsonRPC(self.server, transport="MQTT") self.api_request_topic = f"{self.instance_name}/moonraker/api/request" self.api_resp_topic = f"{self.instance_name}/moonraker/api/response" self.klipper_status_topic = f"{self.instance_name}/klipper/status" self.klipper_state_prefix = f"{self.instance_name}/klipper/state" self.moonraker_status_topic = f"{self.instance_name}/moonraker/status" - status_cfg: Dict[str, Any] = config.getdict("status_objects", {}, - allow_empty_fields=True) - self.status_objs: Dict[str, Any] = {} + status_cfg: Dict[str, str] = config.getdict( + "status_objects", {}, allow_empty_fields=True + ) + self.status_objs: Dict[str, Optional[List[str]]] = {} for key, val in status_cfg.items(): if val is not None: - self.status_objs[key] = [v.strip() for v in val.split(',') - if v.strip()] + self.status_objs[key] = [v.strip() for v in val.split(',') if v.strip()] else: self.status_objs[key] = None if status_cfg: @@ -330,10 +338,6 @@ def __init__(self, config: ConfigHelper) -> None: self.timestamp_deque: Deque = deque(maxlen=20) self.api_qos = config.getint('api_qos', self.qos) if config.getboolean("enable_moonraker_api", True): - api_cache = self.server.register_api_transport("mqtt", self) - for api_def in api_cache.values(): - if "mqtt" in api_def.supported_transports: - self.register_api_handler(api_def) self.subscribe_topic(self.api_request_topic, self._process_api_request, self.api_qos) @@ -361,14 +365,12 @@ async def component_init(self) -> None: self._do_reconnect(first=True) ) - async def _handle_klippy_started(self, state: str) -> None: + async def _handle_klippy_started(self, state: KlippyState) -> None: if self.status_objs: - args = {'objects': self.status_objs} - try: - await self.klippy.request( - WebRequest("objects/subscribe", args, conn=self)) - except self.server.error: - pass + kapi: KlippyAPI = self.server.lookup_component("klippy_apis") + await kapi.subscribe_from_transport( + self.status_objs, self, default=None, + ) def _on_message(self, client: str, @@ -670,51 +672,19 @@ async def _handle_subscription_request(self, } async def _process_api_request(self, payload: bytes) -> None: - response = await self.json_rpc.dispatch(payload.decode()) + rpc: JsonRPC = self.server.lookup_component("jsonrpc") + response = await rpc.dispatch(payload, self) if response is not None: await self.publish_topic(self.api_resp_topic, response, self.api_qos) - def register_api_handler(self, api_def: APIDefinition) -> None: - if api_def.callback is None: - # Remote API, uses RPC to reach out to Klippy - mqtt_method = api_def.jrpc_methods[0] - rpc_cb = self._generate_remote_callback(api_def.endpoint) - self.json_rpc.register_method(mqtt_method, rpc_cb) - else: - # Local API, uses local callback - for mqtt_method, req_method in \ - zip(api_def.jrpc_methods, api_def.request_methods): - rpc_cb = self._generate_local_callback( - api_def.endpoint, req_method, api_def.callback) - self.json_rpc.register_method(mqtt_method, rpc_cb) - logging.info( - "Registering MQTT JSON-RPC methods: " - f"{', '.join(api_def.jrpc_methods)}") - - def remove_api_handler(self, api_def: APIDefinition) -> None: - for jrpc_method in api_def.jrpc_methods: - self.json_rpc.remove_method(jrpc_method) - - def _generate_local_callback(self, - endpoint: str, - request_method: str, - callback: Callable[[WebRequest], Coroutine] - ) -> RPCCallback: - async def func(args: Dict[str, Any]) -> Any: - self._check_timestamp(args) - result = await callback(WebRequest(endpoint, args, request_method)) - return result - return func - - def _generate_remote_callback(self, endpoint: str) -> RPCCallback: - async def func(args: Dict[str, Any]) -> Any: - self._check_timestamp(args) - result = await self.klippy.request(WebRequest(endpoint, args)) - return result - return func - - def _check_timestamp(self, args: Dict[str, Any]) -> None: + @property + def transport_type(self) -> TransportType: + return TransportType.MQTT + + def screen_rpc_request( + self, api_def: APIDefinition, req_type: RequestType, args: Dict[str, Any] + ) -> None: ts = args.pop("mqtt_timestamp", None) if ts is not None: if ts in self.timestamp_deque: diff --git a/moonraker/components/notifier.py b/moonraker/components/notifier.py index 4cf4236db..7f01296e6 100644 --- a/moonraker/components/notifier.py +++ b/moonraker/components/notifier.py @@ -10,6 +10,7 @@ import logging import pathlib import re +from ..common import JobEvent, RequestType # Annotation imports from typing import ( @@ -29,23 +30,20 @@ class Notifier: def __init__(self, config: ConfigHelper) -> None: self.server = config.get_server() self.notifiers: Dict[str, NotifierInstance] = {} - self.events: Dict[str, NotifierEvent] = {} + self.events: Dict[str, List[NotifierInstance]] = {} prefix_sections = config.get_prefix_sections("notifier") - - self.register_events(config) self.register_remote_actions() - for section in prefix_sections: cfg = config[section] try: notifier = NotifierInstance(cfg) - - for event in self.events: - if event in notifier.events or "*" in notifier.events: - self.events[event].register_notifier(notifier) - + for job_event in list(JobEvent): + if job_event == JobEvent.STANDBY: + continue + evt_name = str(job_event) + if "*" in notifier.events or evt_name in notifier.events: + self.events.setdefault(evt_name, []).append(notifier) logging.info(f"Registered notifier: '{notifier.get_name()}'") - except Exception as e: msg = f"Failed to load notifier[{cfg.get_name()}]\n{e}" self.server.add_warning(msg) @@ -53,6 +51,9 @@ def __init__(self, config: ConfigHelper) -> None: self.notifiers[notifier.get_name()] = notifier self.register_endpoints(config) + self.server.register_event_handler( + "job_state:state_changed", self._on_job_state_changed + ) def register_remote_actions(self): self.server.register_remote_method("notify", self.notify_action) @@ -61,47 +62,24 @@ async def notify_action(self, name: str, message: str = ""): if name not in self.notifiers: raise self.server.error(f"Notifier '{name}' not found", 404) notifier = self.notifiers[name] - await notifier.notify("remote_action", [], message) - def register_events(self, config: ConfigHelper): - - self.events["started"] = NotifierEvent( - "started", - "job_state:started", - config) - - self.events["complete"] = NotifierEvent( - "complete", - "job_state:complete", - config) - - self.events["error"] = NotifierEvent( - "error", - "job_state:error", - config) - - self.events["cancelled"] = NotifierEvent( - "cancelled", - "job_state:cancelled", - config) - - self.events["paused"] = NotifierEvent( - "paused", - "job_state:paused", - config) - - self.events["resumed"] = NotifierEvent( - "resumed", - "job_state:resumed", - config) + async def _on_job_state_changed( + self, + job_event: JobEvent, + prev_stats: Dict[str, Any], + new_stats: Dict[str, Any] + ) -> None: + evt_name = str(job_event) + for notifier in self.events.get(evt_name, []): + await notifier.notify(evt_name, [prev_stats, new_stats]) def register_endpoints(self, config: ConfigHelper): self.server.register_endpoint( - "/server/notifiers/list", ["GET"], self._handle_notifier_list + "/server/notifiers/list", RequestType.GET, self._handle_notifier_list ) self.server.register_debug_endpoint( - "/debug/notifiers/test", ["POST"], self._handle_notifier_test + "/debug/notifiers/test", RequestType.POST, self._handle_notifier_test ) async def _handle_notifier_list( @@ -134,34 +112,6 @@ async def _handle_notifier_test( "stats": print_stats } - -class NotifierEvent: - def __init__(self, identifier: str, event_name: str, config: ConfigHelper): - self.identifier = identifier - self.event_name = event_name - self.server = config.get_server() - self.notifiers: Dict[str, NotifierInstance] = {} - self.config = config - - self.server.register_event_handler(self.event_name, self._handle) - - def register_notifier(self, notifier: NotifierInstance): - self.notifiers[notifier.get_name()] = notifier - - async def _handle(self, *args) -> None: - logging.info(f"'{self.identifier}' notifier event triggered'") - await self.invoke_notifiers(args) - - async def invoke_notifiers(self, args): - for notifier_name in self.notifiers: - try: - notifier = self.notifiers[notifier_name] - await notifier.notify(self.identifier, args) - except Exception as e: - logging.info(f"Failed to notify [{notifier_name}]\n{e}") - continue - - class NotifierInstance: def __init__(self, config: ConfigHelper) -> None: self.config = config diff --git a/moonraker/components/octoprint_compat.py b/moonraker/components/octoprint_compat.py index 33341658f..8d77dd1c8 100644 --- a/moonraker/components/octoprint_compat.py +++ b/moonraker/components/octoprint_compat.py @@ -6,6 +6,7 @@ from __future__ import annotations import logging +from ..common import RequestType, TransportType, KlippyState # Annotation imports from typing import ( @@ -15,6 +16,7 @@ List, ) if TYPE_CHECKING: + from ..klippy_connection import KlippyConnection from ..confighelper import ConfigHelper from ..common import WebRequest from .klippy_apis import KlippyAPI as APIComp @@ -65,22 +67,27 @@ def __init__(self, config: ConfigHelper) -> None: # Version & Server information self.server.register_endpoint( - '/api/version', ['GET'], self._get_version, - transports=['http'], wrap_result=False) + '/api/version', RequestType.GET, self._get_version, + transports=TransportType.HTTP, wrap_result=False + ) self.server.register_endpoint( - '/api/server', ['GET'], self._get_server, - transports=['http'], wrap_result=False) + '/api/server', RequestType.GET, self._get_server, + transports=TransportType.HTTP, wrap_result=False + ) # Login, User & Settings self.server.register_endpoint( - '/api/login', ['POST'], self._post_login_user, - transports=['http'], wrap_result=False) + '/api/login', RequestType.POST, self._post_login_user, + transports=TransportType.HTTP, wrap_result=False + ) self.server.register_endpoint( - '/api/currentuser', ['GET'], self._post_login_user, - transports=['http'], wrap_result=False) + '/api/currentuser', RequestType.GET, self._post_login_user, + transports=TransportType.HTTP, wrap_result=False + ) self.server.register_endpoint( - '/api/settings', ['GET'], self._get_settings, - transports=['http'], wrap_result=False) + '/api/settings', RequestType.GET, self._get_settings, + transports=TransportType.HTTP, wrap_result=False + ) # File operations # Note that file upload is handled in file_manager.py @@ -88,30 +95,34 @@ def __init__(self, config: ConfigHelper) -> None: # Job operations self.server.register_endpoint( - '/api/job', ['GET'], self._get_job, - transports=['http'], wrap_result=False) + '/api/job', RequestType.GET, self._get_job, + transports=TransportType.HTTP, wrap_result=False + ) # TODO: start/cancel/restart/pause jobs # Printer operations self.server.register_endpoint( - '/api/printer', ['GET'], self._get_printer, - transports=['http'], wrap_result=False) + '/api/printer', RequestType.GET, self._get_printer, + transports=TransportType.HTTP, wrap_result=False) self.server.register_endpoint( - '/api/printer/command', ['POST'], self._post_command, - transports=['http'], wrap_result=False) + '/api/printer/command', RequestType.POST, self._post_command, + transports=TransportType.HTTP, wrap_result=False + ) # TODO: head/tool/bed/chamber specific read/issue # Printer profiles self.server.register_endpoint( - '/api/printerprofiles', ['GET'], self._get_printerprofiles, - transports=['http'], wrap_result=False) + '/api/printerprofiles', RequestType.GET, self._get_printerprofiles, + transports=TransportType.HTTP, wrap_result=False + ) # Upload Handlers self.server.register_upload_handler( "/api/files/local", location_prefix="api/files/moonraker") self.server.register_endpoint( - "/api/files/moonraker/(?P.+)", ['POST'], - self._select_file, transports=['http'], wrap_result=False) + "/api/files/moonraker/(?P.+)", RequestType.POST, + self._select_file, transports=TransportType.HTTP, wrap_result=False + ) # System # TODO: shutdown/reboot/restart operations @@ -143,10 +154,11 @@ def _handle_status_update(self, status: Dict[str, Any]) -> None: data.update(status[heater_name]) def printer_state(self) -> str: - klippy_state = self.server.get_klippy_state() - if klippy_state in ["disconnected", "startup"]: + kconn: KlippyConnection = self.server.lookup_component("klippy_connection") + klippy_state = kconn.state + if not klippy_state.startup_complete(): return 'Offline' - elif klippy_state != 'ready': + elif klippy_state != KlippyState.READY: return 'Error' return { 'standby': 'Operational', @@ -192,11 +204,11 @@ async def _get_server(self, """ Server status """ - klippy_state = self.server.get_klippy_state() + kconn: KlippyConnection = self.server.lookup_component("klippy_connection") + klippy_state = kconn.state return { 'server': OCTO_VERSION, - 'safemode': ( - None if klippy_state == 'ready' else 'settings') + 'safemode': None if klippy_state == KlippyState.READY else 'settings' } async def _post_login_user(self, diff --git a/moonraker/components/power.py b/moonraker/components/power.py index 4d97a2f89..355688b48 100644 --- a/moonraker/components/power.py +++ b/moonraker/components/power.py @@ -12,6 +12,7 @@ import time from urllib.parse import quote, urlencode from ..utils import json_wrapper as jsonw +from ..common import RequestType, KlippyState # Annotation imports from typing import ( @@ -74,20 +75,24 @@ def __init__(self, config: ConfigHelper) -> None: self.devices[dev.get_name()] = dev self.server.register_endpoint( - "/machine/device_power/devices", ['GET'], - self._handle_list_devices) + "/machine/device_power/devices", RequestType.GET, self._handle_list_devices + ) self.server.register_endpoint( - "/machine/device_power/status", ['GET'], - self._handle_batch_power_request) + "/machine/device_power/status", RequestType.GET, + self._handle_batch_power_request + ) self.server.register_endpoint( - "/machine/device_power/on", ['POST'], - self._handle_batch_power_request) + "/machine/device_power/on", RequestType.POST, + self._handle_batch_power_request + ) self.server.register_endpoint( - "/machine/device_power/off", ['POST'], - self._handle_batch_power_request) + "/machine/device_power/off", RequestType.POST, + self._handle_batch_power_request + ) self.server.register_endpoint( - "/machine/device_power/device", ['GET', 'POST'], - self._handle_single_power_request) + "/machine/device_power/device", RequestType.GET | RequestType.POST, + self._handle_single_power_request + ) self.server.register_remote_method( "set_device_power", self.set_device_power) self.server.register_event_handler( @@ -122,34 +127,35 @@ async def _handle_job_queued(self, queue_info: Dict[str, Any]) -> None: ) await dev.process_request("on") - async def _handle_list_devices(self, - web_request: WebRequest - ) -> Dict[str, Any]: + async def _handle_list_devices( + self, web_request: WebRequest + ) -> Dict[str, Any]: dev_list = [d.get_device_info() for d in self.devices.values()] output = {"devices": dev_list} return output - async def _handle_single_power_request(self, - web_request: WebRequest - ) -> Dict[str, Any]: + async def _handle_single_power_request( + self, web_request: WebRequest + ) -> Dict[str, Any]: dev_name: str = web_request.get_str('device') - req_action = web_request.get_action() + req_type = web_request.get_request_type() if dev_name not in self.devices: raise self.server.error(f"No valid device named {dev_name}") dev = self.devices[dev_name] - if req_action == 'GET': + if req_type == RequestType.GET: action = "status" - elif req_action == "POST": + elif req_type == RequestType.POST: action = web_request.get_str('action').lower() if action not in ["on", "off", "toggle"]: - raise self.server.error( - f"Invalid requested action '{action}'") + raise self.server.error(f"Invalid requested action '{action}'") + else: + raise self.server.error(f"Invalid Request Type: {req_type}") result = await dev.process_request(action) return {dev_name: result} - async def _handle_batch_power_request(self, - web_request: WebRequest - ) -> Dict[str, Any]: + async def _handle_batch_power_request( + self, web_request: WebRequest + ) -> Dict[str, Any]: args = web_request.get_args() ep = web_request.get_endpoint() if not args: @@ -256,11 +262,11 @@ def __init__(self, config: ConfigHelper) -> None: 'initial_state', None ) - def _schedule_firmware_restart(self, state: str = "") -> None: + def _schedule_firmware_restart(self, state: KlippyState) -> None: if not self.need_scheduled_restart: return self.need_scheduled_restart = False - if state == "ready": + if state == KlippyState.READY: logging.info( f"Power Device {self.name}: Klipper reports 'ready', " "aborting FIRMWARE_RESTART" @@ -298,8 +304,9 @@ async def process_power_changed(self) -> None: await self.process_bound_services() if self.state == "on" and self.klipper_restart: self.need_scheduled_restart = True - klippy_state = self.server.get_klippy_state() - if klippy_state in ["disconnected", "startup"]: + kconn: KlippyConnection = self.server.lookup_component("klippy_connection") + klippy_state = kconn.state + if not klippy_state.startup_complete(): # If klippy is currently disconnected or hasn't proceeded past # the startup state, schedule the restart in the # "klippy_started" event callback. @@ -332,7 +339,8 @@ def process_klippy_shutdown(self) -> None: self.off_when_shutdown_delay, self._power_off_on_shutdown) def _power_off_on_shutdown(self) -> None: - if self.server.get_klippy_state() != "shutdown": + kconn: KlippyConnection = self.server.lookup_component("klippy_connection") + if kconn.state != KlippyState.SHUTDOWN: return logging.info( f"Powering off device '{self.name}' due to klippy shutdown") diff --git a/moonraker/components/proc_stats.py b/moonraker/components/proc_stats.py index 404952d4d..ca8efd5af 100644 --- a/moonraker/components/proc_stats.py +++ b/moonraker/components/proc_stats.py @@ -15,6 +15,7 @@ import logging from collections import deque from ..utils import ioctl_macros +from ..common import RequestType # Annotation imports from typing import ( @@ -79,9 +80,11 @@ def __init__(self, config: ConfigHelper) -> None: self.cpu_stats_file = pathlib.Path(CPU_STAT_PATH) self.meminfo_file = pathlib.Path(MEM_AVAIL_PATH) self.server.register_endpoint( - "/machine/proc_stats", ["GET"], self._handle_stat_request) + "/machine/proc_stats", RequestType.GET, self._handle_stat_request + ) self.server.register_event_handler( - "server:klippy_shutdown", self._handle_shutdown) + "server:klippy_shutdown", self._handle_shutdown + ) self.server.register_notification("proc_stats:proc_stat_update") self.proc_stat_queue: Deque[Dict[str, Any]] = deque(maxlen=30) self.last_update_time = time.time() diff --git a/moonraker/components/sensor.py b/moonraker/components/sensor.py index 3c1bc28a5..ea6b16b29 100644 --- a/moonraker/components/sensor.py +++ b/moonraker/components/sensor.py @@ -12,6 +12,7 @@ from collections import defaultdict, deque from dataclasses import dataclass, replace from functools import partial +from ..common import RequestType # Annotation imports from typing import ( @@ -180,17 +181,17 @@ def __init__(self, config: ConfigHelper) -> None: # Register endpoints self.server.register_endpoint( "/server/sensors/list", - ["GET"], + RequestType.GET, self._handle_sensor_list_request, ) self.server.register_endpoint( "/server/sensors/info", - ["GET"], + RequestType.GET, self._handle_sensor_info_request, ) self.server.register_endpoint( "/server/sensors/measurements", - ["GET"], + RequestType.GET, self._handle_sensor_measurements_request, ) diff --git a/moonraker/components/simplyprint.py b/moonraker/components/simplyprint.py index 41732a64d..dc0a9b67d 100644 --- a/moonraker/components/simplyprint.py +++ b/moonraker/components/simplyprint.py @@ -17,7 +17,7 @@ import tempfile from queue import SimpleQueue from ..loghelper import LocalQueueHandler -from ..common import Subscribable, WebRequest +from ..common import APITransport, JobEvent, KlippyState from ..utils import json_wrapper as jsonw from typing import ( @@ -28,6 +28,7 @@ List, Union, Any, + Callable, ) if TYPE_CHECKING: from ..app import InternalTransport @@ -57,7 +58,7 @@ "ping" ] -class SimplyPrint(Subscribable): +class SimplyPrint(APITransport): def __init__(self, config: ConfigHelper) -> None: self.server = config.get_server() self._logger = ProtoLogger(config) @@ -157,19 +158,7 @@ def __init__(self, config: ConfigHelper) -> None: self.server.register_event_handler( "server:klippy_disconnect", self._on_klippy_disconnected) self.server.register_event_handler( - "job_state:started", self._on_print_start) - self.server.register_event_handler( - "job_state:paused", self._on_print_paused) - self.server.register_event_handler( - "job_state:resumed", self._on_print_resumed) - self.server.register_event_handler( - "job_state:standby", self._on_print_standby) - self.server.register_event_handler( - "job_state:complete", self._on_print_complete) - self.server.register_event_handler( - "job_state:error", self._on_print_error) - self.server.register_event_handler( - "job_state:cancelled", self._on_print_cancelled) + "job_state:state_changed", self._on_job_state_changed) self.server.register_event_handler( "klippy_apis:pause_requested", self._on_pause_requested) self.server.register_event_handler( @@ -542,7 +531,7 @@ async def _test_webcam(self) -> None: async def _on_klippy_ready(self) -> None: last_stats: Dict[str, Any] = self.job_state.get_last_stats() if last_stats["state"] == "printing": - self._on_print_start(last_stats, last_stats, False) + self._on_print_started(last_stats, last_stats, False) else: self._update_state("operational") query: Optional[Dict[str, Any]] @@ -591,15 +580,9 @@ async def _on_klippy_ready(self) -> None: if not sub_objs: return # Create our own subscription rather than use the host sub - args = {'objects': sub_objs} - klippy: KlippyConnection - klippy = self.server.lookup_component("klippy_connection") - try: - resp: Dict[str, Dict[str, Any]] = await klippy.request( - WebRequest("objects/subscribe", args, conn=self)) - status: Dict[str, Any] = resp.get("status", {}) - except self.server.error: - status = {} + status: Dict[str, Any] = await self.klippy_apis.subscribe_from_transport( + sub_objs, self, default={} + ) if status: logging.debug(f"SimplyPrint: Got Initial Status: {status}") self.printer_status = status @@ -651,12 +634,12 @@ def _on_websocket_removed(self, ws: BaseRemoteConnection) -> None: self.cache.firmware_info.update(ui_data) self.send_sp("machine_data", ui_data) - def _on_klippy_startup(self, state: str) -> None: - if state != "ready": + def _on_klippy_startup(self, state: KlippyState) -> None: + if state != KlippyState.READY: self._update_state("error") kconn: KlippyConnection kconn = self.server.lookup_component("klippy_connection") - self.send_sp("printer_error", {"error": kconn.state_message}) + self.send_sp("printer_error", {"error": kconn.state.message}) self.send_sp("connection", {"new": "connected"}) self._send_firmware_data() @@ -664,7 +647,7 @@ def _on_klippy_shutdown(self) -> None: self._update_state("error") kconn: KlippyConnection kconn = self.server.lookup_component("klippy_connection") - self.send_sp("printer_error", {"error": kconn.state_message}) + self.send_sp("printer_error", {"error": kconn.state.message}) def _on_klippy_disconnected(self) -> None: self._update_state("offline") @@ -674,7 +657,14 @@ def _on_klippy_disconnected(self) -> None: self.cache.reset_print_state() self.printer_status = {} - def _on_print_start( + def _on_job_state_changed(self, job_event: JobEvent, *args) -> None: + callback: Optional[Callable] = getattr(self, f"_on_print_{job_event}", None) + if callback is not None: + callback(*args) + else: + logging.info(f"No defined callback for Job Event: {job_event}") + + def _on_print_started( self, prev_stats: Dict[str, Any], new_stats: Dict[str, Any], @@ -931,10 +921,11 @@ def _update_temps(self, eventtime: float) -> None: self.send_sp("temps", temp_data) def _update_state_from_klippy(self) -> None: - kstate = self.server.get_klippy_state() - if kstate == "ready": + kconn: KlippyConnection = self.server.lookup_component("klippy_connection") + klippy_state = kconn.state + if klippy_state == KlippyState.READY: sp_state = "operational" - elif kstate in ["error", "shutdown"]: + elif klippy_state in [KlippyState.ERROR, KlippyState.SHUTDOWN]: sp_state = "error" else: sp_state = "offline" @@ -1617,7 +1608,8 @@ async def start_print(self) -> None: self.simplyprint.send_sp("file_progress", data) async def _check_can_print(self) -> bool: - if self.server.get_klippy_state() != "ready": + kconn: KlippyConnection = self.server.lookup_component("klippy_connection") + if kconn.state != KlippyState.READY: return False kapi: KlippyAPI = self.server.lookup_component("klippy_apis") try: diff --git a/moonraker/components/spoolman.py b/moonraker/components/spoolman.py index 7585cb3eb..e5cec2f17 100644 --- a/moonraker/components/spoolman.py +++ b/moonraker/components/spoolman.py @@ -9,6 +9,7 @@ import datetime import logging from typing import TYPE_CHECKING, Dict, Any +from ..common import RequestType if TYPE_CHECKING: from typing import Optional @@ -64,12 +65,12 @@ def _register_listeners(self): def _register_endpoints(self): self.server.register_endpoint( "/server/spoolman/spool_id", - ["GET", "POST"], + RequestType.GET | RequestType.POST, self._handle_spool_id_request, ) self.server.register_endpoint( "/server/spoolman/proxy", - ["POST"], + RequestType.POST, self._proxy_spoolman_request, ) @@ -157,7 +158,7 @@ async def track_filament_usage(self): self.extruded = 0 async def _handle_spool_id_request(self, web_request: WebRequest): - if web_request.get_action() == "POST": + if web_request.get_request_type() == RequestType.POST: spool_id = web_request.get_int("spool_id", None) await self.set_active_spool(spool_id) # For GET requests we will simply return the spool_id diff --git a/moonraker/components/update_manager/common.py b/moonraker/components/update_manager/common.py index a55513475..4ed928947 100644 --- a/moonraker/components/update_manager/common.py +++ b/moonraker/components/update_manager/common.py @@ -9,7 +9,7 @@ import sys import copy import pathlib -from enum import Enum +from ...common import ExtendedEnum from ...utils import source_info from typing import ( TYPE_CHECKING, @@ -46,25 +46,13 @@ } } -class ExtEnum(Enum): - @classmethod - def from_string(cls, enum_name: str): - str_name = enum_name.upper() - for name, member in cls.__members__.items(): - if name == str_name: - return cls(member.value) - raise ValueError(f"No enum member named {enum_name}") - - def __str__(self) -> str: - return self._name_.lower() # type: ignore - -class AppType(ExtEnum): +class AppType(ExtendedEnum): NONE = 1 WEB = 2 GIT_REPO = 3 ZIP = 4 -class Channel(ExtEnum): +class Channel(ExtendedEnum): STABLE = 1 BETA = 2 DEV = 3 diff --git a/moonraker/components/update_manager/update_manager.py b/moonraker/components/update_manager/update_manager.py index 9786f039d..4c5651b49 100644 --- a/moonraker/components/update_manager/update_manager.py +++ b/moonraker/components/update_manager/update_manager.py @@ -17,6 +17,7 @@ from .zip_deploy import ZipDeploy from .system_deploy import PackageDeploy from .web_deploy import WebClientDeploy +from ...common import RequestType # Annotation imports from typing import ( @@ -130,32 +131,32 @@ def __init__(self, config: ConfigHelper) -> None: self._handle_auto_refresh) self.server.register_endpoint( - "/machine/update/moonraker", ["POST"], - self._handle_update_request) + "/machine/update/moonraker", RequestType.POST, self._handle_update_request + ) self.server.register_endpoint( - "/machine/update/klipper", ["POST"], - self._handle_update_request) + "/machine/update/klipper", RequestType.POST, self._handle_update_request + ) self.server.register_endpoint( - "/machine/update/system", ["POST"], - self._handle_update_request) + "/machine/update/system", RequestType.POST, self._handle_update_request + ) self.server.register_endpoint( - "/machine/update/client", ["POST"], - self._handle_update_request) + "/machine/update/client", RequestType.POST, self._handle_update_request + ) self.server.register_endpoint( - "/machine/update/full", ["POST"], - self._handle_full_update_request) + "/machine/update/full", RequestType.POST, self._handle_full_update_request + ) self.server.register_endpoint( - "/machine/update/status", ["GET"], - self._handle_status_request) + "/machine/update/status", RequestType.GET, self._handle_status_request + ) self.server.register_endpoint( - "/machine/update/refresh", ["POST"], - self._handle_refresh_request) + "/machine/update/refresh", RequestType.POST, self._handle_refresh_request + ) self.server.register_endpoint( - "/machine/update/recover", ["POST"], - self._handle_repo_recovery) + "/machine/update/recover", RequestType.POST, self._handle_repo_recovery + ) self.server.register_endpoint( - "/machine/update/rollback", ["POST"], - self._handle_rollback) + "/machine/update/rollback", RequestType.POST, self._handle_rollback + ) self.server.register_notification("update_manager:update_response") self.server.register_notification("update_manager:update_refreshed") diff --git a/moonraker/components/webcam.py b/moonraker/components/webcam.py index 8ddfdc2cc..6c33b88b4 100644 --- a/moonraker/components/webcam.py +++ b/moonraker/components/webcam.py @@ -10,6 +10,7 @@ import socket import uuid import logging +from ..common import RequestType from typing import ( TYPE_CHECKING, Optional, @@ -50,14 +51,14 @@ def __init__(self, config: ConfigHelper) -> None: self.webcams[webcam.name] = webcam self.server.register_endpoint( - "/server/webcams/list", ["GET"], self._handle_webcam_list + "/server/webcams/list", RequestType.GET, self._handle_webcam_list ) self.server.register_endpoint( - "/server/webcams/item", ["GET", "POST", "DELETE"], + "/server/webcams/item", RequestType.all(), self._handle_webcam_request ) self.server.register_endpoint( - "/server/webcams/test", ["POST"], self._handle_webcam_test + "/server/webcams/test", RequestType.POST, self._handle_webcam_test ) self.server.register_notification("webcam:webcams_changed") self.server.register_event_handler( @@ -163,13 +164,13 @@ def _lookup_camera( return webcam async def _handle_webcam_request(self, web_request: WebRequest) -> Dict[str, Any]: - action = web_request.get_action() - webcam = self._lookup_camera(web_request, action != "POST") + req_type = web_request.get_request_type() + webcam = self._lookup_camera(web_request, req_type != RequestType.POST) webcam_data: Dict[str, Any] = {} - if action == "GET": + if req_type == RequestType.GET: assert webcam is not None webcam_data = webcam.as_dict() - elif action == "POST": + elif req_type == RequestType.POST: if webcam is not None: if webcam.source == "config": raise self.server.error( @@ -191,7 +192,7 @@ async def _handle_webcam_request(self, web_request: WebRequest) -> Dict[str, Any webcam = WebCam.from_web_request(self.server, web_request, uid) await self._save_cam(webcam) webcam_data = webcam.as_dict() - elif action == "DELETE": + elif req_type == RequestType.DELETE: assert webcam is not None if webcam.source == "config": raise self.server.error( @@ -200,7 +201,7 @@ async def _handle_webcam_request(self, web_request: WebRequest) -> Dict[str, Any ) webcam_data = webcam.as_dict() self._delete_cam(webcam) - if action != "GET": + if req_type != RequestType.GET: self.server.send_event( "webcam:webcams_changed", {"webcams": self._list_webcams()} ) diff --git a/moonraker/components/wled.py b/moonraker/components/wled.py index 9e6afafe2..ea758e0f2 100644 --- a/moonraker/components/wled.py +++ b/moonraker/components/wled.py @@ -16,6 +16,7 @@ from tornado.httpclient import AsyncHTTPClient from tornado.httpclient import HTTPRequest from ..utils import json_wrapper as jsonw +from ..common import RequestType # Annotation imports from typing import ( @@ -388,23 +389,24 @@ def __init__(self: WLED, config: ConfigHelper) -> None: # As moonraker is about making things a web api, let's try it # Yes, this is largely a cut-n-paste from power.py self.server.register_endpoint( - "/machine/wled/strips", ["GET"], - self._handle_list_strips) + "/machine/wled/strips", RequestType.GET, self._handle_list_strips + ) self.server.register_endpoint( - "/machine/wled/status", ["GET"], - self._handle_batch_wled_request) + "/machine/wled/status", RequestType.GET, self._handle_batch_wled_request + ) self.server.register_endpoint( - "/machine/wled/on", ["POST"], - self._handle_batch_wled_request) + "/machine/wled/on", RequestType.POST, self._handle_batch_wled_request + ) self.server.register_endpoint( - "/machine/wled/off", ["POST"], - self._handle_batch_wled_request) + "/machine/wled/off", RequestType.POST, self._handle_batch_wled_request + ) self.server.register_endpoint( - "/machine/wled/toggle", ["POST"], - self._handle_batch_wled_request) + "/machine/wled/toggle", RequestType.POST, self._handle_batch_wled_request + ) self.server.register_endpoint( - "/machine/wled/strip", ["GET", "POST"], - self._handle_single_wled_request) + "/machine/wled/strip", RequestType.GET | RequestType.POST, + self._handle_single_wled_request + ) async def component_init(self) -> None: try: @@ -521,19 +523,19 @@ async def _handle_single_wled_request(self: WLED, intensity: int = web_request.get_int('intensity', -1) speed: int = web_request.get_int('speed', -1) - req_action = web_request.get_action() + req_type = web_request.get_request_type() if strip_name not in self.strips: raise self.server.error(f"No valid strip named {strip_name}") strip = self.strips[strip_name] - if req_action == 'GET': + if req_type == RequestType.GET: return {strip_name: strip.get_strip_info()} - elif req_action == "POST": + elif req_type == RequestType.POST: action = web_request.get_str('action').lower() if action not in ["on", "off", "toggle", "control"]: - raise self.server.error( - f"Invalid requested action '{action}'") - result = await self._process_request(strip, action, preset, - brightness, intensity, speed) + raise self.server.error(f"Invalid requested action '{action}'") + result = await self._process_request( + strip, action, preset, brightness, intensity, speed + ) return {strip_name: result} async def _handle_batch_wled_request(self: WLED, diff --git a/moonraker/components/zeroconf.py b/moonraker/components/zeroconf.py index 77742bb1c..8c9dd823b 100644 --- a/moonraker/components/zeroconf.py +++ b/moonraker/components/zeroconf.py @@ -14,6 +14,7 @@ from email.utils import formatdate from zeroconf import IPVersion from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf +from ..common import RequestType, TransportType from typing import ( TYPE_CHECKING, @@ -29,7 +30,6 @@ from ..confighelper import ConfigHelper from ..common import WebRequest from ..app import MoonrakerApp - from .authorization import Authorization from .machine import Machine ZC_SERVICE_TYPE = "_moonraker._tcp.local." @@ -208,17 +208,14 @@ def __init__(self, config: ConfigHelper) -> None: self.boot_id = int(eventloop.get_loop_time()) self.config_id = 1 self.ad_timer = eventloop.register_timer(self._advertise_presence) - auth: Optional[Authorization] - auth = self.server.load_component(config, "authorization", None) - if auth is not None: - auth.register_permited_path("/server/zeroconf/ssdp") self.server.register_endpoint( "/server/zeroconf/ssdp", - ["GET"], + RequestType.GET, self._handle_xml_request, - transports=["http"], + transports=TransportType.HTTP, wrap_result=False, - content_type="application/xml" + content_type="application/xml", + auth_required=False ) def _create_ssdp_socket( diff --git a/moonraker/klippy_connection.py b/moonraker/klippy_connection.py index b7400e69b..93737c183 100644 --- a/moonraker/klippy_connection.py +++ b/moonraker/klippy_connection.py @@ -14,6 +14,7 @@ import pathlib from .utils import ServerError, get_unix_peer_credentials from .utils import json_wrapper as jsonw +from .common import KlippyState, RequestType # Annotation imports from typing import ( @@ -31,8 +32,7 @@ ) if TYPE_CHECKING: from .server import Server - from .app import MoonrakerApp - from .common import WebRequest, Subscribable, BaseRemoteConnection + from .common import WebRequest, APITransport, BaseRemoteConnection from .confighelper import ConfigHelper from .components.klippy_apis import KlippyAPI from .components.file_manager.file_manager import FileManager @@ -78,9 +78,9 @@ def __init__(self, server: Server) -> None: self._peer_cred: Dict[str, int] = {} self._service_info: Dict[str, Any] = {} self.init_attempts: int = 0 - self._state: str = "disconnected" - self._state_message: str = "Klippy Disconnected" - self.subscriptions: Dict[Subscribable, Subscription] = {} + self._state: KlippyState = KlippyState.DISCONNECTED + self._state.set_message("Klippy Disconnected") + self.subscriptions: Dict[APITransport, Subscription] = {} self.subscription_cache: Dict[str, Dict[str, Any]] = {} # Setup remote methods accessable to Klippy. Note that all # registered remote methods should be of the notification type, @@ -106,14 +106,14 @@ def klippy_apis(self) -> KlippyAPI: return self.server.lookup_component("klippy_apis") @property - def state(self) -> str: + def state(self) -> KlippyState: if self.is_connected() and not self._klippy_started: - return "startup" + return KlippyState.STARTUP return self._state @property def state_message(self) -> str: - return self._state_message + return self._state.message @property def klippy_info(self) -> Dict[str, Any]: @@ -241,7 +241,7 @@ def _on_agent_method_received(**kwargs) -> None: connection.call_method(method_name, kwargs) self.remote_methods[method_name] = _on_agent_method_received self.klippy_reg_methods.append(method_name) - if self._methods_registered and self._state != "disconnected": + if self._methods_registered and self._state != KlippyState.DISCONNECTED: coro = self.klippy_apis.register_method(method_name) return self.event_loop.create_task(coro) return None @@ -331,7 +331,7 @@ async def _init_klippy_connection(self) -> bool: self._methods_registered = False self._missing_reqs.clear() self.init_attempts = 0 - self._state = "startup" + self._state = KlippyState.STARTUP while self.server.is_running(): await asyncio.sleep(INIT_TIME) await self._check_ready() @@ -351,10 +351,12 @@ async def _request_endpoints(self) -> None: if result is None: return endpoints = result.get('endpoints', []) - app: MoonrakerApp = self.server.lookup_component("application") for ep in endpoints: if ep not in RESERVED_ENDPOINTS: - app.register_remote_handler(ep) + self.server.register_endpoint( + ep, RequestType.GET | RequestType.POST, self.request, + is_remote=True + ) async def _request_initial_subscriptions(self) -> None: try: @@ -391,8 +393,10 @@ async def _check_ready(self) -> None: msg = f"Klipper Version: {version}" self.server.add_log_rollover_item("klipper_version", msg) self._klippy_info = dict(result) + state_message: str = self._state.message if "state_message" in self._klippy_info: - self._state_message = self._klippy_info["state_message"] + state_message = self._klippy_info["state_message"] + self._state.set_message(state_message) if "state" not in result: return if send_id: @@ -400,19 +404,20 @@ async def _check_ready(self) -> None: await self.server.send_event("server:klippy_identified") # Request initial endpoints to register info, emergency stop APIs await self._request_endpoints() - self._state = result["state"] - if self._state != "startup": + self._state = KlippyState.from_string(result["state"], state_message) + if self._state != KlippyState.STARTUP: await self._request_initial_subscriptions() # Register remaining endpoints available await self._request_endpoints() startup_state = self._state - await self.server.send_event( - "server:klippy_started", startup_state - ) + await self.server.send_event("server:klippy_started", startup_state) self._klippy_started = True - if self._state != "ready": - logging.info("\n" + self._state_message) - if self._state == "shutdown" and startup_state != "shutdown": + if self._state != KlippyState.READY: + logging.info("\n" + self._state.message) + if ( + self._state == KlippyState.SHUTDOWN and + startup_state != KlippyState.SHUTDOWN + ): # Klippy shutdown during startup event self.server.send_event("server:klippy_shutdown") else: @@ -425,10 +430,10 @@ async def _check_ready(self) -> None: logging.exception( f"Unable to register method '{method}'") self._methods_registered = True - if self._state == "ready": + if self._state == KlippyState.READY: logging.info("Klippy ready") await self.server.send_event("server:klippy_ready") - if self._state == "shutdown": + if self._state == KlippyState.SHUTDOWN: # Klippy shutdown during ready event self.server.send_event("server:klippy_shutdown") else: @@ -520,21 +525,23 @@ def _process_status_update( self.subscription_cache.setdefault(field, {}).update(item) if 'webhooks' in status: wh: Dict[str, str] = status['webhooks'] + state_message: str = self._state.message if "state_message" in wh: - self._state_message = wh["state_message"] + state_message = wh["state_message"] + self._state.set_message(state_message) # XXX - process other states (startup, ready, error, etc)? if "state" in wh: - state = wh["state"] + new_state = KlippyState.from_string(wh["state"], state_message) if ( - state == "shutdown" and + new_state == KlippyState.SHUTDOWN and not self._klippy_initializing and - self._state != "shutdown" + self._state != KlippyState.SHUTDOWN ): # If the shutdown state is received during initialization # defer the event, the init routine will handle it. logging.info("Klippy has shutdown") self.server.send_event("server:klippy_shutdown") - self._state = state + self._state = new_state for conn, sub in self.subscriptions.items(): conn_status: Dict[str, Any] = {} for name, fields in sub.items(): @@ -650,14 +657,14 @@ async def _request_standard( finally: self.pending_requests.pop(base_request.id, None) - def remove_subscription(self, conn: Subscribable) -> None: + def remove_subscription(self, conn: APITransport) -> None: self.subscriptions.pop(conn, None) def is_connected(self) -> bool: return self.writer is not None and not self.closing def is_ready(self) -> bool: - return self._state == "ready" + return self._state == KlippyState.READY def is_printing(self) -> bool: if not self.is_ready(): @@ -705,8 +712,8 @@ async def _on_connection_closed(self) -> None: self._klippy_initializing = False self._klippy_started = False self._methods_registered = False - self._state = "disconnected" - self._state_message = "Klippy Disconnected" + self._state = KlippyState.DISCONNECTED + self._state.set_message("Klippy Disconnected") for request in self.pending_requests.values(): request.set_exception(ServerError("Klippy Disconnected", 503)) self.pending_requests = {} diff --git a/moonraker/loghelper.py b/moonraker/loghelper.py index 419a8c35a..fd936564a 100644 --- a/moonraker/loghelper.py +++ b/moonraker/loghelper.py @@ -12,6 +12,7 @@ import sys import asyncio from queue import SimpleQueue as Queue +from .common import RequestType # Annotation imports from typing import ( @@ -112,7 +113,7 @@ def __init__( def set_server(self, server: Server) -> None: self.server = server self.server.register_endpoint( - "/server/logs/rollover", ['POST'], self._handle_log_rollover + "/server/logs/rollover", RequestType.POST, self._handle_log_rollover ) def set_rollover_info(self, name: str, item: str) -> None: diff --git a/moonraker/server.py b/moonraker/server.py index f968c23bb..b2907287a 100755 --- a/moonraker/server.py +++ b/moonraker/server.py @@ -25,6 +25,8 @@ from .klippy_connection import KlippyConnection from .utils import ServerError, Sentinel, get_software_info, json_wrapper from .loghelper import LogManager +from .common import RequestType +from .websockets import WebsocketManager # Annotation imports from typing import ( @@ -41,7 +43,6 @@ ) if TYPE_CHECKING: from .common import WebRequest - from .websockets import WebsocketManager from .components.file_manager.file_manager import FileManager from .components.machine import Machine from .components.extensions import ExtensionManager @@ -91,22 +92,25 @@ def __init__(self, # Tornado Application/Server self.moonraker_app = app = MoonrakerApp(config) - self.register_endpoint = app.register_local_handler - self.register_debug_endpoint = app.register_debug_handler + self.register_endpoint = app.register_endpoint + self.register_debug_endpoint = app.register_debug_endpoint self.register_static_file_handler = app.register_static_file_handler self.register_upload_handler = app.register_upload_handler - self.register_api_transport = app.register_api_transport self.log_manager.set_server(self) + self.websocket_manager = WebsocketManager(config) for warning in args.get("startup_warnings", []): self.add_warning(warning) self.register_endpoint( - "/server/info", ['GET'], self._handle_info_request) + "/server/info", RequestType.GET, self._handle_info_request + ) self.register_endpoint( - "/server/config", ['GET'], self._handle_config_request) + "/server/config", RequestType.GET, self._handle_config_request + ) self.register_endpoint( - "/server/restart", ['POST'], self._handle_server_restart) + "/server/restart", RequestType.POST, self._handle_server_restart + ) self.register_notification("server:klippy_ready") self.register_notification("server:klippy_shutdown") self.register_notification("server:klippy_disconnect", @@ -305,8 +309,7 @@ def register_component(self, component_name: str, component: Any) -> None: def register_notification( self, event_name: str, notify_name: Optional[str] = None ) -> None: - wsm: WebsocketManager = self.lookup_component("websockets") - wsm.register_notification(event_name, notify_name) + self.websocket_manager.register_notification(event_name, notify_name) def register_event_handler( self, event: str, callback: FlexCallback @@ -364,9 +367,6 @@ def get_host_info(self) -> Dict[str, Any]: def get_klippy_info(self) -> Dict[str, Any]: return self.klippy_connection.klippy_info - def get_klippy_state(self) -> str: - return self.klippy_connection.state - def _handle_term_signal(self) -> None: logging.info("Exiting with signal SIGTERM") self.event_loop.register_callback(self._stop_server, "terminate") @@ -390,6 +390,7 @@ async def _stop_server(self, exit_reason: str = "restart") -> None: await asyncio.sleep(.1) try: await self.moonraker_app.close() + await self.websocket_manager.close() except Exception: logging.exception("Error Closing App") @@ -433,7 +434,6 @@ async def _handle_info_request(self, web_request: WebRequest) -> Dict[str, Any]: reg_dirs = [] if file_manager is not None: reg_dirs = file_manager.get_registered_dirs() - wsm: WebsocketManager = self.lookup_component('websockets') mreqs = self.klippy_connection.missing_requirements if raw: warnings = list(self.warnings.values()) @@ -443,12 +443,12 @@ async def _handle_info_request(self, web_request: WebRequest) -> Dict[str, Any]: ] return { 'klippy_connected': self.klippy_connection.is_connected(), - 'klippy_state': self.klippy_connection.state, + 'klippy_state': str(self.klippy_connection.state), 'components': list(self.components.keys()), 'failed_components': self.failed_components, 'registered_directories': reg_dirs, 'warnings': warnings, - 'websocket_count': wsm.get_count(), + 'websocket_count': self.websocket_manager.get_count(), 'moonraker_version': self.app_args['software_version'], 'missing_klippy_requirements': mreqs, 'api_version': API_VERSION, diff --git a/moonraker/utils/__init__.py b/moonraker/utils/__init__.py index e9ebde1bf..9043f8656 100644 --- a/moonraker/utils/__init__.py +++ b/moonraker/utils/__init__.py @@ -19,6 +19,7 @@ import struct import socket import enum +import ipaddress from . import source_info from . import json_wrapper @@ -39,6 +40,7 @@ SYS_MOD_PATHS = glob.glob("/usr/lib/python3*/dist-packages") SYS_MOD_PATHS += glob.glob("/usr/lib/python3*/site-packages") +IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] class ServerError(Exception): def __init__(self, message: str, status_code: int = 400) -> None: @@ -264,3 +266,9 @@ def pretty_print_time(seconds: int) -> str: continue fmt_list.append(f"{val} {ident}" if val == 1 else f"{val} {ident}s") return ", ".join(fmt_list) + +def parse_ip_address(address: str) -> Optional[IPAddress]: + try: + return ipaddress.ip_address(address) + except Exception: + return None diff --git a/moonraker/websockets.py b/moonraker/websockets.py index 6aa86d521..c51ffa130 100644 --- a/moonraker/websockets.py +++ b/moonraker/websockets.py @@ -6,18 +6,16 @@ from __future__ import annotations import logging -import ipaddress import asyncio from tornado.websocket import WebSocketHandler, WebSocketClosedError from tornado.web import HTTPError from .common import ( + RequestType, WebRequest, BaseRemoteConnection, - APITransport, - APIDefinition, - JsonRPC + TransportType, ) -from .utils import ServerError +from .utils import ServerError, parse_ip_address # Annotation imports from typing import ( @@ -35,9 +33,10 @@ if TYPE_CHECKING: from .server import Server from .klippy_connection import KlippyConnection as Klippy + from .confighelper import ConfigHelper from .components.extensions import ExtensionManager from .components.authorization import Authorization - IPUnion = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] + from .utils import IPAddress ConvType = Union[str, bool, float, int] ArgVal = Union[None, int, float, bool, str] RPCCallback = Callable[..., Coroutine] @@ -45,17 +44,21 @@ CLIENT_TYPES = ["web", "mobile", "desktop", "display", "bot", "agent", "other"] -class WebsocketManager(APITransport): - def __init__(self, server: Server) -> None: - self.server = server +class WebsocketManager: + def __init__(self, config: ConfigHelper) -> None: + self.server = config.get_server() self.clients: Dict[int, BaseRemoteConnection] = {} self.bridge_connections: Dict[int, BridgeSocket] = {} - self.rpc = JsonRPC(server) self.closed_event: Optional[asyncio.Event] = None - - self.rpc.register_method("server.websocket.id", self._handle_id_request) - self.rpc.register_method( - "server.connection.identify", self._handle_identify) + self.server.register_endpoint( + "/server/websocket/id", RequestType.GET, self._handle_id_request, + TransportType.WEBSOCKET + ) + self.server.register_endpoint( + "/server/connection/identify", RequestType.POST, self._handle_identify, + TransportType.WEBSOCKET, auth_required=False + ) + self.server.register_component("websockets", self) def register_notification( self, @@ -74,72 +77,27 @@ def notify_handler(*args): self.notify_clients(notify_name, args) self.server.register_event_handler(event_name, notify_handler) - def register_api_handler(self, api_def: APIDefinition) -> None: - klippy: Klippy = self.server.lookup_component("klippy_connection") - if api_def.callback is None: - # Remote API, uses RPC to reach out to Klippy - ws_method = api_def.jrpc_methods[0] - rpc_cb = self._generate_callback( - api_def.endpoint, "", klippy.request - ) - self.rpc.register_method(ws_method, rpc_cb) - else: - # Local API, uses local callback - for ws_method, req_method in \ - zip(api_def.jrpc_methods, api_def.request_methods): - rpc_cb = self._generate_callback( - api_def.endpoint, req_method, api_def.callback - ) - self.rpc.register_method(ws_method, rpc_cb) - logging.info( - "Registering Websocket JSON-RPC methods: " - f"{', '.join(api_def.jrpc_methods)}" - ) - - def remove_api_handler(self, api_def: APIDefinition) -> None: - for jrpc_method in api_def.jrpc_methods: - self.rpc.remove_method(jrpc_method) - - def _generate_callback( - self, - endpoint: str, - request_method: str, - callback: Callable[[WebRequest], Coroutine] - ) -> RPCCallback: - async def func(args: Dict[str, Any]) -> Any: - sc: BaseRemoteConnection = args.pop("_socket_") - sc.check_authenticated(path=endpoint) - result = await callback( - WebRequest(endpoint, args, request_method, sc, - ip_addr=sc.ip_addr, user=sc.user_info)) - return result - return func - - async def _handle_id_request(self, args: Dict[str, Any]) -> Dict[str, int]: - sc: BaseRemoteConnection = args["_socket_"] - sc.check_authenticated() + async def _handle_id_request(self, web_request: WebRequest) -> Dict[str, int]: + sc = web_request.get_client_connection() + assert sc is not None return {'websocket_id': sc.uid} - async def _handle_identify(self, args: Dict[str, Any]) -> Dict[str, int]: - sc: BaseRemoteConnection = args["_socket_"] - sc.authenticate( - token=args.get("access_token", None), - api_key=args.get("api_key", None) - ) + async def _handle_identify(self, web_request: WebRequest) -> Dict[str, int]: + sc = web_request.get_client_connection() + assert sc is not None if sc.identified: raise self.server.error( f"Connection already identified: {sc.client_data}" ) - try: - name = str(args["client_name"]) - version = str(args["version"]) - client_type: str = str(args["type"]).lower() - url = str(args["url"]) - except KeyError as e: - missing_key = str(e).split(":")[-1].strip() - raise self.server.error( - f"No data for argument: {missing_key}" - ) from None + name = web_request.get_str("client_name") + version = web_request.get_str("version") + client_type: str = web_request.get_str("type").lower() + url = web_request.get_str("url") + sc.authenticate( + token=web_request.get_str("access_token", None), + api_key=web_request.get_str("api_key", None) + ) + if client_type not in CLIENT_TYPES: raise self.server.error(f"Invalid Client Type: {client_type}") sc.client_data = { @@ -272,9 +230,13 @@ class WebSocket(WebSocketHandler, BaseRemoteConnection): def initialize(self) -> None: self.on_create(self.settings['server']) - self.ip_addr: str = self.request.remote_ip or "" + self._ip_addr = parse_ip_address(self.request.remote_ip or "") self.last_pong_time: float = self.eventloop.get_loop_time() + @property + def ip_addr(self) -> Optional[IPAddress]: + return self._ip_addr + @property def hostname(self) -> str: return self.request.host_name @@ -362,7 +324,7 @@ def prepare(self) -> None: auth: AuthComp = self.server.lookup_component('authorization', None) if auth is not None: try: - self._user_info = auth.check_authorized(self.request) + self._user_info = auth.authenticate_request(self.request) except Exception as e: logging.info(f"Websocket Failed Authentication: {e}") self._user_info = None @@ -377,13 +339,17 @@ def initialize(self) -> None: self.wsm: WebsocketManager = self.server.lookup_component("websockets") self.eventloop = self.server.get_event_loop() self.uid = id(self) - self.ip_addr: str = self.request.remote_ip or "" + self._ip_addr = parse_ip_address(self.request.remote_ip or "") self.last_pong_time: float = self.eventloop.get_loop_time() self.is_closed = False self.klippy_writer: Optional[asyncio.StreamWriter] = None self.klippy_write_buf: List[bytes] = [] self.klippy_queue_busy: bool = False + @property + def ip_addr(self) -> Optional[IPAddress]: + return self._ip_addr + @property def hostname(self) -> str: return self.request.host_name @@ -502,7 +468,7 @@ async def prepare(self) -> None: ) auth: AuthComp = self.server.lookup_component("authorization", None) if auth is not None: - self.current_user = auth.check_authorized(self.request) + self.current_user = auth.authenticate_request(self.request) kconn: Klippy = self.server.lookup_component("klippy_connection") try: reader, writer = await kconn.open_klippy_connection()