Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unexpected exception in Google Calendar OAuth exchange #73963

Merged
merged 4 commits into from
Jun 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 51 additions & 30 deletions homeassistant/components/google/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

from collections.abc import Awaitable, Callable
import datetime
import logging
from typing import Any, cast
Expand All @@ -19,9 +18,12 @@

from homeassistant.components.application_credentials import AuthImplementation
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.helpers import config_entry_oauth2_flow
from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.event import (
async_track_point_in_utc_time,
async_track_time_interval,
)
from homeassistant.util import dt

from .const import (
Expand Down Expand Up @@ -76,6 +78,9 @@ def __init__(
self._oauth_flow = oauth_flow
self._device_flow_info: DeviceFlowInfo = device_flow_info
self._exchange_task_unsub: CALLBACK_TYPE | None = None
self._timeout_unsub: CALLBACK_TYPE | None = None
self._listener: CALLBACK_TYPE | None = None
self._creds: Credentials | None = None

@property
def verification_url(self) -> str:
Expand All @@ -87,15 +92,22 @@ def user_code(self) -> str:
"""Return the code that the user should enter at the verification url."""
return self._device_flow_info.user_code # type: ignore[no-any-return]

async def start_exchange_task(
self, finished_cb: Callable[[Credentials | None], Awaitable[None]]
@callback
def async_set_listener(
self,
update_callback: CALLBACK_TYPE,
) -> None:
"""Start the device auth exchange flow polling.
"""Invoke the update callback when the exchange finishes or on timeout."""
self._listener = update_callback

The callback is invoked with the valid credentials or with None on timeout.
"""
@property
def creds(self) -> Credentials | None:
"""Return result of exchange step or None on timeout."""
return self._creds

def async_start_exchange(self) -> None:
"""Start the device auth exchange flow polling."""
_LOGGER.debug("Starting exchange flow")
assert not self._exchange_task_unsub
max_timeout = dt.utcnow() + datetime.timedelta(seconds=EXCHANGE_TIMEOUT_SECONDS)
# For some reason, oauth.step1_get_device_and_user_codes() returns a datetime
# object without tzinfo. For the comparison below to work, it needs one.
Expand All @@ -104,31 +116,40 @@ async def start_exchange_task(
)
expiration_time = min(user_code_expiry, max_timeout)

def _exchange() -> Credentials:
return self._oauth_flow.step2_exchange(
device_flow_info=self._device_flow_info
)

async def _poll_attempt(now: datetime.datetime) -> None:
assert self._exchange_task_unsub
_LOGGER.debug("Attempting OAuth code exchange")
# Note: The callback is invoked with None when the device code has expired
creds: Credentials | None = None
if now < expiration_time:
try:
creds = await self._hass.async_add_executor_job(_exchange)
except FlowExchangeError:
_LOGGER.debug("Token not yet ready; trying again later")
return
self._exchange_task_unsub()
self._exchange_task_unsub = None
await finished_cb(creds)

self._exchange_task_unsub = async_track_time_interval(
self._hass,
_poll_attempt,
self._async_poll_attempt,
datetime.timedelta(seconds=self._device_flow_info.interval),
)
self._timeout_unsub = async_track_point_in_utc_time(
self._hass, self._async_timeout, expiration_time
)

async def _async_poll_attempt(self, now: datetime.datetime) -> None:
_LOGGER.debug("Attempting OAuth code exchange")
try:
self._creds = await self._hass.async_add_executor_job(self._exchange)
except FlowExchangeError:
_LOGGER.debug("Token not yet ready; trying again later")
return
self._finish()

def _exchange(self) -> Credentials:
return self._oauth_flow.step2_exchange(device_flow_info=self._device_flow_info)

@callback
def _async_timeout(self, now: datetime.datetime) -> None:
_LOGGER.debug("OAuth token exchange timeout")
self._finish()

@callback
def _finish(self) -> None:
if self._exchange_task_unsub:
self._exchange_task_unsub()
if self._timeout_unsub:
self._timeout_unsub()
if self._listener:
self._listener()


def get_feature_access(
Expand Down
8 changes: 4 additions & 4 deletions homeassistant/components/google/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from gcal_sync.api import GoogleCalendarService
from gcal_sync.exceptions import ApiException
from oauth2client.client import Credentials
import voluptuous as vol

from homeassistant import config_entries
Expand Down Expand Up @@ -96,17 +95,18 @@ async def async_step_auth(
return self.async_abort(reason="oauth_error")
self._device_flow = device_flow

async def _exchange_finished(creds: Credentials | None) -> None:
def _exchange_finished() -> None:
self.external_data = {
DEVICE_AUTH_CREDS: creds
DEVICE_AUTH_CREDS: device_flow.creds
} # is None on timeout/expiration
self.hass.async_create_task(
self.hass.config_entries.flow.async_configure(
flow_id=self.flow_id, user_input={}
)
)

await device_flow.start_exchange_task(_exchange_finished)
device_flow.async_set_listener(_exchange_finished)
device_flow.async_start_exchange()

return self.async_show_progress(
step_id="auth",
Expand Down
50 changes: 31 additions & 19 deletions tests/components/google/test_config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from aiohttp.client_exceptions import ClientError
from freezegun.api import FrozenDateTimeFactory
from oauth2client.client import (
DeviceFlowInfo,
FlowExchangeError,
OAuth2Credentials,
OAuth2DeviceCodeError,
Expand Down Expand Up @@ -59,18 +60,26 @@ async def mock_code_flow(
) -> YieldFixture[Mock]:
"""Fixture for initiating OAuth flow."""
with patch(
"oauth2client.client.OAuth2WebServerFlow.step1_get_device_and_user_codes",
"homeassistant.components.google.api.OAuth2WebServerFlow.step1_get_device_and_user_codes",
) as mock_flow:
mock_flow.return_value.user_code_expiry = utcnow() + code_expiration_delta
mock_flow.return_value.interval = CODE_CHECK_INTERVAL
mock_flow.return_value = DeviceFlowInfo.FromResponse(
{
"device_code": "4/4-GMMhmHCXhWEzkobqIHGG_EnNYYsAkukHspeYUk9E8",
"user_code": "GQVQ-JKEC",
"verification_url": "https://www.google.com/device",
"expires_in": code_expiration_delta.total_seconds(),
"interval": CODE_CHECK_INTERVAL,
}
)
yield mock_flow


@pytest.fixture
async def mock_exchange(creds: OAuth2Credentials) -> YieldFixture[Mock]:
"""Fixture for mocking out the exchange for credentials."""
with patch(
"oauth2client.client.OAuth2WebServerFlow.step2_exchange", return_value=creds
"homeassistant.components.google.api.OAuth2WebServerFlow.step2_exchange",
return_value=creds,
) as mock:
yield mock

Expand Down Expand Up @@ -108,7 +117,6 @@ async def fire_alarm(hass, point_in_time):
await hass.async_block_till_done()


@pytest.mark.freeze_time("2022-06-03 15:19:59-00:00")
async def test_full_flow_yaml_creds(
hass: HomeAssistant,
mock_code_flow: Mock,
Expand All @@ -131,9 +139,8 @@ async def test_full_flow_yaml_creds(
"homeassistant.components.google.async_setup_entry", return_value=True
) as mock_setup:
# Run one tick to invoke the credential exchange check
freezer.tick(CODE_CHECK_ALARM_TIMEDELTA)
await fire_alarm(hass, datetime.datetime.utcnow())
await hass.async_block_till_done()
now = utcnow()
await fire_alarm(hass, now + CODE_CHECK_ALARM_TIMEDELTA)
result = await hass.config_entries.flow.async_configure(
flow_id=result["flow_id"]
)
Expand All @@ -143,11 +150,12 @@ async def test_full_flow_yaml_creds(
assert "data" in result
data = result["data"]
assert "token" in data
assert 0 < data["token"]["expires_in"] <= 60 * 60
assert (
data["token"]["expires_in"]
== 60 * 60 - CODE_CHECK_ALARM_TIMEDELTA.total_seconds()
datetime.datetime.now().timestamp()
<= data["token"]["expires_at"]
< (datetime.datetime.now() + datetime.timedelta(days=8)).timestamp()
)
assert data["token"]["expires_at"] == 1654273199.0
data["token"].pop("expires_at")
data["token"].pop("expires_in")
assert data == {
Expand Down Expand Up @@ -238,7 +246,7 @@ async def test_code_error(
assert await component_setup()

with patch(
"oauth2client.client.OAuth2WebServerFlow.step1_get_device_and_user_codes",
"homeassistant.components.google.api.OAuth2WebServerFlow.step1_get_device_and_user_codes",
side_effect=OAuth2DeviceCodeError("Test Failure"),
):
result = await hass.config_entries.flow.async_init(
Expand All @@ -248,13 +256,13 @@ async def test_code_error(
assert result.get("reason") == "oauth_error"


@pytest.mark.parametrize("code_expiration_delta", [datetime.timedelta(minutes=-5)])
@pytest.mark.parametrize("code_expiration_delta", [datetime.timedelta(seconds=50)])
async def test_expired_after_exchange(
hass: HomeAssistant,
mock_code_flow: Mock,
component_setup: ComponentSetup,
) -> None:
"""Test successful creds setup."""
"""Test credential exchange expires."""
assert await component_setup()

result = await hass.config_entries.flow.async_init(
Expand All @@ -265,10 +273,14 @@ async def test_expired_after_exchange(
assert "description_placeholders" in result
assert "url" in result["description_placeholders"]

# Run one tick to invoke the credential exchange check
now = utcnow()
await fire_alarm(hass, now + CODE_CHECK_ALARM_TIMEDELTA)
await hass.async_block_till_done()
# Fail first attempt then advance clock past exchange timeout
with patch(
"homeassistant.components.google.api.OAuth2WebServerFlow.step2_exchange",
side_effect=FlowExchangeError(),
):
now = utcnow()
await fire_alarm(hass, now + datetime.timedelta(seconds=65))
await hass.async_block_till_done()

result = await hass.config_entries.flow.async_configure(flow_id=result["flow_id"])
assert result.get("type") == "abort"
Expand All @@ -295,7 +307,7 @@ async def test_exchange_error(
# Run one tick to invoke the credential exchange check
now = utcnow()
with patch(
"oauth2client.client.OAuth2WebServerFlow.step2_exchange",
"homeassistant.components.google.api.OAuth2WebServerFlow.step2_exchange",
side_effect=FlowExchangeError(),
):
now += CODE_CHECK_ALARM_TIMEDELTA
Expand Down