Skip to content

Commit

Permalink
added test function for plot functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
nkhadka21 committed Dec 5, 2024
1 parent f910c26 commit fa29ff9
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 1 deletion.
1 change: 0 additions & 1 deletion slsim/Plots/plot_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,6 @@ def plot_lightcurves(data, images=True):

# Adjust layout to avoid overlaps
plt.tight_layout()
plt.show()
return fig


Expand Down
59 changes: 59 additions & 0 deletions tests/test_Plots/test_plot_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from slsim.Plots.plot_functions import (
create_image_montage_from_image_list,
plot_montage_of_random_injected_lens,
create_montage, plot_lightcurves
)
from slsim.image_simulation import sharp_image
from slsim.Sources.source import Source
Expand Down Expand Up @@ -122,6 +123,64 @@ def test_plot_montage_of_random_injected_lens(quasar_lens_pop_instance):
assert isinstance(fig, plt.Figure)
assert fig.get_size_inches()[0] == np.array([num_cols * 3, num_rows * 3])[0]

def test_create_montage_basics():
images = [
np.random.rand(5, 5),
np.random.rand(5, 5),
np.random.rand(5, 5),
np.random.rand(5, 5),
]
montage = create_montage(images)

# Check shape
assert montage.shape == (5, 15) # 1 row, 3 images wide

# Check normalization range
assert np.min(montage) >= 0
assert np.max(montage) <= 1

def test_create_montage_specified_grid():
images = [
np.random.rand(5, 5),
np.random.rand(5, 5),
np.random.rand(5, 5),
]
grid_size = (1, 3)
montage = create_montage(images, grid_size=grid_size)

# Check shape
assert montage.shape == (5, 15) # 1 row, 3 images wide

def test_plot_lightcurves():
data = {
"magnitudes": {
"mag_image_1": {"g": np.random.rand(5), "r": np.random.rand(5)},
"mag_image_2": {"g": np.random.rand(5), "r": np.random.rand(5)},
},
"errors_low": {
"mag_error_image_1_low": {"g": np.random.rand(5), "r": np.random.rand(5)},
"mag_error_image_2_low": {"g": np.random.rand(5), "r": np.random.rand(5)},
},
"errors_high": {
"mag_error_image_1_high": {"g": np.random.rand(5), "r": np.random.rand(5)},
"mag_error_image_2_high": {"g": np.random.rand(5), "r": np.random.rand(5)},
},
"obs_time": {"g": np.arange(5), "r": np.arange(5)},
"image_lists": {
"g": [np.random.rand(10, 10) for _ in range(3)],
"r": [np.random.rand(10, 10) for _ in range(3)],
},
}

fig = plot_lightcurves(data, images=True)
fig2 = plot_lightcurves(data, images=False)
ax1=fig.get_axes()
ax2=fig2.get_axes()
assert fig is not None
assert isinstance(fig, plt.Figure)
assert fig2 is not None
assert isinstance(fig2, plt.Figure)
assert len(ax1) == len(ax2)+2

if __name__ == "__main__":
pytest.main()

0 comments on commit fa29ff9

Please sign in to comment.