Skip to content

Commit

Permalink
Improve paths() function, more tests, small update on trailing whites…
Browse files Browse the repository at this point in the history
…pace when printing a net
  • Loading branch information
mdko committed Sep 18, 2021
1 parent b37df27 commit 6b0e0b0
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 8 deletions.
65 changes: 58 additions & 7 deletions pyrtl/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,34 @@ def extract_area_delay_from_yosys_output(yosys_output):
return area, delay


class PathsResult(dict):
def print(self, file=sys.stdout):
""" Pretty print the result of calling paths()
:param f: the open file to print to (defaults to stdout)
:return: None
"""
# All this work, to make sure it's determinstic
def path_sort_key(path):
dst_names = [net.dests[0].name if net.dests else '' for net in path]
return (len(path), dst_names)

for start in sorted(self.keys(), key=lambda w: w.name):
print("From %s" % start.name, file=file)
for end in sorted(self[start].keys(), key=lambda w: w.name):
print(" To %s" % end.name, file=file)
paths = self[start][end]
if len(paths) > 0:
for i, paths in enumerate(sorted(paths, key=path_sort_key)):
print(" Path %d" % i, file=file)
for path in paths:
print(" %s" % str(path), file=file)
else:
print(" (No paths)", file=file)


def paths(src=None, dst=None, dst_nets=None, block=None):
""" Get the list of paths from src to dst.
""" Get the list of all paths from src to dst.
:param Union[WireVector, Iterable[WireVector]] src: source wire(s) from which to
trace your paths; if None, will get paths from all Inputs
Expand All @@ -420,7 +446,8 @@ def paths(src=None, dst=None, dst_nets=None, block=None):
:param Block block: block to use (defaults to working block)
:return: a map of the form {src_wire: {dst_wire: [path]}} for each src_wire in src
(or all inputs if src is None), dst_wire in dst (or all outputs if dst is None),
where path is a list of nets
where path is a list of nets. This map is also an instance of PathsResult,
so you can call `print()` on it to pretty print it.
You can provide dst_nets (the result of calling pyrtl.net_connections()), if you plan
on calling this function repeatedly on a block that hasn't changed, to speed things up.
Expand All @@ -434,6 +461,11 @@ def paths(src=None, dst=None, dst_nets=None, block=None):
to a given dst wire.
If src and dst are both single wires, you still need to access the result via paths[src][dst].
This also finds and returns the loop paths in the case of registers or memories that feed into
themselves, i.e. paths[src][src] is not necessarily empty.
It does not distinguish between loops that include synchronous vs asynchronous memories.
"""
block = working_block(block)

Expand Down Expand Up @@ -472,17 +504,36 @@ def dfs(w, curr_path):
paths.append(curr_path)
for dst_net in dst_nets.get(w, []):
# Avoid loops and the mem net (has no output wire)
if (dst_net not in curr_path) and (dst_net.op != '@'):
dfs(dst_net.dests[0], curr_path + [dst_net])
if dst_net not in curr_path:
if dst_net.op == '@': # dests will be the read ports
for read_net in dst_net.op_param[1].readport_nets:
dfs(read_net.dests[0], curr_path + [dst_net, read_net])
else:
dfs(dst_net.dests[0], curr_path + [dst_net])
dfs(src, [])
return paths

all_paths = collections.defaultdict(dict)
for src_wire in src:
for dst_wire in dst:
all_paths[src_wire][dst_wire] = paths_src_dst(src_wire, dst_wire)

return all_paths
paths = paths_src_dst(src_wire, dst_wire)
# Remove empty paths...
paths = list(filter(lambda x: len(x) > 0, paths))
# ...and those that are supersets of others (resulting from an inner loop).
if src_wire is not dst_wire:
paths = sorted(paths, key=lambda p: len(p), reverse=True)
keep = []
for i in range(len(paths)):
# Check if there is a path in paths[i+1:] that is the suffix
# of paths[i] (paths[i] is at least as large as each path in
# paths[i+1:]). If so, paths[i] contains a loop since both start
# at src_wire, so don't keep it.
if not any(paths[i][-len(p):] == p for p in paths[i + 1:]):
keep.append(paths[i])
paths = keep
all_paths[src_wire][dst_wire] = paths

return PathsResult(all_paths)


def distance(src, dst, f, block=None):
Expand Down
3 changes: 2 additions & 1 deletion pyrtl/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ def __str__(self):

else: # not in ipython
if self.op in 'w~&|^n+-*<>=xcsr':
return "{} <-- {} -- {} {}".format(lhs, self.op, rhs, options)
options = ' ' + options if options else ''
return "{} <-- {} -- {}{}".format(lhs, self.op, rhs, options)
elif self.op in 'm@':
memid, memblock = self.op_param
extrainfo = 'memid=' + str(memid)
Expand Down
129 changes: 129 additions & 0 deletions tests/test_analysis.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import print_function, unicode_literals, absolute_import

import unittest
import io
import pyrtl


Expand Down Expand Up @@ -113,10 +114,50 @@ def setUp(self):
pyrtl.reset_working_block()


paths_print_output = """\
From i
To o
Path 0
tmp5/3W <-- - -- i/2I, tmp4/2W
tmp6/3W <-- | -- tmp2/3W, tmp5/3W
o/3O <-- w -- tmp6/3W
Path 1
tmp1/3W <-- c -- tmp0/1W, i/2I
tmp2/3W <-- & -- tmp1/3W, j/3I
tmp6/3W <-- | -- tmp2/3W, tmp5/3W
o/3O <-- w -- tmp6/3W
To p
Path 0
tmp8/4W <-- c -- tmp7/2W, i/2I
tmp9/5W <-- - -- k/4I, tmp8/4W
p/5O <-- w -- tmp9/5W
From j
To o
Path 0
tmp2/3W <-- & -- tmp1/3W, j/3I
tmp6/3W <-- | -- tmp2/3W, tmp5/3W
o/3O <-- w -- tmp6/3W
To p
(No paths)
From k
To o
(No paths)
To p
Path 0
tmp9/5W <-- - -- k/4I, tmp8/4W
p/5O <-- w -- tmp9/5W
"""


class TestPaths(unittest.TestCase):

def setUp(self):
pyrtl.reset_working_block()
# To compare textual consistency, need to make
# sure we're starting at the same index for all
# automatically created names.
pyrtl.wire._reset_wire_indexers()
pyrtl.memory._reset_memory_indexer()

def test_one_path_to_one_output(self):
a = pyrtl.Input(4, 'a')
Expand Down Expand Up @@ -181,6 +222,83 @@ def test_subset_of_all_paths(self):
self.assertNotIn(p, paths_from_k) # Because p was not provided as target output
self.assertEqual(len(paths_from_k[o]), 0) # 0 paths from k to o

def test_paths_empty_src_and_dst_equal_with_no_other_logic(self):
i = pyrtl.Input(4, 'i')
paths = pyrtl.paths(i, i)
self.assertEqual(len(paths[i][i]), 0)

def test_paths_with_loop(self):
r = pyrtl.Register(1, 'r')
r.next <<= r & ~r
paths = pyrtl.paths(r, r)
self.assertEqual(len(paths[r][r]), 2)
p1, p2 = sorted(paths[r][r], key=lambda p: len(p), reverse=True)
self.assertEqual(len(p1), 3)
self.assertEqual(p1[0].op, '~')
self.assertEqual(p1[1].op, '&')
self.assertEqual(p1[2].op, 'r')
self.assertEqual(len(p2), 2)
self.assertEqual(p2[0].op, '&')
self.assertEqual(p2[1].op, 'r')

def test_paths_loop_and_input(self):
i = pyrtl.Input(1, 'i')
o = pyrtl.Output(1, 'o')
r = pyrtl.Register(1, 'r')
r.next <<= i & r
o <<= r
paths = pyrtl.paths(r, o)
self.assertEqual(len(paths[r][o]), 1)

def test_paths_loop_get_arbitrary_inner_wires(self):
w = pyrtl.WireVector(1, 'w')
y = w & pyrtl.Const(1)
w <<= ~y
paths = pyrtl.paths(w, y)
self.assertEqual(len(paths[w][y]), 1)
self.assertEqual(paths[w][y][0][0].op, '&')

def test_paths_no_path_exists(self):
i = pyrtl.Input(1, 'i')
o = pyrtl.Output(1, 'o')
o <<= ~i

w = pyrtl.WireVector(1, 'w')
y = w & pyrtl.Const(1)
w <<= ~y

paths = pyrtl.paths(w, o)
self.assertEqual(len(paths[w][o]), 0)

def test_paths_with_memory(self):
i = pyrtl.Input(4, 'i')
o = pyrtl.Output(8, 'o')
mem = pyrtl.MemBlock(8, 32, 'mem')
waddr = pyrtl.Input(32, 'waddr')
raddr = pyrtl.Input(32, 'raddr')
data = mem[raddr]
mem[waddr] <<= (i + ~data).truncate(8)
o <<= data

paths = pyrtl.paths(i, o)
path = paths[i][o][0]
self.assertEqual(path[0].op, 'c')
self.assertEqual(path[1].op, '+')
self.assertEqual(path[2].op, 's')
self.assertEqual(path[3].op, '@')
self.assertEqual(path[4].op, 'm')
self.assertEqual(path[5].op, 'w')

# TODO Once issue with _MemIndexed lookups is resolved,
# these should be `data` instead of `data.wire`.
paths = pyrtl.paths(data.wire, data.wire)
path = paths[data.wire][data.wire][0]
self.assertEqual(path[0].op, '~')
self.assertEqual(path[1].op, '+')
self.assertEqual(path[2].op, 's')
self.assertEqual(path[3].op, '@')
self.assertEqual(path[4].op, 'm')

def test_all_paths(self):
a, b, c = pyrtl.input_list('a/2 b/4 c/1')
o, p = pyrtl.output_list('o/4 p/2')
Expand Down Expand Up @@ -218,6 +336,17 @@ def test_all_paths(self):
paths_c_to_p = paths[c][p]
self.assertEqual(len(paths_c_to_p), 1)

def test_pretty_print(self):
i, j, k = pyrtl.input_list('i/2 j/3 k/4')
o, p = pyrtl.Output(name='o'), pyrtl.Output(name='p')
o <<= (i & j) | (i - 1)
p <<= k - i

paths = pyrtl.paths()
output = io.StringIO()
paths.print(file=output)
self.assertEqual(output.getvalue(), paths_print_output)


class TestDistance(unittest.TestCase):

Expand Down

0 comments on commit 6b0e0b0

Please sign in to comment.