-
Notifications
You must be signed in to change notification settings - Fork 0
/
dagviz.py
61 lines (53 loc) · 2.32 KB
/
dagviz.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from typing import Hashable, Dict, List
import array
import networkx as nx
def visualize_dag(g: nx.DiGraph, round_angle: bool = False) -> str:
"""
Creates a text rendering of a directed acyclic graph (DAG) for visualization purposes in a terminal.
:param g: A directed acyclic graph, of type `nx.DiGraph`
:param round_angle: Whether to use a round-angled box drawing character or not
:return: A multi-line string representation of the directed acyclic graph, each line corresponding to a node
"""
assert nx.is_directed_acyclic_graph(g), "Graph contains cycles"
rows: List[Hashable] = []
node_to_row: Dict[Hashable, int] = {}
indents: List[int] = []
def _process_dag(g: nx.DiGraph, indent: int):
for sg in nx.weakly_connected_components(g):
_process_component(g.subgraph(sg), indent=indent)
def _process_component(g: nx.DiGraph, indent: int):
sources = [v for v in g.nodes if g.in_degree(v) == 0]
for i in range(len(sources)):
node_to_row[sources[i]] = len(rows)
rows.append(sources[i])
indents.append(indent + i)
_process_dag(
g.subgraph(set(g.nodes).difference(sources)),
indent=indent + len(sources)
)
_process_dag(g, indent=0)
a = [
array.array('u', [u' '] * indents[i] * 2)
for i in range(len(rows))
]
for i, u in enumerate(rows):
successors = sorted(g.successors(u), key=lambda v: node_to_row[v])
if len(successors) == 0:
continue
l = node_to_row[successors[-1]]
for j in range(i + 1, l):
a[j][indents[i] * 2] = u'│'
for v in successors[:-1]:
j = node_to_row[v]
a[j][indents[i] * 2] = u'┼' \
if indents[i] > 0 and a[j][indents[i] * 2 - 1] == u'─' \
else u'├'
for k in range(indents[i] * 2 + 1, indents[j] * 2):
a[j][k] = u'─'
a[l][indents[i] * 2] = u'┴' \
if indents[i] > 0 and a[l][indents[i] * 2 - 1] == u'─' \
else (u'╰' if round_angle else u'└')
for k in range(indents[i] * 2 + 1, indents[l] * 2):
a[l][k] = u'─'
lines: List[str] = [l.tounicode() + "• " + str(i).replace('\n', ' ') for l, i in zip(a, rows)]
return '\n'.join(lines)