Skip to content

Commit

Permalink
fix: missing entry name for location lookup config flow (#6)
Browse files Browse the repository at this point in the history
* fix: missing entry name for location lookup config flow

* add missing requirements_test.txt

* formatting
  • Loading branch information
firstof9 authored Dec 10, 2023
1 parent ed7068b commit bda9286
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 46 deletions.
21 changes: 16 additions & 5 deletions custom_components/gasbuddy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@
from gasbuddy import GasBuddy
from gasbuddy.exceptions import APIError, LibraryError

from .const import CONF_INTERVAL, CONF_STATION_ID, COORDINATOR, DOMAIN, ISSUE_URL, PLATFORMS, VERSION
from .const import (
CONF_INTERVAL,
CONF_STATION_ID,
COORDINATOR,
DOMAIN,
ISSUE_URL,
PLATFORMS,
VERSION,
)

_LOGGER = logging.getLogger(__name__)

Expand All @@ -25,6 +33,7 @@ async def async_setup( # pylint: disable-next=unused-argument
"""Disallow configuration via YAML."""
return True


async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
"""Set up is called when Home Assistant is loading our component."""
hass.data.setdefault(DOMAIN, {})
Expand All @@ -44,8 +53,7 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
if not coordinator.last_update_success:
raise ConfigEntryNotReady

hass.data[DOMAIN][config_entry.entry_id] = { COORDINATOR: coordinator }

hass.data[DOMAIN][config_entry.entry_id] = {COORDINATOR: coordinator}

for platform in PLATFORMS:
hass.async_create_task(
Expand All @@ -54,6 +62,7 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b

return True


async def update_listener(hass: HomeAssistant, config_entry: ConfigEntry) -> None:
"""Update listener."""
_LOGGER.debug("Attempting to reload entities from the %s integration", DOMAIN)
Expand All @@ -71,6 +80,7 @@ async def update_listener(hass: HomeAssistant, config_entry: ConfigEntry) -> Non

await hass.config_entries.async_reload(config_entry.entry_id)


async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
"""Handle removal of an entry."""
_LOGGER.debug("Attempting to unload entities from the %s integration", DOMAIN)
Expand All @@ -88,7 +98,8 @@ async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) ->
_LOGGER.debug("Successfully removed entities from the %s integration", DOMAIN)
hass.data[DOMAIN].pop(config_entry.entry_id)

return unload_ok
return unload_ok


class GasBuddyUpdateCoordinator(DataUpdateCoordinator):
"""Class to manage fetching data from the API."""
Expand Down Expand Up @@ -117,5 +128,5 @@ async def _async_update_data(self) -> dict:
self._data = {}
except Exception as exception:
raise UpdateFailed() from exception

return self._data
91 changes: 60 additions & 31 deletions custom_components/gasbuddy/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,20 @@

import gasbuddy

from .const import CONF_NAME, CONF_INTERVAL, CONF_POSTAL, CONF_STATION_ID, DEFAULT_NAME, DOMAIN
from .const import (
CONF_NAME,
CONF_INTERVAL,
CONF_POSTAL,
CONF_STATION_ID,
DEFAULT_NAME,
DOMAIN,
)

_LOGGER = logging.getLogger(__name__)
MENU_OPTIONS = ["manual", "search"]
MENU_SEARCH = ["home", "postal"]


async def validate_station(station: int) -> bool:
"""Validate statation ID."""
check = await gasbuddy.GasBuddy(station_id=station).price_lookup()
Expand All @@ -27,6 +35,7 @@ async def validate_station(station: int) -> bool:
return False
return True


async def _get_station_list(hass, user_input) -> list | None:
"""Return list of utilities by lat/lon."""
lat = None
Expand All @@ -35,7 +44,6 @@ async def _get_station_list(hass, user_input) -> list | None:

if user_input is not None and CONF_POSTAL in user_input.keys():
postal = user_input[CONF_POSTAL]


if not bool(postal):
lat = hass.config.latitude
Expand All @@ -56,6 +64,7 @@ async def _get_station_list(hass, user_input) -> list | None:
_LOGGER.debug("stations_list: %s", stations_list)
return stations_list


def _get_schema_manual(hass: Any, user_input: list, default_dict: list) -> Any:
"""Gets a schema using the default_dict as a backup."""
if user_input is None:
Expand All @@ -72,7 +81,10 @@ def _get_default(key: str, fallback_default: Any = None) -> Any | None:
}
)

def _get_schema_home(hass: Any, user_input: list, default_dict: list, station_list: list) -> Any:

def _get_schema_home(
hass: Any, user_input: list, default_dict: list, station_list: list
) -> Any:
"""Gets a schema using the default_dict as a backup."""
if user_input is None:
user_input = {}
Expand All @@ -83,10 +95,14 @@ def _get_default(key: str, fallback_default: Any = None) -> Any | None:

return vol.Schema(
{
vol.Required(CONF_STATION_ID, default=_get_default(CONF_STATION_ID)): vol.In(station_list),
vol.Required(
CONF_STATION_ID, default=_get_default(CONF_STATION_ID)
): vol.In(station_list),
vol.Required(CONF_NAME, default=_get_default(CONF_NAME, DEFAULT_NAME)): str,
}
)


def _get_schema_postal(hass: Any, user_input: list, default_dict: list) -> Any:
"""Gets a schema using the default_dict as a backup."""
if user_input is None:
Expand All @@ -99,10 +115,14 @@ def _get_default(key: str, fallback_default: Any = None) -> Any | None:
return vol.Schema(
{
vol.Required(CONF_POSTAL, default=_get_default(CONF_POSTAL)): str,
vol.Required(CONF_NAME, default=_get_default(CONF_NAME, DEFAULT_NAME)): str,
}
)

def _get_schema_postal_list(hass: Any, user_input: list, default_dict: list, station_list: list) -> Any:

def _get_schema_postal_list(
hass: Any, user_input: list, default_dict: list, station_list: list
) -> Any:
"""Gets a schema using the default_dict as a backup."""
if user_input is None:
user_input = {}
Expand All @@ -113,10 +133,13 @@ def _get_default(key: str, fallback_default: Any = None) -> Any | None:

return vol.Schema(
{
vol.Required(CONF_STATION_ID, default=_get_default(CONF_STATION_ID)): vol.In(station_list),
vol.Required(
CONF_STATION_ID, default=_get_default(CONF_STATION_ID)
): vol.In(station_list),
}
)


def _get_schema_options(hass: Any, user_input: list, default_dict: list) -> Any:
"""Gets a schema using the default_dict as a backup."""
if user_input is None:
Expand All @@ -132,6 +155,7 @@ def _get_default(key: str, fallback_default: Any = None) -> Any | None:
}
)


@config_entries.HANDLERS.register(DOMAIN)
class GasBuddyFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
"""Config flow for GasBuddy."""
Expand Down Expand Up @@ -162,9 +186,11 @@ async def async_step_manual(self, user_input={}):
self._errors[CONF_STATION_ID] = "station_id"
else:
self._data.update(user_input)
return self.async_create_entry(title=self._data[CONF_NAME], data=self._data)
return await self._show_config_manual(user_input)

return self.async_create_entry(
title=self._data[CONF_NAME], data=self._data
)
return await self._show_config_manual(user_input)

async def _show_config_manual(self, user_input):
"""Show the configuration form to edit location data."""

Expand All @@ -177,15 +203,15 @@ async def _show_config_manual(self, user_input):
step_id="manual",
data_schema=_get_schema_manual(self.hass, user_input, defaults),
errors=self._errors,
)
)

# Search option
async def async_step_search(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Handle the flow initialized by the user."""
return self.async_show_menu(step_id="search", menu_options=MENU_SEARCH)
return self.async_show_menu(step_id="search", menu_options=MENU_SEARCH)

# Use lat/lon from HA
async def async_step_home(self, user_input={}):
"""Handle a flow initialized by the user."""
Expand All @@ -196,8 +222,8 @@ async def async_step_home(self, user_input={}):
user_input[CONF_INTERVAL] = 3600
self._data.update(user_input)
return self.async_create_entry(title=self._data[CONF_NAME], data=self._data)
return await self._show_config_home(user_input)
return await self._show_config_home(user_input)

async def _show_config_home(self, user_input):
"""Show the configuration form to edit location data."""
defaults = {}
Expand All @@ -208,8 +234,8 @@ async def _show_config_home(self, user_input):
step_id="home",
data_schema=_get_schema_home(self.hass, user_input, defaults, station_list),
errors=self._errors,
)
)

# User input postal code
async def async_step_postal(self, user_input={}):
"""Handle a flow initialized by the user."""
Expand All @@ -218,8 +244,8 @@ async def async_step_postal(self, user_input={}):
if user_input is not None:
self._data.update(user_input)
return await self.async_step_postal_list()
return await self._show_config_postal(user_input)
return await self._show_config_postal(user_input)

async def _show_config_postal(self, user_input):
"""Show the configuration form to edit location data."""
defaults = {}
Expand All @@ -228,8 +254,8 @@ async def _show_config_postal(self, user_input):
step_id="postal",
data_schema=_get_schema_postal(self.hass, user_input, defaults),
errors=self._errors,
)
)

async def async_step_postal_list(self, user_input={}):
"""Handle a flow initialized by the user."""
self._errors = {}
Expand All @@ -239,8 +265,8 @@ async def async_step_postal_list(self, user_input={}):
user_input[CONF_INTERVAL] = 3600
self._data.update(user_input)
return self.async_create_entry(title=self._data[CONF_NAME], data=self._data)
return await self._show_config_postal_list(user_input)
return await self._show_config_postal_list(user_input)

async def _show_config_postal_list(self, user_input):
"""Show the configuration form to edit location data."""
defaults = {}
Expand All @@ -249,17 +275,20 @@ async def _show_config_postal_list(self, user_input):

return self.async_show_form(
step_id="postal_list",
data_schema=_get_schema_postal_list(self.hass, user_input, defaults, station_list),
data_schema=_get_schema_postal_list(
self.hass, user_input, defaults, station_list
),
errors=self._errors,
)
)

@staticmethod
@callback
def async_get_options_flow(config_entry):
return GasBuddyOptionsFlow(config_entry)

return GasBuddyOptionsFlow(config_entry)


class GasBuddyOptionsFlow(config_entries.OptionsFlow):
"""Options flow for GasBuddy."""
"""Options flow for GasBuddy."""

def __init__(self, config_entry):
"""Initialize."""
Expand All @@ -273,11 +302,11 @@ async def async_step_init(self, user_input=None):
self._data.update(user_input)
return self.async_create_entry(title="", data=self._data)
return await self._show_options_form(user_input)

async def _show_options_form(self, user_input):
"""Show the configuration form to edit options."""
return self.async_show_form(
step_id="init",
data_schema=_get_schema_options(self.hass, user_input, self._data),
errors=self._errors,
)
)
2 changes: 1 addition & 1 deletion custom_components/gasbuddy/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
DOMAIN = "gasbuddy"
VERSION = "1.0"
ISSUE_URL = "https://github.com/firstof9/ha-gasbuddy/issues"
PLATFORMS = ['sensor']
PLATFORMS = ["sensor"]

# sensor constants
UNIT_OF_MEASURE = {
Expand Down
17 changes: 13 additions & 4 deletions custom_components/gasbuddy/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,18 @@
from homeassistant.config_entries import ConfigEntry
from homeassistant.helpers.update_coordinator import CoordinatorEntity

from .const import CONF_NAME, CONF_STATION_ID, COORDINATOR, DOMAIN, SENSOR_TYPES, UNIT_OF_MEASURE
from .const import (
CONF_NAME,
CONF_STATION_ID,
COORDINATOR,
DOMAIN,
SENSOR_TYPES,
UNIT_OF_MEASURE,
)

_LOGGER = logging.getLogger(__name__)


async def async_setup_entry(hass, entry, async_add_entities):
"""Set up the GasBuddy sensors."""
coordinator = hass.data[DOMAIN][entry.entry_id][COORDINATOR]
Expand All @@ -26,6 +34,7 @@ async def async_setup_entry(hass, entry, async_add_entities):

async_add_entities(sensors, False)


class GasBuddySensor(CoordinatorEntity, SensorEntity):
"""Implementation of a GasBuddy sensor."""

Expand Down Expand Up @@ -80,15 +89,15 @@ def native_unit_of_measurement(self) -> Any:
uom = self.coordinator.data["unit_of_measure"]
currency = self.coordinator.data["currency"]
if uom is not None and currency is not None:
return f'{currency}/{UNIT_OF_MEASURE[uom]}'
return f"{currency}/{UNIT_OF_MEASURE[uom]}"
return None

@property
def extra_state_attributes(self) -> Optional[dict]:
"""Return sesnsor attributes."""
credit = self.coordinator.data[self._type]["credit"]
attrs = {}
attrs[ATTR_ATTRIBUTION] = f'{credit} via GasBuddy'
attrs[ATTR_ATTRIBUTION] = f"{credit} via GasBuddy"
attrs["last_updated"] = self.coordinator.data[self._type]["last_updated"]
attrs[CONF_STATION_ID] = self.coordinator.data[CONF_STATION_ID]
return attrs
Expand All @@ -109,4 +118,4 @@ def available(self) -> bool:
@property
def should_poll(self) -> bool:
"""No need to poll. Coordinator notifies entity of updates."""
return False
return False
10 changes: 10 additions & 0 deletions requirements_test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
-r requirements.txt
pytest-homeassistant-custom-component
black==23.11.0
flake8==6.1.0
mypy==1.7.1
pydocstyle==6.3.0
isort==5.12.0
pylint==3.0.2
tox==4.11.4
pytest
Loading

0 comments on commit bda9286

Please sign in to comment.