Skip to content

Commit

Permalink
Merge pull request #277 from mehatamm/master
Browse files Browse the repository at this point in the history
Automatic Graph Drawing for draw_d3
  • Loading branch information
jvdwetering authored Nov 19, 2024
2 parents a2775a2 + 0f782d5 commit cd0f992
Showing 1 changed file with 61 additions and 16 deletions.
77 changes: 61 additions & 16 deletions pyzx/drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def draw(g: Union[BaseGraph[VT,ET], Circuit], labels: bool=False, **kwargs) -> A
# allow global setting to labels=False
# TODO: probably better to make labels Optional[bool]
labels = labels or settings.show_labels

if get_mode() == "shell":
return draw_matplotlib(g, labels, **kwargs)
elif get_mode() == "browser":
Expand Down Expand Up @@ -288,15 +287,51 @@ def draw_matplotlib(
# library_code += '</script>'
# display(HTML(library_code))

def auto_layout_vertex_locs(g:BaseGraph[VT, ET]): #Force-based graph drawing algorithm given by Eades(1984):
c1 = 2 #Sample parameters that work decently well
c2 = 1
c3 = 1
c4 = .1
v_locs:Dict[VT, Tuple[float, float]] = dict()
for v in g.vertices():
v_locs[v]=(random.random()*math.sqrt(g.num_vertices()), random.random()*math.sqrt(g.num_vertices()))
for i in range(100): #100 iterations of force-based drawing
forces:Dict[VT, Tuple[float, float]] = dict()
for v in g.vertices():
forces[v] = (0, 0)
for v1 in g.vertices():
if(v!=v1):
diff = (v_locs[v][0]-v_locs[v1][0], v_locs[v][1]-v_locs[v1][1])
d = math.sqrt(diff[0]*diff[0]+diff[1]*diff[1])
if g.connected(v1, v): #edge between vertices: apply rule c1*log(d/c2)
force_mag = -c1*math.log(d/c2) #negative force attracts
elif v != v1: #nonadjacent vertices: apply rule -c3/d^2
force_mag = c3/(d*d) #positive force repels
else: #free body in question, applies no force on itself
raise ValueError("Vertices ended up at same point")
v_force = (diff[0]*force_mag*c4/d, diff[1]*force_mag*c4/d)
forces[v] = (forces[v][0]+v_force[0], forces[v][1]+v_force[1])
for v in g.vertices(): #leave y value constant if input or output
v_locs[v]=(v_locs[v][0]+forces[v][0], v_locs[v][1]+forces[v][1])
max_x = max(v[0] for v in v_locs.values())
min_x = min(v[0] for v in v_locs.values())
max_y = max(v[1] for v in v_locs.values())
min_y = min(v[1] for v in v_locs.values())
v_locs = {k:(v[0]-min_x, v[1]-min_y) for k, v in v_locs.items()} #translate to origin
return v_locs, max_x-min_x, max_y-min_y


def draw_d3(
g: Union[BaseGraph[VT,ET], Circuit],
labels:bool=False,
scale:Optional[FloatInt]=None,
auto_hbox:Optional[bool]=None,
show_scalar:bool=False,
vdata: List[str]=[]
vdata: List[str]=[],
auto_layout = False
) -> Any:

"""If auto_layout is checked, will automatically space vertices of graph
with no regard to qubit/row."""
if get_mode() not in ("notebook", "browser"):
raise Exception("This method only works when loaded in a webpage or Jupyter notebook")

Expand All @@ -310,25 +345,35 @@ def draw_d3(
# use an 8-digit random alphanum instead.
graph_id = ''.join(random_graphid.choice(string.ascii_letters + string.digits) for _ in range(8))

minrow = min([g.row(v) for v in g.vertices()], default=0)
maxrow = max([g.row(v) for v in g.vertices()], default=0)
minqub = min([g.qubit(v) for v in g.vertices()], default=0)
maxqub = max([g.qubit(v) for v in g.vertices()], default=0)
if(auto_layout):
v_dict, w, h = auto_layout_vertex_locs(g)
if scale is None:
scale = 800 / w
if scale > 50: scale = 50
if scale < 20: scale = 20

w = (w+2) * scale
h = (h+3) * scale
else:
minrow = min([g.row(v) for v in g.vertices()], default=0)
maxrow = max([g.row(v) for v in g.vertices()], default=0)
minqub = min([g.qubit(v) for v in g.vertices()], default=0)
maxqub = max([g.qubit(v) for v in g.vertices()], default=0)

if scale is None:
scale = 800 / (maxrow-minrow + 2)
if scale > 50: scale = 50
if scale < 20: scale = 20

if scale is None:
scale = 800 / (maxrow-minrow + 2)
if scale > 50: scale = 50
if scale < 20: scale = 20
w = (maxrow-minrow + 2) * scale
h = (maxqub-minqub + 3) * scale

node_size = 0.2 * scale
if node_size < 2: node_size = 2

w = (maxrow-minrow + 2) * scale
h = (maxqub-minqub + 3) * scale

nodes = [{'name': str(v),
'x': (g.row(v)-minrow + 1) * scale,
'y': (g.qubit(v)-minqub + 2) * scale,
'x': (v_dict[v][0]+1)*scale if auto_layout else (g.row(v)-minrow + 1) * scale,
'y': (v_dict[v][1]+2)*scale if auto_layout else (g.qubit(v)-minqub + 2) * scale,
't': g.type(v),
'phase': phase_to_s(g.phase(v), g.type(v)) if g.type(v) != VertexType.Z_BOX else str(get_z_box_label(g, v)),
'ground': g.is_ground(v),
Expand Down

0 comments on commit cd0f992

Please sign in to comment.