Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Oct 16, 2024
1 parent 01ec7fe commit 7c64f73
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 16 deletions.
2 changes: 1 addition & 1 deletion docs/notebooks/test_snia.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
"source": [
"from tdastro.astro_utils.opsim import OpSim, oversample_opsim\n",
"\n",
"opsim_name = Path(tdastro._TDASTRO_TEST_DATA_DIR, \"opsim_shorten.db\")\n",
"opsim_name = tdastro._TDASTRO_TEST_DATA_DIR / \"opsim_shorten.db\"\n",
"base_opsim = OpSim.from_db(opsim_name)\n",
"oversampled_observations = oversample_opsim(\n",
" base_opsim,\n",
Expand Down
26 changes: 11 additions & 15 deletions tests/tdastro/test_graph_state.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import tempfile
from pathlib import Path

import numpy as np
import pytest
from astropy.table import Table
Expand Down Expand Up @@ -378,27 +375,26 @@ def test_graph_state_update_multi():
state.update(state4)


def test_graph_to_from_file():
def test_graph_to_from_file(tmp_path):
"""Test that we can create an AstroPy Table from a GraphState."""
state = GraphState(num_samples=3)
state.set("a", "v1", [1.0, 2.0, 3.0])
state.set("a", "v2", [3.0, 4.0, 5.0])
state.set("b", "v1", [6.0, 7.0, 8.0])

with tempfile.TemporaryDirectory() as dir_name:
file_path = Path(dir_name, "state.ecsv")
assert not file_path.is_file()
file_path = tmp_path / "state.ecsv"
assert not file_path.is_file()

state.save_to_file(file_path)
assert file_path.is_file()
state.save_to_file(file_path)
assert file_path.is_file()

state2 = GraphState.from_file(file_path)
assert state == state2
state2 = GraphState.from_file(file_path)
assert state == state2

# Cannot overwrite with it set to False, but works when set to True.
with pytest.raises(OSError):
state.save_to_file(file_path)
state.save_to_file(file_path, overwrite=True)
# Cannot overwrite with it set to False, but works when set to True.
with pytest.raises(OSError):
state.save_to_file(file_path)
state.save_to_file(file_path, overwrite=True)


def test_transpose_dict_of_list():
Expand Down

0 comments on commit 7c64f73

Please sign in to comment.