From e16566ae32aa50841fc2cac2c28aa9d1a3e2f295 Mon Sep 17 00:00:00 2001 From: Eivind Jahren Date: Thu, 26 Sep 2024 14:57:07 +0200 Subject: [PATCH] Test transaction in StatefulStorageTest --- .../unit_tests/storage/test_local_storage.py | 62 ++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/tests/ert/unit_tests/storage/test_local_storage.py b/tests/ert/unit_tests/storage/test_local_storage.py index 0ed8f22db78..a5bc5dd2957 100644 --- a/tests/ert/unit_tests/storage/test_local_storage.py +++ b/tests/ert/unit_tests/storage/test_local_storage.py @@ -494,12 +494,16 @@ def test_write_transaction(data): class RaisingWriteNamedTemporaryFile: + entered = False + def __init__(self, *args, **kwargs): self.wrapped = tempfile.NamedTemporaryFile(*args, **kwargs) # noqa + RaisingWriteNamedTemporaryFile.entered = False def __enter__(self, *args, **kwargs): self.actual_handle = self.wrapped.__enter__(*args, **kwargs) mock_handle = MagicMock() + RaisingWriteNamedTemporaryFile.entered = True def ctrlc(_): raise RuntimeError() @@ -517,9 +521,11 @@ def test_write_transaction_failure(tmp_path): with patch( "ert.storage.local_storage.NamedTemporaryFile", RaisingWriteNamedTemporaryFile, - ), pytest.raises(RuntimeError): + ) as f, pytest.raises(RuntimeError): storage._write_transaction(path, b"deadbeaf") + assert f.entered + assert not path.exists() @@ -669,6 +675,34 @@ def save_field(self, model_ensemble: Ensemble, field_data): ).to_dataset(), ) + @rule( + model_ensemble=ensembles, + field_data=grid.flatmap(lambda g: arrays(np.float32, shape=g[1].shape)), + ) + def write_error_in_save_field(self, model_ensemble: Ensemble, field_data): + storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid) + parameters = model_ensemble.parameter_values.values() + fields = [p for p in parameters if isinstance(p, Field)] + iens = 1 + assume(not storage_ensemble.realizations_initialized([iens])) + for f in fields: + with patch( + "ert.storage.local_storage.NamedTemporaryFile", + RaisingWriteNamedTemporaryFile, + ) as temp_file, pytest.raises(RuntimeError): + storage_ensemble.save_parameters( + f.name, + iens, + xr.DataArray( + field_data, + name="values", + dims=["x", "y", "z"], # type: ignore + ).to_dataset(), + ) + + assert temp_file.entered + assert not storage_ensemble.realizations_initialized([iens]) + @rule( model_ensemble=ensembles, ) @@ -831,6 +865,32 @@ def set_failure(self, model_ensemble: Ensemble, data: st.DataObject, message: st ) model_ensemble.failure_messages[realization] = message + @rule(model_ensemble=ensembles, data=st.data(), message=st.text()) + def write_error_in_set_failure( + self, + model_ensemble: Ensemble, + data: st.DataObject, + message: str, + ): + storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid) + realization = data.draw( + st.integers(min_value=0, max_value=storage_ensemble.ensemble_size - 1) + ) + assume(not storage_ensemble.has_failure(realization)) + + storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid) + + with patch( + "ert.storage.local_storage.NamedTemporaryFile", + RaisingWriteNamedTemporaryFile, + ) as f, pytest.raises(RuntimeError): + storage_ensemble.set_failure( + realization, RealizationStorageState.PARENT_FAILURE, message + ) + assert f.entered + + assert not storage_ensemble.has_failure(realization) + @rule(model_ensemble=ensembles, data=st.data()) def get_failure(self, model_ensemble: Ensemble, data: st.DataObject): storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid)