From 1988aa3d7059134b0fa10e9907b669f665e09659 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 25 Oct 2024 16:23:59 -0400 Subject: [PATCH] Fix how node_label is passed by sncosmo models. --- src/tdastro/example_runs/simulate_snia.py | 3 +++ src/tdastro/sources/sncomso_models.py | 9 +++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/tdastro/example_runs/simulate_snia.py b/src/tdastro/example_runs/simulate_snia.py index 46c8043..6df68b8 100644 --- a/src/tdastro/example_runs/simulate_snia.py +++ b/src/tdastro/example_runs/simulate_snia.py @@ -103,6 +103,7 @@ def run_snia_end2end(oversampled_observations, passbands_dir, nsample=1): dec=NumpyRandomFunc("uniform", low=-0.5, high=0.5), # all pointings Dec = 0.0 hostmass=NumpyRandomFunc("uniform", low=7, high=12), redshift=NumpyRandomFunc("uniform", low=0.1, high=0.4), + node_label="host", ) distmod_func = DistModFromRedshift(host.redshift, H0=73.0, Omega_m=0.3) @@ -117,6 +118,7 @@ def run_snia_end2end(oversampled_observations, passbands_dir, nsample=1): alpha=0.14, beta=3.1, m_abs=m_abs_func, + node_label="x0_func", ) sncosmo_modelname = "salt3" @@ -130,6 +132,7 @@ def run_snia_end2end(oversampled_observations, passbands_dir, nsample=1): ra=NumpyRandomFunc("normal", loc=host.ra, scale=0.01), dec=NumpyRandomFunc("normal", loc=host.dec, scale=0.01), redshift=host.redshift, + node_label="source", ) passbands = PassbandGroup( diff --git a/src/tdastro/sources/sncomso_models.py b/src/tdastro/sources/sncomso_models.py index f02a974..596a815 100644 --- a/src/tdastro/sources/sncomso_models.py +++ b/src/tdastro/sources/sncomso_models.py @@ -27,12 +27,17 @@ class SncosmoWrapperModel(PhysicalModel): ---------- source_name : `str` The name used to set the source. + t0 : `float` + The start time of the sncosmo model. + Default: 0.0 + node_label : `str`, optional + An identifier (or name) for the current node. **kwargs : `dict`, optional Any additional keyword arguments. """ - def __init__(self, source_name, t0=0.0, **kwargs): - super().__init__(**kwargs) + def __init__(self, source_name, t0=0.0, node_label=None, **kwargs): + super().__init__(node_label=node_label, **kwargs) self.source_name = source_name self.source = get_source(source_name)