diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d746e08..12159d0 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -17,6 +17,7 @@ New features and enhancements * ``fg.taylordiagram`` can now accept datasets with many dimensions (not only `taylor_params`), provided that they all share the same `ref_std` (e.g. normalized taylor diagrams) (:pull:`214`). * A new optional way to organize points in a ``fg.taylordiagram`` with `colors_key`, `markers_key` : DataArrays with a common dimension value or a common attribute are grouped with the same color/marker (:pull:`214`). * Heatmap (``fg.matplotlib.heatmap``) now supports `row,col` arguments in `plot_kw`, allowing to plot a grid of heatmaps. (:issue:`208`, :pull:`219`). +* New function ``fg.matplotlib.triheatmap`` (:pull:`199`). Breaking changes ^^^^^^^^^^^^^^^^ diff --git a/docs/notebooks/figanos_docs.ipynb b/docs/notebooks/figanos_docs.ipynb index 0a5140a..8b7be49 100644 --- a/docs/notebooks/figanos_docs.ipynb +++ b/docs/notebooks/figanos_docs.ipynb @@ -989,6 +989,67 @@ "cell_type": "markdown", "id": "61", "metadata": {}, + "source": [ + "## Triangle heatmaps\n", + "\n", + "The `triheatmap` function is based on the matplotlib function [tripcolor](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.tripcolor.html). 
It can create a heatmap with 2 or 4 triangles in each square of the heatmap.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62", + "metadata": {}, + "outputs": [], + "source": [ + "# Create a fake data\n", + "da = xr.DataArray(data=np.random.rand(2,3,4),\n", + " coords=dict(realization=['A', 'B'],\n", + " method=['a','b', 'c'],\n", + " experiment=['ssp126','ssp245','ssp370','ssp585'],\n", + " ))\n", + "da.name='pr' # to guess the cmap\n", + "# will be automatically detected for the cbar label\n", + "da.attrs['long_name']= 'precipitation' \n", + "da.attrs['units']= 'mm'\n", + "\n", + "# Plot a heatmap\n", + "fg.triheatmap(da,\n", + " z='experiment', # which dimension should be represented by triangles\n", + " divergent=True, # for the cmap\n", + " cbar='unique', # only show one cbar\n", + " plot_kw={'vmin':-1, 'vmax':1} # we are only showing the 1st cbar, so make sure the cbar of each triangle is the same\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63", + "metadata": {}, + "outputs": [], + "source": [ + "# Create a fake data\n", + "da = xr.DataArray(data=np.random.rand(4,3,2),\n", + " coords=dict(realization=['A', 'B', 'C','D'],\n", + " method=['a','b', 'c'],\n", + " season=['DJF','JJA'],\n", + " ))\n", + "da.attrs['description']= \"La plus belle saison de ma vie\"\n", + "\n", + "# Plot a heatmap\n", + "fg.triheatmap(da,\n", + " z='season',\n", + " cbar='each', # show a cbar per triangle\n", + " use_attrs={'title':'description'},\n", + " cbar_kw=[{'label':'winter'},{'label':'summer'}], # Use a list to change the cbar associated with each triangle type (upper or lower)\n", + " plot_kw=[{'cmap':'winter'},{'cmap':'summer'}]) # Use a list to change each triangle type (upper or lower)" + ] + }, + { + "cell_type": "markdown", + "id": "64", + "metadata": {}, "source": [ "## Taylor Diagrams\n", "\n", @@ -1006,7 +1067,7 @@ { "cell_type": "code", "execution_count": null, - "id": "62", + "id": "65", "metadata": 
{}, "outputs": [], "source": [ @@ -1025,7 +1086,7 @@ }, { "cell_type": "markdown", - "id": "63", + "id": "66", "metadata": {}, "source": [ "### Normalized taylor diagram\n", @@ -1036,7 +1097,7 @@ { "cell_type": "code", "execution_count": null, - "id": "64", + "id": "67", "metadata": {}, "outputs": [], "source": [ @@ -1070,7 +1131,7 @@ }, { "cell_type": "markdown", - "id": "65", + "id": "68", "metadata": {}, "source": [ "## Partition plots\n", @@ -1090,7 +1151,7 @@ { "cell_type": "code", "execution_count": null, - "id": "66", + "id": "69", "metadata": {}, "outputs": [], "source": [ @@ -1138,7 +1199,7 @@ }, { "cell_type": "markdown", - "id": "67", + "id": "70", "metadata": {}, "source": [ "Compute uncertainties with xclim and use `fractional_uncertainty` to have the right format to plot." @@ -1147,7 +1208,7 @@ { "cell_type": "code", "execution_count": null, - "id": "68", + "id": "71", "metadata": {}, "outputs": [], "source": [ @@ -1182,7 +1243,7 @@ { "cell_type": "code", "execution_count": null, - "id": "69", + "id": "72", "metadata": {}, "outputs": [], "source": [ @@ -1212,7 +1273,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.12.0" } }, "nbformat": 4, diff --git a/src/figanos/matplotlib/__init__.py b/src/figanos/matplotlib/__init__.py index 974fcda..855deeb 100644 --- a/src/figanos/matplotlib/__init__.py +++ b/src/figanos/matplotlib/__init__.py @@ -10,6 +10,7 @@ stripes, taylordiagram, timeseries, + triheatmap, violin, ) from .utils import categorical_colors, plot_logo, set_mpl_style diff --git a/src/figanos/matplotlib/plot.py b/src/figanos/matplotlib/plot.py index 64ef69a..01e9592 100644 --- a/src/figanos/matplotlib/plot.py +++ b/src/figanos/matplotlib/plot.py @@ -27,6 +27,7 @@ from matplotlib.cm import ScalarMappable from matplotlib.lines import Line2D from matplotlib.projections import PolarAxes +from matplotlib.tri import Triangulation from mpl_toolkits.axisartist.floating_axes import 
def _two_triangle_grid(m: int, n: int) -> list:
    """Build the two triangulations splitting each cell of an ``m`` x ``n`` grid along a diagonal.

    Cells are unit squares whose corners sit on integer coordinates, so cell
    ``(i, j)`` is centred on ``(i + 0.5, j + 0.5)``. Triangles are listed with
    the y index outermost (``for j ... for i ...``).

    Returns
    -------
    list of matplotlib.tri.Triangulation
        ``[first, second]``: the triangle holding the cell corner ``(i, j)`` and
        the triangle holding the opposite corner ``(i + 1, j + 1)``.
    """
    xs, ys = np.meshgrid(np.arange(m + 1), np.arange(n + 1))
    stride = m + 1  # number of vertices in one row of the vertex grid
    cells = [(i, j) for j in range(n) for i in range(m)]
    first = [
        (i + j * stride, i + 1 + j * stride, i + (j + 1) * stride) for i, j in cells
    ]
    second = [
        (i + 1 + j * stride, i + 1 + (j + 1) * stride, i + (j + 1) * stride)
        for i, j in cells
    ]
    return [Triangulation(xs.ravel(), ys.ravel(), tri) for tri in (first, second)]


def _four_triangle_grid(m: int, n: int) -> list:
    """Build the four triangulations splitting each cell of an ``m`` x ``n`` grid from its centre.

    Cells are unit squares centred on integer coordinates. The vertex array
    holds the ``(m + 1) * (n + 1)`` cell corners first, followed by the
    ``m * n`` cell centres. Triangles are listed with the y index outermost.

    Returns
    -------
    list of matplotlib.tri.Triangulation
        One triangulation per side of the cell, in the order of the edges
        ``y = j - 0.5``, ``x = i + 0.5``, ``y = j + 0.5``, ``x = i - 0.5``.
    """
    # corners of the little squares
    xv, yv = np.meshgrid(np.arange(-0.5, m), np.arange(-0.5, n))
    # centres of the little squares
    xc, yc = np.meshgrid(np.arange(0, m), np.arange(0, n))
    x = np.concatenate([xv.ravel(), xc.ravel()])
    y = np.concatenate([yv.ravel(), yc.ravel()])

    stride = m + 1  # number of corner vertices in one row
    cstart = (m + 1) * (n + 1)  # index of the first centre vertex
    cells = [(i, j) for j in range(n) for i in range(m)]

    side_1 = [
        (i + j * stride, i + 1 + j * stride, cstart + i + j * m) for i, j in cells
    ]
    side_2 = [
        (i + 1 + j * stride, i + 1 + (j + 1) * stride, cstart + i + j * m)
        for i, j in cells
    ]
    side_3 = [
        (i + 1 + (j + 1) * stride, i + (j + 1) * stride, cstart + i + j * m)
        for i, j in cells
    ]
    side_4 = [
        (i + (j + 1) * stride, i + j * stride, cstart + i + j * m) for i, j in cells
    ]
    return [Triangulation(x, y, tri) for tri in (side_1, side_2, side_3, side_4)]


def triheatmap(
    data: xr.DataArray | xr.Dataset,
    z: str,
    ax: matplotlib.axes.Axes | None = None,
    use_attrs: dict[str, Any] | None = None,
    fig_kw: dict[str, Any] | None = None,
    plot_kw: dict[str, Any] | None | list = None,
    cmap: str | matplotlib.colors.Colormap | None = None,
    divergent: bool | int | float = False,
    cbar: bool | str = "unique",
    cbar_kw: dict[str, Any] | None | list = None,
) -> matplotlib.axes.Axes:
    """Create a triangle heatmap from a DataArray.

    Each cell of the heatmap is split into 2 or 4 triangles, one per value of
    the dimension `z`.

    Note that most of the code comes from:
    https://stackoverflow.com/questions/66048529/how-to-create-a-heatmap-where-each-cell-is-divided-into-4-triangles

    Parameters
    ----------
    data : DataArray or Dataset
        Input data to plot. If a Dataset, only its first variable is plotted.
    z : str
        Dimension to plot on the triangles. Its length should be 2 or 4.
    ax : matplotlib axis, optional
        Matplotlib axis on which to plot. A new figure and axis are created if None.
    use_attrs : dict, optional
        Dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description').
        Default value is {'cbar_label': 'long_name', 'cbar_units': 'units'}.
        Valid keys are: 'title', 'xlabel', 'ylabel', 'cbar_label', 'cbar_units'.
    fig_kw : dict, optional
        Arguments to pass to `plt.subplots()`. Only used when `ax` is None.
    plot_kw : dict or list, optional
        Arguments to pass to the 'plt.tripcolor()' function.
        It can be a list of dictionaries to pass different arguments to each type of triangles
        (upper/lower or north/east/south/west). When a single dict is given, 'cmap' and 'ec' (white)
        are filled in if missing; a list is used as is.
    cmap : matplotlib.colors.Colormap or str, optional
        Colormap to use. If str, can be a matplotlib or name of the file of an IPCC colormap (see data/ipcc_colors).
        If None, look for common variables (from data/ipcc_colors/variable_groups.json) in the name of the DataArray
        or its 'history' attribute and use corresponding colormap, aligned with the IPCC Visual Style Guide 2022
        (https://www.ipcc.ch/site/assets/uploads/2022/09/IPCC_AR6_WGI_VisualStyleGuide_2022.pdf).
        A string that is neither a matplotlib colormap nor a known file is logged and treated as None.
    divergent : bool or int or float
        If int or float, becomes center of cmap. Default center is 0.
        Only used when the colormap is guessed (`cmap` is None).
    cbar : {False, True, 'unique', 'each'}
        If False, don't show the colorbar.
        If True or 'unique', show a unique colorbar for all triangle types. (The cbar of the first triangle is used).
        If 'each', show a colorbar for each triangle type.
    cbar_kw : dict or list, optional
        Arguments to pass to 'plt.colorbar()'.
        It can be a list of dictionaries to pass different arguments to each type of triangles
        (upper/lower or north/east/south/west).

    Returns
    -------
    matplotlib.axes.Axes

    Raises
    ------
    TypeError
        If `data` is neither a DataArray nor a Dataset.
    ValueError
        If `data` has fewer than two dimensions besides `z`, or if the length of `z` is not 2 or 4.
    """
    # create empty dicts if None
    # use_attrs is copied because defaults are added to it below and must not leak to the caller
    use_attrs = dict(empty_dict(use_attrs))
    fig_kw = empty_dict(fig_kw)
    plot_kw = empty_dict(plot_kw)
    cbar_kw = empty_dict(cbar_kw)

    # `True` is documented as an alias of "unique"; evaluate it once so the
    # vmin/vmax warning and the colorbar drawing always agree.
    unique_cbar = cbar in ("unique", True)

    # select data to plot
    if isinstance(data, xr.DataArray):
        da = data
    elif isinstance(data, xr.Dataset):
        if len(data.data_vars) > 1:
            warnings.warn(
                "data is xr.Dataset; only the first variable will be used in plot"
            )
        da = list(data.values())[0]
    else:
        raise TypeError("`data` must contain a xr.DataArray or xr.Dataset")

    # setup fig, axis
    if ax is None:
        _, ax = plt.subplots(**fig_kw)

    # colormap
    if isinstance(cmap, str) and cmap not in plt.colormaps():
        try:
            cmap = create_cmap(filename=cmap)
        except FileNotFoundError:
            logging.warning("Colormap not found. Using default.")
            cmap = None  # fall through to the guessed colormap below

    if cmap is None:
        cdata = Path(__file__).parents[1] / "data/ipcc_colors/variable_groups.json"
        cmap = create_cmap(
            get_var_group(path_to_json=cdata, da=da),
            divergent=divergent,
        )

    # prep data: one array per triangle type
    d = [da.sel(**{z: v}).values for v in da[z].values]

    other_dims = [di for di in da.dims if di != z]
    if len(other_dims) > 2:
        warnings.warn(
            "More than 3 dimensions in data. The first two after dim will be used as the dimensions of the heatmap."
        )
    if len(other_dims) < 2:
        raise ValueError(
            "Data must have 3 dimensions. If you only have 2 dimensions, use fg.heatmap."
        )

    if plot_kw == {} and unique_cbar:
        warnings.warn(
            'With cbar="unique" only the colorbar of the first triangle'
            " will be shown. No `plot_kw` was passed. vmin and vmax will be set the max"
            " and min of data."
        )
        plot_kw = {"vmax": da.max().values, "vmin": da.min().values}

    if isinstance(plot_kw, dict):
        # defaults first so user-supplied keys win; builds a new dict instead of
        # mutating the caller's
        shared_plot_kw = {"cmap": cmap, "ec": "white", **plot_kw}
        plot_kw = [shared_plot_kw for _ in range(len(d))]

    labels_x = da[other_dims[0]].values
    labels_y = da[other_dims[1]].values
    m, n = len(labels_x), len(labels_y)

    # plot
    if len(d) == 2:
        triangul = _two_triangle_grid(m, n)
        tick_offset = 0.5  # cells span [i, i + 1]
    elif len(d) == 4:
        triangul = _four_triangle_grid(m, n)
        tick_offset = 0.0  # cells are centred on integers
    else:
        raise ValueError(
            f"The length of the dimensiondim ({z},{len(d)}) should be either 2 or 4. It represents the number of triangles."
        )

    # Each slice is indexed (x, y) while the triangles are enumerated with y
    # outermost, so the slice must be transposed before flattening; otherwise
    # colours do not line up with the tick labels on non-square grids.
    imgs = [
        ax.tripcolor(t, np.ravel(val.T), **plotkw)
        for t, val, plotkw in zip(triangul, d, plot_kw)
    ]

    ax.set_xticks(np.arange(m) + tick_offset, labels=labels_x, rotation=45)
    ax.set_yticks(np.arange(n) + tick_offset, labels=labels_y, rotation=90)

    ax.set_title(get_attributes(use_attrs.get("title", None), data))
    ax.set_xlabel(other_dims[0])
    ax.set_ylabel(other_dims[1])
    if "xlabel" in use_attrs:
        ax.set_xlabel(get_attributes(use_attrs["xlabel"], data))
    if "ylabel" in use_attrs:
        ax.set_ylabel(get_attributes(use_attrs["ylabel"], data))
    ax.set_aspect("equal", "box")
    ax.invert_yaxis()
    ax.tick_params(left=False, bottom=False)
    ax.spines["bottom"].set_visible(False)
    ax.spines["left"].set_visible(False)

    # create cbar label
    # set default use_attrs values
    use_attrs.setdefault("cbar_label", "long_name")
    use_attrs.setdefault("cbar_units", "units")
    cbar_label = get_attributes(use_attrs["cbar_label"], data)
    cbar_units = get_attributes(use_attrs["cbar_units"], data)
    if len(cbar_units) >= 1:  # avoids '()' as label
        cbar_label = cbar_label + " (" + cbar_units + ")"

    if isinstance(cbar_kw, dict):
        shared_cbar_kw = {"label": cbar_label, **cbar_kw}
        cbar_kw = [shared_cbar_kw for _ in range(len(d))]

    if unique_cbar:
        plt.colorbar(imgs[0], ax=ax, **cbar_kw[0])
    elif cbar == "each":
        for i in reversed(range(len(d))):  # switch order of colorbars
            plt.colorbar(imgs[i], ax=ax, **cbar_kw[i])

    return ax