Skip to content

Commit

Permalink
Merge pull request #171 from lincc-frameworks/switch_away_from_os_join
Browse files Browse the repository at this point in the history
Replace os.join with Path
  • Loading branch information
jeremykubica authored Oct 16, 2024
2 parents e9cd611 + 7c64f73 commit 80b009b
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 55 deletions.
7 changes: 4 additions & 3 deletions docs/notebooks/test_snia.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import os\n",
"import sncosmo\n",
"import tdastro\n",
"\n",
"from pathlib import Path\n",
"\n",
"from tdastro.example_runs.simulate_snia import run_snia_end2end"
]
},
Expand All @@ -46,7 +47,7 @@
"source": [
"from tdastro.astro_utils.opsim import OpSim, oversample_opsim\n",
"\n",
"opsim_name = os.path.join(tdastro._TDASTRO_TEST_DATA_DIR, \"opsim_shorten.db\")\n",
"opsim_name = tdastro._TDASTRO_TEST_DATA_DIR / \"opsim_shorten.db\"\n",
"base_opsim = OpSim.from_db(opsim_name)\n",
"oversampled_observations = oversample_opsim(\n",
" base_opsim,\n",
Expand Down Expand Up @@ -91,7 +92,7 @@
"metadata": {},
"outputs": [],
"source": [
"passbands_dir = os.path.join(tdastro._TDASTRO_TEST_DATA_DIR, \"passbands\")\n",
"passbands_dir = Path(tdastro._TDASTRO_TEST_DATA_DIR, \"passbands\")\n",
"res, passbands = run_snia_end2end(\n",
" oversampled_observations,\n",
" passbands_dir=passbands_dir,\n",
Expand Down
40 changes: 24 additions & 16 deletions src/tdastro/astro_utils/passbands.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import logging
import os
import socket
import urllib.parse
import urllib.request
from typing import Literal, Optional
from pathlib import Path
from typing import Literal, Optional, Union
from urllib.error import HTTPError, URLError

import numpy as np
Expand All @@ -29,7 +29,7 @@ def __init__(
self,
preset: str = None,
passband_parameters: Optional[list] = None,
table_dir: Optional[str] = None,
table_dir: Optional[Union[str, Path]] = None,
given_passbands: list = None,
**kwargs,
):
Expand All @@ -45,7 +45,7 @@ def __init__(
- survey : str
- filter_name : str
Dictionaries may also contain the following optional parameters:
- table_path : str
- table_path : str or Path
- table_url : str
- delta_wave : float
- trim_quantile : float
Expand Down Expand Up @@ -82,10 +82,15 @@ def __init__(
parameters[key] = value

# Set the table path if it is not already set and a table_dir is provided
if "table_path" not in parameters and table_dir is not None:
parameters["table_path"] = os.path.join(
table_dir, parameters["survey"], f"{parameters['filter_name']}.dat"
)
if "table_path" not in parameters:
if table_dir is not None:
parameters["table_path"] = Path(
table_dir,
parameters["survey"],
f"{parameters['filter_name']}.dat",
)
elif isinstance(parameters["table_path"], str):
parameters["table_path"] = Path(parameters["table_path"])

passband = Passband(**parameters)
self.passbands[passband.full_name] = passband
Expand Down Expand Up @@ -128,7 +133,7 @@ def _load_preset(self, preset: str, table_dir: Optional[str], **kwargs) -> None:
if table_dir is None:
self.passbands[f"LSST_{filter_name}"] = Passband("LSST", filter_name, **kwargs)
else:
table_path = os.path.join(table_dir, "LSST", f"{filter_name}.dat")
table_path = Path(table_dir, "LSST", f"{filter_name}.dat")
self.passbands[f"LSST_{filter_name}"] = Passband(
"LSST",
filter_name,
Expand Down Expand Up @@ -276,7 +281,7 @@ def __init__(
filter_name: str,
delta_wave: Optional[float] = 5.0,
trim_quantile: Optional[float] = 1e-3,
table_path: Optional[str] = None,
table_path: Optional[Union[str, Path]] = None,
table_url: Optional[str] = None,
table_values: Optional[np.array] = None,
units: Optional[Literal["nm", "A"]] = "A",
Expand Down Expand Up @@ -318,7 +323,7 @@ def __init__(
self.filter_name = filter_name
self.full_name = f"{survey}_{filter_name}"

self.table_path = table_path
self.table_path = Path(table_path) if table_path is not None else None
self.table_url = table_url
self.units = units
self._in_band_wave_indices = None
Expand Down Expand Up @@ -364,11 +369,14 @@ def _load_transmission_table(self, force_download: bool = False) -> None:
"""
# Check if the table file exists locally, and download it if it does not
if self.table_path is None:
self.table_path = os.path.join(
os.path.dirname(__file__), f"passbands/{self.survey}/{self.filter_name}.dat"
self.table_path = Path(
Path(__file__).parent,
"passbands",
self.survey,
f"{self.filter_name}.dat",
)
os.makedirs(os.path.dirname(self.table_path), exist_ok=True)
if force_download or not os.path.exists(self.table_path):
self.table_path.parent.mkdir(parents=True, exist_ok=True)
if force_download or not self.table_path.exists():
self._download_transmission_table()

# Load the table file
Expand Down Expand Up @@ -409,7 +417,7 @@ def _download_transmission_table(self) -> bool:
socket.setdefaulttimeout(10)
logging.info(f"Retrieving {self.table_url}")
urllib.request.urlretrieve(self.table_url, self.table_path)
if os.path.getsize(self.table_path) == 0:
if self.table_path.stat().st_size == 0:
logging.error(f"Transmission table downloaded for {self.full_name} is empty.")
return False
else:
Expand Down
2 changes: 1 addition & 1 deletion src/tdastro/graph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def from_file(cls, filename):
Parameters
----------
filename : str
filename : str or Path
The name of the file.
"""
data_table = ascii.read(filename, format="ecsv")
Expand Down
17 changes: 8 additions & 9 deletions tests/tdastro/astro_utils/test_opsim.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
import tempfile
from pathlib import Path

Expand Down Expand Up @@ -134,28 +133,28 @@ def test_write_read_opsim():
ops_data = OpSim(pd.DataFrame(values))

with tempfile.TemporaryDirectory() as dir_name:
filename = os.path.join(dir_name, "test_write_read_opsim.db")
file_path = Path(dir_name, "test_write_read_opsim.db")

# The opsim does not exist until we write it.
assert not Path(filename).is_file()
assert not file_path.is_file()
with pytest.raises(FileNotFoundError):
_ = OpSim.from_db(filename)
_ = OpSim.from_db(file_path)

# We can write the opsim db.
ops_data.write_opsim_table(filename)
assert Path(filename).is_file()
ops_data.write_opsim_table(file_path)
assert file_path.is_file()

# We can reread the opsim db.
ops_data2 = OpSim.from_db(filename)
ops_data2 = OpSim.from_db(file_path)
assert len(ops_data2) == 5
assert np.allclose(values["observationStartMJD"], ops_data2["observationStartMJD"].to_numpy())
assert np.allclose(values["fieldRA"], ops_data2["fieldRA"].to_numpy())
assert np.allclose(values["fieldDec"], ops_data2["fieldDec"].to_numpy())

# We cannot overwrite unless we set overwrite=True
with pytest.raises(ValueError):
ops_data.write_opsim_table(filename, overwrite=False)
ops_data.write_opsim_table(filename, overwrite=True)
ops_data.write_opsim_table(file_path, overwrite=False)
ops_data.write_opsim_table(file_path, overwrite=True)


def test_opsim_range_search():
Expand Down
7 changes: 4 additions & 3 deletions tests/tdastro/astro_utils/test_passbands.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os
from pathlib import Path
from unittest.mock import patch

import numpy as np
Expand All @@ -17,8 +17,9 @@ def create_lsst_passband(path, filter_name, **kwargs):
def create_toy_passband(path, transmission_table, filter_name="a", **kwargs):
"""Helper function to create a toy Passband object for testing."""
survey = "TOY"
table_path = f"{path}/{survey}/{filter_name}.dat"
os.makedirs(os.path.dirname(table_path), exist_ok=True)
dir_path = Path(path, survey)
dir_path.mkdir(parents=True, exist_ok=True)
table_path = dir_path / f"{filter_name}.dat"

# Create a transmission table file
with open(table_path, "w") as f:
Expand Down
14 changes: 7 additions & 7 deletions tests/tdastro/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os.path
from pathlib import Path

import pytest
from tdastro import _TDASTRO_TEST_DATA_DIR
Expand All @@ -7,31 +7,31 @@
@pytest.fixture
def test_data_dir():
"""Return the base test data directory."""
return _TDASTRO_TEST_DATA_DIR
return Path(_TDASTRO_TEST_DATA_DIR)


@pytest.fixture
def grid_data_good_file(test_data_dir):
"""Return the file path for the good grid input file."""
return os.path.join(test_data_dir, "grid_input_good.ecsv")
return test_data_dir / "grid_input_good.ecsv"


@pytest.fixture
def grid_data_bad_file(test_data_dir):
"""Return the file path for the bad grid input file."""
return os.path.join(test_data_dir, "grid_input_bad.txt")
return test_data_dir / "grid_input_bad.txt"


@pytest.fixture
def opsim_small(test_data_dir):
"""Return the file path for the bad grid input file."""
return os.path.join(test_data_dir, "opsim_small.db")
return test_data_dir / "opsim_small.db"


@pytest.fixture
def opsim_shorten(test_data_dir):
"""Return the file path for the bad grid input file."""
return os.path.join(test_data_dir, "opsim_shorten.db")
return test_data_dir / "opsim_shorten.db"


@pytest.fixture
Expand All @@ -54,4 +54,4 @@ def oversampled_observations(opsim_shorten):
@pytest.fixture
def passbands_dir(test_data_dir):
"""Return the file path for passbands directory."""
return os.path.join(test_data_dir, "passbands")
return test_data_dir / "passbands"
27 changes: 11 additions & 16 deletions tests/tdastro/test_graph_state.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
import os
import tempfile
from pathlib import Path

import numpy as np
import pytest
from astropy.table import Table
Expand Down Expand Up @@ -379,27 +375,26 @@ def test_graph_state_update_multi():
state.update(state4)


def test_graph_to_from_file():
def test_graph_to_from_file(tmp_path):
"""Test that we can create an AstroPy Table from a GraphState."""
state = GraphState(num_samples=3)
state.set("a", "v1", [1.0, 2.0, 3.0])
state.set("a", "v2", [3.0, 4.0, 5.0])
state.set("b", "v1", [6.0, 7.0, 8.0])

with tempfile.TemporaryDirectory() as dir_name:
file_path = os.path.join(dir_name, "state.ecsv")
assert not Path(file_path).is_file()
file_path = tmp_path / "state.ecsv"
assert not file_path.is_file()

state.save_to_file(file_path)
assert Path(file_path).is_file()
state.save_to_file(file_path)
assert file_path.is_file()

state2 = GraphState.from_file(file_path)
assert state == state2
state2 = GraphState.from_file(file_path)
assert state == state2

# Cannot overwrite with it set to False, but works when set to True.
with pytest.raises(OSError):
state.save_to_file(file_path)
state.save_to_file(file_path, overwrite=True)
# Cannot overwrite with it set to False, but works when set to True.
with pytest.raises(OSError):
state.save_to_file(file_path)
state.save_to_file(file_path, overwrite=True)


def test_transpose_dict_of_list():
Expand Down

0 comments on commit 80b009b

Please sign in to comment.