Skip to content

Commit

Permalink
Rework discovery timeout logic (#153)
Browse files Browse the repository at this point in the history
  • Loading branch information
Darsstar authored May 4, 2024
1 parent 33a1318 commit 266a3d0
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 94 deletions.
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
packages=setuptools.find_packages(exclude=["tests", "tests.*"]),
install_requires=[
"aiohttp>=3.5.4, <4",
"async_timeout>=4.0.2",
"voluptuous>=0.11.5",
"importlib_metadata>=3.6; python_version<'3.10'",
"typing_extensions>=4.1.0; python_version<'3.11'",
Expand Down
17 changes: 11 additions & 6 deletions solax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,21 @@
import asyncio
import logging

from async_timeout import timeout

from solax.discovery import discover
from solax.inverter import Inverter, InverterResponse
from solax.inverter_http_client import REQUEST_TIMEOUT

_LOGGER = logging.getLogger(__name__)


REQUEST_TIMEOUT = 5
__all__ = (
"discover",
"real_time_api",
"rt_request",
"Inverter",
"InverterResponse",
"RealTimeAPI",
"REQUEST_TIMEOUT",
)


async def rt_request(inv: Inverter, retry, t_wait=0) -> InverterResponse:
Expand All @@ -23,8 +29,7 @@ async def rt_request(inv: Inverter, retry, t_wait=0) -> InverterResponse:
new_wait = (t_wait * 2) + 5
retry = retry - 1
try:
async with timeout(REQUEST_TIMEOUT):
return await inv.get_data()
return await inv.get_data()
except asyncio.TimeoutError:
if retry > 0:
return await rt_request(inv, retry, new_wait)
Expand Down
164 changes: 79 additions & 85 deletions solax/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
import sys
from asyncio import Future, Task
from collections import defaultdict
from typing import Dict, Literal, Optional, Sequence, Set, TypedDict, Union, cast

from async_timeout import timeout
from typing import Dict, Literal, Sequence, Set, TypedDict, Union, cast

from solax.inverter import Inverter
from solax.inverter_http_client import InverterHttpClient
Expand All @@ -29,7 +27,6 @@


class DiscoveryKeywords(TypedDict, total=False):
timeout: Optional[float]
inverters: Sequence[Inverter]
return_when: Union[Literal["ALL_COMPLETED"], Literal["FIRST_COMPLETED"]]

Expand Down Expand Up @@ -72,89 +69,86 @@ async def _discovery_task(i) -> Inverter:
async def discover(
host, port, pwd="", **kwargs: Unpack[DiscoveryKeywords]
) -> Union[Inverter, Set[Inverter]]:
async with timeout(kwargs.get("timeout", 15)):
done: Set[_InverterTask] = set()
pending: Set[_InverterTask] = set()
failures = set()
requests: Dict[InverterHttpClient, Future] = defaultdict(
asyncio.get_running_loop().create_future
)

return_when = kwargs.get("return_when", asyncio.FIRST_COMPLETED)
for cls in kwargs.get("inverters", REGISTRY):
for inverter in cls.build_all_variants(host, port, pwd):
inverter.http_client = cast(
InverterHttpClient,
_DiscoveryHttpClient(
inverter, inverter.http_client, requests[inverter.http_client]
),
)

pending.add(
asyncio.create_task(_discovery_task(inverter), name=f"{inverter}")
)

if not pending:
raise DiscoveryError("No inverters to try to discover")

def cancel(pending: Set[_InverterTask]) -> Set[_InverterTask]:
for task in pending:
task.cancel()
return pending

def remove_failures_from(done: Set[_InverterTask]) -> None:
for task in set(done):
exc = task.exception()
if exc:
failures.add(exc)
done.remove(task)

# stagger HTTP request to prevent accidental Denial Of Service
async def stagger() -> None:
for http_client, future in requests.items():
future.set_result(asyncio.create_task(http_client.request()))
await asyncio.sleep(1)

staggered = asyncio.create_task(stagger())

while pending and (not done or return_when != asyncio.FIRST_COMPLETED):
try:
done, pending = await asyncio.wait(pending, return_when=return_when)
except asyncio.CancelledError:
staggered.cancel()
await asyncio.gather(
staggered, *cancel(pending), return_exceptions=True
)
raise

remove_failures_from(done)

if done and return_when == asyncio.FIRST_COMPLETED:
break

logging.debug("%d discovery tasks are still running...", len(pending))

if pending and return_when != asyncio.FIRST_COMPLETED:
pending.update(done)
done.clear()
done: Set[_InverterTask] = set()
pending: Set[_InverterTask] = set()
failures = set()
requests: Dict[InverterHttpClient, Future] = defaultdict(
asyncio.get_running_loop().create_future
)

return_when = kwargs.get("return_when", asyncio.FIRST_COMPLETED)
for cls in kwargs.get("inverters", REGISTRY):
for inverter in cls.build_all_variants(host, port, pwd):
inverter.http_client = cast(
InverterHttpClient,
_DiscoveryHttpClient(
inverter, inverter.http_client, requests[inverter.http_client]
),
)

pending.add(
asyncio.create_task(_discovery_task(inverter), name=f"{inverter}")
)

if not pending:
raise DiscoveryError("No inverters to try to discover")

def cancel(pending: Set[_InverterTask]) -> Set[_InverterTask]:
for task in pending:
task.cancel()
return pending

def remove_failures_from(done: Set[_InverterTask]) -> None:
for task in set(done):
exc = task.exception()
if exc:
failures.add(exc)
done.remove(task)

# stagger HTTP request to prevent accidental Denial Of Service
async def stagger() -> None:
for http_client, future in requests.items():
future.set_result(asyncio.create_task(http_client.request()))
await asyncio.sleep(1)

staggered = asyncio.create_task(stagger())

while pending and (not done or return_when != asyncio.FIRST_COMPLETED):
try:
done, pending = await asyncio.wait(pending, return_when=return_when)
except asyncio.CancelledError:
staggered.cancel()
await asyncio.gather(staggered, *cancel(pending), return_exceptions=True)
raise

remove_failures_from(done)
staggered.cancel()
await asyncio.gather(staggered, *cancel(pending), return_exceptions=True)

if done:
logging.info("Discovered inverters: %s", {task.result() for task in done})
if return_when == asyncio.FIRST_COMPLETED:
return await next(iter(done))

return {task.result() for task in done}

raise DiscoveryError(
"Unable to connect to the inverter at "
f"host={host} port={port}, or your inverter is not supported yet.\n"
"Please see https://github.com/squishykid/solax/wiki/DiscoveryError\n"
f"Failures={str(failures)}"
)

if done and return_when == asyncio.FIRST_COMPLETED:
break

logging.debug("%d discovery tasks are still running...", len(pending))

if pending and return_when != asyncio.FIRST_COMPLETED:
pending.update(done)
done.clear()

remove_failures_from(done)
staggered.cancel()
await asyncio.gather(staggered, *cancel(pending), return_exceptions=True)

if done:
logging.info("Discovered inverters: %s", {task.result() for task in done})
if return_when == asyncio.FIRST_COMPLETED:
return await next(iter(done))

return {task.result() for task in done}

raise DiscoveryError(
"Unable to connect to the inverter at "
f"host={host} port={port}, or your inverter is not supported yet.\n"
"Please see https://github.com/squishykid/solax/wiki/DiscoveryError\n"
f"Failures={str(failures)}"
)


class DiscoveryError(Exception):
Expand Down
10 changes: 8 additions & 2 deletions solax/inverter_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
if sys.version_info >= (3, 10):
from dataclasses import KW_ONLY


REQUEST_TIMEOUT = 5.0
_CACHE: WeakValueDictionary[int, InverterHttpClient] = WeakValueDictionary()


Expand Down Expand Up @@ -107,7 +109,9 @@ async def request(self):
async def get(self):
url = self.url + "?" + self.query if self.query else self.url
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=self.headers) as req:
async with session.get(
url, headers=self.headers, timeout=REQUEST_TIMEOUT
) as req:
req.raise_for_status()
resp = await req.read()
return resp
Expand All @@ -116,7 +120,9 @@ async def post(self):
url = self.url + "?" + self.query if self.query else self.url
data = self.data.encode("utf-8") if self.data else None
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=self.headers, data=data) as req:
async with session.post(
url, headers=self.headers, data=data, timeout=REQUEST_TIMEOUT
) as req:
req.raise_for_status()
resp = await req.read()
return resp
Expand Down

0 comments on commit 266a3d0

Please sign in to comment.