Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Added function to create nice labels and units #161

Merged
merged 6 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 47 additions & 54 deletions glidertest/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,21 @@



def plot_updown_bias(df: pd.DataFrame, ax: plt.Axes = None, xlabel='Temperature [C]', **kw: dict, ) -> tuple({plt.Figure, plt.Axes}):
def plot_updown_bias(ds: xr.Dataset, var='TEMP', v_res=1, ax: plt.Axes = None, **kw: dict, ) -> tuple({plt.Figure, plt.Axes}):
"""
This function can be used to plot the up and downcast differences computed with the updown_bias function

Parameters
----------
df: pandas dataframe containing dc (Dive - Climb average), cd (Climb - Dive average) and depth
ds: xarray on OG1 format containing at least time, depth, latitude, longitude and the selected variable.
Data should not be gridded.
var: Selected variable
v_res: Vertical resolution for the gridding
ax: axis to plot the data
xlabel: label for the x-axis

Returns
-------
A line plot comparing the day and night average over depth for the selected day
A line plot comparing the climb and dive average over depth

Original author
----------------
Expand All @@ -41,27 +43,31 @@ def plot_updown_bias(df: pd.DataFrame, ax: plt.Axes = None, xlabel='Temperature
if ax is None:
fig, ax = plt.subplots()
third_width = fig.get_size_inches()[0] / 3.11
fig.set_size_inches(third_width, third_width *1.1)
fig.set_size_inches(third_width, third_width * 1.1)
force_plot = True
else:
fig = plt.gcf()
force_plot = False

df = tools.quant_updown_bias(ds, var=var, v_res=v_res)
if not all(hasattr(df, attr) for attr in ['dc', 'depth']):
ax.text(0.5, 0.55, xlabel, va='center', ha='center', transform=ax.transAxes, bbox=dict(facecolor='white', alpha=0.5, edgecolor='none'))
ax.text(0.5, 0.45, 'data unavailable', va='center', ha='center', transform=ax.transAxes, bbox=dict(facecolor='white', alpha=0.5, edgecolor='none'))
ax.text(0.5, 0.55, ds[var].standard_name, va='center', ha='center', transform=ax.transAxes,
bbox=dict(facecolor='white', alpha=0.5, edgecolor='none'))
ax.text(0.5, 0.45, 'data unavailable', va='center', ha='center', transform=ax.transAxes,
bbox=dict(facecolor='white', alpha=0.5, edgecolor='none'))
else:
ax.plot(df.dc, df.depth, label='Dive-Climb')
ax.plot(df.cd, df.depth, label='Climb-Dive')
ax.plot(df.dc, df.depth, label='Dive-Climb', **kw)
ax.plot(df.cd, df.depth, label='Climb-Dive', **kw)
ax.legend(loc=3)
lims = np.abs(df.dc)
ax.set_xlim(-np.nanpercentile(lims, 99.5), np.nanpercentile(lims, 99.5))
ax.set_ylim(df.depth.max() + 1, -df.depth.max() / 30)
ax.set_xlabel(xlabel)
ax.set_xlabel(f'{utilities.plotting_labels(var)} ({utilities.plotting_units(ds,var)})')
ax.set_ylabel(f'Depth (m)')
ax.grid()
if force_plot:
plt.show()
return fig, ax
return fig, ax

def plot_basic_vars(ds: xr.Dataset, v_res=1, start_prof=0, end_prof=-1):
"""
Expand Down Expand Up @@ -113,16 +119,16 @@ def plot_basic_vars(ds: xr.Dataset, v_res=1, start_prof=0, end_prof=-1):
ax1.plot(np.nanmean(salG, axis=0), depthG[0, :], c='red')
ax2.plot(np.nanmean(denG, axis=0), depthG[0, :], c='black')

ax[0].set(xlabel=f'Temperature [C]', ylabel='Depth (m)')
ax[0].set(ylabel='Depth (m)', xlabel=f'{utilities.plotting_labels("TEMP")} \n({utilities.plotting_units(ds,"TEMP")})')
ax[0].tick_params(axis='x', colors='blue')
ax[0].xaxis.label.set_color('blue')
ax1.spines['bottom'].set_color('blue')
ax1.set(xlabel=f'Salinity [PSU]')
ax1.set(xlabel=f'{utilities.plotting_labels("PSAL")} ({utilities.plotting_units(ds,"PSAL")})')
ax1.xaxis.label.set_color('red')
ax1.spines['top'].set_color('red')
ax1.tick_params(axis='x', colors='red')
ax2.spines['bottom'].set_color('black')
ax2.set(xlabel=f'Density [kg m-3]')
ax2.set(xlabel=f'{utilities.plotting_labels("DENSITY")} ({utilities.plotting_units(ds,"DENSITY")})')
ax2.xaxis.label.set_color('black')
ax2.spines['top'].set_color('black')
ax2.tick_params(axis='x', colors='black')
Expand All @@ -136,7 +142,7 @@ def plot_basic_vars(ds: xr.Dataset, v_res=1, start_prof=0, end_prof=-1):
chlaG = chlaG[start_prof:end_prof, :]
ax2_1 = ax[1].twiny()
ax2_1.plot(np.nanmean(chlaG, axis=0), depthG[0, :], c='green')
ax2_1.set(xlabel=f'Chlorophyll-a [mg m-3]')
ax2_1.set(xlabel=f'{utilities.plotting_labels("CHLA")} ({utilities.plotting_units(ds,"CHLA")})')
ax2_1.xaxis.label.set_color('green')
ax2_1.spines['top'].set_color('green')
ax2_1.tick_params(axis='x', colors='green')
Expand All @@ -147,7 +153,7 @@ def plot_basic_vars(ds: xr.Dataset, v_res=1, start_prof=0, end_prof=-1):
oxyG, profG, depthG = utilities.construct_2dgrid(ds.PROFILE_NUMBER, ds.DEPTH, ds.DOXY, p, z)
oxyG = oxyG[start_prof:end_prof, :]
ax[1].plot(np.nanmean(oxyG, axis=0), depthG[0, :], c='orange')
ax[1].set(xlabel=f'Oxygen [mmol m-3]')
ax[1].set(xlabel=f'{utilities.plotting_labels("DOXY")} \n({utilities.plotting_units(ds,"DOXY")})')
ax[1].xaxis.label.set_color('orange')
ax[1].spines['top'].set_color('orange')
ax[1].tick_params(axis='x', colors='orange')
Expand Down Expand Up @@ -212,7 +218,7 @@ def process_optics_assess(ds, var='CHLA'):
ax.grid()
ax.set(ylim=(np.nanpercentile(bottom_opt_data, 0.5), np.nanpercentile(bottom_opt_data, 99.5)),
xlabel='Measurements',
ylabel=var)
ylabel=f'{utilities.plotting_labels(var)} ({utilities.plotting_units(ds,var)})')
plt.show()
percentage_change = (((slope * len(bottom_opt_data) + intercept) - intercept) / abs(intercept)) * 100

Expand All @@ -227,18 +233,16 @@ def process_optics_assess(ds, var='CHLA'):
return ax


def plot_daynight_avg(day: pd.DataFrame, night: pd.DataFrame, ax: plt.Axes = None, sel_day=None,
xlabel='Chlorophyll [mg m-3]', **kw: dict, ) -> tuple({plt.Figure, plt.Axes}):
def plot_daynight_avg(ds,var='PSAL', ax: plt.Axes = None, sel_day=None, **kw: dict, ) -> tuple({plt.Figure, plt.Axes}):
"""
This function can be used to plot the day and night averages computed with the day_night_avg function

Parameters
----------
day: pandas dataframe containing the day averages
night: pandas dataframe containing the night averages
ds: xarray dataset in OG1 format containing at least time, depth and the selected variable
var: name of the selected variable
ax: axis to plot the data
sel_day: selected day to plot. Defaults to the median day
xlabel: label for the x-axis

Returns
-------
Expand All @@ -249,6 +253,7 @@ def plot_daynight_avg(day: pd.DataFrame, night: pd.DataFrame, ax: plt.Axes = Non
Chiara Monforte

"""
day, night = tools.compute_daynight_avg(ds, sel_var=var)
if not sel_day:
dates = list(day.date.dropna().values) + list(night.date.dropna().values)
dates.sort()
Expand All @@ -260,14 +265,15 @@ def plot_daynight_avg(day: pd.DataFrame, night: pd.DataFrame, ax: plt.Axes = Non
else:
fig = plt.gcf()
force_plot = False

ax.plot(night.where(night.date == sel_day).dropna().dat, night.where(night.date == sel_day).dropna().depth,
label='Night time average')
ax.plot(day.where(day.date == sel_day).dropna().dat, day.where(day.date == sel_day).dropna().depth,
label='Daytime average')
ax.legend()
ax.invert_yaxis()
ax.grid()
ax.set(xlabel=xlabel, ylabel='Depth [m]')
ax.set(xlabel=f'{utilities.plotting_labels(var)} ({utilities.plotting_units(ds,var)})', ylabel='Depth (m)')
ax.set_title(sel_day)
if force_plot:
plt.show()
Expand Down Expand Up @@ -339,13 +345,13 @@ def plot_quench_assess(ds: xr.Dataset, sel_var: str, ax: plt.Axes = None, start_
ax.axvline(np.unique(n), c='blue')
for m in np.unique(sunrise):
ax.axvline(np.unique(m), c='orange')
ax.set_ylabel('Depth [m]')
ax.set_ylabel('Depth (m)')

# Set x-tick labels based on duration of the selection
# Could pop out as a utility plotting function?
utilities._time_axis_formatter(ax, ds_sel, format_x_axis=True)

plt.colorbar(c, label=f'{sel_var} [{ds[sel_var].units}]')
plt.colorbar(c, label=f'{utilities.plotting_labels(sel_var)} ({utilities.plotting_units(ds,sel_var)})')
plt.show()
return fig, ax

Expand Down Expand Up @@ -383,10 +389,10 @@ def check_temporal_drift(ds: xr.Dataset, var: str, ax: plt.Axes = None, **kw: di
# Set x-tick labels based on duration of the selection
utilities._time_axis_formatter(ax[0], ds, format_x_axis=True)

ax[0].set(ylim=(np.nanpercentile(ds[var], 0.01), np.nanpercentile(ds[var], 99.99)), ylabel=var)
ax[0].set(ylim=(np.nanpercentile(ds[var], 0.01), np.nanpercentile(ds[var], 99.99)), ylabel=f'{utilities.plotting_labels(var)} ({utilities.plotting_units(ds,var)})')

c = ax[1].scatter(ds[var], ds.DEPTH, c=mdates.date2num(ds.TIME), s=10)
ax[1].set(xlim=(np.nanpercentile(ds[var], 0.01), np.nanpercentile(ds[var], 99.99)), ylabel='Depth (m)', xlabel=var)
ax[1].set(xlim=(np.nanpercentile(ds[var], 0.01), np.nanpercentile(ds[var], 99.99)), ylabel='Depth (m)', xlabel=f'{utilities.plotting_labels(var)} ({utilities.plotting_units(ds,var)})')
ax[1].invert_yaxis()

[a.grid() for a in ax]
Expand Down Expand Up @@ -496,8 +502,8 @@ def plot_glider_track(ds: xr.Dataset, ax: plt.Axes = None, **kw: dict) -> tuple(
ax.add_feature(cfeature.OCEAN)
ax.add_feature(cfeature.COASTLINE)

ax.set_xlabel('Longitude')
ax.set_ylabel('Latitude')
ax.set_xlabel(f'Longitude')
ax.set_ylabel(f'Latitude')
ax.set_title('Glider Track')
gl = ax.gridlines(draw_labels=True, color='black', alpha=0.5, linestyle='--')
gl.top_labels = False
Expand Down Expand Up @@ -694,7 +700,7 @@ def plot_sampling_period(ds: xr.Dataset, ax: plt.Axes = None, variable='TEMP'):
ax.set_xlabel('Time Spacing (s)')
if variable=='TEMP': ax.set_ylabel('Frequency')
ax.set_title('Histogram of Sampling Period' + '\n' +
'for ' + variable + ', \n' +
'for ' + utilities.plotting_labels(variable) + ', \n' +
'valid values: {:.1f}'.format(100*(np.sum(nonan)/ds.TIME.values.shape[0]))+'%')

annotation_text = (
Expand Down Expand Up @@ -1083,7 +1089,7 @@ def plot_hysteresis(ds, var='DOXY', v_res=1, perct_err=2, ax=None):
[a.grid() for a in ax]
[a.invert_yaxis() for a in ax]
ax[0].set_ylabel('Depth (m)')
ax[0].set_xlabel(f'{var} concentration $=mean$ \n({ds[var].units})')
ax[0].set_xlabel(f'{utilities.plotting_labels(var)} $=mean$ \n({utilities.plotting_units(ds, var)})')
ax[1].set_xlabel(f'Absolute difference = |$\Delta$| \n({ds[var].units})')
ax[2].set_xlabel('Percent error = |$\Delta$|/$mean$ \n(%)')
for ax1 in ax[:-1]:
Expand All @@ -1093,7 +1099,7 @@ def plot_hysteresis(ds, var='DOXY', v_res=1, perct_err=2, ax=None):
vmax=np.nanpercentile(np.diff(varG, axis=0), 99.5), cmap='seismic')
plt.colorbar(c, ax=ax[3], label=f'Difference dive-climb \n({ds[var].units})', fraction=0.05)
ax[3].set(ylabel='Depth (m)', xlabel='Profile number')
fig.suptitle(var, y=.98)
fig.suptitle(utilities.plotting_labels(var), y=.98)
if force_plot:
plt.show()
return fig, ax
Expand All @@ -1108,18 +1114,14 @@ def plot_outlier_duration(ds: xr.Dataset, rolling_mean: pd.Series, overtime, std
Parameters
----------
ds : An xarray object containing at least the variables 'TIME', 'DEPTH', and 'PROFILE_NUMBER'.
These are used to compute the profile durations and plot depth profiles.

These are used to compute the profile durations and plot depth profiles.
rolling_mean : A series representing the rolling mean of the profile durations,
which is used to highlight outliers based on standard deviation.

overtime : A list of profile numbers identified as having unusual durations.
These profiles are marked on the plot to highlight the outliers.

which is used to highlight outliers based on standard deviation.
overtime : A list of profile numbers identified as having unusual durations.
These profiles are marked on the plot to highlight the outliers.
std : float, optional, default 2
The number of standard deviations above and below the rolling mean that will be used to define the range
of "normal" durations. Profiles outside this range are considered outliers.

ax :The axes object on which to plot the results. If not provided, a new figure with two subplots is created.

Returns
Expand Down Expand Up @@ -1185,13 +1187,9 @@ def plot_global_range(ds, var='DOXY', min_val=-5, max_val=600, ax=None):
Parameters
----------
ds : The xarray dataset containing the variable (`var`) to be plotted.

var : The name of the variable to plot. Default is 'DOXY'.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As general guidance, try not to add these stylistic fixes to unrelated code in a PR; it makes the review process more confusing! You can do a separate PR that just fixes typos, style, etc. Or, if you're feeling extra cool, you can check out Python linters like Ruff (from Astral).

We can leave them in for this PR though :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok I did remove it haha

min_val : The minimum value of the global range to highlight on the plot. Default is -5.

max_val : The maximum value of the global range to highlight on the plot. Default is 600.

var : The name of the variable to plot.
min_val : The minimum value of the global range to highlight on the plot.
max_val : The maximum value of the global range to highlight on the plot.
ax : matplotlib.axes.Axes, optional
The axes on which to plot the histogram. If `None`, a new figure and axes are created.
Default is `None`.
Expand All @@ -1200,7 +1198,6 @@ def plot_global_range(ds, var='DOXY', min_val=-5, max_val=600, ax=None):
-------
fig : matplotlib.figure.Figure
The figure object containing the plot.

ax : matplotlib.axes.Axes
The axes object containing the histogram plot.

Expand All @@ -1219,7 +1216,7 @@ def plot_global_range(ds, var='DOXY', min_val=-5, max_val=600, ax=None):
ax.hist(ds[var], bins=50)
ax.axvline(min_val, c='r')
ax.axvline(max_val, c='r')
ax.set(xlabel=f'{ds[var].long_name} ({ds[var].units})', ylabel='Frequency')
ax.set(xlabel=f'{utilities.plotting_labels(var)} ({utilities.plotting_units(ds,var)})', ylabel='Frequency')
ax.set_title('Global range check')
ax.grid()
if force_plot:
Expand All @@ -1237,17 +1234,13 @@ def plot_ioosqc(data, suspect_threshold=[25], fail_threshold=[50], title='', ax=
-----------
data : The result from the IOOS_QC test.
A sequence of numerical values representing the data points to be plotted.

suspect_threshold : A list containing one or two numerical values indicating the thresholds for suspect values. If one value is provided,
it applies to both lower and upper bounds for suspect data points. If two values are provided, they define the
lower and upper bounds for suspect values.

fail_threshold A list containing one or two numerical values indicating the thresholds for fail values. Similar to `suspect_threshold`,
it can have one or two values to define the bounds for fail data points.

title : str, optional, default = ''
The title to display at the top of the plot.

ax : matplotlib Axes object, optional, default = None
If provided, the plot will be drawn on this existing Axes object. If None, a new figure and axis will be created.

Expand Down Expand Up @@ -1299,7 +1292,7 @@ def plot_ioosqc(data, suspect_threshold=[25], fail_threshold=[50], title='', ax=

ax2.set_yticklabels(a_2, fontsize=12)

ax.set_xlabel('Data Index')
ax.set_xlabel('Data index')
ax.grid()
ax.set_title(title)
if force_plot:
Expand Down
Loading
Loading