From e66744b6e9b3475cb2fb3908777112992a01bf76 Mon Sep 17 00:00:00 2001 From: Eric Callahan Date: Sun, 11 Aug 2024 06:23:36 -0400 Subject: [PATCH] power: improve basic auth implementation Rather than pass the user name and password via the url, supply them directly to the http request. This should guarantee that the authorization header is generated correctly. Signed-off-by: Eric Callahan --- moonraker/components/http_client.py | 18 +++++++--- moonraker/components/power.py | 51 ++++++++++++++++++----------- 2 files changed, 45 insertions(+), 24 deletions(-) diff --git a/moonraker/components/http_client.py b/moonraker/components/http_client.py index ef002bb6d..fc23f6b78 100644 --- a/moonraker/components/http_client.py +++ b/moonraker/components/http_client.py @@ -80,7 +80,9 @@ async def request( retry_pause_time: float = .1, enable_cache: bool = False, send_etag: bool = True, - send_if_modified_since: bool = True + send_if_modified_since: bool = True, + basic_auth_user: Optional[str] = None, + basic_auth_pass: Optional[str] = None ) -> HttpResponse: cache_key = url.split("?", 1)[0] method = method.upper() @@ -103,9 +105,17 @@ async def request( headers = req_headers timeout = 1 + connect_timeout + request_timeout - request = HTTPRequest(url, method, headers, body=body, - request_timeout=request_timeout, - connect_timeout=connect_timeout) + req_args: Dict[str, Any] = dict( + body=body, + request_timeout=request_timeout, + connect_timeout=connect_timeout + ) + if basic_auth_user is not None: + assert basic_auth_pass is not None + req_args["auth_username"] = basic_auth_user + req_args["auth_password"] = basic_auth_pass + req_args["auth_mode"] = "basic" + request = HTTPRequest(url, method, headers, **req_args) err: Optional[BaseException] = None for i in range(attempts): if i: diff --git a/moonraker/components/power.py b/moonraker/components/power.py index f698dd0ea..3042ac6b8 100644 --- a/moonraker/components/power.py +++ b/moonraker/components/power.py @@ -451,12 +451,16 @@ def __init__( self.addr: str = config.get("address") self.port = config.getint("port", default_port) self.user = config.load_template("user", default_user).render() - self.password = config.load_template( - "password", default_password).render() + self.password = config.load_template("password", default_password).render() + self.has_basic_auth: bool = False self.protocol = config.get("protocol", default_protocol) if self.port == -1: self.port = 443 if self.protocol.lower() == "https" else 80 + def enable_basic_authentication(self) -> None: + if self.user and self.password: + self.has_basic_auth = True + async def init_state(self) -> None: async with self.request_lock: last_err: Exception = Exception() @@ -492,9 +496,15 @@ async def init_state(self) -> None: async def _send_http_command( self, url: str, command: str, retries: int = 3 ) -> Dict[str, Any]: + ba_user: Optional[str] = None + ba_pass: Optional[str] = None + if self.has_basic_auth: + ba_user = self.user + ba_pass = self.password response = await self.client.get( - url, request_timeout=20., attempts=retries, - retry_pause_time=1., enable_cache=False) + url, request_timeout=20., attempts=retries, retry_pause_time=1., + enable_cache=False, basic_auth_user=ba_user, basic_auth_pass=ba_pass + ) response.raise_for_status( f"Error sending '{self.type}' command: {command}") data = cast(dict, response.json()) @@ -632,7 +642,7 @@ async def _handle_ready(self) -> None: sub: Dict[str, Optional[List[str]]] = {self.object_name: None} data = await kapis.subscribe_objects(sub, self._status_update, None) if not self._validate_data(data): - self.state == "error" + self.state = "error" else: assert data is not None self._set_state_from_data(data) @@ -1012,6 +1022,7 @@ def __init__(self, config: ConfigHelper) -> None: super().__init__(config, default_user="admin", default_password="") self.output_id = config.getint("output_id", 0) self.timer = config.get("timer", "") + self.enable_basic_authentication() async def _send_shelly_command(self, command: str) -> Dict[str, Any]: query_args: Dict[str, Any] = {} @@ -1023,12 +1034,8 @@ async def _send_shelly_command(self, command: str) -> Dict[str, Any]: query_args["timer"] = self.timer elif command != "info": raise self.server.error(f"Invalid shelly command: {command}") - if self.password != "": - out_pwd = f"{quote(self.user)}:{quote(self.password)}@" - else: - out_pwd = "" query = urlencode(query_args) - url = f"{self.protocol}://{out_pwd}{quote(self.addr)}/{out_cmd}?{query}" + url = f"{self.protocol}://{quote(self.addr)}/{out_cmd}?{query}" return await self._send_http_command(url, command) async def _send_status_request(self) -> str: @@ -1102,6 +1109,7 @@ class HomeSeer(HTTPDevice): def __init__(self, config: ConfigHelper) -> None: super().__init__(config, default_user="admin", default_password="") self.device = config.getint("device") + self.enable_basic_authentication() async def _send_homeseer( self, request: str, state: str = "" @@ -1116,8 +1124,7 @@ async def _send_homeseer( query_args["label"] = state query = urlencode(query_args) url = ( - f"{self.protocol}://{quote(self.user)}:{quote(self.password)}@" - f"{quote(self.addr)}:{self.port}/JSON?{query}" + f"{self.protocol}://{quote(self.addr)}:{self.port}/JSON?{query}" ) return await self._send_http_command(url, request) @@ -1182,6 +1189,7 @@ def __init__(self, config: ConfigHelper) -> None: super().__init__(config, default_user="admin", default_password="admin") self.output_id = config.get("output_id", "") + self.enable_basic_authentication() async def _send_loxonev1_command(self, command: str) -> Dict[str, Any]: if command in ["on", "off"]: @@ -1190,11 +1198,7 @@ async def _send_loxonev1_command(self, command: str) -> Dict[str, Any]: out_cmd = f"jdev/sps/io/{quote(self.output_id)}" else: raise self.server.error(f"Invalid loxonev1 command: {command}") - if self.password != "": - out_pwd = f"{quote(self.user)}:{quote(self.password)}@" - else: - out_pwd = "" - url = f"http://{out_pwd}{quote(self.addr)}/{out_cmd}" + url = f"http://{quote(self.addr)}/{out_cmd}" return await self._send_http_command(url, command) async def _send_status_request(self) -> str: @@ -1242,6 +1246,7 @@ def _on_state_update(self, payload: bytes) -> None: context = { 'payload': payload.decode() } + response: str = "" try: response = self.state_response.render(context) except Exception as e: @@ -1389,7 +1394,6 @@ async def set_power(self, state: str) -> None: class HueDevice(HTTPDevice): - def __init__(self, config: ConfigHelper) -> None: super().__init__(config, default_port=80) self.device_id = config.get("device_id") @@ -1428,7 +1432,7 @@ async def _send_status_request(self) -> str: return "on" if resp["state"][self.on_state] else "off" class GenericHTTP(HTTPDevice): - def __init__(self, config: ConfigHelper,) -> None: + def __init__(self, config: ConfigHelper) -> None: super().__init__(config, is_generic=True) self.urls: Dict[str, str] = { "on": config.gettemplate("on_url").render(), @@ -1439,10 +1443,17 @@ def __init__(self, config: ConfigHelper,) -> None: "request_template", None, is_async=True ) self.response_template = config.gettemplate("response_template", is_async=True) + self.enable_basic_authentication() async def _send_generic_request(self, command: str) -> str: + ba_user: Optional[str] = None + ba_pass: Optional[str] = None + if self.has_basic_auth: + ba_user = self.user + ba_pass = self.password request = self.client.wrap_request( - self.urls[command], request_timeout=20., attempts=3, retry_pause_time=1. + self.urls[command], request_timeout=20., attempts=3, retry_pause_time=1., + basic_auth_user=ba_user, basic_auth_pass=ba_pass ) context: Dict[str, Any] = { "command": command,