Skip to content

Commit

Permalink
Add the module to the object's name string
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Jul 17, 2024
1 parent 27fc159 commit 39092b4
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 23 deletions.
7 changes: 0 additions & 7 deletions src/tdastro/astro_utils/cosmology.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,10 @@ class RedshiftDistFunc(FunctionNode):
"""

def __init__(self, redshift, cosmology, **kwargs):
# Augment the node identifier string to include the cosmology name.
if "node_identifier" in kwargs:
node_identifier = f"{kwargs['node_identifier']}({cosmology.name})"
else:
node_identifier = f"{cosmology.name}"

# Call the super class's constructor with the needed information.
super().__init__(
func=redshift_to_distance,
redshift=redshift,
cosmology=cosmology,
node_identifier=node_identifier,
**kwargs,
)
5 changes: 3 additions & 2 deletions src/tdastro/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ def __init__(self, node_identifier=None, graph_base_seed=None, **kwargs):

def __str__(self):
"""Return the string representation of the node."""
name = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
if self.node_identifier:
return f"{self.node_identifier}={self.__class__.__name__}"
return f"{self.node_identifier}={name}"
else:
return self.__class__.__name__
return name

def set_graph_base_seed(self, graph_base_seed):
"""Set a new graph base seed.
Expand Down
2 changes: 1 addition & 1 deletion tests/tdastro/sources/test_sncosmo_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def test_sncomso_models_hsiao() -> None:
model = SncosmoWrapperModel("hsiao")
model.set(amplitude=1.0e-10)
assert model.amplitude == 1.0e-10
assert str(model) == "SncosmoWrapperModel"
assert str(model) == "tdastro.sources.sncomso_models.SncosmoWrapperModel"

assert np.array_equal(model.param_names, ["amplitude"])
assert np.array_equal(model.parameters, [1.0e-10])
Expand Down
4 changes: 2 additions & 2 deletions tests/tdastro/sources/test_spline_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def test_spline_model_flat() -> None:
wavelengths = np.linspace(100.0, 500.0, 25)
fluxes = np.full((len(times), len(wavelengths)), 1.0)
model = SplineModel(times, wavelengths, fluxes)
assert str(model) == "SplineModel"
assert str(model) == "tdastro.sources.spline_model.SplineModel"

test_times = np.array([0.0, 1.0, 2.0, 3.0, 10.0])
test_waves = np.array([0.0, 100.0, 200.0, 1000.0])
Expand All @@ -19,7 +19,7 @@ def test_spline_model_flat() -> None:
np.testing.assert_array_almost_equal(values, expected)

model2 = SplineModel(times, wavelengths, fluxes, amplitude=5.0, node_identifier="test")
assert str(model2) == "test=SplineModel"
assert str(model2) == "test=tdastro.sources.spline_model.SplineModel"

values2 = model2.evaluate(test_times, test_waves)
assert values2.shape == (5, 4)
Expand Down
4 changes: 2 additions & 2 deletions tests/tdastro/sources/test_static_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_static_source() -> None:
assert model.ra is None
assert model.dec is None
assert model.distance is None
assert str(model) == "my_static_source=StaticSource"
assert str(model) == "my_static_source=tdastro.sources.static_source.StaticSource"

times = np.array([1, 2, 3, 4, 5, 10])
wavelengths = np.array([100.0, 200.0, 300.0])
Expand All @@ -49,7 +49,7 @@ def test_static_source_host() -> None:
assert model.ra == 1.0
assert model.dec == 2.0
assert model.distance == 3.0
assert str(model) == "StaticSource"
assert str(model) == "tdastro.sources.static_source.StaticSource"


def test_static_source_resample() -> None:
Expand Down
18 changes: 9 additions & 9 deletions tests/tdastro/test_base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,9 @@ def test_parameterized_node_attributes():

settings = model1.get_all_parameter_values(True)
assert len(settings) == 3
assert settings["1=PairModel.value1"] == 0.5
assert settings["1=PairModel.value2"] == 1.5
assert settings["1=PairModel.value_sum"] == 2.0
assert settings["1=test_base_models.PairModel.value1"] == 0.5
assert settings["1=test_base_models.PairModel.value2"] == 1.5
assert settings["1=test_base_models.PairModel.value_sum"] == 2.0

# Use value1=model.value and value2=3.0
model2 = PairModel(value1=model1, value2=3.0, node_identifier="2")
Expand All @@ -187,12 +187,12 @@ def test_parameterized_node_attributes():

settings = model2.get_all_parameter_values(True)
assert len(settings) == 6
assert settings["1=PairModel.value1"] == 0.5
assert settings["1=PairModel.value2"] == 1.5
assert settings["1=PairModel.value_sum"] == 2.0
assert settings["2=PairModel.value1"] == 0.5
assert settings["2=PairModel.value2"] == 3.0
assert settings["2=PairModel.value_sum"] == 3.5
assert settings["1=test_base_models.PairModel.value1"] == 0.5
assert settings["1=test_base_models.PairModel.value2"] == 1.5
assert settings["1=test_base_models.PairModel.value_sum"] == 2.0
assert settings["2=test_base_models.PairModel.value1"] == 0.5
assert settings["2=test_base_models.PairModel.value2"] == 3.0
assert settings["2=test_base_models.PairModel.value_sum"] == 3.5


def test_parameterized_node_get_dependencies():
Expand Down

0 comments on commit 39092b4

Please sign in to comment.