Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: improve speed of BottomUp #309

Merged
merged 1 commit into from
Nov 28, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions src/ruptures/detection/bottomup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ def __init__(self, model="l2", custom_cost=None, min_size=2, jump=5, params=None

def _grow_tree(self):
"""Grow the entire binary tree."""
partition = [(0, self.n_samples)]
partition = [(-self.n_samples, (0, self.n_samples))]
stop = False
while not stop: # recursively divide the signal
stop = True
start, end = max(partition, key=lambda t: t[1] - t[0])
_, (start, end) = partition[0]
mid = (start + end) * 0.5
bkps = list()
for bkp in range(start, end):
Expand All @@ -50,15 +50,15 @@ def _grow_tree(self):
bkps.append(bkp)
if len(bkps) > 0: # at least one admissible breakpoint was found
bkp = min(bkps, key=lambda x: abs(x - mid))
partition.remove((start, end))
partition.append((start, bkp))
partition.append((bkp, end))
heapq.heappop(partition)
heapq.heappush(partition, (-bkp + start, (start, bkp)))
heapq.heappush(partition, (-end + bkp, (bkp, end)))
stop = False

partition.sort()
partition.sort(key=lambda x: x[1])
# compute segment costs
leaves = list()
for start, end in partition:
for _, (start, end) in partition:
val = self.cost.error(start, end)
leaf = Bnode(start, end, val)
leaves.append(leaf)
Expand Down Expand Up @@ -87,6 +87,7 @@ def _seg(self, n_bkps=None, pen=None, epsilon=None):
dict: partition dict {(start, end): cost value,...}
"""
leaves = sorted(self.leaves)
keys = [leaf.start for leaf in leaves]
removed = set()
merged = []
for left, right in pairwise(leaves):
Expand Down Expand Up @@ -121,10 +122,13 @@ def _seg(self, n_bkps=None, pen=None, epsilon=None):
if not stop:
# updates the list of leaves (i.e. segments of the partitions)
# find the merged segments indexes
keys = [leaf.start for leaf in leaves]
left_idx = bisect_left(keys, leaf.left.start)
leaves[left_idx] = leaf # replace leaf.left
del leaves[left_idx + 1] # remove leaf.right
# replace leaf.left
leaves[left_idx] = leaf
keys[left_idx] = leaf.start
# remove leaf.right
del leaves[left_idx + 1]
del keys[left_idx + 1]
# add to the set of removed segments.
removed.add(leaf.left)
removed.add(leaf.right)
Expand Down
Loading