diff --git a/localtileserver/client.py b/localtileserver/client.py index f9725f32..bdaa4fe4 100644 --- a/localtileserver/client.py +++ b/localtileserver/client.py @@ -1,8 +1,10 @@ from collections.abc import Iterable +import json import logging import pathlib from typing import List, Optional, Union +from matplotlib.colors import Colormap, ListedColormap import rasterio import requests from rio_tiler.io import Reader @@ -416,7 +418,7 @@ def create_url(self, path: str, client: bool = False): def get_tile_url( self, indexes: Optional[List[int]] = None, - colormap: Optional[str] = None, + colormap: Optional[Union[str, Colormap, List[str]]] = None, vmin: Optional[Union[float, List[float]]] = None, vmax: Optional[Union[float, List[float]]] = None, nodata: Optional[Union[int, float]] = None, @@ -446,8 +448,18 @@ def get_tile_url( if indexes is not None: params["indexes"] = indexes if colormap is not None: - # make sure palette is valid - palette_valid_or_raise(colormap) + if isinstance(colormap, ListedColormap): + colormap = json.dumps([c for c in colormap.colors]) + elif isinstance(colormap, Colormap): + colormap = json.dumps( + {k: tuple(v.tolist()) for k, v in enumerate(colormap(range(256), 1, 1))} + ) + elif isinstance(colormap, list): + colormap = json.dumps(colormap) + else: + # make sure palette is valid + palette_valid_or_raise(colormap) + params["colormap"] = colormap if vmin is not None: if isinstance(vmin, Iterable) and not isinstance(indexes, Iterable): diff --git a/localtileserver/tiler/handler.py b/localtileserver/tiler/handler.py index d1949772..1e0af982 100644 --- a/localtileserver/tiler/handler.py +++ b/localtileserver/tiler/handler.py @@ -1,7 +1,9 @@ """Methods for working with images.""" +import json import pathlib from typing import Dict, List, Optional, Tuple, Union +from matplotlib.colors import Colormap, LinearSegmentedColormap, ListedColormap import numpy as np import rasterio from rasterio.enums import ColorInterp @@ -150,7 +152,23 @@ def _render_image( colormap: Optional[str] = None, img_format: str = "PNG", ): - colormap = cmap.get(colormap) if colormap else None + if colormap in cmap.list(): + colormap = cmap.get(colormap) + elif isinstance(colormap, ListedColormap): + c = LinearSegmentedColormap.from_list("", colormap.colors, N=256) + colormap = {k: tuple(v) for k, v in enumerate(c(range(256), 1, 1))} + elif isinstance(colormap, Colormap): + colormap = {k: tuple(v) for k, v in enumerate(colormap(range(256), 1, 1))} + elif colormap: + c = json.loads(colormap) + if isinstance(c, list): + c = LinearSegmentedColormap.from_list("", c, N=256) + colormap = {k: tuple(v) for k, v in enumerate(c(range(256), 1, 1))} + else: + colormap = {} + for key, value in c.items(): + colormap[int(key)] = tuple(value) + if ( not colormap and len(indexes) == 1 diff --git a/localtileserver/widgets.py b/localtileserver/widgets.py index 06863b5a..b33a1b3e 100644 --- a/localtileserver/widgets.py +++ b/localtileserver/widgets.py @@ -3,6 +3,7 @@ from typing import List, Optional, Union import warnings +from matplotlib.colors import Colormap import rasterio from localtileserver.client import TileClient, get_or_create_tile_client @@ -23,7 +24,7 @@ def get_leaflet_tile_layer( port: Union[int, str] = "default", debug: bool = False, indexes: Optional[List[int]] = None, - colormap: Optional[str] = None, + colormap: Optional[Union[str, Colormap, List[str]]] = None, vmin: Optional[Union[float, List[float]]] = None, vmax: Optional[Union[float, List[float]]] = None, nodata: Optional[Union[int, float]] = None, diff --git a/tests/baseline/test_custom_colormap[colormap0-None].png b/tests/baseline/test_custom_colormap[colormap0-None].png new file mode 100644 index 00000000..6dfbf811 Binary files /dev/null and b/tests/baseline/test_custom_colormap[colormap0-None].png differ diff --git a/tests/baseline/test_custom_colormap[colormap1-2].png b/tests/baseline/test_custom_colormap[colormap1-2].png new file mode 100644 index 00000000..d099c8f8 Binary files /dev/null and b/tests/baseline/test_custom_colormap[colormap1-2].png differ diff --git a/tests/test_rendering.py b/tests/test_rendering.py index bf3cb403..685bf6b9 100644 --- a/tests/test_rendering.py +++ b/tests/test_rendering.py @@ -1,3 +1,4 @@ +from matplotlib.colors import ListedColormap import pytest from .utilities import get_content @@ -56,6 +57,24 @@ def test_tile_colormap(bahamas, compare, colormap, indexes): compare(direct_content) +@pytest.mark.parametrize( + "colormap,indexes", + [ + (ListedColormap(["red", "blue"]), None), + (ListedColormap(["blue", "green"]), 2), + ], +) +def test_custom_colormap(bahamas, compare, colormap, indexes): + # Get a tile over the REST API + tile_url = bahamas.get_tile_url(colormap=colormap, indexes=indexes).format(z=8, x=72, y=110) + rest_content = get_content(tile_url) + # Get tile directly + direct_content = bahamas.tile(z=8, x=72, y=110, colormap=colormap, indexes=indexes) + # Make sure they are the same + assert rest_content == direct_content + compare(direct_content) + + @pytest.mark.parametrize("vmin", [100, [100, 200, 250]]) def test_tile_vmin(bahamas, compare, vmin): url = bahamas.get_tile_url(