Skip to content

Commit

Permalink
power: improve basic auth implementation
Browse files Browse the repository at this point in the history
Rather than pass the user name and password via the
url, supply them directly to the http request.  This should guarantee that the authorization header is
generated correctly.

Signed-off-by:  Eric Callahan <[email protected]>
  • Loading branch information
Arksine committed Aug 11, 2024
1 parent 30ac5df commit e66744b
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 24 deletions.
18 changes: 14 additions & 4 deletions moonraker/components/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ async def request(
retry_pause_time: float = .1,
enable_cache: bool = False,
send_etag: bool = True,
send_if_modified_since: bool = True
send_if_modified_since: bool = True,
basic_auth_user: Optional[str] = None,
basic_auth_pass: Optional[str] = None
) -> HttpResponse:
cache_key = url.split("?", 1)[0]
method = method.upper()
Expand All @@ -103,9 +105,17 @@ async def request(
headers = req_headers

timeout = 1 + connect_timeout + request_timeout
request = HTTPRequest(url, method, headers, body=body,
request_timeout=request_timeout,
connect_timeout=connect_timeout)
req_args: Dict[str, Any] = dict(
body=body,
request_timeout=request_timeout,
connect_timeout=connect_timeout
)
if basic_auth_user is not None:
assert basic_auth_pass is not None
req_args["auth_username"] = basic_auth_user
req_args["auth_password"] = basic_auth_pass
req_args["auth_mode"] = "basic"
request = HTTPRequest(url, method, headers, **req_args)
err: Optional[BaseException] = None
for i in range(attempts):
if i:
Expand Down
51 changes: 31 additions & 20 deletions moonraker/components/power.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,12 +451,16 @@ def __init__(
self.addr: str = config.get("address")
self.port = config.getint("port", default_port)
self.user = config.load_template("user", default_user).render()
self.password = config.load_template(
"password", default_password).render()
self.password = config.load_template("password", default_password).render()
self.has_basic_auth: bool = False
self.protocol = config.get("protocol", default_protocol)
if self.port == -1:
self.port = 443 if self.protocol.lower() == "https" else 80

def enable_basic_authentication(self) -> None:
if self.user and self.password:
self.has_basic_auth = True

async def init_state(self) -> None:
async with self.request_lock:
last_err: Exception = Exception()
Expand Down Expand Up @@ -492,9 +496,15 @@ async def init_state(self) -> None:
async def _send_http_command(
self, url: str, command: str, retries: int = 3
) -> Dict[str, Any]:
ba_user: Optional[str] = None
ba_pass: Optional[str] = None
if self.has_basic_auth:
ba_user = self.user
ba_pass = self.password
response = await self.client.get(
url, request_timeout=20., attempts=retries,
retry_pause_time=1., enable_cache=False)
url, request_timeout=20., attempts=retries, retry_pause_time=1.,
enable_cache=False, basic_auth_user=ba_user, basic_auth_pass=ba_pass
)
response.raise_for_status(
f"Error sending '{self.type}' command: {command}")
data = cast(dict, response.json())
Expand Down Expand Up @@ -632,7 +642,7 @@ async def _handle_ready(self) -> None:
sub: Dict[str, Optional[List[str]]] = {self.object_name: None}
data = await kapis.subscribe_objects(sub, self._status_update, None)
if not self._validate_data(data):
self.state == "error"
self.state = "error"
else:
assert data is not None
self._set_state_from_data(data)
Expand Down Expand Up @@ -1012,6 +1022,7 @@ def __init__(self, config: ConfigHelper) -> None:
super().__init__(config, default_user="admin", default_password="")
self.output_id = config.getint("output_id", 0)
self.timer = config.get("timer", "")
self.enable_basic_authentication()

async def _send_shelly_command(self, command: str) -> Dict[str, Any]:
query_args: Dict[str, Any] = {}
Expand All @@ -1023,12 +1034,8 @@ async def _send_shelly_command(self, command: str) -> Dict[str, Any]:
query_args["timer"] = self.timer
elif command != "info":
raise self.server.error(f"Invalid shelly command: {command}")
if self.password != "":
out_pwd = f"{quote(self.user)}:{quote(self.password)}@"
else:
out_pwd = ""
query = urlencode(query_args)
url = f"{self.protocol}://{out_pwd}{quote(self.addr)}/{out_cmd}?{query}"
url = f"{self.protocol}://{quote(self.addr)}/{out_cmd}?{query}"
return await self._send_http_command(url, command)

async def _send_status_request(self) -> str:
Expand Down Expand Up @@ -1102,6 +1109,7 @@ class HomeSeer(HTTPDevice):
def __init__(self, config: ConfigHelper) -> None:
super().__init__(config, default_user="admin", default_password="")
self.device = config.getint("device")
self.enable_basic_authentication()

async def _send_homeseer(
self, request: str, state: str = ""
Expand All @@ -1116,8 +1124,7 @@ async def _send_homeseer(
query_args["label"] = state
query = urlencode(query_args)
url = (
f"{self.protocol}://{quote(self.user)}:{quote(self.password)}@"
f"{quote(self.addr)}:{self.port}/JSON?{query}"
f"{self.protocol}://{quote(self.addr)}:{self.port}/JSON?{query}"
)
return await self._send_http_command(url, request)

Expand Down Expand Up @@ -1182,6 +1189,7 @@ def __init__(self, config: ConfigHelper) -> None:
super().__init__(config, default_user="admin",
default_password="admin")
self.output_id = config.get("output_id", "")
self.enable_basic_authentication()

async def _send_loxonev1_command(self, command: str) -> Dict[str, Any]:
if command in ["on", "off"]:
Expand All @@ -1190,11 +1198,7 @@ async def _send_loxonev1_command(self, command: str) -> Dict[str, Any]:
out_cmd = f"jdev/sps/io/{quote(self.output_id)}"
else:
raise self.server.error(f"Invalid loxonev1 command: {command}")
if self.password != "":
out_pwd = f"{quote(self.user)}:{quote(self.password)}@"
else:
out_pwd = ""
url = f"http://{out_pwd}{quote(self.addr)}/{out_cmd}"
url = f"http://{quote(self.addr)}/{out_cmd}"
return await self._send_http_command(url, command)

async def _send_status_request(self) -> str:
Expand Down Expand Up @@ -1242,6 +1246,7 @@ def _on_state_update(self, payload: bytes) -> None:
context = {
'payload': payload.decode()
}
response: str = ""
try:
response = self.state_response.render(context)
except Exception as e:
Expand Down Expand Up @@ -1389,7 +1394,6 @@ async def set_power(self, state: str) -> None:


class HueDevice(HTTPDevice):

def __init__(self, config: ConfigHelper) -> None:
super().__init__(config, default_port=80)
self.device_id = config.get("device_id")
Expand Down Expand Up @@ -1428,7 +1432,7 @@ async def _send_status_request(self) -> str:
return "on" if resp["state"][self.on_state] else "off"

class GenericHTTP(HTTPDevice):
def __init__(self, config: ConfigHelper,) -> None:
def __init__(self, config: ConfigHelper) -> None:
super().__init__(config, is_generic=True)
self.urls: Dict[str, str] = {
"on": config.gettemplate("on_url").render(),
Expand All @@ -1439,10 +1443,17 @@ def __init__(self, config: ConfigHelper,) -> None:
"request_template", None, is_async=True
)
self.response_template = config.gettemplate("response_template", is_async=True)
self.enable_basic_authentication()

async def _send_generic_request(self, command: str) -> str:
ba_user: Optional[str] = None
ba_pass: Optional[str] = None
if self.has_basic_auth:
ba_user = self.user
ba_pass = self.password
request = self.client.wrap_request(
self.urls[command], request_timeout=20., attempts=3, retry_pause_time=1.
self.urls[command], request_timeout=20., attempts=3, retry_pause_time=1.,
basic_auth_user=ba_user, basic_auth_pass=ba_pass
)
context: Dict[str, Any] = {
"command": command,
Expand Down

0 comments on commit e66744b

Please sign in to comment.