diff --git a/algorithms.py b/algorithms.py index 716f882e3..c69b95685 100644 --- a/algorithms.py +++ b/algorithms.py @@ -115,7 +115,7 @@ def find(self, v): # Once we drop support for 3.9 we can use slots=True to prevent # writing extra attrs. -@dataclasses.dataclass # (slots=True) +@dataclasses.dataclass(slots=True) # FIXME class Segment: """ A class representing a single segment. Each segment has a left @@ -130,7 +130,7 @@ class Segment: prev: Segment = None next: Segment = None # noqa: A003 lineage: Lineage = None - population: int = -1 + # population: int = -1 # label: int = 0 def __repr__(self): @@ -145,10 +145,11 @@ def show_chain(seg): return s[:-2] def __lt__(self, other): - return (self.left, self.right, self.population, self.node) < ( + # TODO not clear here why we need population in the key? + return (self.left, self.right, self.lineage.population, self.node) < ( other.left, other.right, - other.population, + other.lineage.population, self.node, ) @@ -730,12 +731,14 @@ def __repr__(self): @dataclasses.dataclass class Lineage: head: Segment + population: int hull: Hull = None label: int = 0 def __str__(self): s = ( - f"Lineage(id={hex(id(self))},label={self.label},hull={self.hull}," + f"Lineage(id={hex(id(self))}," + f"population={self.population},label={self.label},hull={self.hull}," f"head={self.head.index}," f"chain={Segment.show_chain(self.head)})" ) @@ -969,9 +972,11 @@ def __init__( def initialise(self, ts): root_time = np.max(self.tables.nodes.time) self.t = root_time - + # TODO get rid of root_segments_head/tail when we add the + # tail attr to lineage root_segments_head = [None for _ in range(ts.num_nodes)] root_segments_tail = [None for _ in range(ts.num_nodes)] + root_lineages = [None for _ in range(ts.num_nodes)] last_S = -1 for tree in ts.trees(): left, right = tree.interval @@ -985,7 +990,9 @@ def initialise(self, ts): for root in tree.roots: population = ts.node(root).population if root_segments_head[root] is None: - seg = self.alloc_segment(left, right, root, population) + seg = self.alloc_segment(left, right, root) + lineage = self.alloc_lineage(seg, population) + root_lineages[root] = lineage root_segments_head[root] = seg root_segments_tail[root] = seg else: @@ -1002,12 +1009,11 @@ def initialise(self, ts): # Insert the segment chains into the algorithm state. for node in range(ts.num_nodes): - seg = root_segments_head[node] - if seg is not None: + lineage = root_lineages[node] + if lineage is not None: + seg = lineage.head left_end = seg.left - pop = seg.population - lineage = self.alloc_lineage(seg) - self.P[seg.population].add(lineage) + self.add_lineage(lineage) while seg is not None: self.set_segment_mass(seg) seg = seg.next @@ -1085,7 +1091,7 @@ def alloc_segment( left, right, node, - population, + population=None, prev=None, next=None, # noqa: A002 lineage=None, @@ -1097,14 +1103,13 @@ def alloc_segment( s.left = left s.right = right s.node = node - s.population = population s.next = next s.prev = prev s.lineage = lineage return s - def alloc_lineage(self, head, *, label=0): - lineage = Lineage(head, label=label) + def alloc_lineage(self, head, population, *, label=0): + lineage = Lineage(head, population=population, label=label) lineage.reset_segments() x = head while x is not None: @@ -1117,7 +1122,6 @@ def copy_segment(self, segment): left=segment.left, right=segment.right, node=segment.node, - population=segment.population, next=segment.next, prev=segment.prev, lineage=segment.lineage, @@ -1189,7 +1193,7 @@ def store_edge(self, left, right, parent, child): ) def add_lineage(self, lineage): - pop = lineage.head.population + pop = lineage.population self.P[pop].add(lineage, lineage.label) # print("add", lineage) x = lineage.head @@ -1615,7 +1619,7 @@ def process_pedigree_common_ancestors(self, ind, ploid): # ancestor in this ploid of this individual. First we remove # them from the populations they are stored in: for _, seg in common_ancestors: - pop = self.P[seg.population] + pop = self.P[seg.lineage.population] pop.remove_individual(seg.lineage) # Merge together these lists of ancestral segments to create the @@ -1723,11 +1727,7 @@ def migration_event(self, j, k): if self.additional_nodes.value & msprime.NODE_IS_MIG_EVENT > 0: self.store_node(k, flags=msprime.NODE_IS_MIG_EVENT) self.store_arg_edges(x) - # Set the population id for each segment also. - u = x - while u is not None: - u.population = k - u = u.next + lineage.population = k def get_recomb_left_bound(self, seg): """ @@ -1794,6 +1794,8 @@ def hudson_recombination_event(self, label, return_heads=False): """ self.num_re_events += 1 y, bp = self.choose_breakpoint(self.recomb_mass_index[label], self.recomb_map) + left_lineage = y.lineage + assert left_lineage.label == label x = y.prev if y.left < bp: # x y @@ -1825,7 +1827,7 @@ def hudson_recombination_event(self, label, return_heads=False): alpha = y lhs_tail = x - right_lineage = self.alloc_lineage(alpha, label=label) + right_lineage = self.alloc_lineage(alpha, left_lineage.population, label=label) if self.model == "smc_k": # modify original hull pop = alpha.population @@ -1839,11 +1841,11 @@ def hudson_recombination_event(self, label, return_heads=False): self.P[alpha.population].add_hull(label, alpha_hull) self.set_segment_mass(alpha) - self.P[alpha.population].add(right_lineage, label) + self.add_lineage(right_lineage) if self.additional_nodes.value & msprime.NODE_IS_RE_EVENT > 0: - self.store_node(lhs_tail.population, flags=msprime.NODE_IS_RE_EVENT) + self.store_node(left_lineage.population, flags=msprime.NODE_IS_RE_EVENT) self.store_arg_edges(lhs_tail) - self.store_node(alpha.population, flags=msprime.NODE_IS_RE_EVENT) + self.store_node(right_lineage.population, flags=msprime.NODE_IS_RE_EVENT) self.store_arg_edges(alpha) ret = None if return_heads: @@ -1888,7 +1890,8 @@ def wiuf_gene_conversion_within_event(self, label): self.num_gc_events += 1 hull = y.get_hull() assert (self.model == "smc_k") == (hull is not None) - pop = y.population + lineage = y.lineage + pop = lineage.population reset_right = -1 # Process left break @@ -1998,7 +2001,7 @@ def wiuf_gene_conversion_within_event(self, label): elif head is not None: new_individual_head = head if new_individual_head is not None: - lineage = self.alloc_lineage(new_individual_head) + lineage = self.alloc_lineage(new_individual_head, pop) if self.model == "smc_k": assert hull_left < hull_right hull_right = min(self.L, hull_right + self.hull_offset) @@ -2033,7 +2036,8 @@ def wiuf_gene_conversion_left_event(self, label): self.num_gc_events += 1 x = y.prev - pop = y.population + lineage = y.lineage + pop = lineage.population lhs_hull = y.get_hull() assert (self.model == "smc_k") == (lhs_hull is not None) if y.left < bp: @@ -2081,8 +2085,8 @@ def wiuf_gene_conversion_left_event(self, label): self.set_segment_mass(alpha) assert alpha.prev is None - lineage = self.alloc_lineage(alpha) - self.P[alpha.population].add(lineage, label) + lineage = self.alloc_lineage(alpha, pop) + self.add_lineage(lineage) def hudson_recombination_event_sweep_phase(self, label, sweep_site, pop_freq): """ @@ -2096,15 +2100,17 @@ def hudson_recombination_event_sweep_phase(self, label, sweep_site, pop_freq): if sweep_site < rhs.left: if r < 1.0 - pop_freq: # move rhs to other population - self.P[rhs.population].remove_individual(right_lin, right_lin.label) + self.P[right_lin.population].remove_individual( + right_lin, right_lin.label + ) self.set_labels(right_lin, 1 - label) - self.P[rhs.population].add(right_lin, right_lin.label) + self.P[right_lin.population].add(right_lin, right_lin.label) else: if r < 1.0 - pop_freq: # move lhs to other population - self.P[rhs.population].remove_individual(left_lin, left_lin.label) + self.P[left_lin.population].remove_individual(left_lin, left_lin.label) self.set_labels(left_lin, 1 - label) - self.P[lhs.population].add(left_lin, left_lin.label) + self.P[left_lin.population].add(left_lin, left_lin.label) def dtwf_generate_breakpoint(self, start): left_bound = start + 1 if self.discrete_genome else start @@ -2213,7 +2219,7 @@ def dtwf_recombine(self, lineage, ind_nodes): lineage.reset_segments() ret.append(lineage) else: - ret.append(self.alloc_lineage(seg)) + ret.append(self.alloc_lineage(seg, lineage.population)) return ret @@ -2243,11 +2249,12 @@ def bottleneck_event(self, pop_id, label, intensity): self.merge_ancestors(H, pop_id, label) def store_additional_nodes_edges(self, flag, new_node_id, z): + # FIXME if self.additional_nodes.value & flag > 0: if new_node_id == -1: - new_node_id = self.store_node(z.population, flags=flag) - else: - self.update_node_flag(new_node_id, flag) + self.store_node(population_index) + new_node_id = len(self.tables.nodes) - 1 + self.update_node_flag(new_node_id, flag) self.store_arg_edges(z, new_node_id) return new_node_id @@ -2273,7 +2280,7 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1): if len(X) == 1: x = X[0] if len(H) > 0 and H[0][0] < x.right: - alpha = self.alloc_segment(x.left, H[0][0], x.node, x.population) + alpha = self.alloc_segment(x.left, H[0][0], x.node) x.left = H[0][0] heapq.heappush(H, (x.left, x)) else: @@ -2320,7 +2327,7 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1): # loop tail; update alpha and integrate it into the state. if alpha is not None: if z is None: - new_lineage = self.alloc_lineage(alpha) + new_lineage = self.alloc_lineage(alpha, pop_id) pop.add(new_lineage, label) else: if (coalescence and not self.coalescing_segments_only) or ( @@ -2517,7 +2524,9 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1): # loop tail; update alpha and integrate it into the state. if alpha is not None: if z is None: - new_lineage = self.alloc_lineage(alpha, label=label) + new_lineage = self.alloc_lineage( + alpha, population_index, label=label + ) else: if (coalescence and not self.coalescing_segments_only) or ( self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0 @@ -2629,11 +2638,12 @@ def print_state(self, verify=False): self.verify() def verify_segments(self): - for pop in self.P: + for pop_index, pop in enumerate(self.P): for label in range(self.num_labels): for lineage in pop.iter_label(label): assert isinstance(lineage, Lineage) assert lineage.label == label + assert lineage.population == pop_index head = lineage.head assert head.lineage is lineage assert head.prev is None @@ -2645,7 +2655,6 @@ def verify_segments(self): assert prev.next is u assert u.prev is prev assert u.left >= prev.right - assert u.population == head.population prev = u u = u.next @@ -2702,10 +2711,10 @@ def verify_mass_index(self, label, mass_index, rate_map, compute_left_bound): for pop_index, pop in enumerate(self.P): for lineage in pop.iter_label(label): u = lineage.head + assert lineage.population == pop_index assert u.prev is None left = compute_left_bound(u) while u is not None: - assert u.population == pop_index assert u.left < u.right left_bound = compute_left_bound(u) s = rate_map.mass_between(left_bound, u.right)