diff --git a/gliderpy/plotting.py b/gliderpy/plotting.py index c47516e..d0f03c8 100644 --- a/gliderpy/plotting.py +++ b/gliderpy/plotting.py @@ -42,11 +42,11 @@ def plot_track(df: pd.DataFrame) -> tuple(plt.Figure, plt.Axes): ax.set_extent([x.min() - dx, x.max() + dx, y.min() - dy, y.max() + dy]) return fig, ax - @register_dataframe_method def plot_transect( df: pd.DataFrame, var: str, + ax: plt.Axes = None, **kw: dict, ) -> tuple(plt.Figure, plt.Axes): """Make a scatter plot of depth vs time coloured by a user defined @@ -57,7 +57,18 @@ def plot_transect( """ cmap = kw.get("cmap", None) - fig, ax = plt.subplots(figsize=(17, 2)) + fignums = plt.get_fignums() + if ax is None and not fignums: + fig, ax = plt.subplots(figsize=(17, 2)) + elif ax: + fig = ax.get_figure() + else: + ax = plt.gca() + fig = plt.gcf() + + if not ax.yaxis_inverted(): + ax.invert_yaxis() + cs = ax.scatter( df.index, df["pressure"], @@ -68,11 +79,15 @@ def plot_transect( cmap=cmap, ) - ax.invert_yaxis() xfmt = mdates.DateFormatter("%H:%Mh\n%d-%b") ax.xaxis.set_major_formatter(xfmt) cbar = fig.colorbar(cs, orientation="vertical", extend="both") cbar.ax.set_ylabel(var) ax.set_ylabel("pressure") + + ax.set_ylim(ax.get_ylim()[0], 0) + return fig, ax + +