Skip to content

Commit

Permalink
Call open asynchronously (#42)
Browse files Browse the repository at this point in the history
* Call open asynchronously
  • Loading branch information
Miicroo authored Aug 17, 2024
1 parent 5fa5797 commit eaf6e0a
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 26 deletions.
24 changes: 17 additions & 7 deletions custom_components/swedish_calendar/api_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from collections import deque
from datetime import date, datetime, timedelta, timezone
from functools import partial
import hashlib
import json
import logging
Expand All @@ -10,17 +11,19 @@
import aiohttp
import async_timeout

from homeassistant.core import HomeAssistant

from .types import ApiData, CacheConfig
from .utils import DateUtils

_LOGGER = logging.getLogger(__name__)


class ApiDataProvider:
def __init__(self, hass: HomeAssistant, session: aiohttp.ClientSession, cache_config: CacheConfig):
    """Set up the API data provider.

    :param hass: Home Assistant instance; passed to the cache so its file I/O
        can be offloaded to the executor thread pool instead of blocking the
        event loop.
    :param session: shared aiohttp client session used for API requests.
    :param cache_config: settings controlling the on-disk response cache.
    """
    self._base_url: str = 'https://sholiday.faboul.se/dagar/v2.1/'
    self._session = session
    self._cache = ApiDataCache(hass, cache_config)

async def fetch_data(self, start: date, end: date) -> list[ApiData]:
urls = deque(self._get_urls(start, end))
Expand Down Expand Up @@ -74,10 +77,10 @@ def _get_url_patterns_for_date_range(start: date, end: date) -> list[str]:
async def _get_json_from_url(self, url, timeout) -> dict[str, Any]:
    """Return the JSON payload for ``url``, preferring the on-disk cache.

    On a cache miss the data is fetched online and the cache is refreshed.
    Both cache calls are awaited because the cache performs its file I/O in
    the executor (see ApiDataCache.get / ApiDataCache.update).
    """
    if self._cache.has_data_for(url):
        _LOGGER.debug("Using cached version of url: %s", url)
        data = await self._cache.get(url)
    else:
        data = await self._get_data_online(url, timeout)
        await self._cache.update(url, data)

    return data

Expand All @@ -101,7 +104,8 @@ def _to_api_data(json_response: dict[str, Any], start: date, end: date) -> list[


class ApiDataCache:
def __init__(self, hass: HomeAssistant, cache_config: CacheConfig):
    """Store the hass handle (for executor-offloaded file I/O) and cache settings."""
    self._hass = hass
    self.config = cache_config

def has_data_for(self, url: str) -> bool:
Expand All @@ -121,7 +125,10 @@ def _cache_age(self, url) -> timedelta:
now_in_utc = datetime.now().astimezone(tz=timezone.utc)
return now_in_utc - cache_in_utc

async def get(self, url) -> dict[str, Any] | None:
    """Read the cached response for ``url`` without blocking the event loop.

    Delegates the blocking file read (_get) to Home Assistant's executor
    thread pool and returns its result (None if the cache entry is unusable).
    """
    return await self._hass.async_add_executor_job(partial(self._get, url=url))

def _get(self, url) -> dict[str, Any] | None:
path = self._url_to_path(url)
data = None
with open(path) as cached_file:
Expand All @@ -133,7 +140,10 @@ def get(self, url) -> dict[str, Any] | None:

return data

async def update(self, url, data: dict[str, Any]) -> None:
    """Persist ``data`` for ``url`` without blocking the event loop.

    Delegates the blocking file write (_update) to Home Assistant's executor
    thread pool.
    """
    return await self._hass.async_add_executor_job(partial(self._update, url=url, data=data))

def _update(self, url, data: dict[str, Any]) -> None:
if self.config.enabled:
path = self._url_to_path(url)
_LOGGER.debug("Caching %s, saving to %s", url, path)
Expand Down
6 changes: 3 additions & 3 deletions custom_components/swedish_calendar/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def __init__(self,
self._first_update = True # Keep track of first update so that we keep boot times down

session = async_get_clientsession(hass)
self._api_data_provider = ApiDataProvider(session=session, cache_config=cache_config)
self._theme_data_updater = ThemeDataUpdater(config=special_themes_config, session=session)
self._theme_provider = ThemeDataProvider(theme_path=special_themes_config.path)
self._api_data_provider = ApiDataProvider(hass=hass, session=session, cache_config=cache_config)
self._theme_data_updater = ThemeDataUpdater(hass=hass, config=special_themes_config, session=session)
self._theme_provider = ThemeDataProvider(hass=hass, theme_path=special_themes_config.path)

super().__init__(
hass,
Expand Down
21 changes: 16 additions & 5 deletions custom_components/swedish_calendar/theme_data.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
import asyncio
from datetime import date
from functools import partial
import json
import logging
from typing import Any

import aiohttp
import async_timeout

from homeassistant.core import HomeAssistant

from .types import SpecialThemesConfig, ThemeData
from .utils import DateUtils

_LOGGER = logging.getLogger(__name__)


class ThemeDataProvider:
def __init__(self, hass: HomeAssistant, theme_path):
    """Store the hass handle (for executor-offloaded file reads) and the themes file path."""
    self._hass = hass
    self._theme_path = theme_path

async def fetch_data(self, start: date, end: date) -> list[ThemeData]:
    """Load special themes within [start, end] without blocking the event loop.

    The blocking file read (_fetch_data) runs in Home Assistant's executor
    thread pool.
    """
    return await self._hass.async_add_executor_job(partial(self._fetch_data, start=start, end=end))

def _fetch_data(self, start: date, end: date) -> list[ThemeData]:
theme_dates = []
try:
with open(self._theme_path) as data_file:
Expand All @@ -44,7 +51,8 @@ def _map_to_theme_dates(json_data: dict[str, Any], start: date, end: date) -> li


class ThemeDataUpdater:
def __init__(self, config: SpecialThemesConfig, session: aiohttp.ClientSession):
def __init__(self, hass: HomeAssistant, config: SpecialThemesConfig, session: aiohttp.ClientSession):
self._hass = hass
self._config = config
self._session = session
self._url = 'https://raw.githubusercontent.com/Miicroo/ha-swedish_calendar/master/custom_components' \
Expand All @@ -56,9 +64,12 @@ def can_update(self):
async def update(self):
    """Download the latest themes JSON and, if anything was received, persist it.

    The blocking file write (_write_update) is offloaded to Home Assistant's
    executor thread pool so the event loop stays responsive.
    """
    new_data = await self._download()
    if new_data:
        await self._hass.async_add_executor_job(partial(self._write_update, new_data=new_data))

def _write_update(self, new_data):
    """Write the downloaded themes JSON to the configured path (blocking).

    Runs in the executor (see update()). UTF-8 is forced explicitly so the
    JSON text is not written with a platform-dependent default encoding.
    """
    with open(self._config.path, 'w', encoding='utf-8') as themes_file:
        themes_file.write(new_data)
    _LOGGER.info('Themes updated with latest json')

async def _download(self) -> str | None:
_LOGGER.debug("Downloading latest themes")
Expand Down
26 changes: 15 additions & 11 deletions tests/test_api_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,60 +8,64 @@
from custom_components.swedish_calendar.types import CacheConfig


def test_api_data_cache_has_data_for_returns_false_if_cache_is_disabled(mocker, hass):
    """Disabled cache -> has_data_for returns False."""
    config = _cache_is_disabled()  # This triggers has_data_for -> False
    api_cache = ApiDataCache(hass=hass, cache_config=config)

    # Other conditions are favourable, so only the disabled flag can fail.
    _cache_is_not_old(mocker)
    _url_is_cached(mocker)

    assert api_cache.has_data_for("https://whatever") is False


def test_api_data_cache_has_data_for_returns_false_if_cache_file_does_not_exist(
    mocker, hass
):
    """Cache (file) does not exist -> has_data_for returns False."""
    config = _cache_is_enabled()
    api_cache = ApiDataCache(hass=hass, cache_config=config)

    _cache_is_not_old(mocker)
    _url_is_not_cached(mocker)  # This triggers has_data_for -> False

    assert api_cache.has_data_for("https://whatever") is False


def test_api_data_cache_has_data_for_returns_false_if_cache_file_is_old(mocker, hass):
    """Cache is too old -> has_data_for returns False."""
    config = _cache_is_enabled()
    api_cache = ApiDataCache(hass=hass, cache_config=config)

    _cache_is_old(mocker)  # This triggers has_data_for -> False
    _url_is_cached(mocker)

    assert api_cache.has_data_for("https://whatever") is False


def test_api_data_cache_has_data_for_returns_true_if_all_conditions_are_true(
    mocker, hass
):
    """Cache enabled + cached file exists + cache not too old -> has_data_for returns True."""
    config = _cache_is_enabled()
    api_cache = ApiDataCache(hass=hass, cache_config=config)

    _cache_is_not_old(mocker)
    _url_is_cached(mocker)

    assert api_cache.has_data_for("https://whatever") is True


async def test_api_data_cache_get_removes_file_on_json_decode_error(mocker, hass):
    """Get cached url removes cache entry if json is malformed."""
    remove_mock = mocker.patch("os.remove")
    config = _cache_is_enabled()
    api_cache = ApiDataCache(hass, config)
    # md5 of the url below, as produced by ApiDataCache._url_to_path.
    expected_file_path = "ecc5e2d8d57c91749b379bc26d1f677a.json"

    # "<html>" is not valid JSON, forcing the decode-error branch.
    mock_open = mock.mock_open(read_data="<html>")
    with mock.patch("builtins.open", mock_open):
        result = await api_cache.get("https://whatever")

    assert result is None
    remove_mock.assert_has_calls(calls=[call(expected_file_path)], any_order=True)
Expand Down

0 comments on commit eaf6e0a

Please sign in to comment.