diff --git a/gliderpy/__init__.py b/gliderpy/__init__.py index 8d2cbf0..9238e5b 100644 --- a/gliderpy/__init__.py +++ b/gliderpy/__init__.py @@ -6,9 +6,10 @@ __version__ = "unknown" from .fetchers import GliderDataFetcher -from .plotting import plot_transect +from .plotting import plot_track, plot_transect __all__ = [ "GliderDataFetcher", + "plot_track", "plot_transect", ] diff --git a/gliderpy/plotting.py b/gliderpy/plotting.py index d1e5d33..c47516e 100644 --- a/gliderpy/plotting.py +++ b/gliderpy/plotting.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING try: + import cartopy.crs as ccrs import matplotlib.dates as mdates import matplotlib.pyplot as plt except ModuleNotFoundError: @@ -22,6 +23,26 @@ from pandas_flavor import register_dataframe_method +@register_dataframe_method +def plot_track(df: pd.DataFrame) -> tuple(plt.Figure, plt.Axes): + """Plot a track of glider path coloured by temperature. + + :return: figures, axes + """ + x = df["longitude"] + y = df["latitude"] + dx, dy = 2, 4 + + fig, ax = plt.subplots( + figsize=(9, 9), + subplot_kw={"projection": ccrs.PlateCarree()}, + ) + ax.scatter(x, y, c=None, s=25, alpha=0.25, edgecolor="none") + ax.coastlines("10m") + ax.set_extent([x.min() - dx, x.max() + dx, y.min() - dy, y.max() + dy]) + return fig, ax + + @register_dataframe_method def plot_transect( df: pd.DataFrame, diff --git a/tests/baseline/test_plot_track.png b/tests/baseline/test_plot_track.png new file mode 100644 index 0000000..2e3b6f7 Binary files /dev/null and b/tests/baseline/test_plot_track.png differ diff --git a/tests/test_plotting.py b/tests/test_plotting.py index cbf15b1..d1bca53 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -5,13 +5,28 @@ import pytest from gliderpy.fetchers import GliderDataFetcher -from gliderpy.plotting import plot_transect +from gliderpy.plotting import plot_track, plot_transect root = Path(__file__).parent +@pytest.mark.mpl_image_compare(baseline_dir=root.joinpath("baseline/")) +def test_plot_track(): + """Image comparison test for plot_track.""" + glider_grab = GliderDataFetcher() + + glider_grab.fetcher.dataset_id = "whoi_406-20160902T1700" + df = glider_grab.to_pandas() + # Generate the plot + fig, ax = plot_track(df) + + # Return the figure for pytest-mpl to compare + return fig + + @pytest.mark.mpl_image_compare(baseline_dir=root.joinpath("baseline/")) def test_plot_transect(): + """Image comparison test for plot_transect.""" glider_grab = GliderDataFetcher() glider_grab.fetcher.dataset_id = "whoi_406-20160902T1700"