Skip to content

Commit

Permalink
Visualization: support new models, operator layers, grid tag coloring…
Browse files Browse the repository at this point in the history
…, and cross-grid edge highlighting
  • Loading branch information
tbennun committed Feb 6, 2024
1 parent f9e43d0 commit 3e245cf
Showing 1 changed file with 54 additions and 16 deletions.
70 changes: 54 additions & 16 deletions scripts/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,17 @@
"""Visualize an LBANN model's layer graph and save to file."""

import argparse
import random
import re
import graphviz
import google.protobuf.text_format
from lbann import lbann_pb2, layers_pb2
from lbann.proto import serialize

# Pastel rainbow (slightly shuffled) from colorkit.co
palette = [
'#ffffff', '#a0c4ff', '#ffadad', '#fdffb6', '#caffbf', '#9bf6ff',
'#bdb2ff', '#ffc6ff', '#ffd6a5'
]

# Parse command-line arguments
parser = argparse.ArgumentParser(
Expand All @@ -17,14 +24,14 @@
parser.add_argument('output',
action='store',
nargs='?',
default='graph.pdf',
default='graph.dot',
type=str,
help='output file (default: graph.pdf)')
help='output file (default: graph.dot)')
parser.add_argument('--file-format',
action='store',
default='pdf',
default='dot',
type=str,
help='output file format (default: pdf)',
help='output file format (default: dot)',
metavar='FORMAT')
parser.add_argument('--label-format',
action='store',
Expand All @@ -39,6 +46,10 @@
type=str,
help='Graphviz visualization scheme (default: dot)',
metavar='ENGINE')
parser.add_argument('--color-cross-grid',
action='store_true',
default=False,
help='Highlight cross-grid edges')
args = parser.parse_args()

# Strip extension from filename
Expand All @@ -51,9 +62,7 @@
label_format = re.sub(r' |-|_', '', args.label_format.lower())

# Read prototext file
proto = lbann_pb2.LbannPB()
with open(args.input, 'r') as f:
google.protobuf.text_format.Merge(f.read(), proto)
proto = serialize.generic_load(args.input)
model = proto.model

# Construct graphviz graph
Expand All @@ -62,29 +71,36 @@
engine=args.graphviz_engine)
graph.attr('node', shape='rect')

layer_to_grid_tag = {}

# Construct nodes in layer graph
layer_types = (set(layers_pb2.Layer.DESCRIPTOR.fields_by_name.keys()) - set([
'name', 'parents', 'children', 'datatype', 'data_layout',
'device_allocation', 'weights', 'freeze', 'hint_layer', 'top', 'bottom',
'type', 'motif_layer'
'type', 'motif_layer', 'parallel_strategy', 'grid_tag'
]))
for l in model.layer:

# Determine layer type
type = ''
ltype = ''
for _type in layer_types:
if l.HasField(_type):
type = getattr(l, _type).DESCRIPTOR.name
ltype = getattr(l, _type).DESCRIPTOR.name
break

# If operator layer, use operator type
if ltype == 'OperatorLayer':
url = l.operator_layer.ops[0].parameters.type_url
ltype = url[url.rfind('.') + 1:]

# Construct node label
label = ''
if label_format == 'nameonly':
label = l.name
elif label_format == 'typeonly':
label = type
label = ltype
elif label_format == 'typeandname':
label = '<{0}<br/>{1}>'.format(type, l.name)
label = '<{0}<br/>{1}>'.format(ltype, l.name)
elif label_format == 'full':
label = '<'
for (index, line) in enumerate(str(l).strip().split('\n')):
Expand All @@ -94,14 +110,36 @@
label += '>'

# Add layer as layer graph node
graph.node(l.name, label=label)
tag = l.grid_tag.value
layer_to_grid_tag[l.name] = tag
attrs = {}
if tag != 0:
attrs = dict(style='filled', fillcolor=palette[tag % len(palette)])
graph.node(l.name, label=label, **attrs)

# Add parent/child relationships as layer graph edges
edges = set()
cross_grid_edges = set()
for l in model.layer:
edges.update([(p, l.name) for p in l.parents.split()])
edges.update([(l.name, c) for c in l.children.split()])
tag = layer_to_grid_tag[l.name]
for p in l.parents:
if tag != layer_to_grid_tag[p]:
cross_grid_edges.add((p, l.name))
else:
edges.add((p, l.name))

for c in l.children:
if tag != layer_to_grid_tag[c]:
cross_grid_edges.add((l.name, c))
else:
edges.add((l.name, c))

graph.edges(edges)
if args.color_cross_grid:
for u, v in cross_grid_edges:
graph.edge(u, v, color='red')
else:
graph.edges(cross_grid_edges)

# Save to file
graph.render(filename=filename, cleanup=True, format=file_format)

0 comments on commit 3e245cf

Please sign in to comment.