Skip to content

Commit

Permalink
Fix how node_label is passed by sncosmo models.
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Oct 25, 2024
1 parent 7121cf6 commit 1988aa3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/tdastro/example_runs/simulate_snia.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def run_snia_end2end(oversampled_observations, passbands_dir, nsample=1):
dec=NumpyRandomFunc("uniform", low=-0.5, high=0.5), # all pointings Dec = 0.0
hostmass=NumpyRandomFunc("uniform", low=7, high=12),
redshift=NumpyRandomFunc("uniform", low=0.1, high=0.4),
node_label="host",
)

distmod_func = DistModFromRedshift(host.redshift, H0=73.0, Omega_m=0.3)
Expand All @@ -117,6 +118,7 @@ def run_snia_end2end(oversampled_observations, passbands_dir, nsample=1):
alpha=0.14,
beta=3.1,
m_abs=m_abs_func,
node_label="x0_func",
)

sncosmo_modelname = "salt3"
Expand All @@ -130,6 +132,7 @@ def run_snia_end2end(oversampled_observations, passbands_dir, nsample=1):
ra=NumpyRandomFunc("normal", loc=host.ra, scale=0.01),
dec=NumpyRandomFunc("normal", loc=host.dec, scale=0.01),
redshift=host.redshift,
node_label="source",
)

passbands = PassbandGroup(
Expand Down
9 changes: 7 additions & 2 deletions src/tdastro/sources/sncomso_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,17 @@ class SncosmoWrapperModel(PhysicalModel):
----------
source_name : `str`
The name used to set the source.
t0 : `float`
The start time of the sncosmo model.
Default: 0.0
node_label : `str`, optional
An identifier (or name) for the current node.
**kwargs : `dict`, optional
Any additional keyword arguments.
"""

def __init__(self, source_name, t0=0.0, **kwargs):
super().__init__(**kwargs)
def __init__(self, source_name, t0=0.0, node_label=None, **kwargs):
super().__init__(node_label=node_label, **kwargs)
self.source_name = source_name
self.source = get_source(source_name)

Expand Down

0 comments on commit 1988aa3

Please sign in to comment.