From 8a8a2141c13947f6afaaf65a8c29c4417cd3a5a3 Mon Sep 17 00:00:00 2001
From: Frank Anema <33519926+Conengmo@users.noreply.github.com>
Date: Sat, 18 May 2024 11:00:18 +0200
Subject: [PATCH] Add type hints (#146)
* Update colormap.py
* Create test_mypy.yml
* add Mypy requirement
* Update pyproject.toml
* Update element.py
* utilities.py wip
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Update branca/utilities.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Restore Div and MacroElement
* Update test_mypy.yml
* Update colormap.py
* Update element.py
* Update utilities.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* error after rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks"
This reverts commit e21b6d4dc5c5510da8c20767d871d07f27253b70.
* another rebase artefact :/
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
.github/workflows/test_mypy.yml | 33 ++++++
branca/colormap.py | 176 ++++++++++++++++-------------
branca/element.py | 189 +++++++++++++++++++++-----------
branca/utilities.py | 60 +++++-----
pyproject.toml | 3 +
requirements-dev.txt | 1 +
6 files changed, 291 insertions(+), 171 deletions(-)
create mode 100644 .github/workflows/test_mypy.yml
diff --git a/.github/workflows/test_mypy.yml b/.github/workflows/test_mypy.yml
new file mode 100644
index 0000000..96715e0
--- /dev/null
+++ b/.github/workflows/test_mypy.yml
@@ -0,0 +1,33 @@
+name: Mypy type hint checks
+
+on:
+ pull_request:
+ push:
+ branches:
+ - main
+
+jobs:
+ run:
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v4
+
+ - name: Setup Micromamba env
+ uses: mamba-org/setup-micromamba@v1
+ with:
+ environment-name: TEST
+ create-args: >-
+ python=3
+ --file requirements.txt
+ --file requirements-dev.txt
+
+ - name: Install branca from source
+ shell: bash -l {0}
+ run: |
+ python -m pip install -e . --no-deps --force-reinstall
+
+ - name: Mypy test
+ shell: bash -l {0}
+ run: |
+ mypy branca
diff --git a/branca/colormap.py b/branca/colormap.py
index e5dd839..2e76ba5 100644
--- a/branca/colormap.py
+++ b/branca/colormap.py
@@ -9,51 +9,67 @@
import json
import math
import os
+from typing import Dict, List, Optional, Sequence, Tuple, Union
from jinja2 import Template
from branca.element import ENV, Figure, JavascriptLink, MacroElement
from branca.utilities import legend_scaler
-rootpath = os.path.abspath(os.path.dirname(__file__))
+rootpath: str = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(rootpath, "_cnames.json")) as f:
- _cnames = json.loads(f.read())
+ _cnames: Dict[str, str] = json.loads(f.read())
with open(os.path.join(rootpath, "_schemes.json")) as f:
- _schemes = json.loads(f.read())
+ _schemes: Dict[str, List[str]] = json.loads(f.read())
-def _is_hex(x):
+TypeRGBInts = Tuple[int, int, int]
+TypeRGBFloats = Tuple[float, float, float]
+TypeRGBAInts = Tuple[int, int, int, int]
+TypeRGBAFloats = Tuple[float, float, float, float]
+TypeAnyColorType = Union[TypeRGBInts, TypeRGBFloats, TypeRGBAInts, TypeRGBAFloats, str]
+
+
+def _is_hex(x: str) -> bool:
return x.startswith("#") and len(x) == 7
-def _parse_hex(color_code):
+def _parse_hex(color_code: str) -> TypeRGBAFloats:
return (
- int(color_code[1:3], 16),
- int(color_code[3:5], 16),
- int(color_code[5:7], 16),
+ _color_int_to_float(int(color_code[1:3], 16)),
+ _color_int_to_float(int(color_code[3:5], 16)),
+ _color_int_to_float(int(color_code[5:7], 16)),
+ 1.0,
)
-def _parse_color(x):
+def _color_int_to_float(x: int) -> float:
+ """Convert an integer between 0 and 255 to a float between 0. and 1.0"""
+ return x / 255.0
+
+
+def _color_float_to_int(x: float) -> int:
+ """Convert a float between 0. and 1.0 to an integer between 0 and 255"""
+ return int(x * 255.9999)
+
+
+def _parse_color(x: Union[tuple, list, str]) -> TypeRGBAFloats:
if isinstance(x, (tuple, list)):
- color_tuple = tuple(x)[:4]
- elif isinstance(x, (str, bytes)) and _is_hex(x):
- color_tuple = _parse_hex(x)
- elif isinstance(x, (str, bytes)):
+ return tuple(tuple(x) + (1.0,))[:4] # type: ignore
+ elif isinstance(x, str) and _is_hex(x):
+ return _parse_hex(x)
+ elif isinstance(x, str):
cname = _cnames.get(x.lower(), None)
if cname is None:
raise ValueError(f"Unknown color {cname!r}.")
- color_tuple = _parse_hex(cname)
+ return _parse_hex(cname)
else:
raise ValueError(f"Unrecognized color code {x!r}")
- if max(color_tuple) > 1.0:
- color_tuple = tuple(u / 255.0 for u in color_tuple)
- return tuple(map(float, (color_tuple + (1.0,))[:4]))
-def _base(x):
+def _base(x: float) -> float:
if x > 0:
base = pow(10, math.floor(math.log10(x)))
return round(x / base) * base
@@ -78,15 +94,15 @@ class ColorMap(MacroElement):
Maximum number of legend tick labels
"""
- _template = ENV.get_template("color_scale.js")
+ _template: Template = ENV.get_template("color_scale.js")
def __init__(
self,
- vmin=0.0,
- vmax=1.0,
- caption="",
- text_color="black",
- max_labels=10,
+ vmin: float = 0.0,
+ vmax: float = 1.0,
+ caption: str = "",
+ text_color: str = "black",
+ max_labels: int = 10,
):
super().__init__()
self._name = "ColorMap"
@@ -95,9 +111,9 @@ def __init__(
self.vmax = vmax
self.caption = caption
self.text_color = text_color
- self.index = [vmin, vmax]
+ self.index: List[float] = [vmin, vmax]
self.max_labels = max_labels
- self.tick_labels = None
+ self.tick_labels: Optional[Sequence[Union[float, str]]] = None
self.width = 450
self.height = 40
@@ -127,7 +143,7 @@ def render(self, **kwargs):
name="d3",
) # noqa
- def rgba_floats_tuple(self, x):
+ def rgba_floats_tuple(self, x: float) -> TypeRGBAFloats:
"""
This class has to be implemented for each class inheriting from
Colormap. This has to be a function of the form float ->
@@ -137,37 +153,37 @@ def rgba_floats_tuple(self, x):
"""
raise NotImplementedError
- def rgba_bytes_tuple(self, x):
+ def rgba_bytes_tuple(self, x: float) -> TypeRGBAInts:
"""Provides the color corresponding to value `x` in the
form of a tuple (R,G,B,A) with int values between 0 and 255.
"""
- return tuple(int(u * 255.9999) for u in self.rgba_floats_tuple(x))
+ return tuple(_color_float_to_int(u) for u in self.rgba_floats_tuple(x)) # type: ignore
- def rgb_bytes_tuple(self, x):
+ def rgb_bytes_tuple(self, x: float) -> TypeRGBInts:
"""Provides the color corresponding to value `x` in the
form of a tuple (R,G,B) with int values between 0 and 255.
"""
return self.rgba_bytes_tuple(x)[:3]
- def rgb_hex_str(self, x):
+ def rgb_hex_str(self, x: float) -> str:
"""Provides the color corresponding to value `x` in the
form of a string of hexadecimal values "#RRGGBB".
"""
return "#%02x%02x%02x" % self.rgb_bytes_tuple(x)
- def rgba_hex_str(self, x):
+ def rgba_hex_str(self, x: float) -> str:
"""Provides the color corresponding to value `x` in the
form of a string of hexadecimal values "#RRGGBBAA".
"""
return "#%02x%02x%02x%02x" % self.rgba_bytes_tuple(x)
- def __call__(self, x):
+ def __call__(self, x: float) -> str:
"""Provides the color corresponding to value `x` in the
form of a string of hexadecimal values "#RRGGBBAA".
"""
return self.rgba_hex_str(x)
- def _repr_html_(self):
+ def _repr_html_(self) -> str:
"""Display the colormap in a Jupyter Notebook.
Does not support all the class arguments.
@@ -264,14 +280,14 @@ class LinearColormap(ColorMap):
def __init__(
self,
- colors,
- index=None,
- vmin=0.0,
- vmax=1.0,
- caption="",
- text_color="black",
- max_labels=10,
- tick_labels=None,
+ colors: Sequence[TypeAnyColorType],
+ index: Optional[Sequence[float]] = None,
+ vmin: float = 0.0,
+ vmax: float = 1.0,
+ caption: str = "",
+ text_color: str = "black",
+ max_labels: int = 10,
+ tick_labels: Optional[Sequence[float]] = None,
):
super().__init__(
vmin=vmin,
@@ -280,7 +296,7 @@ def __init__(
text_color=text_color,
max_labels=max_labels,
)
- self.tick_labels = tick_labels
+ self.tick_labels: Optional[Sequence[float]] = tick_labels
n = len(colors)
if n < 2:
@@ -289,9 +305,9 @@ def __init__(
self.index = [vmin + (vmax - vmin) * i * 1.0 / (n - 1) for i in range(n)]
else:
self.index = list(index)
- self.colors = [_parse_color(x) for x in colors]
+ self.colors: List[TypeRGBAFloats] = [_parse_color(x) for x in colors]
- def rgba_floats_tuple(self, x):
+ def rgba_floats_tuple(self, x: float) -> TypeRGBAFloats:
"""Provides the color corresponding to value `x` in the
form of a tuple (R,G,B,A) with float values between 0. and 1.
"""
@@ -308,20 +324,20 @@ def rgba_floats_tuple(self, x):
else:
raise ValueError("Thresholds are not sorted.")
- return tuple(
+ return tuple( # type: ignore
(1.0 - p) * self.colors[i - 1][j] + p * self.colors[i][j] for j in range(4)
)
def to_step(
self,
- n=None,
- index=None,
- data=None,
- method=None,
- quantiles=None,
- round_method=None,
- max_labels=10,
- ):
+ n: Optional[int] = None,
+ index: Optional[Sequence[float]] = None,
+ data: Optional[Sequence[float]] = None,
+ method: str = "linear",
+ quantiles: Optional[Sequence[float]] = None,
+ round_method: Optional[str] = None,
+ max_labels: int = 10,
+ ) -> "StepColormap":
"""Splits the LinearColormap into a StepColormap.
Parameters
@@ -382,11 +398,7 @@ def to_step(
max_ = max(data)
min_ = min(data)
scaled_cm = self.scale(vmin=min_, vmax=max_)
- method = (
- "quantiles"
- if quantiles is not None
- else method if method is not None else "linear"
- )
+ method = "quantiles" if quantiles is not None else method
if method.lower().startswith("lin"):
if n is None:
raise ValueError(msg)
@@ -454,7 +466,12 @@ def to_step(
tick_labels=self.tick_labels,
)
- def scale(self, vmin=0.0, vmax=1.0, max_labels=10):
+ def scale(
+ self,
+ vmin: float = 0.0,
+ vmax: float = 1.0,
+ max_labels: int = 10,
+ ) -> "LinearColormap":
"""Transforms the colorscale so that the minimal and maximal values
fit the given parameters.
"""
@@ -510,14 +527,14 @@ class StepColormap(ColorMap):
def __init__(
self,
- colors,
- index=None,
- vmin=0.0,
- vmax=1.0,
- caption="",
- text_color="black",
- max_labels=10,
- tick_labels=None,
+ colors: Sequence[TypeAnyColorType],
+ index: Optional[Sequence[float]] = None,
+ vmin: float = 0.0,
+ vmax: float = 1.0,
+ caption: str = "",
+ text_color: str = "black",
+ max_labels: int = 10,
+ tick_labels: Optional[Sequence[float]] = None,
):
super().__init__(
vmin=vmin,
@@ -535,9 +552,9 @@ def __init__(
self.index = [vmin + (vmax - vmin) * i * 1.0 / n for i in range(n + 1)]
else:
self.index = list(index)
- self.colors = [_parse_color(x) for x in colors]
+ self.colors: List[TypeRGBAFloats] = [_parse_color(x) for x in colors]
- def rgba_floats_tuple(self, x):
+ def rgba_floats_tuple(self, x: float) -> TypeRGBAFloats:
"""
Provides the color corresponding to value `x` in the
form of a tuple (R,G,B,A) with float values between 0. and 1.
@@ -549,9 +566,13 @@ def rgba_floats_tuple(self, x):
return self.colors[-1]
i = len([u for u in self.index if u <= x]) # 0 < i < n.
- return tuple(self.colors[i - 1])
+ return self.colors[i - 1]
- def to_linear(self, index=None, max_labels=10):
+ def to_linear(
+ self,
+ index: Optional[Sequence[float]] = None,
+ max_labels: int = 10,
+ ) -> LinearColormap:
"""
Transforms the StepColormap into a LinearColormap.
@@ -584,7 +605,12 @@ def to_linear(self, index=None, max_labels=10):
max_labels=max_labels,
)
- def scale(self, vmin=0.0, vmax=1.0, max_labels=10):
+ def scale(
+ self,
+ vmin: float = 0.0,
+ vmax: float = 1.0,
+ max_labels: int = 10,
+ ) -> "StepColormap":
"""Transforms the colorscale so that the minimal and maximal values
fit the given parameters.
"""
@@ -611,7 +637,7 @@ def __init__(self):
for key, val in _schemes.items():
setattr(self, key, LinearColormap(val))
- def _repr_html_(self):
+ def _repr_html_(self) -> str:
return Template(
"""
@@ -634,7 +660,7 @@ def __init__(self):
for key, val in _schemes.items():
setattr(self, key, StepColormap(val))
- def _repr_html_(self):
+ def _repr_html_(self) -> str:
return Template(
"""
diff --git a/branca/element.py b/branca/element.py
index 1912aca..a1e005a 100644
--- a/branca/element.py
+++ b/branca/element.py
@@ -14,11 +14,12 @@
from html import escape
from os import urandom
from pathlib import Path
+from typing import BinaryIO, List, Optional, Tuple, Type, Union
from urllib.request import urlopen
from jinja2 import Environment, PackageLoader, Template
-from .utilities import _camelify, _parse_size, none_max, none_min
+from .utilities import TypeParseSize, _camelify, _parse_size, none_max, none_min
ENV = Environment(loader=PackageLoader("branca", "templates"))
@@ -45,26 +46,30 @@ class Element:
"""
- _template = Template(
+ _template: Template = Template(
"{% for name, element in this._children.items() %}\n"
" {{element.render(**kwargs)}}"
"{% endfor %}",
)
- def __init__(self, template=None, template_name=None):
- self._name = "Element"
- self._id = hexlify(urandom(16)).decode()
- self._children = OrderedDict()
- self._parent = None
- self._template_str = template
- self._template_name = template_name
+ def __init__(
+ self,
+ template: Optional[str] = None,
+ template_name: Optional[str] = None,
+ ):
+ self._name: str = "Element"
+ self._id: str = hexlify(urandom(16)).decode()
+ self._children: OrderedDict[str, Element] = OrderedDict()
+ self._parent: Optional[Element] = None
+ self._template_str: Optional[str] = template
+ self._template_name: Optional[str] = template_name
if template is not None:
self._template = Template(template)
elif template_name is not None:
self._template = ENV.get_template(template_name)
- def __getstate__(self):
+ def __getstate__(self) -> dict:
"""Modify object state when pickling the object.
jinja2 Templates cannot be pickled, so remove the instance attribute
@@ -83,7 +88,7 @@ def __setstate__(self, state: dict):
self.__dict__.update(state)
- def get_name(self):
+ def get_name(self) -> str:
"""Returns a string representation of the object.
This string has to be unique and to be a python and
javascript-compatible
@@ -91,13 +96,13 @@ def get_name(self):
"""
return _camelify(self._name) + "_" + self._id
- def _get_self_bounds(self):
+ def _get_self_bounds(self) -> List[List[Optional[float]]]:
"""Computes the bounds of the object itself (not including it's children)
in the form [[lat_min, lon_min], [lat_max, lon_max]]
"""
return [[None, None], [None, None]]
- def get_bounds(self):
+ def get_bounds(self) -> List[List[Optional[float]]]:
"""Computes the bounds of the object and all it's children
in the form [[lat_min, lon_min], [lat_max, lon_max]].
"""
@@ -117,7 +122,12 @@ def get_bounds(self):
]
return bounds
- def add_children(self, child, name=None, index=None):
+ def add_children(
+ self,
+ child: "Element",
+ name: Optional[str] = None,
+ index: Optional[int] = None,
+ ) -> "Element":
"""Add a child."""
warnings.warn(
"Method `add_children` is deprecated. Please use `add_child` instead.",
@@ -126,7 +136,12 @@ def add_children(self, child, name=None, index=None):
)
return self.add_child(child, name=name, index=index)
- def add_child(self, child, name=None, index=None):
+ def add_child(
+ self,
+ child: "Element",
+ name: Optional[str] = None,
+ index: Optional[int] = None,
+ ) -> "Element":
"""Add a child."""
if name is None:
name = child.get_name()
@@ -139,13 +154,24 @@ def add_child(self, child, name=None, index=None):
child._parent = self
return self
- def add_to(self, parent, name=None, index=None):
+ def add_to(
+ self,
+ parent: "Element",
+ name: Optional[str] = None,
+ index: Optional[int] = None,
+ ) -> "Element":
"""Add element to a parent."""
parent.add_child(self, name=name, index=index)
return self
- def to_dict(self, depth=-1, ordered=True, **kwargs):
+ def to_dict(
+ self,
+ depth: int = -1,
+ ordered: bool = True,
+ **kwargs,
+ ) -> Union[dict, OrderedDict]:
"""Returns a dict representation of the object."""
+ dict_fun: Type[Union[dict, OrderedDict]]
if ordered:
dict_fun = OrderedDict
else:
@@ -159,25 +185,30 @@ def to_dict(self, depth=-1, ordered=True, **kwargs):
(name, child.to_dict(depth=depth - 1))
for name, child in self._children.items()
],
- ) # noqa
+ )
return out
- def to_json(self, depth=-1, **kwargs):
+ def to_json(self, depth: int = -1, **kwargs) -> str:
"""Returns a JSON representation of the object."""
return json.dumps(self.to_dict(depth=depth, ordered=True), **kwargs)
- def get_root(self):
+ def get_root(self) -> "Element":
"""Returns the root of the elements tree."""
if self._parent is None:
return self
else:
return self._parent.get_root()
- def render(self, **kwargs):
+ def render(self, **kwargs) -> str:
"""Renders the HTML representation of the element."""
return self._template.render(this=self, kwargs=kwargs)
- def save(self, outfile, close_file=True, **kwargs):
+ def save(
+ self,
+ outfile: Union[str, bytes, Path, BinaryIO],
+ close_file: bool = True,
+ **kwargs,
+ ):
"""Saves an Element into a file.
Parameters
@@ -187,6 +218,7 @@ def save(self, outfile, close_file=True, **kwargs):
close_file : bool, default True
Whether the file has to be closed after write.
"""
+ fid: BinaryIO
if isinstance(outfile, (str, bytes, Path)):
fid = open(outfile, "wb")
else:
@@ -202,15 +234,27 @@ def save(self, outfile, close_file=True, **kwargs):
class Link(Element):
"""An abstract class for embedding a link in the HTML."""
- def get_code(self):
+ def __init__(self, url: str, download: bool = False):
+ super().__init__()
+ self.url = url
+ self.code: Optional[bytes] = None
+ if download:
+ self.get_code()
+
+ def get_code(self) -> bytes:
"""Opens the link and returns the response's content."""
if self.code is None:
self.code = urlopen(self.url).read()
return self.code
- def to_dict(self, depth=-1, **kwargs):
+ def to_dict(
+ self,
+ depth: int = -1,
+ ordered: bool = True,
+ **kwargs,
+ ) -> Union[dict, OrderedDict]:
"""Returns a dict representation of the object."""
- out = super().to_dict(depth=-1, **kwargs)
+ out = super().to_dict(depth=depth, ordered=ordered, **kwargs)
out["url"] = self.url
return out
@@ -235,13 +279,9 @@ class JavascriptLink(Link):
"{% endif %}",
)
- def __init__(self, url, download=False):
- super().__init__()
+ def __init__(self, url: str, download: bool = False):
+ super().__init__(url=url, download=download)
self._name = "JavascriptLink"
- self.url = url
- self.code = None
- if download:
- self.get_code()
class CssLink(Link):
@@ -264,13 +304,9 @@ class CssLink(Link):
"{% endif %}",
)
- def __init__(self, url, download=False):
- super().__init__()
+ def __init__(self, url: str, download: bool = False):
+ super().__init__(url=url, download=download)
self._name = "CssLink"
- self.url = url
- self.code = None
- if download:
- self.get_code()
class Figure(Element):
@@ -314,11 +350,11 @@ class Figure(Element):
def __init__(
self,
- width="100%",
- height=None,
- ratio="60%",
- title=None,
- figsize=None,
+ width: str = "100%",
+ height: Optional[str] = None,
+ ratio: str = "60%",
+ title: Optional[str] = None,
+ figsize: Optional[Tuple[int, int]] = None,
):
super().__init__()
self._name = "Figure"
@@ -346,7 +382,12 @@ def __init__(
name="meta_http",
)
- def to_dict(self, depth=-1, **kwargs):
+ def to_dict(
+ self,
+ depth: int = -1,
+ ordered: bool = True,
+ **kwargs,
+ ) -> Union[dict, OrderedDict]:
"""Returns a dict representation of the object."""
out = super().to_dict(depth=depth, **kwargs)
out["header"] = self.header.to_dict(depth=depth - 1, **kwargs)
@@ -354,17 +395,17 @@ def to_dict(self, depth=-1, **kwargs):
out["script"] = self.script.to_dict(depth=depth - 1, **kwargs)
return out
- def get_root(self):
+ def get_root(self) -> "Figure":
"""Returns the root of the elements tree."""
return self
- def render(self, **kwargs):
+ def render(self, **kwargs) -> str:
"""Renders the HTML representation of the element."""
for name, child in self._children.items():
child.render(**kwargs)
return self._template.render(this=self, kwargs=kwargs)
- def _repr_html_(self, **kwargs):
+ def _repr_html_(self, **kwargs) -> str:
"""Displays the Figure in a Jupyter notebook."""
html = escape(self.render(**kwargs))
if self.height is None:
@@ -387,7 +428,7 @@ def _repr_html_(self, **kwargs):
).format(html=html, width=self.width, height=self.height)
return iframe
- def add_subplot(self, x, y, n, margin=0.05):
+ def add_subplot(self, x: int, y: int, n: int, margin: float = 0.05) -> "Div":
"""Creates a div child subplot in a matplotlib.figure.add_subplot style.
Parameters
@@ -398,8 +439,11 @@ def add_subplot(self, x, y, n, margin=0.05):
The number of columns in the grid.
n : int
The cell number in the grid, counted from 1 to x*y.
+ margin : float, default 0.05
+ Factor to add to the left, top, width and height parameters.
- Example:
+ Example
+ -------
>>> fig.add_subplot(3, 2, 5)
# Create a div in the 5th cell of a 3rows x 2columns
grid(bottom-left corner).
@@ -447,9 +491,15 @@ class Html(Element):
'' # noqa
"{% if this.script %}{{this.data}}{% else %}{{this.data|e}}{% endif %}
",
- ) # noqa
+ )
- def __init__(self, data, script=False, width="100%", height="100%"):
+ def __init__(
+ self,
+ data: str,
+ script: bool = False,
+ width: TypeParseSize = "100%",
+ height: TypeParseSize = "100%",
+ ):
super().__init__()
self._name = "Html"
self.script = script
@@ -494,18 +544,18 @@ class Div(Figure):
def __init__(
self,
- width="100%",
- height="100%",
- left="0%",
- top="0%",
- position="relative",
+ width: TypeParseSize = "100%",
+ height: TypeParseSize = "100%",
+ left: TypeParseSize = "0%",
+ top: TypeParseSize = "0%",
+ position: str = "relative",
):
super(Figure, self).__init__()
self._name = "Div"
# Size Parameters.
- self.width = _parse_size(width)
- self.height = _parse_size(height)
+ self.width = _parse_size(width) # type: ignore
+ self.height = _parse_size(height) # type: ignore
self.left = _parse_size(left)
self.top = _parse_size(top)
self.position = position
@@ -522,7 +572,7 @@ def __init__(
self.html._parent = self
self.script._parent = self
- def get_root(self):
+ def get_root(self) -> "Div":
"""Returns the root of the elements tree."""
return self
@@ -554,14 +604,14 @@ def render(self, **kwargs):
if script is not None:
figure.script.add_child(Element(script(self, kwargs)), name=self.get_name())
- def _repr_html_(self, **kwargs):
+ def _repr_html_(self, **kwargs) -> str:
"""Displays the Div in a Jupyter notebook."""
if self._parent is None:
self.add_to(Figure())
- out = self._parent._repr_html_(**kwargs)
+ out = self._parent._repr_html_(**kwargs) # type: ignore
self._parent = None
else:
- out = self._parent._repr_html_(**kwargs)
+ out = self._parent._repr_html_(**kwargs) # type: ignore
return out
@@ -588,7 +638,14 @@ class IFrame(Element):
width="600px", height="300px".
"""
- def __init__(self, html=None, width="100%", height=None, ratio="60%", figsize=None):
+ def __init__(
+ self,
+ html: Optional[Union[str, Element]] = None,
+ width: str = "100%",
+ height: Optional[str] = None,
+ ratio: str = "60%",
+ figsize: Optional[Tuple[int, int]] = None,
+ ):
super().__init__()
self._name = "IFrame"
@@ -599,19 +656,17 @@ def __init__(self, html=None, width="100%", height=None, ratio="60%", figsize=No
self.width = str(60 * figsize[0]) + "px"
self.height = str(60 * figsize[1]) + "px"
- if isinstance(html, str) or isinstance(html, bytes):
+ if isinstance(html, str):
self.add_child(Element(html))
elif html is not None:
self.add_child(html)
- def render(self, **kwargs):
+ def render(self, **kwargs) -> str:
"""Renders the HTML representation of the element."""
html = super().render(**kwargs)
html = "data:text/html;charset=utf-8;base64," + base64.b64encode(
html.encode("utf8"),
- ).decode(
- "utf8",
- ) # noqa
+ ).decode("utf8")
if self.height is None:
iframe = (
diff --git a/branca/utilities.py b/branca/utilities.py
index 4bb3540..14a1d46 100644
--- a/branca/utilities.py
+++ b/branca/utilities.py
@@ -14,36 +14,43 @@
import struct
import typing
import zlib
-from typing import Any, Callable, Union
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
from jinja2 import Environment, PackageLoader
try:
import numpy as np
except ImportError:
- np = None
+ np = None # type: ignore
if typing.TYPE_CHECKING:
from branca.colormap import ColorMap
-rootpath = os.path.abspath(os.path.dirname(__file__))
+rootpath: str = os.path.abspath(os.path.dirname(__file__))
-def get_templates():
+TypeParseSize = Union[int, float, str, Tuple[float, str]]
+
+
+def get_templates() -> Environment:
"""Get Jinja templates."""
return Environment(loader=PackageLoader("branca", "templates"))
-def legend_scaler(legend_values, max_labels=10.0):
+def legend_scaler(
+ legend_values: Sequence[float],
+ max_labels: int = 10,
+) -> List[Union[float, str]]:
"""
Downsamples the number of legend values so that there isn't a collision
of text on the legend colorbar (within reason). The colorbar seems to
support ~10 entries as a maximum.
"""
+ legend_ticks: List[Union[float, str]]
if len(legend_values) < max_labels:
- legend_ticks = legend_values
+ legend_ticks = list(legend_values)
else:
spacer = int(math.ceil(len(legend_values) / max_labels))
legend_ticks = []
@@ -53,16 +60,11 @@ def legend_scaler(legend_values, max_labels=10.0):
return legend_ticks
-def linear_gradient(hexList, nColors):
+def linear_gradient(hexList: List[str], nColors: int) -> List[str]:
"""
Given a list of hexcode values, will return a list of length
nColors where the colors are linearly interpolated between the
(r, g, b) tuples that are given.
-
- Examples
- --------
- >>> linear_gradient([(0, 0, 0), (255, 0, 0), (255, 255, 0)], 100)
-
"""
def _scale(start, finish, length, i):
@@ -80,7 +82,7 @@ def _scale(start, finish, length, i):
thex = "0" + thex
return thex
- allColors = []
+ allColors: List[str] = []
# Separate (R, G, B) pairs.
for start, end in zip(hexList[:-1], hexList[1:]):
# Linearly interpolate between pair of hex ###### values and
@@ -93,7 +95,7 @@ def _scale(start, finish, length, i):
allColors.append("".join(["#", r, g, b]))
# Pick only nColors colors from the total list.
- result = []
+ result: List[str] = []
for counter in range(nColors):
fraction = float(counter) / (nColors - 1)
index = int(fraction * (len(allColors) - 1))
@@ -101,7 +103,7 @@ def _scale(start, finish, length, i):
return result
-def color_brewer(color_code, n=6):
+def color_brewer(color_code: str, n: int = 6) -> List[str]:
"""
Generate a colorbrewer color scheme of length 'len', type 'scheme.
Live examples can be seen at http://colorbrewer2.org/
@@ -198,7 +200,11 @@ def color_brewer(color_code, n=6):
return color_scheme
-def image_to_url(image, colormap=None, origin="upper"):
+def image_to_url(
+ image: Any,
+ colormap: Union["ColorMap", Callable, None] = None,
+ origin: str = "upper",
+) -> str:
"""Infers the type of an image argument and transforms it into a URL.
Parameters
@@ -212,7 +218,7 @@ def image_to_url(image, colormap=None, origin="upper"):
origin : ['upper' | 'lower'], optional, default 'upper'
Place the [0, 0] index of the array in the upper left or
lower left corner of the axes.
- colormap : callable, used only for `mono` image.
+ colormap : ColorMap or callable, used only for `mono` image.
Function of the form [x -> (r,g,b)] or [x -> (r,g,b,a)]
for transforming a mono image into RGB.
It must output iterables of length 3 or 4, with values between
@@ -344,21 +350,17 @@ def png_pack(png_tag, data):
)
-def _camelify(out):
+def _camelify(out: str) -> str:
return (
(
"".join(
[
(
"_" + x.lower()
- if i < len(out) - 1
- and x.isupper()
- and out[i + 1].islower() # noqa
+ if i < len(out) - 1 and x.isupper() and out[i + 1].islower()
else (
x.lower() + "_"
- if i < len(out) - 1
- and x.islower()
- and out[i + 1].isupper() # noqa
+ if i < len(out) - 1 and x.islower() and out[i + 1].isupper()
else x.lower()
)
)
@@ -368,10 +370,10 @@ def _camelify(out):
)
.lstrip("_")
.replace("__", "_")
- ) # noqa
+ )
-def _parse_size(value):
+def _parse_size(value: TypeParseSize) -> Tuple[float, str]:
if isinstance(value, (int, float)):
return float(value), "px"
elif isinstance(value, str):
@@ -421,7 +423,7 @@ def _locations_tolist(x):
return x
-def none_min(x, y):
+def none_min(x: Optional[float], y: Optional[float]) -> Optional[float]:
if x is None:
return y
elif y is None:
@@ -430,7 +432,7 @@ def none_min(x, y):
return min(x, y)
-def none_max(x, y):
+def none_max(x: Optional[float], y: Optional[float]) -> Optional[float]:
if x is None:
return y
elif y is None:
@@ -439,7 +441,7 @@ def none_max(x, y):
return max(x, y)
-def iter_points(x):
+def iter_points(x: Union[List, Tuple]) -> list:
"""Iterates over a list representing a feature, and returns a list of points,
whatever the shape of the array (Point, MultiPolyline, etc).
"""
diff --git a/pyproject.toml b/pyproject.toml
index 97c7d5c..26a8a96 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,3 +1,6 @@
[build-system]
requires = ["setuptools>=41.2", "setuptools_scm"]
build-backend = "setuptools.build_meta"
+
+[tool.mypy]
+ignore_missing_imports = true
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 549938d..3c3b55f 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -7,6 +7,7 @@ flake8-mutable
flake8-print
isort
jupyter
+mypy
nbsphinx
nbval
numpy