diff --git a/docs/notebooks/test_snia.ipynb b/docs/notebooks/test_snia.ipynb index e884b9f9..b618aa5f 100644 --- a/docs/notebooks/test_snia.ipynb +++ b/docs/notebooks/test_snia.ipynb @@ -3,69 +3,119 @@ { "cell_type": "code", "execution_count": null, - "id": "1bd86f47-7061-4aff-a7e0-be226219b139", - "metadata": {}, + "id": "f2dc23f31b3601e1", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T18:10:25.986845Z", + "start_time": "2024-09-26T18:10:25.982521Z" + } + }, "outputs": [], "source": [ - "%load_ext autoreload\n", - "%autoreload 2" + "# %load_ext autoreload\n", + "# %autoreload 2" ] }, { "cell_type": "code", "execution_count": null, - "id": "4bf6fe10-59dc-439e-9371-c0cc15755f62", - "metadata": {}, + "id": "4e459a387df01a7c", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T18:10:25.995745Z", + "start_time": "2024-09-26T18:10:25.992519Z" + } + }, "outputs": [], "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import os\n", + "import sncosmo\n", "import sys\n", - "from pathlib import Path\n", "import tdastro\n", "\n", - "sys.path.append(\n", - " str((Path(tdastro.__file__).parent / \"..\" / \"..\" / \"tests\" / \"tdastro\" / \"sources\").resolve())\n", - ")" + "# Append the path to the test directory so we can import run_snia_end2end\n", + "test_path = tdastro._TDASTRO_TEST_DIR\n", + "sys.path.append(str(test_path.resolve()))\n", + "\n", + "from sources.test_snia import run_snia_end2end" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "817f5d23-925d-4ef7-bc22-ad110020e57d", + "cell_type": "markdown", + "id": "a26e33f8", "metadata": {}, - "outputs": [], "source": [ - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import sncosmo" + "### Create the data we will use for this test\n", + "\n", + "Load a sample opsim file (opsim_shorten.db) from the test's data directory and use the `oversample_opsim()` function to sample every 0.01 days from MJD=61406.0 to MJD=61771.0." ] }, { "cell_type": "code", "execution_count": null, - "id": "c87c4488-9368-4a0a-8400-1596bb15492a", + "id": "89adedae", "metadata": {}, "outputs": [], "source": [ - "from test_snia import test_snia_end2end" + "from tdastro.astro_utils.opsim import oversample_opsim, OpSim\n", + "\n", + "opsim_name = os.path.join(tdastro._TDASTRO_TEST_DATA_DIR, \"opsim_shorten.db\")\n", + "base_opsim = OpSim.from_db(opsim_name)\n", + "oversampled_observations = oversample_opsim(\n", + " base_opsim,\n", + " pointing=(0.0, 0.0),\n", + " search_radius=180.0,\n", + " delta_t=0.01,\n", + " time_range=(61406.0, 61771.0),\n", + " bands=None,\n", + " strategy=\"darkest_sky\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "1c96e021", + "metadata": {}, + "source": [ + "### Run the test\n", + "\n", + "Run the end to end test using the `run_snia_end2end()` to generate 20 samples." ] }, { "cell_type": "code", "execution_count": null, - "id": "3285d471-c454-4c7a-b7cb-b8a79ab7ec98", - "metadata": {}, + "id": "58a2bc03f3a3a3a0", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T18:10:56.578886Z", + "start_time": "2024-09-26T18:10:26.023246Z" + } + }, "outputs": [], "source": [ - "res = test_snia_end2end(\n", - " None, opsim_db_file=None, opsim=False, nsample=10, return_result=True, phase_rest=np.linspace(-15, 45, 20)\n", - ")" + "passbands_dir = os.path.join(tdastro._TDASTRO_TEST_DATA_DIR, \"passbands\")\n", + "res, passbands = run_snia_end2end(\n", + " oversampled_observations,\n", + " passbands_dir=passbands_dir,\n", + " nsample=20,\n", + ")\n", + "\n", + "print(f\"Produced {len(res)} samples.\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "c60e849f-8529-4714-92e1-345182a12a68", - "metadata": {}, + "id": "5c4b0574aec01df4", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T18:10:56.595273Z", + "start_time": "2024-09-26T18:10:56.589976Z" + } + }, "outputs": [], "source": [ "hostmass = [x[\"parameter_values\"][\"hostmass\"] for x in res]\n", @@ -80,8 +130,13 @@ { "cell_type": "code", "execution_count": null, - "id": "57607e1e-e18a-4408-a7b9-ec400c72c9a8", - "metadata": {}, + "id": "aaefb2f29d444cf2", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T18:10:56.712236Z", + "start_time": "2024-09-26T18:10:56.601978Z" + } + }, "outputs": [], "source": [ "plt.hist(hostmass)" @@ -90,8 +145,13 @@ { "cell_type": "code", "execution_count": null, - "id": "561c6c8c-84d9-484d-9759-56b020a018d6", - "metadata": {}, + "id": "b86d82fe1a518d5a", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T18:10:56.771308Z", + "start_time": "2024-09-26T18:10:56.719033Z" + } + }, "outputs": [], "source": [ "plt.hist(x1)\n", @@ -101,8 +161,13 @@ { "cell_type": "code", "execution_count": null, - "id": "8addd896-8f21-4813-8053-b75debf114b2", - "metadata": {}, + "id": "edc24806f5752b25", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T18:10:56.832246Z", + "start_time": "2024-09-26T18:10:56.777555Z" + } + }, "outputs": [], "source": [ "plt.hist(c)" @@ -111,8 +176,13 @@ { "cell_type": "code", "execution_count": null, - "id": "bd1bbdd5-f236-4e6d-84f3-161ab6cca4fd", - "metadata": {}, + "id": "35ab73410bd98ac1", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T18:10:56.892230Z", + "start_time": "2024-09-26T18:10:56.838414Z" + } + }, "outputs": [], "source": [ "plt.scatter(hostmass, x1)" @@ -121,8 +191,13 @@ { "cell_type": "code", "execution_count": null, - "id": "151bd30a-26f4-40eb-ba99-611f82d61edf", - "metadata": {}, + "id": "2b56df734bccfe1b", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T18:10:56.949151Z", + "start_time": "2024-09-26T18:10:56.899465Z" + } + }, "outputs": [], "source": [ "plt.hist(x0)" @@ -131,8 +206,13 @@ { "cell_type": "code", "execution_count": null, - "id": "3b9e3896-c530-419c-8bb0-a9c1984eae0d", - "metadata": {}, + "id": "1140a75e38023d80", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T18:10:57.012781Z", + "start_time": "2024-09-26T18:10:56.956603Z" + } + }, "outputs": [], "source": [ "plt.hist(z)" @@ -141,8 +221,13 @@ { "cell_type": "code", "execution_count": null, - "id": "b31893a6-be21-4c69-8e6c-4cfec316d309", - "metadata": {}, + "id": "75b49a8b4f8fa1b4", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T18:10:57.111919Z", + "start_time": "2024-09-26T18:10:57.020201Z" + } + }, "outputs": [], "source": [ "# cosmo = FlatLambdaCDM(H0=73, Om0=0.3)\n", @@ -170,8 +255,13 @@ { "cell_type": "code", "execution_count": null, - "id": "d7443d9e-7bbd-4087-b31e-8a4030c32b16", - "metadata": {}, + "id": "98996d76d56cb7b9", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T18:10:57.206103Z", + "start_time": "2024-09-26T18:10:57.118280Z" + } + }, "outputs": [], "source": [ "plt.scatter(z, mb)\n", @@ -183,8 +273,13 @@ { "cell_type": "code", "execution_count": null, - "id": "494d4400-c7ef-453d-bfa3-4abea973997e", - "metadata": {}, + "id": "320e5ddcec3fcdd", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T18:10:57.252874Z", + "start_time": "2024-09-26T18:10:57.212412Z" + } + }, "outputs": [], "source": [ "plt.scatter(hostmass, mu - distmod)\n", @@ -194,8 +289,13 @@ { "cell_type": "code", "execution_count": null, - "id": "55a3c72f-a231-4108-b518-779e37502cd7", - "metadata": {}, + "id": "4dce3f1ba8792c87", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T18:10:57.388666Z", + "start_time": "2024-09-26T18:10:57.259232Z" + } + }, "outputs": [], "source": [ "for i in range(0, 3):\n", @@ -204,7 +304,7 @@ " except Exception:\n", " continue\n", " saltpars = {\"x0\": x0[i], \"x1\": x1[i], \"c\": c[i], \"z\": z[i], \"t0\": t0[i]}\n", - " model = sncosmo.Model(\"salt2-h17\")\n", + " model = sncosmo.Model(\"salt3\")\n", " model.update(saltpars)\n", " print(saltpars)\n", " print(model.parameters)\n", @@ -217,41 +317,46 @@ { "cell_type": "code", "execution_count": null, - "id": "acdcbb01-ead0-4a20-b479-3324ccd7e562", - "metadata": {}, + "id": "6e1d0698de0bfaaa", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T18:10:57.917869Z", + "start_time": "2024-09-26T18:10:57.394956Z" + } + }, "outputs": [], "source": [ "for i in range(0, 3):\n", - " phase_obs = res[i][\"phase_rest\"] * (1 + z[i])\n", " times = res[i][\"times\"]\n", - " colors = \"gr\"\n", - " for color, f in zip(\"gr\", colors):\n", - " plt.plot(\n", - " times, res[i][\"bandfluxes\"][\"LSST_\" + f], \"-\", marker=\"o\", label=f, color=color, alpha=0.6, lw=2\n", - " )\n", + " colors = [\"red\", \"brown\"]\n", + " for f, color in zip(\"ri\", colors):\n", + " band_name = f\"LSST_{f}\"\n", + " plt.plot(times, res[i][\"bandfluxes\"][band_name], \"-\", label=f, color=color, alpha=0.6, lw=2)\n", " saltpars = {\"x0\": x0[i], \"x1\": x1[i], \"c\": c[i], \"z\": z[i], \"t0\": t0[i]}\n", - " model = sncosmo.Model(\"salt2-h17\")\n", + " model = sncosmo.Model(\"salt3\")\n", " model.update(saltpars)\n", " print(saltpars)\n", - " flux = model.bandflux(\"lsst\" + f, times, zpsys=\"ab\", zp=8.9 + 2.5 * 9) # -48.6)\n", + " sncosmo_band = sncosmo.Bandpass(\n", + " *passbands.passbands[band_name].processed_transmission_table.T, name=band_name\n", + " )\n", + " flux = model.bandflux(sncosmo_band, times, zpsys=\"ab\", zp=8.9 + 2.5 * 9) # -48.6)\n", " plt.plot(times, flux, \"--\", label=f, color=color)\n", + " plt.xlabel(\"MJD\")\n", + " plt.ylabel(\"Flux, nJy\")\n", " plt.legend()\n", " plt.show()" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "d95053a7-1e5c-4d36-8579-82b3dd8b9cbb", - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": null, "id": "dc680d32-cc9c-428f-90c4-944649999d9f", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T18:10:57.925337Z", + "start_time": "2024-09-26T18:10:57.924252Z" + } + }, "outputs": [], "source": [] } @@ -272,7 +377,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.10.4" } }, "nbformat": 4, diff --git a/src/tdastro/__init__.py b/src/tdastro/__init__.py index e69de29b..4f5f9584 100644 --- a/src/tdastro/__init__.py +++ b/src/tdastro/__init__.py @@ -0,0 +1,6 @@ +from pathlib import Path + +# Define some global directory paths to use for testing, notebooks, etc. +_TDASTRO_BASE_DIR = Path(__file__).parent.parent.parent +_TDASTRO_TEST_DIR = _TDASTRO_BASE_DIR / "tests" / "tdastro" +_TDASTRO_TEST_DATA_DIR = _TDASTRO_TEST_DIR / "data" diff --git a/src/tdastro/astro_utils/passbands.py b/src/tdastro/astro_utils/passbands.py index 024e0647..4d3091f1 100644 --- a/src/tdastro/astro_utils/passbands.py +++ b/src/tdastro/astro_utils/passbands.py @@ -78,6 +78,9 @@ def __str__(self) -> str: f"{', '.join(self.passbands.keys())}" ) + def __len__(self) -> int: + return len(self.passbands) + def _load_preset(self, preset: str, **kwargs) -> None: """Load a pre-defined set of passbands. diff --git a/tests/tdastro/astro_utils/test_opsim.py b/tests/tdastro/astro_utils/test_opsim.py index 372d44ae..746a1024 100644 --- a/tests/tdastro/astro_utils/test_opsim.py +++ b/tests/tdastro/astro_utils/test_opsim.py @@ -285,3 +285,16 @@ def test_oversample_opsim(opsim_shorten): oversampled["skyBrightness"].unique().size >= oversampled["filter"].unique().size ), "there should be at least as many skyBrightness values as bands" assert oversampled["skyBrightness"].isna().sum() == 0, "skyBrightness has NaN values" + + +def test_fixture_oversampled_observations(oversampled_observations): + """Test the fixture oversampled_observations.""" + assert len(oversampled_observations) == 36_500 + assert set(oversampled_observations["filter"]) == {"g", "r"} + assert oversampled_observations["skyBrightness"].isna().sum() == 0 + assert oversampled_observations["skyBrightness"].unique().size >= 2 + assert np.all(oversampled_observations["observationStartMJD"] >= 61406.0) + assert np.all(oversampled_observations["observationStartMJD"] <= 61771.0) + np.testing.assert_allclose(oversampled_observations["fieldRA"], 0.0) + np.testing.assert_allclose(oversampled_observations["fieldDec"], 0.0) + np.testing.assert_allclose(np.diff(oversampled_observations["observationStartMJD"]), 0.01) diff --git a/tests/tdastro/conftest.py b/tests/tdastro/conftest.py index 1116e9ad..259d28c1 100644 --- a/tests/tdastro/conftest.py +++ b/tests/tdastro/conftest.py @@ -1,15 +1,13 @@ import os.path import pytest - -DATA_DIR_NAME = "data" -TEST_DIR = os.path.dirname(__file__) +from tdastro import _TDASTRO_TEST_DATA_DIR @pytest.fixture def test_data_dir(): """Return the base test data directory.""" - return os.path.join(TEST_DIR, DATA_DIR_NAME) + return _TDASTRO_TEST_DATA_DIR @pytest.fixture @@ -36,6 +34,23 @@ def opsim_shorten(test_data_dir): return os.path.join(test_data_dir, "opsim_shorten.db") +@pytest.fixture +def oversampled_observations(opsim_shorten): + """Return an OpSim object with 0.01 day cadence spanning year 2027.""" + from tdastro.astro_utils.opsim import OpSim, oversample_opsim + + base_opsim = OpSim.from_db(opsim_shorten) + return oversample_opsim( + base_opsim, + pointing=(0.0, 0.0), + search_radius=180.0, + delta_t=0.01, + time_range=(61406.0, 61771.0), + bands=None, + strategy="darkest_sky", + ) + + @pytest.fixture def passbands_dir(test_data_dir): """Return the file path for passbands directory.""" diff --git a/tests/tdastro/sources/test_snia.py b/tests/tdastro/sources/test_snia.py index 16530736..a686146c 100644 --- a/tests/tdastro/sources/test_snia.py +++ b/tests/tdastro/sources/test_snia.py @@ -1,7 +1,6 @@ import numpy as np import sncosmo from astropy import units as u -from tdastro.astro_utils.opsim import OpSim from tdastro.astro_utils.passbands import PassbandGroup from tdastro.astro_utils.snia_utils import DistModFromRedshift, HostmassX1Func, X0FromDistMod from tdastro.astro_utils.unit_utils import flam_to_fnu @@ -12,12 +11,8 @@ def draw_single_random_sn( source, - wavelengths_rest=None, - phase_rest=None, - passbands=None, - opsim=False, - opsim_data=None, - randseed=None, + opsim, + passbands, ): """ Draw a single random SN realiztion @@ -28,17 +23,15 @@ def draw_single_random_sn( z = source.get_param(state, "redshift") wave_obs = passbands.waves wavelengths_rest = wave_obs / (1.0 + z) - phase_obs = phase_rest * (1.0 + z) - res = {"wavelengths_rest": wavelengths_rest, "phase_rest": phase_rest} + res = {"wavelengths_rest": wavelengths_rest} t0 = source.get_param(state, "t0") - times = t0 + phase_obs if opsim: ra = source.get_param(state, "ra") dec = source.get_param(state, "dec") - obs = opsim_data.get_observations(ra, dec, radius=1.75, cols=["time", "filter"]) + obs = opsim.get_observations(ra, dec, radius=1.75, cols=["time", "filter"]) times = obs["time"] phase_obs = times - t0 @@ -70,31 +63,33 @@ def draw_single_random_sn( return res -def test_snia_end2end( - opsim_small, - passbands_dir, - opsim_db_file=None, - opsim=True, - nsample=1, - return_result=False, - phase_rest=None, - wavelengths_rest=None, -): - """Test that we can sample and create SN Ia simulation using the salt3 model.""" - - opsim_data = OpSim.from_db(opsim_db_file if opsim_db_file else opsim_small) if opsim else None - - ra_min = opsim_data["fieldRA"].min() if opsim_data else 0.0 - ra_max = opsim_data["fieldRA"].max() if opsim_data else 360.0 - dec_min = opsim_data["fieldDec"].min() if opsim_data else -90.0 - dec_max = opsim_data["fieldDec"].max() if opsim_data else 33.5 - t_min = opsim_data["observationStartMJD"].min() if opsim_data else 60796.0 - t_max = opsim_data["observationStartMJD"].max() if opsim_data else 64448.0 +def run_snia_end2end(oversampled_observations, passbands_dir, nsample=1): + """Test that we can sample and create SN Ia simulation using the salt3 model. + + Parameters + ---------- + oversampled_observations : OpSim + The opsim data to use. + passbands_dir : str + The name of the directory holding the passband information. + nsample : int + The number of samples to test. + Default: 1 + + Returns + ------- + res_list : dict + A dictionary of lists of sampling and result information. + passbands : PassbandGroup + The passbands used. + """ + t_min = oversampled_observations["observationStartMJD"].min() + t_max = oversampled_observations["observationStartMJD"].max() - # Create a host galaxy anywhere on the sky. + # Create a host galaxy. host = SNIaHost( - ra=NumpyRandomFunc("uniform", low=ra_min, high=ra_max), - dec=NumpyRandomFunc("uniform", low=dec_min, high=dec_max), + ra=NumpyRandomFunc("uniform", low=-0.5, high=0.5), # all pointings RA = 0.0 + dec=NumpyRandomFunc("uniform", low=-0.5, high=0.5), # all pointings Dec = 0.0 hostmass=NumpyRandomFunc("uniform", low=7, high=12), redshift=NumpyRandomFunc("uniform", low=0.01, high=0.02), ) @@ -113,7 +108,7 @@ def test_snia_end2end( m_abs=m_abs_func, ) - sncosmo_modelname = "salt2-h17" + sncosmo_modelname = "salt3" source = SncosmoWrapperModel( sncosmo_modelname, @@ -129,42 +124,38 @@ def test_snia_end2end( passbands = PassbandGroup( passband_parameters=[ { - "survey": "LSST", - "filter_name": "u", - "table_path": f"{passbands_dir}/LSST/u.dat", - "units": "nm", - }, - { - "survey": "LSST", "filter_name": "r", "table_path": f"{passbands_dir}/LSST/r.dat", - "units": "nm", + }, + { + "filter_name": "i", + "table_path": f"{passbands_dir}/LSST/u.dat", }, ], + survey="LSST", + units="nm", + trim_quantile=0.001, delta_wave=1, ) res_list = [] - - if phase_rest is None: - phase_rest = np.array([-5.0, 0.0, 10.0]) - if wavelengths_rest is None: - wavelengths_rest = np.linspace(3000, 8000, 200) - + any_valid_results = False for _n in range(0, nsample): res = draw_single_random_sn( source, - wavelengths_rest=wavelengths_rest, - phase_rest=phase_rest, + opsim=oversampled_observations, passbands=passbands, - opsim=opsim, - opsim_data=opsim_data, ) + if res is None: + continue + any_valid_results = True + state = res["state"] + p = {} for parname in ["t0", "x0", "x1", "c", "redshift", "ra", "dec"]: - p[parname] = source.get_param(state, parname) + p[parname] = float(source.get_param(state, parname)) for parname in ["hostmass"]: p[parname] = host.get_param(state, parname) for parname in ["distmod"]: @@ -174,22 +165,30 @@ def test_snia_end2end( saltpars = {"x0": p["x0"], "x1": p["x1"], "c": p["c"], "z": p["redshift"], "t0": p["t0"]} model = sncosmo.Model(sncosmo_modelname) model.update(saltpars) - z = p["redshift"] wave = passbands.waves - if opsim is None: - time = phase_rest * (1 + z) + p["t0"] - assert np.allclose(res["times"], time) - else: - time = res["times"] + time = res["times"] flux_sncosmo = model.flux(time, wave) - assert np.allclose(res["flux_flam"] * 1e10, flux_sncosmo * 1e10) + np.testing.assert_allclose(res["flux_flam"], flux_sncosmo, atol=1e-30, rtol=1e-5) - for f in passbands.passbands: - bandflux_sncosmo = model.bandflux(f.replace("_", ""), time, zpsys="ab", zp=8.9 + 2.5 * 9) - assert np.allclose(res["bandfluxes"][f], bandflux_sncosmo, rtol=0.1) + for f, passband in passbands.passbands.items(): + # Skip test for negative fluxes + if np.any(flux_sncosmo < 0): + continue + sncosmo_band = sncosmo.Bandpass(*passband.processed_transmission_table.T, name=f) + bandflux_sncosmo = model.bandflux(sncosmo_band, time, zpsys="ab", zp=8.9 + 2.5 * 9) + np.testing.assert_allclose(res["bandfluxes"][f], bandflux_sncosmo, rtol=1e-1, err_msg=f"band {f}") res_list.append(res) - if return_result: - return res_list + assert any_valid_results, f"No valid results found over all {nsample} samples." + + return res_list, passbands + + +def test_snia_end2end(oversampled_observations, passbands_dir): + """Test that the end to end run works.""" + num_samples = 1 + res_list, passbands = run_snia_end2end(oversampled_observations, passbands_dir, nsample=num_samples) + assert len(res_list) == num_samples + assert len(passbands) == 2