diff --git a/tests/test_plots.py b/tests/test_plots.py index d0015c5..004b66d 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -13,7 +13,7 @@ def test_plots(start_prof=0, end_prof=100): ds = ds.drop_vars(['DENSITY']) fig, ax = plots.plot_basic_vars(ds, start_prof=start_prof, end_prof=end_prof) assert ax[0].get_ylabel() == 'Depth (m)' - assert ax[0].get_xlabel() == 'Temperature \n(°C)' + assert ax[0].get_xlabel() == f'{utilities.plotting_labels("TEMP")} \n({utilities.plotting_units(ds,"TEMP")})' def test_up_down_bias(v_res=1): @@ -24,20 +24,20 @@ def test_up_down_bias(v_res=1): lims = np.abs(df.dc) assert ax.get_xlim() == (-np.nanpercentile(lims, 99.5), np.nanpercentile(lims, 99.5)) assert ax.get_ylim() == (df.depth.max() + 1, -df.depth.max() / 30) - assert ax.get_xlabel() == 'Practical salinity (PSU)' + assert ax.get_xlabel() == f'{utilities.plotting_labels("PSAL")} ({utilities.plotting_units(ds,"PSAL")})' # check without passing axis new_fig, new_ax = plots.plot_updown_bias(ds, var='PSAL', v_res=1) assert new_ax.get_xlim() == (-np.nanpercentile(lims, 99.5), np.nanpercentile(lims, 99.5)) assert new_ax.get_ylim() == (df.depth.max() + 1, -df.depth.max() / 30) - assert new_ax.get_xlabel() == 'Practical salinity (PSU)' + assert new_ax.get_xlabel() == f'{utilities.plotting_labels("PSAL")} ({utilities.plotting_units(ds,"PSAL")})' def test_chl(var1='CHLA', var2='BBP700'): ds = fetchers.load_sample_dataset() ax = plots.process_optics_assess(ds, var=var1) - assert ax.get_ylabel() == 'Chlorophyll (mg m⁻³)' + assert ax.get_ylabel() == f'{utilities.plotting_labels(var1)} ({utilities.plotting_units(ds,var1)})' ax = plots.process_optics_assess(ds, var=var2) - assert ax.get_ylabel() == 'Red backscatter, b${bp}$(700) (m⁻¹)' + assert ax.get_ylabel() == f'{utilities.plotting_labels(var2)} ({utilities.plotting_units(ds,var2)})' with pytest.raises(KeyError) as e: plots.process_optics_assess(ds, var='nonexistent_variable') @@ -53,7 +53,7 @@ def test_quench_sequence(ylim=45): fig, ax = plots.plot_daynight_avg(ds, var='TEMP') assert ax.get_ylabel() == 'Depth (m)' - assert ax.get_xlabel() == 'Temperature (°C)' + assert ax.get_xlabel() == f'{utilities.plotting_labels("TEMP")} ({utilities.plotting_units(ds,"TEMP")})' def test_temporal_drift(var='DOXY'): @@ -128,4 +128,4 @@ def test_plot_sampling_period_all(): def test_plot_max_depth(): ds = fetchers.load_sample_dataset() - plots.plot_max_depth_per_profile(ds) \ No newline at end of file + plots.plot_max_depth_per_profile(ds)