diff --git a/jdaviz/configs/cubeviz/plugins/spectral_extraction/tests/test_spectral_extraction.py b/jdaviz/configs/cubeviz/plugins/spectral_extraction/tests/test_spectral_extraction.py index 295432051e..0cfc4044d6 100644 --- a/jdaviz/configs/cubeviz/plugins/spectral_extraction/tests/test_spectral_extraction.py +++ b/jdaviz/configs/cubeviz/plugins/spectral_extraction/tests/test_spectral_extraction.py @@ -570,6 +570,24 @@ def test_default_spectral_extraction(cubeviz_helper, spectrum1d_cube_fluxunit_jy ) +def test_spectral_extraction_unit_conv_one_spec( + cubeviz_helper, spectrum1d_cube_fluxunit_jy_per_steradian +): + cubeviz_helper.load_data(spectrum1d_cube_fluxunit_jy_per_steradian) + spectrum_viewer = cubeviz_helper.app.get_viewer( + cubeviz_helper._default_spectrum_viewer_reference_name) + uc = cubeviz_helper.plugins["Unit Conversion"] + assert uc.flux_unit == "Jy" + uc.flux_unit.selected = "MJy" + spec_extr_plugin = cubeviz_helper.plugins['Spectral Extraction'] + # Overwrite the one and only default extraction. + collapsed = spec_extr_plugin.extract() + # Actual values not in display unit but should not affect display unit. + assert collapsed.flux.unit == u.Jy + assert uc.flux_unit.selected == "MJy" + assert spectrum_viewer.state.y_display_unit == "MJy" + + @pytest.mark.usefixtures('_jail') @pytest.mark.remote_data @pytest.mark.parametrize( diff --git a/jdaviz/configs/specviz/tests/test_viewers.py b/jdaviz/configs/specviz/tests/test_viewers.py index 2c40e390ff..ba706c32e1 100644 --- a/jdaviz/configs/specviz/tests/test_viewers.py +++ b/jdaviz/configs/specviz/tests/test_viewers.py @@ -26,6 +26,21 @@ def test_spectrum_viewer_axis_labels(specviz_helper, input_unit, y_axis_label): assert (y_axis_label in label) +@pytest.mark.xfail(reason="FIXME: Some callback magic needs to happen somewhere.") +def test_spectrum_viewer_keep_unit_when_removed(specviz_helper, spectrum1d): + specviz_helper.load_data(spectrum1d, data_label="Test") + uc = specviz_helper.plugins["Unit Conversion"] + assert uc.flux_unit == "Jy" + uc.flux_unit = "MJy" + specviz_helper.app.remove_data_from_viewer("spectrum-viewer", "Test") + specviz_helper.app.add_data_to_viewer("spectrum-viewer", "Test") + # Actual values not in display unit but should not affect display unit. + spec = specviz_helper.get_spectra(data_label="Test", apply_slider_redshift=False) + assert spec.flux.unit == u.Jy + assert uc.flux_unit.selected == "MJy" + assert specviz_helper.app._get_display_unit('spectral_y') == "MJy" + + class TestResetLimitsTwoTests: """See https://github.com/spacetelescope/lcviz/pull/93""" diff --git a/jdaviz/core/template_mixin.py b/jdaviz/core/template_mixin.py index fd7ff237c3..7d1c27be61 100644 --- a/jdaviz/core/template_mixin.py +++ b/jdaviz/core/template_mixin.py @@ -3910,7 +3910,13 @@ def add_results_from_plugin(self, data_item, replace=None, label=None): add_to_viewer_vis = [True] preserved_attributes = [{}] + enforce_flux_unit = None if label in self.app.data_collection: + if self.app.config == "cubeviz": + sv = self.app.get_viewer( + self.app._jdaviz_helper._default_spectrum_viewer_reference_name) + if len(sv.state.layers) == 1: + enforce_flux_unit = self.app._get_display_unit('spectral_y') for viewer_ref in add_to_viewer_refs: self.app.remove_data_from_viewer(viewer_ref, label) self.app.data_collection.remove(self.app.data_collection[label]) @@ -3943,6 +3949,9 @@ def add_results_from_plugin(self, data_item, replace=None, label=None): label, visible=visible, clear_other_data=this_replace) + if enforce_flux_unit: + sv.state.y_display_unit = enforce_flux_unit + if preserved != {}: layer_state = [layer.state for layer in this_viewer.layers if layer.layer.label == label][0]