diff --git a/aioelectricitymaps/const.py b/aioelectricitymaps/const.py index 66e51a1..659b55e 100644 --- a/aioelectricitymaps/const.py +++ b/aioelectricitymaps/const.py @@ -1,9 +1,12 @@ +"""Constants for aioelectricitymaps.""" API_BASE_URL = "https://api-access.electricitymaps.com/free-tier/" LEGACY_API_BASE_URL = "https://api.co2signal.com/v1/" class ApiEndpoints: + """Class holding API endpoints.""" + LEGACY_CARBON_INTENSITY = LEGACY_API_BASE_URL + "latest" CARBON_INTENSITY = API_BASE_URL + "home-assistant" ZONES = "https://api.electricitymap.org/v3/zones" diff --git a/aioelectricitymaps/decorators.py b/aioelectricitymaps/decorators.py index 75e76b9..f08b12c 100644 --- a/aioelectricitymaps/decorators.py +++ b/aioelectricitymaps/decorators.py @@ -13,7 +13,7 @@ def retry_legacy( func: Callable[_P, Coroutine[Any, Any, _R]], ) -> Callable[_P, Coroutine[Any, Any, _R]]: - """Decorator to retry a function with the legacy API if SwitchedToLegacyAPI is raised.""" + """Retry a function with the legacy API if SwitchedToLegacyAPI is raised.""" async def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R: try: diff --git a/aioelectricitymaps/electricitymaps.py b/aioelectricitymaps/electricitymaps.py index 4228fb8..fa4ac98 100644 --- a/aioelectricitymaps/electricitymaps.py +++ b/aioelectricitymaps/electricitymaps.py @@ -4,7 +4,7 @@ from dataclasses import dataclass import json import logging -from typing import Any +from typing import Any, Self from aiohttp import ClientSession @@ -24,6 +24,8 @@ @dataclass class ElectricityMaps: + """ElectricityMaps API client.""" + token: str session: ClientSession | None = None @@ -49,38 +51,40 @@ async def _get(self, url: str, params: dict[str, Any] | None = None) -> Any: ) as response: parsed = await response.json() except json.JSONDecodeError as exception: + msg = f"JSON decoding failed: {exception}" raise ElectricityMapsDecodeError( - f"JSON decoding failed: {exception}", + msg, ) from exception except Exception as exc: + msg = f"Unknown error occurred while fetching data: {exc}" raise ElectricityMapsError( - f"Unknown error occurred while fetching data: {exc}", + msg, ) from exc - else: - _LOGGER.debug( - "Got response with status %s and body: %s", - response.status, - await response.text(), - ) - # check for invalid token - if ( - "message" in parsed - and response.status == 404 - and ( - "No data product found" in parsed["message"] - or "Invalid authentication" in parsed["message"] + _LOGGER.debug( + "Got response with status %s and body: %s", + response.status, + await response.text(), + ) + + # check for invalid token + if ( + "message" in parsed + and response.status == 404 + and ( + "No data product found" in parsed["message"] + or "Invalid authentication" in parsed["message"] + ) + ): + # enable legacy mode and let the function recalled by the decorator + if not self._is_legacy_token: + _LOGGER.debug( + "Detected invalid token on new API, retrying on legacy API.", ) - ): - # enable legacy mode and let the function recalled by the decorator - if not self._is_legacy_token: - _LOGGER.debug( - "Detected invalid token on new API, retrying on legacy API.", - ) - self._is_legacy_token = True - raise SwitchedToLegacyAPI + self._is_legacy_token = True + raise SwitchedToLegacyAPI - raise InvalidToken + raise InvalidToken return parsed @@ -128,7 +132,7 @@ async def close(self) -> None: if self.session and self._close_session: await self.session.close() - async def __aenter__(self) -> ElectricityMaps: + async def __aenter__(self) -> Self: """Async enter.""" return self diff --git a/aioelectricitymaps/exceptions.py b/aioelectricitymaps/exceptions.py index 2ed706a..07e083e 100644 --- a/aioelectricitymaps/exceptions.py +++ b/aioelectricitymaps/exceptions.py @@ -6,7 +6,10 @@ class ElectricityMapsError(Exception): class SwitchedToLegacyAPI(ElectricityMapsError): - """""" + """Error raised when API switched to legacy. + + Caught by retry_legacy decorator. + """ class InvalidToken(ElectricityMapsError): diff --git a/aioelectricitymaps/marshmallow.py b/aioelectricitymaps/marshmallow.py index e7617b1..3df8023 100644 --- a/aioelectricitymaps/marshmallow.py +++ b/aioelectricitymaps/marshmallow.py @@ -1,4 +1,6 @@ """Module contains classes for de-/serialisation with marshmallow.""" +from __future__ import annotations + from dataclasses import dataclass, field from dataclasses_json import DataClassJsonMixin, config @@ -9,7 +11,7 @@ @dataclass(slots=True, frozen=True) class ZoneList(dict[str, Zone], DataClassJsonMixin): - """List of zones""" + """List of zones.""" zones: dict[str, Zone] = field( metadata=config( diff --git a/aioelectricitymaps/models.py b/aioelectricitymaps/models.py index 1b12799..ad6b1e5 100644 --- a/aioelectricitymaps/models.py +++ b/aioelectricitymaps/models.py @@ -1,4 +1,6 @@ """Models to the electricitymaps.com API.""" +from __future__ import annotations + from dataclasses import dataclass, field from dataclasses_json import DataClassJsonMixin, LetterCase, config diff --git a/pyproject.toml b/pyproject.toml index 46891f5..42beccc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,6 +103,7 @@ disable = [ "too-many-instance-attributes", "too-many-arguments", "too-many-public-methods", + "too-few-public-methods", "wrong-import-order", ] @@ -128,6 +129,8 @@ ignore = [ "PLR0913", # Too many arguments "TCH001", "TCH003", + "BLE001", # disable temporarily + "N818", # disable temporarily ] select = ["ALL"] diff --git a/tests/__init__.py b/tests/__init__.py index a195af2..5e9cf4e 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,9 +1,8 @@ """Helpers for the tests.""" -import os +from pathlib import Path -def load_fixture(filename): +def load_fixture(filename: str) -> str: """Load a fixture.""" - path = os.path.join(os.path.dirname(__file__), "fixtures", filename) - with open(path, encoding="utf-8") as fptr: - return fptr.read() + path = Path(__file__).parent / "fixtures" / filename + return path.read_text() diff --git a/tests/conftest.py b/tests/conftest.py index 86a6ca4..b286117 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,13 @@ +"""Fixtures for aioelectricitymaps tests.""" from aresponses import ResponsesMockServer import pytest from . import load_fixture -@pytest.fixture -def mock_response(aresponses: ResponsesMockServer) -> None: +@pytest.fixture(name="mock_response") +def _mock_response(aresponses: ResponsesMockServer) -> None: + """Mock an API response.""" aresponses.add( "api-access.electricitymaps.com", "/free-tier/home-assistant", @@ -18,8 +20,9 @@ def mock_response(aresponses: ResponsesMockServer) -> None: ) -@pytest.fixture -def mock_broken_response(aresponses: ResponsesMockServer) -> None: +@pytest.fixture(name="mock_broken_response") +def _mock_broken_response(aresponses: ResponsesMockServer) -> None: + """Mock a bad API response.""" aresponses.add( "api-access.electricitymaps.com", "/free-tier/home-assistant", diff --git a/tests/ruff.toml b/tests/ruff.toml new file mode 100644 index 0000000..1257561 --- /dev/null +++ b/tests/ruff.toml @@ -0,0 +1,15 @@ +# This extend our general Ruff rules specifically for tests +extend = "../pyproject.toml" + +extend-select = [ + "PT", # Use @pytest.fixture without parentheses +] + +extend-ignore = [ + "S101", # Use of assert detected. As these are tests... + "S105", # Detection of passwords... + "S106", # Detection of passwords... + "SLF001", # Tests will access private/protected members... + "TCH002", # pytest doesn't like this one... + "PLR0913", # we're overwriting function that has many arguments +] diff --git a/tests/test_electricitymaps.py b/tests/test_electricitymaps.py index 1c1ca8a..050e288 100644 --- a/tests/test_electricitymaps.py +++ b/tests/test_electricitymaps.py @@ -4,39 +4,41 @@ import aiohttp from aresponses import ResponsesMockServer import pytest +from syrupy.assertion import SnapshotAssertion from aioelectricitymaps import ElectricityMaps from aioelectricitymaps.exceptions import ( ElectricityMapsDecodeError, ElectricityMapsError, ) -from tests import load_fixture +from . import load_fixture -@pytest.mark.asyncio -async def test_asyncio_protocol(mock_response) -> None: + +@pytest.mark.usefixtures("mock_response") +async def test_asyncio_protocol() -> None: """Test the asyncio protocol implementation.""" async with ElectricityMaps(token="abc123") as em: assert await em.latest_carbon_intensity_by_country_code("DE") -@pytest.mark.asyncio -async def test_json_request_without_session(mock_response, snapshot) -> None: +@pytest.mark.usefixtures("mock_response") +async def test_json_request_without_session(snapshot: SnapshotAssertion) -> None: """Test JSON response is handled correctly without given session.""" em = ElectricityMaps(token="abc123") assert await em.latest_carbon_intensity_by_country_code("DE") == snapshot -@pytest.mark.asyncio -async def test_json_request_with_session(mock_response, snapshot) -> None: +@pytest.mark.usefixtures("mock_response") +async def test_json_request_with_session(snapshot: SnapshotAssertion) -> None: """Test JSON response is handled correctly with given session.""" async with aiohttp.ClientSession() as session: em = ElectricityMaps(token="abc123", session=session) assert await em.latest_carbon_intensity_by_country_code("DE") == snapshot -@pytest.mark.asyncio -async def test_carbon_intensity_by_coordinates(mock_response, snapshot) -> None: +@pytest.mark.usefixtures("mock_response") +async def test_carbon_intensity_by_coordinates(snapshot: SnapshotAssertion) -> None: """Test carbon_intentsity_by_coordinates with given session.""" async with aiohttp.ClientSession() as session: em = ElectricityMaps(token="abc123", session=session) @@ -49,8 +51,8 @@ async def test_carbon_intensity_by_coordinates(mock_response, snapshot) -> None: ) -@pytest.mark.asyncio -async def test_broken_json_request(mock_broken_response) -> None: +@pytest.mark.usefixtures("mock_broken_response") +async def test_broken_json_request() -> None: """Test JSON response is handled correctly with given session.""" async with aiohttp.ClientSession() as session: em = ElectricityMaps(token="abc123", session=session) @@ -59,7 +61,6 @@ async def test_broken_json_request(mock_broken_response) -> None: await em.latest_carbon_intensity_by_country_code("DE") -@pytest.mark.asyncio async def test_catching_unknown_error() -> None: """Test JSON response is handled correctly with given session.""" async with aiohttp.ClientSession() as session: @@ -70,8 +71,11 @@ async def test_catching_unknown_error() -> None: await em.latest_carbon_intensity_by_country_code("DE") -@pytest.mark.asyncio -async def test_zones_request(aresponses: ResponsesMockServer, snapshot) -> None: +async def test_zones_request( + aresponses: ResponsesMockServer, + snapshot: SnapshotAssertion, +) -> None: + """Test zones request.""" aresponses.add( "api.electricitymap.org", "/v3/zones",