diff --git a/unit_tests/test_zaza_model.py b/unit_tests/test_zaza_model.py index 66fd9b1b..f46d5158 100644 --- a/unit_tests/test_zaza_model.py +++ b/unit_tests/test_zaza_model.py @@ -785,11 +785,14 @@ def test_run_on_unit(self): self.cmd = cmd = 'somecommand someargument' self.patch_object(model, 'Model') self.patch_object(model, 'get_unit_from_name') + self.patch("inspect.isawaitable", name="isawaitable", + return_value=True) self.get_unit_from_name.return_value = self.unit1 self.Model.return_value = self.Model_mock self.assertEqual(model.run_on_unit('app/2', cmd), expected) self.unit1.run.assert_called_once_with(cmd, timeout=None) + self.action.wait.assert_called_once() def test_run_on_unit_juju2_x(self): del self.action.results @@ -803,11 +806,14 @@ def test_run_on_unit_juju2_x(self): self.cmd = cmd = 'somecommand someargument' self.patch_object(model, 'Model') self.patch_object(model, 'get_unit_from_name') + self.patch("inspect.isawaitable", name="isawaitable", + return_value=False) self.get_unit_from_name.return_value = self.unit1 self.Model.return_value = self.Model_mock self.assertEqual(model.run_on_unit('app/2', cmd), expected) self.unit1.run.assert_called_once_with(cmd, timeout=None) + self.action.wait.assert_not_called() def test_run_on_unit_lc_keys(self): self.patch_object(model, 'get_juju_model', return_value='mname') @@ -900,6 +906,8 @@ def test_run_on_leader(self): self.cmd = cmd = 'somecommand someargument' self.patch_object(model, 'Model') self.Model.return_value = self.Model_mock + self.patch('inspect.isawaitable', return_value=True, + name='isawaitable') self.assertEqual(model.run_on_leader('app', cmd), expected) self.unit2.run.assert_called_once_with(cmd, timeout=None) @@ -1700,6 +1708,8 @@ def test_block_until_file_has_contents(self): _fileobj = mock.MagicMock() _fileobj.__enter__().read.return_value = "somestring" self._open.return_value = _fileobj + self.patch('inspect.isawaitable', return_value=True, + name='isawaitable') model.block_until_file_has_contents( 'app', '/tmp/src/myfile.txt', @@ -1725,6 +1735,8 @@ def test_block_until_file_has_contents_juju2_x(self): _fileobj = mock.MagicMock() _fileobj.__enter__().read.return_value = "somestring" self._open.return_value = _fileobj + self.patch('inspect.isawaitable', return_value=False, + name='isawaitable') model.block_until_file_has_contents( 'app', '/tmp/src/myfile.txt', @@ -1805,6 +1817,8 @@ def test_block_until_file_missing(self): self.Model.return_value = self.Model_mock self.patch_object(model, 'get_juju_model', return_value='mname') self.action.results = {'stdout': "1"} + self.patch('inspect.isawaitable', return_value=True, + name='isawaitable') model.block_until_file_missing( 'app', '/tmp/src/myfile.txt', @@ -2776,6 +2790,51 @@ async def _g(): await model.async_block_until(_f, _g, timeout=0.1) + async def test_update_unknown_action_status_invalid_params(self): + """Test update unknown action status invalid params.""" + self.assertRaises(ValueError, model.update_unknown_action_status, + None, mock.MagicMock()) + self.assertRaises(ValueError, model.update_unknown_action_status, + mock.MagicMock(), None) + + async def test_update_unknown_action_status_not_unknown(self): + """Test update unknown action status with status not unknown.""" + mock_model = mock.AsyncMock() + action_obj = mock.AsyncMock() + action_obj.data = {"status": "running"} + + await model.async_update_unknown_action_status(model, action_obj) + mock_model.get_action_status.assert_not_called() + + async def test_update_unknown_action_status_no_completion_timestamp(self): + """Test update unknown action status without a completion timestamp.""" + model_mock = mock.AsyncMock() + action_obj = mock.MagicMock() + action_obj.data = {"status": "unknown", "completed": ""} + + await model.async_update_unknown_action_status(model_mock, action_obj) + model_mock.get_action_status.assert_not_called() + + async def test_update_unknown_action_status(self): + """Test update unknown action status updates status.""" + mock_model = mock.AsyncMock() + + class ActionObj: + id = "1234" + data = { + "status": "unknown", + "completed": "2024-03-01T12:45:14" + } + + async def get_action_status(_id): + return {"1234": "completed"} + + mock_model.get_action_status.side_effect = get_action_status + action_obj = ActionObj() + + await model.async_update_unknown_action_status(mock_model, action_obj) + self.assertEqual(action_obj.data["status"], "completed") + async def test_run_on_machine(self): with mock.patch.object( model.generic_utils, diff --git a/zaza/model.py b/zaza/model.py index 04166394..7e7526b6 100644 --- a/zaza/model.py +++ b/zaza/model.py @@ -570,7 +570,8 @@ async def async_run_on_unit(unit_name, command, model_name=None, timeout=None): model = await get_model(model_name) unit = await async_get_unit_from_name(unit_name, model) action = await unit.run(command, timeout=timeout) - await action.wait() + if inspect.isawaitable(action): + await action.wait() action = _normalise_action_object(action) results = action.data.get('results') return _normalise_action_results(results) @@ -598,7 +599,8 @@ async def async_run_on_leader(application_name, command, model_name=None, is_leader = await unit.is_leader_from_status() if is_leader: action = await unit.run(command, timeout=timeout) - await action.wait() + if inspect.isawaitable(action): + await action.wait() action = _normalise_action_object(action) results = action.data.get('results') return _normalise_action_results(results) @@ -1104,6 +1106,44 @@ def _normalise_action_object(action_obj): return action_obj +async def async_update_unknown_action_status(model, action_obj): + """Update the action status if the status is unknown. + + Updates the action status for an action object when its data has + a completion timestamp, indicating it did complete, but the status + is set to unknown. When the action object is in this state, this + function will query for the latest status. This function will only + query for the status update a single time. + + :param model: the model the action_obj belongs to + :type model: juju.model.ModelEntity + :param action_obj: an action that should be updated if the status + is unknown. + :type action_obj: juju.model.ActionEntity + :raises ValueError: If either the model or action_obj is invalid + :return: None + """ + if not model or not action_obj: + raise ValueError("Expected model and action_obj to be valid. " + f"Got model: {model}, action_obj: {action_obj}") + + # If the status is not unknown, don't update the status + if action_obj.data.get('status', '') != 'unknown': + return + + # If the completed timestamp is not set, don't update the status + if not action_obj.data.get('completed', ''): + return + + logging.debug("Action results were complete but status is unknown. " + "Refreshing status.") + updated_status = await model.get_action_status(action_obj.id) + action_obj.data['status'] = updated_status.get(action_obj.id) + + +update_unknown_action_status = sync_wrapper(async_update_unknown_action_status) + + async def async_run_action(unit_name, action_name, model_name=None, action_params=None, raise_on_failure=False): """Run action on given unit. @@ -1130,6 +1170,8 @@ async def async_run_action(unit_name, action_name, model_name=None, action_obj = await unit.run_action(action_name, **action_params) await action_obj.wait() action_obj = _normalise_action_object(action_obj) + await async_update_unknown_action_status(model, action_obj) + if raise_on_failure and action_obj.data['status'] != 'completed': try: output = await model.get_action_output(action_obj.id) @@ -1171,6 +1213,7 @@ async def async_run_action_on_leader(application_name, action_name, **action_params) await action_obj.wait() action_obj = _normalise_action_object(action_obj) + await async_update_unknown_action_status(model, action_obj) if raise_on_failure and action_obj.data['status'] != 'completed': try: output = await model.get_action_output(action_obj.id) @@ -1225,6 +1268,7 @@ async def _check_actions(): await async_block_until(_check_actions, timeout=timeout) for action_obj in actions: + await async_update_unknown_action_status(model, action_obj) if raise_on_failure and action_obj.data['status'] != 'completed': try: output = await model.get_action_output(action_obj.id) @@ -2116,7 +2160,8 @@ async def _check_file(): for unit in units: try: output = await unit.run('cat {}'.format(remote_file)) - await output.wait() + if inspect.isawaitable(output): + await output.wait() results = {} try: results = output.results @@ -2257,7 +2302,8 @@ async def _check_for_file(model): for unit in units: try: output = await unit.run('test -e "{}"; echo $?'.format(path)) - await output.wait() + if inspect.isawaitable(output): + await output.wait() output = _normalise_action_object(output) output_result = _normalise_action_results( output.data.get('results', {}))