diff --git a/localtileserver/client.py b/localtileserver/client.py index f9725f32..4093d740 100644 --- a/localtileserver/client.py +++ b/localtileserver/client.py @@ -2,6 +2,7 @@ import logging import pathlib from typing import List, Optional, Union +from matplotlib.colors import Colormap import rasterio import requests @@ -416,7 +417,7 @@ def create_url(self, path: str, client: bool = False): def get_tile_url( self, indexes: Optional[List[int]] = None, - colormap: Optional[str] = None, + colormap: Optional[Union[str, Colormap]] = None, vmin: Optional[Union[float, List[float]]] = None, vmax: Optional[Union[float, List[float]]] = None, nodata: Optional[Union[int, float]] = None, @@ -441,13 +442,23 @@ def get_tile_url( The value from the band to use to interpret as not valid data. """ + import json # First handle query parameters to check for errors params = {} if indexes is not None: params["indexes"] = indexes if colormap is not None: - # make sure palette is valid - palette_valid_or_raise(colormap) + + if isinstance(colormap, Colormap): + colormap = json.dumps( + {k:tuple(v.tolist()) for k,v in enumerate(colormap(range(256),1,1))} + ) + elif isinstance(colormap, list): + colormap = json.dumps(colormap) + else: + # make sure palette is valid + palette_valid_or_raise(colormap) + params["colormap"] = colormap if vmin is not None: if isinstance(vmin, Iterable) and not isinstance(indexes, Iterable): diff --git a/localtileserver/tiler/handler.py b/localtileserver/tiler/handler.py index d1949772..b28974d7 100644 --- a/localtileserver/tiler/handler.py +++ b/localtileserver/tiler/handler.py @@ -150,7 +150,21 @@ def _render_image( colormap: Optional[str] = None, img_format: str = "PNG", ): - colormap = cmap.get(colormap) if colormap else None + import json + from matplotlib.colors import LinearSegmentedColormap + + if colormap in cmap.list(): + colormap = cmap.get(colormap) + else: + c = json.loads(colormap) + if isinstance(c, list): + c = LinearSegmentedColormap.from_list('', c, N=256) + colormap = {k:tuple(v) for k,v in enumerate(c(range(256),1,1))} + else: + colormap = {} + for key,value in c.items(): + colormap[int(key)] = tuple(value) + if ( not colormap and len(indexes) == 1 diff --git a/localtileserver/widgets.py b/localtileserver/widgets.py index 06863b5a..1517ef2f 100644 --- a/localtileserver/widgets.py +++ b/localtileserver/widgets.py @@ -1,6 +1,7 @@ import logging import pathlib from typing import List, Optional, Union +from matplotlib.colors import Colormap import warnings import rasterio @@ -23,7 +24,7 @@ def get_leaflet_tile_layer( port: Union[int, str] = "default", debug: bool = False, indexes: Optional[List[int]] = None, - colormap: Optional[str] = None, + colormap: Optional[Union[str, Colormap]] = None, vmin: Optional[Union[float, List[float]]] = None, vmax: Optional[Union[float, List[float]]] = None, nodata: Optional[Union[int, float]] = None,