Skip to content

Commit

Permalink
Custom color mapping circos nodes and edges (#689)
Browse files Browse the repository at this point in the history
* add custom mapping

* Update contributors.md

* black formatting

* ensure palette is used properly

* Update encodings.py

don't pass palette to the other functions

---------

Co-authored-by: Kelvin <[email protected]>
  • Loading branch information
zktuong and zktuong authored Aug 9, 2024
1 parent fe3e491 commit 1bbbcda
Show file tree
Hide file tree
Showing 9 changed files with 187 additions and 41 deletions.
1 change: 1 addition & 0 deletions docs/contributors.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
- Eduarda Centeno <eduardacenteno (at) hotmail.com>
- Alireza Hosseini <alirezatheh (at) gmail.com>
- Yashrajsinh Jadeja (@Yashrajsinh-Jadeja)
- Kelvin Tuong <@zktuong>
25 changes: 18 additions & 7 deletions nxviz/annotate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Annotation submodule."""
from functools import partial, update_wrapper
from typing import Dict, Hashable
from typing import Dict, Hashable, Union, Optional, List

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -240,13 +240,18 @@ def matrix_block(
ax.add_patch(patch)


def colormapping(data: pd.Series, legend_kwargs: Dict = {}, ax=None):
def colormapping(
data: pd.Series,
legend_kwargs: Dict = {},
ax=None,
palette: Optional[Union[Dict, List]] = None,
):
"""Annotate node color mapping.
If the color attribute is continuous, a colorbar will be added to the matplotlib figure.
Otherwise, a legend will be added.
"""
cmap, data_family = encodings.data_cmap(data)
cmap, data_family = encodings.data_cmap(data, palette)
if ax is None:
ax = plt.gca()
if data_family == "continuous":
Expand All @@ -258,8 +263,12 @@ def colormapping(data: pd.Series, legend_kwargs: Dict = {}, ax=None):
fig = plt.gcf()
fig.colorbar(scalarmap)
else:
labels = data.drop_duplicates().sort_values()
cfunc = encodings.color_func(data)
if (palette is not None) and (isinstance(palette, dict)):
labels = pd.Series(list(palette.keys()))
else:
labels = pd.Series(data.unique())
cmap, _ = encodings.data_cmap(labels, palette)
cfunc = encodings.color_func(labels, palette)
colors = labels.apply(cfunc)
patchlist = []
for color, label in zip(colors, labels):
Expand All @@ -280,25 +289,27 @@ def node_colormapping(
color_by: Hashable,
legend_kwargs: Dict = {"loc": "upper right", "bbox_to_anchor": (0.0, 1.0)},
ax=None,
palette: Optional[Union[Dict, List]] = None,
):
"""Annotate node color mapping."""
nt = utils.node_table(G)
data = nt[color_by]
colormapping(data, legend_kwargs, ax)
colormapping(data, legend_kwargs, ax, palette)


def edge_colormapping(
G: nx.Graph,
color_by: Hashable,
legend_kwargs: Dict = {"loc": "lower right", "bbox_to_anchor": (0.0, 0.0)},
ax=None,
palette: Optional[Union[Dict, List]] = None,
):
"""Annotate edge color mapping."""
if ax is None:
ax = plt.gca()
et = utils.edge_table(G)
data = et[color_by]
colormapping(data, legend_kwargs, ax)
colormapping(data, legend_kwargs, ax, palette)


def node_labels(G, layout_func, group_by, sort_by, fontdict={}, ax=None):
Expand Down
33 changes: 31 additions & 2 deletions nxviz/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


from functools import partial, update_wrapper
from typing import Callable, Dict, Hashable
from typing import Callable, Dict, Hashable, Optional, Union, List

import matplotlib.pyplot as plt
import networkx as nx
Expand Down Expand Up @@ -31,6 +31,8 @@ def base(
edge_enc_kwargs: Dict = {},
node_layout_kwargs: Dict = {},
edge_line_kwargs: Dict = {},
node_palette: Optional[Union[Dict, List]] = None,
edge_palette: Optional[Union[Dict, List]] = None,
):
"""High-level graph plotting function.
Expand All @@ -49,6 +51,10 @@ def base(
- `node_size_by`: Node metadata attribute key to set node size.
- `node_enc_kwargs`: Keyword arguments to set node visual encodings.
TODO: Elaborate on what these arguments are.
- `node_palette`: Optional custom palette of colours for plotting categorical groupings
in a list/dictionary. Colours must be values `matplotlib.colors.ListedColormap`
can interpret. If a dictionary is provided, key and record corresponds to
category and colour respectively.
### Edges
Expand All @@ -60,6 +66,7 @@ def base(
- `edge_lw_by`: Edge metdata attribute key to set edge line width.
- `edge_alpha_by`: Edge metdata attribute key to set edge transparency.
- `edge_enc_kwargs`: Keyword arguments to set edge visual encodings.
- `edge_palette`: Same as node_palette but for edges.
TODO: Elaborate on what these arguments are.
"""
pos = node_layout_func(
Expand All @@ -71,6 +78,7 @@ def base(
alpha_by=node_alpha_by,
encodings_kwargs=node_enc_kwargs,
layout_kwargs=node_layout_kwargs,
palette=node_palette,
)
edge_line_func(
G,
Expand All @@ -80,6 +88,7 @@ def base(
lw_by=edge_lw_by,
alpha_by=edge_alpha_by,
encodings_kwargs=edge_enc_kwargs,
palette=edge_palette,
)

despine()
Expand Down Expand Up @@ -146,6 +155,8 @@ def base_cloned(
node_layout_kwargs: Dict = {},
edge_line_kwargs: Dict = {},
cloned_node_layout_kwargs: Dict = {},
node_palette: Optional[Union[Dict, List]] = None,
edge_palette: Optional[Union[Dict, List]] = None,
):
"""High-level graph plotting function.
Expand All @@ -164,6 +175,10 @@ def base_cloned(
- `node_size_by`: Node metadata attribute key to set node size.
- `node_enc_kwargs`: Keyword arguments to set node visual encodings.
TODO: Elaborate on what these arguments are.
- `node_palette`: Optional custom palette of colours for plotting categorical groupings
in a list/dictionary. Colours must be values `matplotlib.colors.ListedColormap`
can interpret. If a dictionary is provided, key and record corresponds to
category and colour respectively.
### Edges
Expand All @@ -175,7 +190,7 @@ def base_cloned(
- `edge_lw_by`: Edge metdata attribute key to set edge line width.
- `edge_alpha_by`: Edge metdata attribute key to set edge transparency.
- `edge_enc_kwargs`: Keyword arguments to set edge visual encodings.
TODO: Elaborate on what these arguments are.
- `edge_palette`: Same as node_palette but for edges.
"""
pos = node_layout_func(
G,
Expand All @@ -186,6 +201,7 @@ def base_cloned(
alpha_by=node_alpha_by,
encodings_kwargs=node_enc_kwargs,
layout_kwargs=node_layout_kwargs,
palette=node_palette,
)
pos_cloned = node_layout_func(
G,
Expand All @@ -196,6 +212,7 @@ def base_cloned(
alpha_by=node_alpha_by,
encodings_kwargs=node_enc_kwargs,
layout_kwargs=cloned_node_layout_kwargs,
palette=node_palette,
)
edge_line_func(
G,
Expand All @@ -206,6 +223,7 @@ def base_cloned(
lw_by=edge_lw_by,
alpha_by=edge_alpha_by,
encodings_kwargs=edge_enc_kwargs,
palette=edge_palette,
**edge_line_kwargs,
)

Expand Down Expand Up @@ -253,6 +271,8 @@ def __init__(
edge_alpha: Hashable = None,
edge_width: Hashable = None,
edgeprops: Dict = None,
node_palette: Optional[Union[Dict, List]] = None,
edge_palette: Optional[Union[Dict, List]] = None,
):
"""Instantiate a plot.
Expand All @@ -269,6 +289,11 @@ def __init__(
- `edge_alpha`: The edge attribute on which to specify the transparency of edges.
- `edge_width`: The edge attribute on which to specify the width of edges.
- `edgeprops`: A `matplotlib-compatible `props` dictionary.
- `node_palette`: Optional custom palette of colours for plotting categorical groupings
in a list/dictionary. Colours must be values `matplotlib.colors.ListedColormap`
can interpret. If a dictionary is provided, key and record corresponds to
category and colour respectively.
- `edge_palette`: Same as node_palette but for edges.
"""
import warnings

Expand Down Expand Up @@ -300,6 +325,8 @@ def draw():
"edge_alpha_by",
"edge_lw_by",
"edge_enc_kwargs",
"node_palette",
"edge_palette",
]

object_api_names = [
Expand All @@ -313,6 +340,8 @@ def draw():
"edge_alpha",
"edge_width",
"edgeprops",
"node_palette",
"edge_palette",
]

functional_to_object = dict(zip(functional_api_names, object_api_names))
Expand Down
15 changes: 10 additions & 5 deletions nxviz/edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from copy import deepcopy
from functools import partial, update_wrapper
from typing import Callable, Dict, Hashable, Tuple, Optional
from typing import Callable, Dict, Hashable, Tuple, Optional, Union, List

import matplotlib.pyplot as plt
import networkx as nx
Expand Down Expand Up @@ -44,16 +44,16 @@ def edge_colors(
nt: pd.DataFrame,
color_by: Hashable,
node_color_by: Hashable,
palette: Optional[Union[Dict, List]] = None,
):
"""Default edge line color function."""
if color_by in ("source_node_color", "target_node_color"):
edge_select_by = color_by.split("_")[0]
return encodings.data_color(
et[edge_select_by].apply(nt[node_color_by].get),
nt[node_color_by],
et[edge_select_by].apply(nt[node_color_by].get), nt[node_color_by], palette
)
elif color_by:
return encodings.data_color(et[color_by], et[color_by])
return encodings.data_color(et[color_by], et[color_by], palette)
return pd.Series(["black"] * len(et), name="color_by")


Expand Down Expand Up @@ -85,6 +85,7 @@ def draw(
alpha_by: Hashable = None,
ax=None,
encodings_kwargs: Dict = {},
palette: Optional[Union[Dict, List]] = None,
**linefunc_kwargs,
):
"""Draw edges to matplotlib axes.
Expand All @@ -108,6 +109,10 @@ def draw(
- `ax`: Matplotlib axes object to plot onto.
- `encodings_kwargs`: A dictionary of kwargs
to determine the visual properties of the edge.
- `palette`: Optional custom palette of colours for plotting categorical groupings
in a list/dictionary. Colours must be values `matplotlib.colors.ListedColormap`
can interpret. If a dictionary is provided, key and record corresponds to
category and colour respectively.
- `linefunc_kwargs`: All other keyword arguments passed in
will be passed onto the appropriate linefunc.
Expand Down Expand Up @@ -135,7 +140,7 @@ def draw(
if ax is None:
ax = plt.gca()
validate_color_by(G, color_by, node_color_by)
edge_color = edge_colors(et, nt, color_by, node_color_by)
edge_color = edge_colors(et, nt, color_by, node_color_by, palette)
encodings_kwargs = deepcopy(encodings_kwargs)
lw = line_width(et, lw_by) * encodings_kwargs.pop("lw_scale", 1.0)

Expand Down
Loading

0 comments on commit 1bbbcda

Please sign in to comment.