Skip to content

Commit

Permalink
Import quirks in executor (#34)
Browse files Browse the repository at this point in the history
* Import quirks in executor

* await

* tests

* coverage and adjust shutdown

* adjust
  • Loading branch information
dmulcahey authored Apr 3, 2024
1 parent c64ce42 commit 3d160a3
Show file tree
Hide file tree
Showing 6 changed files with 602 additions and 72 deletions.
195 changes: 144 additions & 51 deletions tests/test_async_.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import asyncio
import functools
import time
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, Mock, patch

import pytest

from zha import async_ as zha_async
from zha.application.gateway import Gateway
from zha.async_ import AsyncUtilMixin, ZHAJob, ZHAJobType, create_eager_task
from zha.decorators import callback
Expand Down Expand Up @@ -493,6 +494,9 @@ async def test_add_job_with_none(zha_gateway: Gateway) -> None:
with pytest.raises(ValueError):
zha_gateway.async_add_job(None, "test_arg")

with pytest.raises(ValueError):
zha_gateway.add_job(None, "test_arg")


async def test_async_functions_with_callback(zha_gateway: Gateway) -> None:
"""Test we deal with async functions accidentally marked as callback."""
Expand Down Expand Up @@ -559,56 +563,6 @@ async def test_task():
assert result.result() == "Foo"


async def test_shutdown_does_not_block_on_normal_tasks(
zha_gateway: Gateway,
) -> None:
"""Ensure shutdown does not block on normal tasks."""
result = asyncio.Future()
unshielded_task = asyncio.sleep(10)

async def test_task():
try:
await unshielded_task
except asyncio.CancelledError:
result.set_result("Foo")

start = time.monotonic()
task = zha_gateway.async_create_task(test_task())
await asyncio.sleep(0)
await zha_gateway.shutdown()
await asyncio.sleep(0)
assert result.done()
assert task.done()
assert time.monotonic() - start < 0.5


async def test_shutdown_does_not_block_on_shielded_tasks(
zha_gateway: Gateway,
) -> None:
"""Ensure shutdown does not block on shielded tasks."""
result = asyncio.Future()
sleep_task = asyncio.ensure_future(asyncio.sleep(10))
shielded_task = asyncio.shield(sleep_task)

async def test_task():
try:
await shielded_task
except asyncio.CancelledError:
result.set_result("Foo")

start = time.monotonic()
task = zha_gateway.async_create_task(test_task())
await asyncio.sleep(0)
await zha_gateway.shutdown()
await asyncio.sleep(0)
assert result.done()
assert task.done()
assert time.monotonic() - start < 0.5

# Cleanup lingering task after test is done
sleep_task.cancel()


@pytest.mark.parametrize("eager_start", [True, False])
async def test_cancellable_ZHAJob(zha_gateway: Gateway, eager_start: bool) -> None:
"""Simulate a shutdown, ensure cancellable jobs are cancelled."""
Expand Down Expand Up @@ -693,3 +647,142 @@ async def _async_add_executor_job():
await zha_gateway.async_block_till_done()
assert len(calls) == 1
await task


@patch("concurrent.futures.Future")
@patch("threading.get_ident")
def test_run_callback_threadsafe_from_inside_event_loop(mock_ident, _) -> None:
"""Testing calling run_callback_threadsafe from inside an event loop."""
callback_fn = MagicMock()

loop = Mock(spec=["call_soon_threadsafe"])

loop._thread_ident = None
mock_ident.return_value = 5
zha_async.run_callback_threadsafe(loop, callback_fn)
assert len(loop.call_soon_threadsafe.mock_calls) == 1

loop._thread_ident = 5
mock_ident.return_value = 5
with pytest.raises(RuntimeError):
zha_async.run_callback_threadsafe(loop, callback_fn)
assert len(loop.call_soon_threadsafe.mock_calls) == 1

loop._thread_ident = 1
mock_ident.return_value = 5
zha_async.run_callback_threadsafe(loop, callback_fn)
assert len(loop.call_soon_threadsafe.mock_calls) == 2


async def test_gather_with_limited_concurrency() -> None:
"""Test gather_with_limited_concurrency limits the number of running tasks."""

runs = 0
now_time = time.time()

async def _increment_runs_if_in_time():
if time.time() - now_time > 0.1:
return -1

nonlocal runs
runs += 1
await asyncio.sleep(0.1)
return runs

results = await zha_async.gather_with_limited_concurrency(
2, *(_increment_runs_if_in_time() for i in range(4))
)

assert results == [2, 2, -1, -1]


async def test_shutdown_run_callback_threadsafe(zha_gateway: Gateway) -> None:
"""Test we can shutdown run_callback_threadsafe."""
zha_async.shutdown_run_callback_threadsafe(zha_gateway.loop)
callback_fn = MagicMock()

with pytest.raises(RuntimeError):
zha_async.run_callback_threadsafe(zha_gateway.loop, callback_fn)


async def test_run_callback_threadsafe(zha_gateway: Gateway) -> None:
"""Test run_callback_threadsafe runs code in the event loop."""
it_ran = False

def callback_fn():
nonlocal it_ran
it_ran = True

assert zha_async.run_callback_threadsafe(zha_gateway.loop, callback_fn)
assert it_ran is False

# Verify that async_block_till_done will flush
# out the callback
await zha_gateway.async_block_till_done()
assert it_ran is True


async def test_run_callback_threadsafe_exception(zha_gateway: Gateway) -> None:
"""Test run_callback_threadsafe runs code in the event loop."""
it_ran = False

def callback_fn():
nonlocal it_ran
it_ran = True
raise ValueError("Test")

future = zha_async.run_callback_threadsafe(zha_gateway.loop, callback_fn)
assert future
assert it_ran is False

# Verify that async_block_till_done will flush
# out the callback
await zha_gateway.async_block_till_done()
assert it_ran is True

with pytest.raises(ValueError):
future.result()


async def test_callback_is_always_scheduled(zha_gateway: Gateway) -> None:
"""Test run_callback_threadsafe always calls call_soon_threadsafe before checking for shutdown."""
# We have to check the shutdown state AFTER the callback is scheduled otherwise
# the function could continue on and the caller call `future.result()` after
# the point in the main thread where callbacks are no longer run.

callback_fn = MagicMock()
zha_async.shutdown_run_callback_threadsafe(zha_gateway.loop)

with (
patch.object(
zha_gateway.loop, "call_soon_threadsafe"
) as mock_call_soon_threadsafe,
pytest.raises(RuntimeError),
):
zha_async.run_callback_threadsafe(zha_gateway.loop, callback_fn)

mock_call_soon_threadsafe.assert_called_once()


async def test_create_eager_task_312(zha_gateway: Gateway) -> None: # pylint: disable=unused-argument
"""Test create_eager_task schedules a task eagerly in the event loop.
For Python 3.12+, the task is scheduled eagerly in the event loop.
"""
events = []

async def _normal_task():
events.append("normal")

async def _eager_task():
events.append("eager")

task1 = zha_async.create_eager_task(_eager_task())
task2 = asyncio.create_task(_normal_task())

assert events == ["eager"]

await asyncio.sleep(0)
assert events == ["eager", "normal"]
await task1
await task2
2 changes: 2 additions & 0 deletions tests/test_cluster_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,9 @@ async def test_poll_control_ikea(poll_control_device: Device) -> None:
poll_control_ch = poll_control_device._endpoints[1].all_cluster_handlers["1:0x0020"]
cluster = poll_control_ch.cluster

delattr(poll_control_device, "manufacturer_code")
poll_control_device.device.node_desc.manufacturer_code = 4476

with mock.patch.object(cluster, "set_long_poll_interval", set_long_poll_mock):
await poll_control_ch.check_in_response(33)

Expand Down
94 changes: 94 additions & 0 deletions tests/test_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Test ZHA executor util."""

import concurrent.futures
import time
from unittest.mock import patch

import pytest

from zha import async_
from zha.async_ import InterruptibleThreadPoolExecutor


async def test_executor_shutdown_can_interrupt_threads(
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test that the executor shutdown can interrupt threads."""

iexecutor = InterruptibleThreadPoolExecutor()

def _loop_sleep_in_executor():
while True:
time.sleep(0.1)

sleep_futures = [iexecutor.submit(_loop_sleep_in_executor) for _ in range(100)]

iexecutor.shutdown()

for future in sleep_futures:
with pytest.raises((concurrent.futures.CancelledError, SystemExit)):
future.result()

assert "is still running at shutdown" in caplog.text
assert "time.sleep(0.1)" in caplog.text


async def test_executor_shutdown_only_logs_max_attempts(
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test that the executor shutdown will only log max attempts."""

iexecutor = InterruptibleThreadPoolExecutor()

def _loop_sleep_in_executor():
time.sleep(0.2)

iexecutor.submit(_loop_sleep_in_executor)

with patch.object(async_, "EXECUTOR_SHUTDOWN_TIMEOUT", 0.3):
iexecutor.shutdown()

assert "time.sleep(0.2)" in caplog.text
assert "is still running at shutdown" in caplog.text
iexecutor.shutdown()


async def test_executor_shutdown_does_not_log_shutdown_on_first_attempt(
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test that the executor shutdown does not log on first attempt."""

iexecutor = InterruptibleThreadPoolExecutor()

def _do_nothing():
return

for _ in range(5):
iexecutor.submit(_do_nothing)

iexecutor.shutdown()

assert "is still running at shutdown" not in caplog.text


async def test_overall_timeout_reached() -> None:
"""Test that shutdown moves on when the overall timeout is reached."""

def _loop_sleep_in_executor():
time.sleep(1)

with patch.object(async_, "EXECUTOR_SHUTDOWN_TIMEOUT", 0.5):
iexecutor = InterruptibleThreadPoolExecutor()
for _ in range(6):
iexecutor.submit(_loop_sleep_in_executor)
start = time.monotonic()
iexecutor.shutdown()
finish = time.monotonic()

# Idealy execution time (finish - start) should be < 1.2 sec.
# CI tests might not run in an ideal environment and timing might
# not be accurate, so we let this test pass
# if the duration is below 3 seconds.
assert finish - start < 3.0

iexecutor.shutdown()
Loading

0 comments on commit 3d160a3

Please sign in to comment.