From 01ec7febe6656696eeee20f338662e50f594ff7a Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Tue, 15 Oct 2024 16:05:10 -0400 Subject: [PATCH 1/2] Replace os.path.join with Path --- docs/notebooks/test_snia.ipynb | 7 ++-- src/tdastro/astro_utils/passbands.py | 40 ++++++++++++--------- src/tdastro/graph_state.py | 2 +- tests/tdastro/astro_utils/test_opsim.py | 17 +++++---- tests/tdastro/astro_utils/test_passbands.py | 7 ++-- tests/tdastro/conftest.py | 14 ++++---- tests/tdastro/test_graph_state.py | 7 ++-- 7 files changed, 51 insertions(+), 43 deletions(-) diff --git a/docs/notebooks/test_snia.ipynb b/docs/notebooks/test_snia.ipynb index 5796de5d..5c5319c4 100644 --- a/docs/notebooks/test_snia.ipynb +++ b/docs/notebooks/test_snia.ipynb @@ -20,10 +20,11 @@ "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", - "import os\n", "import sncosmo\n", "import tdastro\n", "\n", + "from pathlib import Path\n", + "\n", "from tdastro.example_runs.simulate_snia import run_snia_end2end" ] }, @@ -46,7 +47,7 @@ "source": [ "from tdastro.astro_utils.opsim import OpSim, oversample_opsim\n", "\n", - "opsim_name = os.path.join(tdastro._TDASTRO_TEST_DATA_DIR, \"opsim_shorten.db\")\n", + "opsim_name = Path(tdastro._TDASTRO_TEST_DATA_DIR, \"opsim_shorten.db\")\n", "base_opsim = OpSim.from_db(opsim_name)\n", "oversampled_observations = oversample_opsim(\n", " base_opsim,\n", @@ -91,7 +92,7 @@ "metadata": {}, "outputs": [], "source": [ - "passbands_dir = os.path.join(tdastro._TDASTRO_TEST_DATA_DIR, \"passbands\")\n", + "passbands_dir = Path(tdastro._TDASTRO_TEST_DATA_DIR, \"passbands\")\n", "res, passbands = run_snia_end2end(\n", " oversampled_observations,\n", " passbands_dir=passbands_dir,\n", diff --git a/src/tdastro/astro_utils/passbands.py b/src/tdastro/astro_utils/passbands.py index 9d8df598..e0b60d82 100644 --- a/src/tdastro/astro_utils/passbands.py +++ b/src/tdastro/astro_utils/passbands.py @@ -1,9 +1,9 @@ import
logging -import os import socket import urllib.parse import urllib.request -from typing import Literal, Optional +from pathlib import Path +from typing import Literal, Optional, Union from urllib.error import HTTPError, URLError import numpy as np @@ -29,7 +29,7 @@ def __init__( self, preset: str = None, passband_parameters: Optional[list] = None, - table_dir: Optional[str] = None, + table_dir: Optional[Union[str, Path]] = None, given_passbands: list = None, **kwargs, ): @@ -45,7 +45,7 @@ def __init__( - survey : str - filter_name : str Dictionaries may also contain the following optional parameters: - - table_path : str + - table_path : str or Path - table_url : str - delta_wave : float - trim_quantile : float @@ -82,10 +82,15 @@ def __init__( parameters[key] = value # Set the table path if it is not already set and a table_dir is provided - if "table_path" not in parameters and table_dir is not None: - parameters["table_path"] = os.path.join( - table_dir, parameters["survey"], f"{parameters['filter_name']}.dat" - ) + if "table_path" not in parameters: + if table_dir is not None: + parameters["table_path"] = Path( + table_dir, + parameters["survey"], + f"{parameters['filter_name']}.dat", + ) + elif isinstance(parameters["table_path"], str): + parameters["table_path"] = Path(parameters["table_path"]) passband = Passband(**parameters) self.passbands[passband.full_name] = passband @@ -128,7 +133,7 @@ def _load_preset(self, preset: str, table_dir: Optional[str], **kwargs) -> None: if table_dir is None: self.passbands[f"LSST_{filter_name}"] = Passband("LSST", filter_name, **kwargs) else: - table_path = os.path.join(table_dir, "LSST", f"{filter_name}.dat") + table_path = Path(table_dir, "LSST", f"{filter_name}.dat") self.passbands[f"LSST_{filter_name}"] = Passband( "LSST", filter_name, @@ -276,7 +281,7 @@ def __init__( filter_name: str, delta_wave: Optional[float] = 5.0, trim_quantile: Optional[float] = 1e-3, - table_path: Optional[str] = None, + table_path: 
Optional[Union[str, Path]] = None, table_url: Optional[str] = None, table_values: Optional[np.array] = None, units: Optional[Literal["nm", "A"]] = "A", @@ -318,7 +323,7 @@ def __init__( self.filter_name = filter_name self.full_name = f"{survey}_{filter_name}" - self.table_path = table_path + self.table_path = Path(table_path) if table_path is not None else None self.table_url = table_url self.units = units self._in_band_wave_indices = None @@ -364,11 +369,14 @@ def _load_transmission_table(self, force_download: bool = False) -> None: """ # Check if the table file exists locally, and download it if it does not if self.table_path is None: - self.table_path = os.path.join( - os.path.dirname(__file__), f"passbands/{self.survey}/{self.filter_name}.dat" + self.table_path = Path( + Path(__file__).parent, + "passbands", + self.survey, + f"{self.filter_name}.dat", ) - os.makedirs(os.path.dirname(self.table_path), exist_ok=True) - if force_download or not os.path.exists(self.table_path): + self.table_path.parent.mkdir(parents=True, exist_ok=True) + if force_download or not self.table_path.exists(): self._download_transmission_table() # Load the table file @@ -409,7 +417,7 @@ def _download_transmission_table(self) -> bool: socket.setdefaulttimeout(10) logging.info(f"Retrieving {self.table_url}") urllib.request.urlretrieve(self.table_url, self.table_path) - if os.path.getsize(self.table_path) == 0: + if self.table_path.stat().st_size == 0: logging.error(f"Transmission table downloaded for {self.full_name} is empty.") return False else: diff --git a/src/tdastro/graph_state.py b/src/tdastro/graph_state.py index ed99c03c..07a33cac 100644 --- a/src/tdastro/graph_state.py +++ b/src/tdastro/graph_state.py @@ -171,7 +171,7 @@ def from_file(cls, filename): Parameters ---------- - filename : str + filename : str or Path The name of the file. 
""" data_table = ascii.read(filename, format="ecsv") diff --git a/tests/tdastro/astro_utils/test_opsim.py b/tests/tdastro/astro_utils/test_opsim.py index 9c67dc8d..a2a4aff2 100644 --- a/tests/tdastro/astro_utils/test_opsim.py +++ b/tests/tdastro/astro_utils/test_opsim.py @@ -1,4 +1,3 @@ -import os import tempfile from pathlib import Path @@ -134,19 +133,19 @@ def test_write_read_opsim(): ops_data = OpSim(pd.DataFrame(values)) with tempfile.TemporaryDirectory() as dir_name: - filename = os.path.join(dir_name, "test_write_read_opsim.db") + file_path = Path(dir_name, "test_write_read_opsim.db") # The opsim does not exist until we write it. - assert not Path(filename).is_file() + assert not file_path.is_file() with pytest.raises(FileNotFoundError): - _ = OpSim.from_db(filename) + _ = OpSim.from_db(file_path) # We can write the opsim db. - ops_data.write_opsim_table(filename) - assert Path(filename).is_file() + ops_data.write_opsim_table(file_path) + assert file_path.is_file() # We can reread the opsim db. 
- ops_data2 = OpSim.from_db(filename) + ops_data2 = OpSim.from_db(file_path) assert len(ops_data2) == 5 assert np.allclose(values["observationStartMJD"], ops_data2["observationStartMJD"].to_numpy()) assert np.allclose(values["fieldRA"], ops_data2["fieldRA"].to_numpy()) @@ -154,8 +153,8 @@ def test_write_read_opsim(): # We cannot overwrite unless we set overwrite=True with pytest.raises(ValueError): - ops_data.write_opsim_table(filename, overwrite=False) - ops_data.write_opsim_table(filename, overwrite=True) + ops_data.write_opsim_table(file_path, overwrite=False) + ops_data.write_opsim_table(file_path, overwrite=True) def test_opsim_range_search(): diff --git a/tests/tdastro/astro_utils/test_passbands.py b/tests/tdastro/astro_utils/test_passbands.py index f587299b..5277ff6f 100644 --- a/tests/tdastro/astro_utils/test_passbands.py +++ b/tests/tdastro/astro_utils/test_passbands.py @@ -1,4 +1,4 @@ -import os +from pathlib import Path from unittest.mock import patch import numpy as np @@ -17,8 +17,9 @@ def create_lsst_passband(path, filter_name, **kwargs): def create_toy_passband(path, transmission_table, filter_name="a", **kwargs): """Helper function to create a toy Passband object for testing.""" survey = "TOY" - table_path = f"{path}/{survey}/{filter_name}.dat" - os.makedirs(os.path.dirname(table_path), exist_ok=True) + dir_path = Path(path, survey) + dir_path.mkdir(parents=True, exist_ok=True) + table_path = dir_path / f"{filter_name}.dat" # Create a transmission table file with open(table_path, "w") as f: diff --git a/tests/tdastro/conftest.py b/tests/tdastro/conftest.py index 259d28c1..44ebcc74 100644 --- a/tests/tdastro/conftest.py +++ b/tests/tdastro/conftest.py @@ -1,4 +1,4 @@ -import os.path +from pathlib import Path import pytest from tdastro import _TDASTRO_TEST_DATA_DIR @@ -7,31 +7,31 @@ @pytest.fixture def test_data_dir(): """Return the base test data directory.""" - return _TDASTRO_TEST_DATA_DIR + return Path(_TDASTRO_TEST_DATA_DIR) @pytest.fixture def 
grid_data_good_file(test_data_dir): """Return the file path for the good grid input file.""" - return os.path.join(test_data_dir, "grid_input_good.ecsv") + return test_data_dir / "grid_input_good.ecsv" @pytest.fixture def grid_data_bad_file(test_data_dir): """Return the file path for the bad grid input file.""" - return os.path.join(test_data_dir, "grid_input_bad.txt") + return test_data_dir / "grid_input_bad.txt" @pytest.fixture def opsim_small(test_data_dir): """Return the file path for the bad grid input file.""" - return os.path.join(test_data_dir, "opsim_small.db") + return test_data_dir / "opsim_small.db" @pytest.fixture def opsim_shorten(test_data_dir): """Return the file path for the bad grid input file.""" - return os.path.join(test_data_dir, "opsim_shorten.db") + return test_data_dir / "opsim_shorten.db" @pytest.fixture @@ -54,4 +54,4 @@ def oversampled_observations(opsim_shorten): @pytest.fixture def passbands_dir(test_data_dir): """Return the file path for passbands directory.""" - return os.path.join(test_data_dir, "passbands") + return test_data_dir / "passbands" diff --git a/tests/tdastro/test_graph_state.py b/tests/tdastro/test_graph_state.py index 7df62f3b..468b4770 100644 --- a/tests/tdastro/test_graph_state.py +++ b/tests/tdastro/test_graph_state.py @@ -1,4 +1,3 @@ -import os import tempfile from pathlib import Path @@ -387,11 +386,11 @@ def test_graph_to_from_file(): state.set("b", "v1", [6.0, 7.0, 8.0]) with tempfile.TemporaryDirectory() as dir_name: - file_path = os.path.join(dir_name, "state.ecsv") - assert not Path(file_path).is_file() + file_path = Path(dir_name, "state.ecsv") + assert not file_path.is_file() state.save_to_file(file_path) - assert Path(file_path).is_file() + assert file_path.is_file() state2 = GraphState.from_file(file_path) assert state == state2 From 7c64f730865048eddb1b6d8b8d19e4b0df08768e Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Wed, 16 Oct 2024 07:59:55 -0400 
Subject: [PATCH 2/2] Address PR comments --- docs/notebooks/test_snia.ipynb | 2 +- tests/tdastro/test_graph_state.py | 26 +++++++++++--------------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/docs/notebooks/test_snia.ipynb b/docs/notebooks/test_snia.ipynb index 5c5319c4..bb9d8ca9 100644 --- a/docs/notebooks/test_snia.ipynb +++ b/docs/notebooks/test_snia.ipynb @@ -47,7 +47,7 @@ "source": [ "from tdastro.astro_utils.opsim import OpSim, oversample_opsim\n", "\n", - "opsim_name = Path(tdastro._TDASTRO_TEST_DATA_DIR, \"opsim_shorten.db\")\n", + "opsim_name = tdastro._TDASTRO_TEST_DATA_DIR / \"opsim_shorten.db\"\n", "base_opsim = OpSim.from_db(opsim_name)\n", "oversampled_observations = oversample_opsim(\n", " base_opsim,\n", diff --git a/tests/tdastro/test_graph_state.py b/tests/tdastro/test_graph_state.py index 468b4770..89badcf6 100644 --- a/tests/tdastro/test_graph_state.py +++ b/tests/tdastro/test_graph_state.py @@ -1,6 +1,3 @@ -import tempfile -from pathlib import Path - import numpy as np import pytest from astropy.table import Table @@ -378,27 +375,26 @@ def test_graph_state_update_multi(): state.update(state4) -def test_graph_to_from_file(): +def test_graph_to_from_file(tmp_path): """Test that we can create an AstroPy Table from a GraphState.""" state = GraphState(num_samples=3) state.set("a", "v1", [1.0, 2.0, 3.0]) state.set("a", "v2", [3.0, 4.0, 5.0]) state.set("b", "v1", [6.0, 7.0, 8.0]) - with tempfile.TemporaryDirectory() as dir_name: - file_path = Path(dir_name, "state.ecsv") - assert not file_path.is_file() + file_path = tmp_path / "state.ecsv" + assert not file_path.is_file() - state.save_to_file(file_path) - assert file_path.is_file() + state.save_to_file(file_path) + assert file_path.is_file() - state2 = GraphState.from_file(file_path) - assert state == state2 + state2 = GraphState.from_file(file_path) + assert state == state2 - # Cannot overwrite with it set to False, but works when set to True. 
- with pytest.raises(OSError): - state.save_to_file(file_path) - state.save_to_file(file_path, overwrite=True) + # Cannot overwrite with it set to False, but works when set to True. + with pytest.raises(OSError): + state.save_to_file(file_path) + state.save_to_file(file_path, overwrite=True) def test_transpose_dict_of_list():