Skip to content

Commit

Permalink
Only wait if an object is awaitable.
Browse files Browse the repository at this point in the history
Between juju 2.x and 3.x, the resulting return objects from the
run method has changed to be awaitable or not. This change checks
to see if the resulting object can be awaited or not. Without this,
zaza will simply hang waiting on the results of an object which is
not actually awaitable.

Signed-off-by: Billy Olsen <[email protected]>
  • Loading branch information
Billy Olsen authored and wolsen committed Mar 1, 2024
1 parent c44d359 commit 1a68ae2
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 5 deletions.
45 changes: 45 additions & 0 deletions unit_tests/test_zaza_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,11 +785,14 @@ def test_run_on_unit(self):
self.cmd = cmd = 'somecommand someargument'
self.patch_object(model, 'Model')
self.patch_object(model, 'get_unit_from_name')
self.patch("inspect.isawaitable", name="isawaitable",
return_value=True)
self.get_unit_from_name.return_value = self.unit1
self.Model.return_value = self.Model_mock
self.assertEqual(model.run_on_unit('app/2', cmd),
expected)
self.unit1.run.assert_called_once_with(cmd, timeout=None)
self.action.wait.assert_called_once()

def test_run_on_unit_juju2_x(self):
del self.action.results
Expand All @@ -803,11 +806,14 @@ def test_run_on_unit_juju2_x(self):
self.cmd = cmd = 'somecommand someargument'
self.patch_object(model, 'Model')
self.patch_object(model, 'get_unit_from_name')
self.patch("inspect.isawaitable", name="isawaitable",
return_value=False)
self.get_unit_from_name.return_value = self.unit1
self.Model.return_value = self.Model_mock
self.assertEqual(model.run_on_unit('app/2', cmd),
expected)
self.unit1.run.assert_called_once_with(cmd, timeout=None)
self.action.wait.assert_not_called()

def test_run_on_unit_lc_keys(self):
self.patch_object(model, 'get_juju_model', return_value='mname')
Expand Down Expand Up @@ -1117,6 +1123,9 @@ def test_run_action_on_units(self):
self.patch_object(model, 'Model')
self.Model.return_value = self.Model_mock
self.patch_object(model, 'async_get_unit_from_name')
self.patch("inspect.isawaitable", return_value=False,
name="isawaitable")
self.patch_object(model, 'async_block_until')
units = {
'app/1': self.unit1,
'app/2': self.unit2}
Expand All @@ -1138,11 +1147,15 @@ async def _async_get_unit_from_name(x, *args):
'backup',
backup_dir='/dev/null')

self.async_block_until.assert_not_called()

def test_run_action_on_units_timeout(self):
self.patch_object(model, 'get_juju_model', return_value='mname')
self.patch_object(model, 'Model')
self.Model.return_value = self.Model_mock
self.patch_object(model, 'get_unit_from_name')
self.patch("inspect.isawaitable", return_value=True,
name="isawaitable")
self.get_unit_from_name.return_value = self.unit1
self.run_action.data = {'status': 'running'}
with self.assertRaises(AsyncTimeoutError):
Expand Down Expand Up @@ -1170,6 +1183,38 @@ async def _fake_get_action_output(_):
raise_on_failure=True,
action_params={'backup_dir': '/dev/null'})

def test_run_action_on_units_async(self):
"""Tests that non-awaitable action results aren't awaited on."""
self.patch_object(model, 'get_juju_model', return_value='mname')
self.patch_object(model, 'Model')
self.Model.return_value = self.Model_mock
self.patch_object(model, 'async_get_unit_from_name')
self.patch_object(model, 'async_block_until')
self.patch("inspect.isawaitable", return_value=True,
name="isawaitable")
units = {
'app/1': self.unit1,
'app/2': self.unit2}

async def _async_get_unit_from_name(x, *args):
nonlocal units
return units[x]

self.async_get_unit_from_name.side_effect = _async_get_unit_from_name
self.run_action.data = {'status': 'completed'}
model.run_action_on_units(
['app/1', 'app/2'],
'backup',
action_params={'backup_dir': '/dev/null'})
self.unit1.run_action.assert_called_once_with(
'backup',
backup_dir='/dev/null')
self.unit2.run_action.assert_called_once_with(
'backup',
backup_dir='/dev/null')

assert self.async_block_until.called

def _application_states_setup(self, setup, units_idle=True):
self.system_ready = True
self._block_until_calls = 0
Expand Down
18 changes: 13 additions & 5 deletions zaza/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,8 @@ async def async_run_on_unit(unit_name, command, model_name=None, timeout=None):
model = await get_model(model_name)
unit = await async_get_unit_from_name(unit_name, model)
action = await unit.run(command, timeout=timeout)
await action.wait()
if inspect.isawaitable(action):
await action.wait()
action = _normalise_action_object(action)
results = action.data.get('results')
return _normalise_action_results(results)
Expand Down Expand Up @@ -598,7 +599,8 @@ async def async_run_on_leader(application_name, command, model_name=None,
is_leader = await unit.is_leader_from_status()
if is_leader:
action = await unit.run(command, timeout=timeout)
await action.wait()
if inspect.isawaitable(action):
await action.wait()

Check warning on line 603 in zaza/model.py

View check run for this annotation

Codecov / codecov/patch

zaza/model.py#L603

Added line #L603 was not covered by tests
action = _normalise_action_object(action)
results = action.data.get('results')
return _normalise_action_results(results)
Expand Down Expand Up @@ -1211,18 +1213,24 @@ async def async_run_action_on_units(units, action_name, action_params=None,

model = await get_model(model_name)
actions = []
async_actions = []
for unit_name in units:
unit = await async_get_unit_from_name(unit_name, model)
action_obj = await unit.run_action(action_name, **action_params)
actions.append(action_obj)
if inspect.isawaitable(action_obj):
async_actions.append(action_obj)
else:
actions.append(action_obj)

async def _check_actions():
for action_obj in actions:
for action_obj in async_actions:
if action_obj.data['status'] in ['running', 'pending']:
return False
return True

await async_block_until(_check_actions, timeout=timeout)
if async_actions:
await async_block_until(_check_actions, timeout=timeout)
actions.extend(async_actions)

for action_obj in actions:
if raise_on_failure and action_obj.data['status'] != 'completed':
Expand Down

0 comments on commit 1a68ae2

Please sign in to comment.