From eaf6e0a6367b833bfe5fcdb8eb5b33a526beed3a Mon Sep 17 00:00:00 2001 From: Magnus Larsson Date: Sat, 17 Aug 2024 15:08:58 +0200 Subject: [PATCH] Call open asynchronously (#42) * Call open asynchronously --- .../swedish_calendar/api_data.py | 24 ++++++++++++----- .../swedish_calendar/coordinator.py | 6 ++--- .../swedish_calendar/theme_data.py | 21 +++++++++++---- tests/test_api_data.py | 26 +++++++++++-------- 4 files changed, 51 insertions(+), 26 deletions(-) diff --git a/custom_components/swedish_calendar/api_data.py b/custom_components/swedish_calendar/api_data.py index 0f3e0df..f2fc099 100644 --- a/custom_components/swedish_calendar/api_data.py +++ b/custom_components/swedish_calendar/api_data.py @@ -1,6 +1,7 @@ import asyncio from collections import deque from datetime import date, datetime, timedelta, timezone +from functools import partial import hashlib import json import logging @@ -10,6 +11,8 @@ import aiohttp import async_timeout +from homeassistant.core import HomeAssistant + from .types import ApiData, CacheConfig from .utils import DateUtils @@ -17,10 +20,10 @@ class ApiDataProvider: - def __init__(self, session: aiohttp.ClientSession, cache_config: CacheConfig): + def __init__(self, hass: HomeAssistant, session: aiohttp.ClientSession, cache_config: CacheConfig): self._base_url: str = 'https://sholiday.faboul.se/dagar/v2.1/' self._session = session - self._cache = ApiDataCache(cache_config) + self._cache = ApiDataCache(hass, cache_config) async def fetch_data(self, start: date, end: date) -> list[ApiData]: urls = deque(self._get_urls(start, end)) @@ -74,10 +77,10 @@ def _get_url_patterns_for_date_range(start: date, end: date) -> list[str]: async def _get_json_from_url(self, url, timeout) -> dict[str, Any]: if self._cache.has_data_for(url): _LOGGER.debug("Using cached version of url: %s", url) - data = self._cache.get(url) + data = await self._cache.get(url) else: data = await self._get_data_online(url, timeout) - self._cache.update(url, data) + await self._cache.update(url, data) return data @@ -101,7 +104,8 @@ def _to_api_data(json_response: dict[str, Any], start: date, end: date) -> list[ class ApiDataCache: - def __init__(self, cache_config: CacheConfig): + def __init__(self, hass: HomeAssistant, cache_config: CacheConfig): + self._hass = hass self.config = cache_config def has_data_for(self, url: str) -> bool: @@ -121,7 +125,10 @@ def _cache_age(self, url) -> timedelta: now_in_utc = datetime.now().astimezone(tz=timezone.utc) return now_in_utc - cache_in_utc - def get(self, url) -> dict[str, Any] | None: + async def get(self, url) -> dict[str, Any] | None: + return await self._hass.async_add_executor_job(partial(self._get, url=url)) + + def _get(self, url) -> dict[str, Any] | None: path = self._url_to_path(url) data = None with open(path) as cached_file: @@ -133,7 +140,10 @@ def get(self, url) -> dict[str, Any] | None: return data - def update(self, url, data: dict[str, Any]) -> None: + async def update(self, url, data: dict[str, Any]) -> None: + return await self._hass.async_add_executor_job(partial(self._update, url=url, data=data)) + + def _update(self, url, data: dict[str, Any]) -> None: if self.config.enabled: path = self._url_to_path(url) _LOGGER.debug("Caching %s, saving to %s", url, path) diff --git a/custom_components/swedish_calendar/coordinator.py b/custom_components/swedish_calendar/coordinator.py index 3610e1b..d775266 100644 --- a/custom_components/swedish_calendar/coordinator.py +++ b/custom_components/swedish_calendar/coordinator.py @@ -37,9 +37,9 @@ def __init__(self, self._first_update = True # Keep track of first update so that we keep boot times down session = async_get_clientsession(hass) - self._api_data_provider = ApiDataProvider(session=session, cache_config=cache_config) - self._theme_data_updater = ThemeDataUpdater(config=special_themes_config, session=session) - self._theme_provider = ThemeDataProvider(theme_path=special_themes_config.path) + self._api_data_provider = ApiDataProvider(hass=hass, session=session, cache_config=cache_config) + self._theme_data_updater = ThemeDataUpdater(hass=hass, config=special_themes_config, session=session) + self._theme_provider = ThemeDataProvider(hass=hass, theme_path=special_themes_config.path) super().__init__( hass, diff --git a/custom_components/swedish_calendar/theme_data.py b/custom_components/swedish_calendar/theme_data.py index c4ad148..c19bd50 100644 --- a/custom_components/swedish_calendar/theme_data.py +++ b/custom_components/swedish_calendar/theme_data.py @@ -1,5 +1,6 @@ import asyncio from datetime import date +from functools import partial import json import logging from typing import Any @@ -7,6 +8,8 @@ import aiohttp import async_timeout +from homeassistant.core import HomeAssistant + from .types import SpecialThemesConfig, ThemeData from .utils import DateUtils @@ -14,10 +17,14 @@ class ThemeDataProvider: - def __init__(self, theme_path): + def __init__(self, hass, theme_path): + self._hass = hass self._theme_path = theme_path async def fetch_data(self, start: date, end: date) -> list[ThemeData]: + return await self._hass.async_add_executor_job(partial(self._fetch_data, start=start, end=end)) + + def _fetch_data(self, start: date, end: date) -> list[ThemeData]: theme_dates = [] try: with open(self._theme_path) as data_file: @@ -44,7 +51,8 @@ def _map_to_theme_dates(json_data: dict[str, Any], start: date, end: date) -> li class ThemeDataUpdater: - def __init__(self, config: SpecialThemesConfig, session: aiohttp.ClientSession): + def __init__(self, hass: HomeAssistant, config: SpecialThemesConfig, session: aiohttp.ClientSession): + self._hass = hass self._config = config self._session = session self._url = 'https://raw.githubusercontent.com/Miicroo/ha-swedish_calendar/master/custom_components' \ @@ -56,9 +64,12 @@ def can_update(self): async def update(self): new_data = await self._download() if new_data: - with open(self._config.path, 'w') as themes_file: - themes_file.write(new_data) - _LOGGER.info('Themes updated with latest json') + await self._hass.async_add_executor_job(partial(self._write_update, new_data=new_data)) + + def _write_update(self, new_data): + with open(self._config.path, 'w') as themes_file: + themes_file.write(new_data) + _LOGGER.info('Themes updated with latest json') async def _download(self) -> str | None: _LOGGER.debug("Downloading latest themes") diff --git a/tests/test_api_data.py b/tests/test_api_data.py index b7dbeff..d27afa6 100644 --- a/tests/test_api_data.py +++ b/tests/test_api_data.py @@ -8,10 +8,10 @@ from custom_components.swedish_calendar.types import CacheConfig -def test_api_data_cache_has_data_for_returns_false_if_cache_is_disabled(mocker): +def test_api_data_cache_has_data_for_returns_false_if_cache_is_disabled(mocker, hass): """Disabled cache -> has_data_for returns False.""" config = _cache_is_disabled() # This triggers has_data_for -> False - api_cache = ApiDataCache(cache_config=config) + api_cache = ApiDataCache(hass=hass, cache_config=config) _cache_is_not_old(mocker) _url_is_cached(mocker) @@ -19,10 +19,12 @@ def test_api_data_cache_has_data_for_returns_false_if_cache_is_disabled(mocker): assert api_cache.has_data_for("https://whatever") is False -def test_api_data_cache_has_data_for_returns_false_if_cache_file_does_not_exist(mocker): +def test_api_data_cache_has_data_for_returns_false_if_cache_file_does_not_exist( + mocker, hass +): """Cache (file) does not exist -> has_data_for returns False.""" config = _cache_is_enabled() - api_cache = ApiDataCache(cache_config=config) + api_cache = ApiDataCache(hass=hass, cache_config=config) _cache_is_not_old(mocker) _url_is_not_cached(mocker) # This triggers has_data_for -> False @@ -30,10 +32,10 @@ def test_api_data_cache_has_data_for_returns_false_if_cache_file_does_not_exist( assert api_cache.has_data_for("https://whatever") is False -def test_api_data_cache_has_data_for_returns_false_if_cache_file_is_old(mocker): +def test_api_data_cache_has_data_for_returns_false_if_cache_file_is_old(mocker, hass): """Cache is too old -> has_data_for returns False.""" config = _cache_is_enabled() - api_cache = ApiDataCache(cache_config=config) + api_cache = ApiDataCache(hass=hass, cache_config=config) _cache_is_old(mocker) # This triggers has_data_for -> False _url_is_cached(mocker) @@ -41,10 +43,12 @@ def test_api_data_cache_has_data_for_returns_false_if_cache_file_is_old(mocker): assert api_cache.has_data_for("https://whatever") is False -def test_api_data_cache_has_data_for_returns_true_if_all_conditions_are_true(mocker): +def test_api_data_cache_has_data_for_returns_true_if_all_conditions_are_true( + mocker, hass +): """Cache enabled + cached file exists + cache not too old -> has_data_for returns True.""" config = _cache_is_enabled() - api_cache = ApiDataCache(cache_config=config) + api_cache = ApiDataCache(hass=hass, cache_config=config) _cache_is_not_old(mocker) _url_is_cached(mocker) @@ -52,16 +56,16 @@ def test_api_data_cache_has_data_for_returns_true_if_all_conditions_are_true(moc assert api_cache.has_data_for("https://whatever") is True -def test_api_data_cache_get_removes_file_on_json_decode_error(mocker): +async def test_api_data_cache_get_removes_file_on_json_decode_error(mocker, hass): """Get cached url removes cache entry if json is malformed.""" remove_mock = mocker.patch("os.remove") config = _cache_is_enabled() - api_cache = ApiDataCache(config) + api_cache = ApiDataCache(hass, config) expected_file_path = "ecc5e2d8d57c91749b379bc26d1f677a.json" mock_open = mock.mock_open(read_data="") with mock.patch("builtins.open", mock_open): - result = api_cache.get("https://whatever") + result = await api_cache.get("https://whatever") assert result is None remove_mock.assert_has_calls(calls=[call(expected_file_path)], any_order=True)