diff --git a/docs/notebooks/test_snia.ipynb b/docs/notebooks/test_snia.ipynb index 5c5319c..bb9d8ca 100644 --- a/docs/notebooks/test_snia.ipynb +++ b/docs/notebooks/test_snia.ipynb @@ -47,7 +47,7 @@ "source": [ "from tdastro.astro_utils.opsim import OpSim, oversample_opsim\n", "\n", - "opsim_name = Path(tdastro._TDASTRO_TEST_DATA_DIR, \"opsim_shorten.db\")\n", + "opsim_name = tdastro._TDASTRO_TEST_DATA_DIR / \"opsim_shorten.db\"\n", "base_opsim = OpSim.from_db(opsim_name)\n", "oversampled_observations = oversample_opsim(\n", " base_opsim,\n", diff --git a/tests/tdastro/test_graph_state.py b/tests/tdastro/test_graph_state.py index 468b477..89badcf 100644 --- a/tests/tdastro/test_graph_state.py +++ b/tests/tdastro/test_graph_state.py @@ -1,6 +1,3 @@ -import tempfile -from pathlib import Path - import numpy as np import pytest from astropy.table import Table @@ -378,27 +375,26 @@ def test_graph_state_update_multi(): state.update(state4) -def test_graph_to_from_file(): +def test_graph_to_from_file(tmp_path): """Test that we can create an AstroPy Table from a GraphState.""" state = GraphState(num_samples=3) state.set("a", "v1", [1.0, 2.0, 3.0]) state.set("a", "v2", [3.0, 4.0, 5.0]) state.set("b", "v1", [6.0, 7.0, 8.0]) - with tempfile.TemporaryDirectory() as dir_name: - file_path = Path(dir_name, "state.ecsv") - assert not file_path.is_file() + file_path = tmp_path / "state.ecsv" + assert not file_path.is_file() - state.save_to_file(file_path) - assert file_path.is_file() + state.save_to_file(file_path) + assert file_path.is_file() - state2 = GraphState.from_file(file_path) - assert state == state2 + state2 = GraphState.from_file(file_path) + assert state == state2 - # Cannot overwrite with it set to False, but works when set to True. - with pytest.raises(OSError): - state.save_to_file(file_path) - state.save_to_file(file_path, overwrite=True) + # Cannot overwrite with it set to False, but works when set to True. + with pytest.raises(OSError): + state.save_to_file(file_path) + state.save_to_file(file_path, overwrite=True) def test_transpose_dict_of_list():