-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbaltic.py
1403 lines (1179 loc) · 71.8 KB
/
baltic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from matplotlib.collections import LineCollection
import re,copy,math,json,sys
import datetime as dt
from functools import reduce
from matplotlib.collections import LineCollection
## public API exported by `from baltic import *` (keeps star-imports from leaking helper names)
__all__ = [
    'decimalDate', 'convertDate', 'calendarDate', 'reticulation',
    'clade', 'leaf', 'node', 'tree',
    'make_tree', 'make_treeJSON', 'loadJSON', 'loadNexus', 'loadNewick', 'untangle',
]
## tree traversal and string output below are recursive, so deep trees need a raised limit
sys.setrecursionlimit(9001)
def decimalDate(date,fmt="%Y-%m-%d",variable=False):
    """ Converts calendar dates in specified format to decimal date.

    date: calendar date string, e.g. '2020-03-15'
    fmt: strptime format of date (default '%Y-%m-%d'); an empty fmt returns date unchanged
    variable: if True, date may be less precise than fmt (e.g. '2020' or '2020-03' for
              '%Y-%m-%d'). fmt is then truncated to the leading delimiter-separated fields
              actually present in date. Formats without a delimiter are used as given.
    Returns the year plus the elapsed fraction of that year, as a float.
    """
    if fmt == "":
        return date
    delimiter=re.search('[^0-9A-Za-z%]',fmt) ## first non-alphanumeric symbol in fmt is taken as the field delimiter
    delimit=delimiter.group() if delimiter is not None else None
    if variable and delimit is not None: ## only a delimited format can be truncated (previously crashed on None.join)
        n_fields=len(date.split(delimit)) ## how many fields the date actually provides
        fmt=delimit.join(fmt.split(delimit)[:n_fields]) ## keep that many leading fields of fmt, for any number of fields
    adatetime=dt.datetime.strptime(date,fmt) ## convert to datetime object
    year=adatetime.year
    boy=dt.datetime(year,1,1) ## beginning of the year
    eoy=dt.datetime(year+1,1,1) ## beginning of next year, so leap years are handled
    return year+((adatetime-boy).total_seconds()/((eoy-boy).total_seconds())) ## fractional year
def calendarDate(timepoint,fmt='%Y-%m-%d'):
    """ Converts a decimal date (e.g. 2020.5) into a calendar date string formatted with fmt. """
    whole_year=int(timepoint)
    start=dt.datetime(whole_year,1,1)
    year_seconds=(start.replace(year=start.year+1)-start).total_seconds() ## length of this particular year
    offset=dt.timedelta(seconds=year_seconds*(timepoint-whole_year))
    return dt.datetime.strftime(start+offset,fmt)
def convertDate(x,start,end):
    """ Re-formats the calendar date string x from strptime format start into format end. """
    parsed=dt.datetime.strptime(x,start)
    return parsed.strftime(end)
class reticulation: ## reticulation class (recombination, conversion, reassortment)
    """ A reticulate branch (recombination, conversion, reassortment).
    It reports itself as leaf-like but is neither a leaf nor a node. """
    def __init__(self,name):
        self.name=name
        self.index=None
        self.branchType='leaf'
        self.parent=None
        self.target=None ## NOTE(review): presumably the branch this reticulation points to; not set in this chunk
        self.length=0.0
        self.height=0.0
        self.absoluteTime=None
        self.traits={}
        self.x=None
        self.y=None
        self.width=0.5
    def is_leaflike(self): return True
    def is_leaf(self): return False
    def is_node(self): return False
class clade: ## clade class
    """ A collapsed subtree drawn as a single tip.
    It poses as a leaf (branchType 'leaf', leaf-like) but is_leaf() is False. """
    def __init__(self,givenName):
        self.name=givenName ## the pretend tip name for the clade
        self.branchType='leaf'
        self.index=None
        self.parent=None
        self.subtree=None ## branches that were collapsed into this clade
        self.leaves=None
        self.length=0.0
        self.height=None
        self.absoluteTime=None
        self.lastHeight=None ## height of the highest tip in the collapsed clade
        self.lastAbsoluteTime=None ## absolute time of the highest tip in the collapsed clade
        self.traits={}
        self.x=None
        self.y=None
        self.width=1
    def is_leaflike(self): return True
    def is_leaf(self): return False
    def is_node(self): return False
class node: ## node class
    """ An internal branch of the tree whose descendant branches are listed in children. """
    def __init__(self):
        self.branchType='node'
        self.index=None ## position of this object in the tree string; unique identifier
        self.parent=None ## parent branch
        self.children=[] ## descendant branches
        self.leaves=set() ## tips descended from this node
        self.childHeight=None ## height of the youngest descendant tip
        self.length=0.0 ## branch length, recovered from the tree string
        self.height=None ## set by traversing the tree and adding up branch lengths
        self.absoluteTime=None ## branch end point in absolute time, once calibrated
        self.traits={} ## annotations from the tree string, e.g. {'posterior':1.0}
        self.x=None ## drawing coordinates, set by drawTree()
        self.y=None
    def is_leaflike(self): return False
    def is_leaf(self): return False
    def is_node(self): return True
class leaf: ## leaf class
    """ A tip of the tree. """
    def __init__(self):
        self.branchType='leaf'
        self.name=None ## tip name after translation (BEAST trees number taxa and provide a map)
        self.index=None ## position of this object in the tree string; unique identifier
        self.parent=None
        self.length=None ## branch length
        self.height=None ## height of tip
        self.absoluteTime=None ## position of tip in absolute time
        self.traits={} ## annotations
        self.x=None ## drawing coordinates
        self.y=None
    def is_leaflike(self): return False
    def is_leaf(self): return True
    def is_node(self): return False
class tree: ## tree class
def __init__(self):
    """ Start an empty tree whose first real branch will hang from an invisible placeholder node. """
    placeholder=node()
    placeholder.index='Root'
    placeholder.length=0.0
    placeholder.height=0.0
    self.cur_node=placeholder ## branch currently being built while parsing
    self.root=None ## becomes the first branch added by add_node/add_leaf
    self.Objects=[] ## flat list of all branches in the tree
    self.tipMap=None ## optional translation of tip labels
    self.treeHeight=0 ## distance between the root and the most recent tip
    self.mostRecent=None ## latest absolute time, set by setAbsoluteTime
    self.ySpan=0.0 ## vertical extent, set by drawTree
def add_reticulation(self,name):
    """ Attach a reticulate branch called name to the current node and make it the current branch. """
    ret=reticulation(name)
    ret.index=name ## reticulations are indexed by their name
    parent=self.cur_node
    ret.parent=parent
    parent.children.append(ret)
    self.Objects.append(ret)
    self.cur_node=ret
def add_node(self,i):
    """ Attach a new node, indexed by its position i in the tree string, beneath the current node.
    The new node becomes the current branch; the first branch ever added becomes the root. """
    child=node()
    child.index=i
    if self.root is None:
        self.root=child
    parent=self.cur_node
    child.parent=parent
    assert parent.is_node(), 'Attempted to add a child to a non-node object. Check if tip names have illegal characters like parentheses.'
    parent.children.append(child)
    self.cur_node=child
    self.Objects.append(child)
def add_leaf(self,i,name):
    """ Attach a new tip called name, indexed by its position i in the tree string, beneath the current node.
    The tip becomes the current branch; the first branch ever added becomes the root. """
    tip=leaf()
    tip.index=i
    if self.root is None: self.root=tip
    parent=self.cur_node
    tip.parent=parent
    assert parent.is_node(), 'Attempted to add a child to a non-node object. Check if tip names have illegal characters like parentheses.'
    parent.children.append(tip)
    tip.name=name
    self.cur_node=tip
    self.Objects.append(tip)
def subtree(self,k=None,traverse_condition=None,stem=True):
    """ Generate a subtree (as a baltic tree object) from a traversal.
    k is the starting branch for traversal (default: root).
    traverse_condition is a function that determines whether a child branch should be visited (default: always true).
    stem: if True (and k has a parent), a deep copy of k's parent is kept as an invisible parent of the
          new root, with height 0.0; otherwise the new tree begins strictly at k.
    Returns a new baltic tree instance, or None if the traversal contains no tips.
    Note - custom traversal functions can result in multitype trees.
    If this is undesired call singleType() on the resulting subtree afterwards. """
    if k is None: ## default start is the root (previously k.parent crashed when stem=True)
        k=self.root
    subtree=copy.deepcopy(self.traverse_tree(k,include_condition=lambda k:True,traverse_condition=traverse_condition))
    if subtree is None or len([w for w in subtree if w.is_leaf()])==0:
        return None
    local_tree=tree() ## new tree object that will hold the subtree
    local_tree.Objects=subtree
    local_tree.root=subtree[0] ## root is the beginning of the traversal
    if stem==True and k.parent is not None: ## keep the stem by giving the root an invisible parent
        local_tree.root.parent=copy.deepcopy(k.parent)
        local_tree.root.parent.height=0.0 ## so heights can be set from zero
        if local_tree.root.parent.parent: local_tree.root.parent.parent=None ## detach the invisible parent from the rest of the copy
    else: ## no stem wanted (or none available): tree begins strictly at node
        local_tree.root.parent=None
    subtree_set=set(subtree) ## quick membership look-up
    if traverse_condition is not None: ## non-default traversal may leave hanging nodes and unvisited children
        for nd in local_tree.getInternal():
            nd.children=[w for w in nd.children if w in subtree_set] ## only keep children seen in traversal
        local_tree.fixHangingNodes()
    if self.tipMap: ## copy over the relevant tip translations
        kept_names={w.name for w in local_tree.getExternal()} ## built once instead of once per tipMap entry
        local_tree.tipMap={tipNum: self.tipMap[tipNum] for tipNum in self.tipMap if self.tipMap[tipNum] in kept_names}
    return local_tree
def singleType(self):
    """ Removes any branches with a single child (multitype nodes).

    Each such node is spliced out. Its child is re-attached to the node's parent and its length is
    extended by the node's length, so the child's distance from the root is unchanged. Nodes are
    processed from the tips towards the root, and the scan repeats until none remain.
    If the spliced-out node is the tree root, its child becomes self.root. Branches are re-sorted at the end.
    """
    while True:
        multiTypeNodes=[k for k in self.Objects if k.is_node() and len(k.children)==1]
        if len(multiTypeNodes)==0:
            break
        for k in sorted(multiTypeNodes,key=lambda x:-x.height): ## deepest nodes first
            child=k.children[0]
            if k.parent is None: ## parentless root (previously AttributeError): the child simply takes over
                child.parent=None
            else:
                grandparent=k.parent if k.parent.index else self.root
                child.parent=grandparent
                grandparent.children.append(child)
                grandparent.children.remove(k)
                grandparent.children=list(set(grandparent.children)) ## guard against duplicated children
            child.length+=k.length ## child inherits the removed node's branch
            self.Objects.remove(k)
            if k is self.root: ## do not leave the tree rooted at a removed branch
                self.root=child
    self.sortBranches()
def setAbsoluteTime(self,date):
    """ Place every branch in absolute time, given the date of the most recent tip.
    Each branch gets absoluteTime = date - treeHeight + height; mostRecent records the latest value. """
    offset=date-self.treeHeight ## absolute time of the root
    for branch in self.Objects:
        branch.absoluteTime=offset+branch.height
    self.mostRecent=max(branch.absoluteTime for branch in self.Objects)
def treeStats(self):
    """ Print a summary of the tree: height, total branch length, topology/annotation flags and object counts. """
    self.traverse_tree()
    obs=self.Objects
    print('\nTree height: %.6f\nTree length: %.6f'%(self.treeHeight,sum([x.length for x in obs])))
    nodes=self.getInternal()
    child_counts=[len(nd.children) for nd in nodes]
    singleton=len(child_counts)==0 ## no internal nodes at all
    strictlyBifurcating=(not singleton) and min(child_counts)==2 and max(child_counts)==2
    multiType=(not singleton) and min(child_counts)==1 ## some node has a single child
    hasTraits=max([len(x.traits) for x in obs])>0
    for flag,message in ((strictlyBifurcating,'strictly bifurcating tree'),
                         (multiType,'multitype tree'),
                         (singleton,'singleton tree'),
                         (hasTraits,'annotations present')):
        if flag:
            print(message)
    print('\nNumbers of objects in tree: %d (%d nodes and %d leaves)\n'%(len(obs),len(nodes),len(self.getExternal())))
def traverse_tree(self,cur_node=None,include_condition=None,traverse_condition=None,collect=None,verbose=False):
    """ Recursively visit branches in pre-order (a branch before its children), starting at cur_node.

    While visiting, the traversal:
      - sets each branch's height to length + parent's height unless a height is already set
        (a parentless starting branch with no height gets 0.0);
      - adds each tip's name to its parent's leaves set, and merges every node's leaves into its parent's;
      - sets each node's childHeight to the largest childHeight (nodes) / height (others) among its children;
      - sets self.treeHeight to the starting node's childHeight. Each node assigns it after its children,
        and the starting node finishes last.

    cur_node: starting branch. Default None starts at self.root; if both conditions are also None,
              the leaves/childHeight of every node and the height of every object are cleared first.
    include_condition: function; branches for which it returns True are collected (default: k.is_leaf())
    traverse_condition: function; only children for which it returns True are visited (default: all)
    collect: list that collected branches are appended to (shared by the recursive calls)
    verbose: boolean, debug printing
    Returns collect, the list of collected branches in visiting order.
    NOTE(review): branches that are neither leaf nor node (clade, reticulation) are collected if included,
    but neither recurse nor add their name to the parent's leaves.
    """
    if cur_node==None: ## if no starting point defined - start from root
        if verbose==True: print('Initiated traversal from root')
        cur_node=self.root
        if traverse_condition==None and include_condition==None: ## reset heights if traversing from scratch
            for k in self.Objects: ## reset various parameters
                if k.is_node():
                    k.leaves=set()
                    k.childHeight=None
                k.height=None ## cleared for every object so stale heights are recomputed
    if traverse_condition==None: traverse_condition=lambda k: True
    if include_condition==None: include_condition=lambda k: k.is_leaf()
    if collect==None: ## initiate collect list if not initiated
        collect=[]
    if cur_node.parent and cur_node.height==None: ## cur_node has a parent - set height if it doesn't have it already
        cur_node.height=cur_node.length+cur_node.parent.height
    elif cur_node.height==None: ## cur_node does not have a parent (root), if height not set before it's zero
        cur_node.height=0.0
    if verbose==True: print('at %s (%s)'%(cur_node.index,cur_node.branchType))
    if include_condition(cur_node): ## test if interested in cur_node
        collect.append(cur_node) ## add to collect list for reporting later
    if cur_node.is_leaf() and self.root!=cur_node: ## cur_node is a tip (and tree is not single tip)
        cur_node.parent.leaves.add(cur_node.name) ## add to parent's list of tips
    elif cur_node.is_node(): ## cur_node is node
        for child in filter(traverse_condition,cur_node.children): ## only traverse through children we're interested
            if verbose==True: print('visiting child %s'%(child.index))
            self.traverse_tree(cur_node=child,include_condition=include_condition,traverse_condition=traverse_condition,verbose=verbose,collect=collect) ## recurse through children
            if verbose==True: print('child %s done'%(child.index))
        assert len(cur_node.children)>0, 'Tried traversing through hanging node without children. Index: %s'%(cur_node.index)
        ## computed over all children, including any skipped by traverse_condition
        cur_node.childHeight=max([child.childHeight if child.is_node() else child.height for child in cur_node.children])
        if cur_node.parent:
            cur_node.parent.leaves=cur_node.parent.leaves.union(cur_node.leaves) ## pass tips seen during traversal to parent
        self.treeHeight=cur_node.childHeight ## it's the highest child of the starting node
    return collect
def renameTips(self,d=None):
    """ Replace every tip name with its entry in the dictionary d (default: self.tipMap, when set). """
    lookup=self.tipMap if d==None and self.tipMap!=None else d
    for tip in self.getExternal():
        tip.name=lookup[tip.name]
def sortBranches(self,descending=True,sort_function=None):
    """ Sort descendants of each node, then redraw the tree.

    With the default key, non-node children come before node children.
    descending=True (default): non-nodes are ordered longest branch first; nodes are ordered by fewest
    descendant tips first, then longest branch first. descending=False reverses both numeric orderings.
    sort_function: custom sort key that replaces the default.
    """
    mod=-1 if descending else 1 ## sign applied to numeric keys (was 0 for descending=False, which disabled sorting)
    if sort_function is None:
        sort_function=lambda k: (k.is_node(),-len(k.leaves)*mod,k.length*mod) if k.is_node() else (k.is_node(),k.length*mod)
    for k in self.getInternal(): ## iterate over nodes
        k.children=sorted(k.children,key=sort_function)
    self.drawTree() ## y positions change with the new child order
def drawTree(self,order=None,width_function=None,pad_nodes=None,verbose=False):
    """ Find x and y coordinates of each branch.

    order: tips in the vertical order they are drawn (default: tips from a pre-order traversal)
    width_function: function giving the vertical space of each tip in order
                    (default: 1 for leaf objects, width+1 for anything else, e.g. clades)
    pad_nodes: optional dict {branch: padding}; adds extra space on both sides of the tips under each branch
    verbose: boolean, debug printing
    Tips get x=height and y from cumulative widths. Each node gets the mean y of its children and a
    yRange spanning its descendants. Also sets self.ySpan and the root's x/y.
    Raises AssertionError on non-unique tip names, a tip-count mismatch, or branches that can never be placed.
    """
    if order is None:
        order=self.traverse_tree() ## tips in traversal order, so they are plotted in the correct vertical order
        if verbose==True: print('Drawing tree in pre-order')
    else:
        if verbose==True: print('Drawing tree with provided order')
    name_order={x.name: i for i,x in enumerate(order)}
    assert len(name_order)==len(order), 'Non-unique names present in tree'
    if width_function is None:
        if verbose==True:
            print('Drawing tree with default widths (1 unit for leaf objects, width+1 for clades)')
        skips=[1 if isinstance(x,leaf) else x.width+1 for x in order]
    else:
        skips=list(map(width_function,order))
    for k in self.Objects: ## reset coordinates for all objects
        k.x=None
        k.y=None
    drawn={} ## indices of objects that have been positioned
    for k in order: ## iterate over tips
        y_idx=name_order[k.name]
        k.x=k.height ## x position is height
        k.y=sum(skips[y_idx:])-skips[y_idx]/2.0 ## earlier tips in order get larger y
        drawn[k.index]=None
    if pad_nodes is not None:
        for n in pad_nodes: ## nodes whose descendants will be padded
            idx=sorted([name_order[lf] for lf in n.leaves]) if n.is_node() else [order.index(n)] ## indices of tips to be padded
            for i,k in enumerate(order):
                if i<idx[0]: ## tips earlier in order than the group are padded twice, opening a gap before it
                    k.y+=pad_nodes[n]
                if (i-1)<idx[-1]: ## i.e. i<=idx[-1]: everything up to the group's end, opening a gap after it
                    k.y+=pad_nodes[n]
        all_ys=[y for y in self.getParameter('y') if y is not None] ## unlike filter(None,...), keeps legitimate 0.0 values
        minY=min(all_ys)
        for k in self.getExternal(): ## shift so the tree starts at y=0.5
            k.y-=minY-0.5
    assert len(self.getExternal())==len(order),'Number of tips in tree does not match number of unique tips, check if two or more collapsed clades were assigned the same name.'
    storePlotted=0
    while len(drawn)!=len(self.Objects): # keep drawing the tree until everything is drawn
        if verbose==True: print('Drawing iteration %d'%(len(drawn)))
        for k in filter(lambda w:w.index not in drawn,self.getInternal()): ## internal nodes not yet drawn
            if len([q.y for q in k.children if q.y!=None])==len(k.children): ## all y coordinates of children known
                if verbose==True: print('Setting node %s coordinates to'%(k.index),end=' ')
                children_y_coords=[q.y for q in k.children if q.y!=None]
                k.x=k.height
                k.y=sum(children_y_coords)/float(len(children_y_coords)) ## middle of the vertical bar
                drawn[k.index]=None
                if verbose==True: print('%s (%s branches drawn)'%(k.y,len(drawn)))
                minYrange=min([min(child.yRange) if child.is_node() else child.y for child in k.children]) ## lowest y across children
                maxYrange=max([max(child.yRange) if child.is_node() else child.y for child in k.children]) ## highest y across children
                setattr(k,'yRange',[minYrange,maxYrange])
        if len(self.Objects)>len(drawn):
            ## message previously used tree.Objects (the class), which raised instead of reporting
            assert len(drawn)>storePlotted,'Got stuck trying to find y positions of objects (%d branches drawn this iteration, %d branches during previous iteration out of %d total)'%(len(drawn),storePlotted,len(self.Objects))
        storePlotted=len(drawn) ## remember how many branches were drawn this iteration
    yvalues=[k.y for k in self.Objects]
    self.ySpan=max(yvalues)-min(yvalues)+min(yvalues)*2 ## y axis span of tree
    if self.root.is_node():
        self.root.x=min([q.x-q.length for q in self.root.children if q.x!=None]) ## set root x and y coordinates
        children_y_coords=[q.y for q in self.root.children if q.y!=None]
        self.root.y=sum(children_y_coords)/float(len(children_y_coords))
    else:
        self.root.x=self.root.length
def drawUnrooted(self,n=None,total=None):
    """
    Calculate x and y coordinates in an unrooted arrangement.
    Each branch is given an angular wedge proportional to the tips it subtends (traits['tau'] holds the
    wedge start) and is drawn from its parent towards the wedge centre.
    Code translated from https://github.com/nextstrain/auspice/commit/fc50bbf5e1d09908be2209450c6c3264f298e98c, written by Richard Neher.
    """
    if n==None: ## top-level call: total width, start at the root, clear previous layout
        total=sum([1 if isinstance(x,leaf) else x.width+1 for x in self.getExternal()])
        n=self.root
        for branch in self.Objects:
            branch.traits['tau']=0.0
            branch.x=0.0
            branch.y=0.0
    def wedge(branch): ## angle allotted to a branch
        return 2*math.pi*(1.0 if branch.is_leaf() else len(branch.leaves))/float(total)
    parent=n.parent
    if parent.x==None:
        parent.x=0.0
        parent.y=0.0
    tau=n.traits['tau']
    half=wedge(n)*0.5
    n.x=parent.x+n.length*math.cos(tau+half)
    n.y=parent.y+n.length*math.sin(tau+half)
    if not n.is_node():
        return
    eta=tau ## running start angle for successive children
    for ch in n.children:
        step=wedge(ch)
        ch.traits['tau']=eta
        eta+=step
        self.drawUnrooted(ch,total)
def commonAncestor(self,descendants,strict=False):
    """
    Find the most recent node object that gave rise to a given list of descendant branches.

    descendants: collection of at least two branches of this tree
    strict: unused, kept for backward compatibility
    Returns the first branch on the path from a descendant towards the root that also lies on every other
    descendant's path. A descendant itself qualifies if it is an ancestor of all the others.
    Walking the path keeps the result correct when zero-length branches give ancestors equal heights;
    picking the shared branch with the greatest height does not.
    Raises AssertionError for fewer than two descendants and IndexError if they share no ancestor.
    """
    assert len(descendants)>1,'Not enough descendants to find common ancestor: %d'%(len(descendants))
    def path_to_root(branch):
        path=[]
        while branch: ## the topmost branch has parent None
            path.append(branch)
            branch=branch.parent
        return path
    paths={k.index: path_to_root(k) for k in descendants}
    path_sets=[set(path) for path in paths.values()]
    for candidate in next(iter(paths.values())): ## nearest-first along one descendant's path
        if all(candidate in path_set for path_set in path_sets):
            return candidate
    raise IndexError('No common ancestor shared by all descendants') ## same exception type as the previous empty sorted(...)[-1]
def collapseSubtree(self,cl,givenName,verbose=False,widthFunction=lambda k:len(k.leaves)):
    """ Collapse an entire subtree into a clade object.

    cl: node whose subtree is collapsed
    givenName: tip name for the resulting clade (also added to tipMap, if present)
    verbose: boolean, debug printing
    widthFunction: function of cl giving the clade's drawing width (default: number of descendant tips)
    The clade takes over cl's index, leaves, length, height, absoluteTime and traits, and stores
    the removed branches in .subtree. Returns the new clade object.
    """
    assert cl.is_node(),'Cannot collapse non-node class'
    collapsedClade=clade(givenName)
    collapsedClade.index=cl.index
    collapsedClade.leaves=cl.leaves
    collapsedClade.length=cl.length
    collapsedClade.height=cl.height
    collapsedClade.parent=cl.parent
    collapsedClade.absoluteTime=cl.absoluteTime
    collapsedClade.traits=cl.traits
    collapsedClade.width=widthFunction(cl)
    if verbose==True: print('Replacing node %s (parent %s) with a clade class'%(cl.index,cl.parent.index))
    parent=cl.parent
    remove_from_tree=self.traverse_tree(cl,include_condition=lambda k: True) ## every branch of the subtree, cl included
    collapsedClade.subtree=remove_from_tree
    assert len(remove_from_tree)<len(self.Objects),'Attempted collapse of entire tree'
    collapsedClade.lastHeight=max([x.height for x in remove_from_tree])
    known_times=[x.absoluteTime for x in remove_from_tree if x.absoluteTime is not None]
    if known_times: ## skip unset times; max() over a mix of None and floats raised TypeError
        collapsedClade.lastAbsoluteTime=max(known_times)
    for k in remove_from_tree:
        self.Objects.remove(k)
    parent.children.remove(cl)
    parent.children.append(collapsedClade)
    self.Objects.append(collapsedClade)
    collapsedClade.parent=parent
    if self.tipMap!=None: self.tipMap[givenName]=givenName
    self.traverse_tree()
    self.sortBranches()
    return collapsedClade
def uncollapseSubtree(self):
    """ Uncollapse all collapsed subtrees.
    Each clade object is replaced by the subtree it stores. Passes repeat until no clades remain, so
    clades nested inside restored subtrees are expanded too. The tree is re-traversed at the end. """
    def collapsed():
        return [k for k in self.Objects if isinstance(k,clade)]
    pending=collapsed()
    while len(pending)>0:
        for cl in pending:
            restored=cl.subtree
            siblings=cl.parent.children
            siblings.remove(cl)
            siblings.append(restored[0]) ## first entry of the stored subtree is its top node
            self.Objects+=restored
            self.Objects.remove(cl)
            if self.tipMap!=None:
                self.tipMap.pop(cl.name,None)
        pending=collapsed()
    self.traverse_tree()
def collapseBranches(self,collapseIf=lambda x:x.traits['posterior']<=0.5,designated_nodes=None,verbose=False):
    """ Collapse all branches that satisfy a function collapseIf (default is an anonymous function that returns true if posterior probability is <=0.5).
    Alternatively, a list of nodes can be supplied to the script.
    Returns a deep copied version of the tree.

    collapseIf: applied to every non-root node of the copy; used only when designated_nodes is empty.
                It is re-applied after each collapse pass until no node qualifies.
    designated_nodes: optional list of nodes of this tree to collapse instead (matched in the copy by index)
    verbose: boolean, debug printing
    Collapsing a node re-attaches its children to its parent and extends each child's branch by the
    collapsed node's length. The original tree is not modified.
    Raises AssertionError if a designated branch is not a node, if the root is designated, or if the
    selection would remove nearly all internal nodes.
    """
    designated=list(designated_nodes) if designated_nodes else [] ## None default avoids a shared mutable default
    newTree=copy.deepcopy(self) ## work on a copy of the tree
    if designated:
        assert [w.branchType for w in designated].count('node')==len(designated),'Non-node class detected in list of nodes designated for deletion'
        ## designated nodes come from self, not the copy, so compare indices; the old check was inverted
        ## and rejected every non-root node
        assert len([w for w in designated if w.index==newTree.root.index])==0,'Root node was designated for deletion'
        designated_indices=[w.index for w in designated]
    def find_nodes_to_delete(): ## nodes of the copy that should (still) be collapsed
        if designated:
            return [w for w in newTree.Objects if w.index in designated_indices]
        return [n for n in newTree.Objects if n.is_node() and collapseIf(n)==True and n!=newTree.root]
    nodes_to_delete=find_nodes_to_delete()
    if verbose==True: print('%s nodes set for collapsing: %s'%(len(nodes_to_delete),[w.index for w in nodes_to_delete]))
    assert len(nodes_to_delete)<len(newTree.getInternal())-1,'Chosen cutoff would remove all branches'
    while len(nodes_to_delete)>0: ## keep reducing the tree while branches remain to be collapsed
        if verbose==True: print('Continuing collapse cycle, %s nodes left'%(len(nodes_to_delete)))
        for k in sorted(nodes_to_delete,key=lambda x:-x.height): ## start with branches near the tips
            zero_node=k.children
            k.parent.children+=zero_node ## hand the children to the collapsed node's parent
            old_parent=k
            new_parent=k.parent
            if new_parent==None:
                new_parent=newTree.root
            if verbose==True: print('Removing node %s, attaching children %s to node %s'%(old_parent.index,[w.index for w in k.children],new_parent.index))
            for w in newTree.Objects: ## re-parent the children and extend their branches
                if w.parent==old_parent:
                    w.parent=new_parent
                    w.length+=old_parent.length
                    if verbose==True: print('Fixing branch length for node %s'%(w.index))
            k.parent.children.remove(k) ## remove traces of the deleted node
            newTree.Objects.remove(k)
            if verbose==True: print('Removing references to node %s'%(k.index))
        nodes_to_delete=find_nodes_to_delete() ## re-evaluate on the collapsed tree (previously a no-op `==`)
    newTree.sortBranches() ## sort the tree to traverse, draw and adjust y coordinates
    return newTree
def toString(self,cur_node=None,traits=None,verbose=False,nexus=False,string_fragment=None,traverse_condition=None,rename=None,quotechar="'",json=False):
    """ Output the topology of the tree with branch lengths and comments to a string.

    cur_node: starting point (default: None, starts at root); ';' is appended only when it is the root
    traits: keys of each branch's traits dict to output as [&key=value,...] comments (default: all traits).
            Strings are quoted, numbers written as-is, lists written as {a,b,...}; other value types are skipped.
    verbose: boolean, debug
    nexus: boolean, whether to output newick (default: False) or nexus (True) formatted tree
    string_fragment: list of string pieces that output is appended to (default: a new list)
    traverse_condition: function; only children for which it returns True are written (default: all)
    rename: optional dict mapping tip names to the names written out
    quotechar: string wrapped around tip names (default: single quote)
    json: must be False when nexus is True
    Returns the joined string of string_fragment.
    """
    if cur_node==None: cur_node=self.root
    if traits==None: traits={key for k in self.Objects for key in k.traits.keys()} ## all trait keys, without quadratic list concatenation
    if string_fragment==None:
        string_fragment=[]
        if nexus==True:
            assert json==False,'Nexus format not a valid option for JSON output'
            if verbose==True: print('Exporting to Nexus format')
            string_fragment.append('#NEXUS\nBegin trees;\ntree TREE1 = [&R] ')
    if traverse_condition==None: traverse_condition=lambda k: True

    def format_comment(branch): ## annotation pieces for branch, from the requested traits
        comment=[]
        if len(traits)>0:
            for tr in traits:
                if tr in branch.traits:
                    value=branch.traits[tr]
                    if verbose==True: print('trait %s available for %s (%s) type: %s'%(tr,branch.index,branch.branchType,type(value)))
                    if isinstance(value,str):
                        comment.append('%s="%s"'%(tr,value))
                        if verbose==True: print('adding string comment %s'%(comment[-1]))
                    elif isinstance(value,float) or isinstance(value,int):
                        comment.append('%s=%s'%(tr,value))
                        if verbose==True: print('adding numeric comment %s'%(comment[-1]))
                    elif isinstance(value,list):
                        rangeComment=[]
                        for val in value:
                            if isinstance(val,str):
                                rangeComment.append('"%s"'%(val))
                            elif isinstance(val,float) or isinstance(val,int):
                                rangeComment.append('%s'%(val))
                        comment.append('%s={%s}'%(tr,','.join(rangeComment)))
                        if verbose==True: print('adding range comment %s'%(comment[-1]))
                elif verbose==True: print('trait %s unavailable for %s (%s)'%(tr,branch.index,branch.branchType))
        return comment

    def emit(branch): ## append branch (and, for nodes, its traversable descendants) to string_fragment
        comment=format_comment(branch)
        if branch.is_node():
            if verbose==True: print('node: %s'%(branch.index))
            string_fragment.append('(')
            traverseChildren=list(filter(traverse_condition,branch.children))
            assert len(traverseChildren)>0,'Node %s does not have traversable children'%(branch.index)
            for c,child in enumerate(traverseChildren):
                if verbose==True: print('moving to child %s of node %s'%(child.index,branch.index))
                emit(child)
                if (c+1)<len(traverseChildren): ## comma between children
                    string_fragment.append(',')
            string_fragment.append(')')
        elif branch.is_leaf():
            if rename==None:
                treeName=branch.name
            else:
                assert isinstance(rename,dict), 'Variable "rename" is not a dictionary'
                assert branch.name in rename, 'Tip name %s not in rename dictionary'%(branch.name)
                treeName=rename[branch.name]
            if verbose==True: print('leaf: %s (%s)'%(branch.index,treeName))
            string_fragment.append("%s%s%s"%(quotechar,treeName,quotechar))
        if len(comment)>0:
            if verbose==True: print('adding comment to %s'%(branch.index))
            string_fragment.append('[&'+','.join(comment)+']') ## end of branch, add annotations
        if verbose==True: print('adding branch length to %s'%(branch.index))
        string_fragment.append(':%8f'%(branch.length))

    emit(cur_node) ## the fragment list is joined once, not at every recursion level
    if cur_node==self.root:
        string_fragment.append(';')
        if nexus==True:
            string_fragment.append('\nEnd;')
        if verbose==True: print('finished')
    return ''.join(string_fragment)
def allTMRCAs(self):
    """
    Build a pairwise TMRCA table for all tips.

    Returns a dict of dicts keyed by tip name on both levels. Each off-diagonal
    entry holds the largest absoluteTime among internal nodes whose ``leaves``
    contain both tips. That is the most recent shared node, since a later node
    overwrites an earlier value when its time is greater or equal. Diagonal
    entries are 0.0. Pairs never found together under a node stay None.
    """
    names=[tip.name for tip in self.getExternal()]
    matrix={}
    for row in names: ## every row starts empty except for the diagonal
        matrix[row]={col:(0.0 if row==col else None) for col in names}
    for node in self.getInternal():
        node_time=node.absoluteTime
        remaining=list(node.leaves) ## descendant tip names of this node
        while remaining:
            first=remaining.pop(0)
            for second in remaining: ## each unordered pair visited once
                current=matrix[first][second]
                if current is None or current<=node_time: ## keep the most recent shared node
                    matrix[first][second]=node_time
                    matrix[second][first]=node_time
    return matrix
def reduceTree(self,keep,verbose=False):
    """
    Reduce the tree to just those branches tracking a small number of tips.
    Returns a new baltic tree object; the original tree is left untouched.

    keep: non-empty collection of leaf branches. They are matched against the
        deep copy by their ``index`` attribute, so index values should be unique.
    verbose: print progress messages.

    Every kept tip's path to the root is retained (root included) and all other
    branches are dropped. Children lists are pruned to the retained branches,
    then the copy's fixHangingNodes, traverse_tree and sortBranches are called.

    Raises AssertionError if ``keep`` is empty or contains a non-leaf branch.
    """
    assert len(keep)>0,"No tips given to reduce the tree to."
    non_leaves=[k for k in keep if not k.is_leaf()] ## only tips may be kept
    assert len(non_leaves)==0, "Embedding contains %d non-leaf branches."%(len(non_leaves))
    if verbose==True: print("Preparing branch hash for keeping %d branches"%(len(keep)))
    branch_hash={k.index:k for k in keep}
    if verbose==True: print("Deep copying tree")
    reduced_tree=copy.deepcopy(self) ## new tree object
    embedding=set() ## unique branches lying on a path from a kept tip to the root
    for k in reduced_tree.Objects: ## branches of the copied tree
        if k.index in branch_hash: ## if branch is designated as one to keep
            cur_b=k
            if verbose==True: print("Traversing to root from %s"%(cur_b.index))
            ## walk towards the root; stop early once we join a path already recorded,
            ## since everything above that branch is already in the embedding
            while cur_b!=reduced_tree.root and cur_b not in embedding:
                if verbose==True: print("at %s root: %s"%(cur_b.index,cur_b==reduced_tree.root))
                embedding.add(cur_b) ## keep track of the path to root
                cur_b=cur_b.parent
    embedding.add(reduced_tree.root) ## add root to embedding
    if verbose==True: print("Finished extracting embedding with %s branches (%s tips, %s nodes)"%(len(embedding),len([w for w in embedding if w.is_leaf()]),len([w for w in embedding if w.is_node()])))
    reduced_tree.Objects=sorted(embedding,key=lambda x:x.height) ## assign branches that are kept to new tree's Objects
    if verbose==True: print("Pruning untraversed lineages")
    for k in reduced_tree.getInternal(): ## iterate through reduced tree
        k.children=[c for c in k.children if c in embedding] ## only keep children that are present in lineage traceback
    reduced_tree.root.children=[c for c in reduced_tree.root.children if c in embedding] ## do the same for root
    reduced_tree.fixHangingNodes()
    if verbose==True: print("Last traversal and branch sorting")
    reduced_tree.traverse_tree() ## traverse
    reduced_tree.sortBranches() ## sort
    return reduced_tree ## return new tree
def countLineages(self,t,attr='absoluteTime',condition=lambda x:True):
    """
    Count branches spanning the value ``t`` of attribute ``attr``.

    A branch is counted when parent.attr < t <= branch.attr and condition(branch)
    is truthy. Branches without a parent, or where either value is None, cannot
    span t and are skipped instead of raising.
    """
    count=0
    for k in self.Objects:
        parent=k.parent
        if parent is None: ## e.g. the root: nothing to span from
            continue
        start=getattr(parent,attr)
        if start is None or not start<t:
            continue
        end=getattr(k,attr)
        if end is not None and t<=end and condition(k):
            count+=1
    return count
def getExternal(self,secondFilter=None):
    """
    Get all branches whose is_leaf() is true (leaves/tips), in self.Objects order.

    secondFilter: optional predicate used to further filter the leaves by an
    additional property. When None, filter() semantics apply and every (truthy)
    leaf is kept.
    """
    leaves=(k for k in self.Objects if k.is_leaf()) ## lazily select tips
    return list(filter(secondFilter,leaves))
def getInternal(self,secondFilter=None):
    """
    Return the internal branches (those whose is_node() is true), in self.Objects order.

    secondFilter: optional predicate that narrows the selection by an additional
    property. When None, filter() semantics apply and every (truthy) node is kept.
    """
    nodes=(branch for branch in self.Objects if branch.is_node()) ## lazy, so predicates interleave per branch
    return list(filter(secondFilter,nodes))
def getBranches(self,attrs=lambda x:True,warn=True):
    """
    Return the branches of self.Objects for which ``attrs(branch)`` is truthy.

    Exactly one match: that branch itself is returned (not wrapped in a list).
    Several matches: a list of them, in self.Objects order.
    No match: raises Exception when ``warn`` is True, otherwise returns an empty list.
    """
    matches=list(filter(attrs,self.Objects))
    if len(matches)==1:
        return matches[0]
    if not matches and warn==True:
        raise Exception('No branches satisfying function were found amongst branches')
    return matches
def getParameter(self,statistic,use_trait=False,which=None):
    """
    Collect a branch attribute or trait across branches.

    statistic: attribute name (attribute mode) or key into each branch's ``traits``
        dict (trait mode).
    use_trait: falsy (the default) reads ``getattr(branch, statistic)``; truthy reads
        ``branch.traits[statistic]``.
    which: optional predicate selecting branches; by default all of self.Objects.

    Returns a list of values in self.Objects order. Branches that lack the
    attribute or trait are skipped.
    """
    branches=self.Objects if which is None else filter(which,self.Objects)
    if use_trait: ## any truthy value selects trait mode; previously non-bools left the result unbound
        return [k.traits[statistic] for k in branches if statistic in k.traits]
    return [getattr(k,statistic) for k in branches if hasattr(k,statistic)]
def fixHangingNodes(self):
    """
    Remove internal nodes without any children ("hanging" nodes).

    Detaching a hanging node from its parent can leave that parent childless,
    so removal propagates upwards until no node in self.Objects is childless.
    self.Objects and the parents' children lists are modified in place.
    A hanging node whose parent is None is simply dropped from self.Objects.
    Nodes that are not listed in self.Objects are never removed.
    """
    is_hanging=lambda k: k.is_node() and len(k.children)==0
    in_tree={id(k) for k in self.Objects} ## only branches listed in Objects are candidates
    pending=[k for k in self.Objects if is_hanging(k)] ## nodes waiting to be removed
    removed=set() ## ids of nodes already removed
    while pending:
        h=pending.pop()
        if id(h) in removed:
            continue
        removed.add(id(h))
        parent=h.parent
        if parent is None:
            continue
        parent.children.remove(h) ## detach from parent
        ## the parent may have just lost its last child; queue it instead of rescanning Objects
        if id(parent) in in_tree and id(parent) not in removed and is_hanging(parent):
            pending.append(parent)
    if removed:
        self.Objects[:]=[k for k in self.Objects if id(k) not in removed] ## in place, keeps list identity
def addText(self,ax,target=None,x_attr=None,y_attr=None,text=None,zorder=None,**kwargs):
    """
    Write a text label on ``ax`` for every branch selected by ``target``.

    Defaults: leaves only, placed at (branch.x, branch.y), labelled with
    branch.name, zorder 4. Extra keyword arguments are passed to ``ax.text``.
    Returns ax.
    """
    if target is None:
        target=lambda branch: branch.is_leaf()
    get_x=x_attr if x_attr is not None else (lambda branch: branch.x)
    get_y=y_attr if y_attr is not None else (lambda branch: branch.y)
    label=text if text is not None else (lambda branch: branch.name)
    level=4 if zorder is None else zorder
    for branch in self.Objects:
        if not target(branch):
            continue
        px=get_x(branch)
        py=get_y(branch)
        ax.text(px,py,label(branch),zorder=level,**kwargs)
    return ax
def plotPoints(self,ax,x_attr=None,y_attr=None,target=None,size=None,colour=None,
               zorder=None,outline=None,outline_size=None,outline_colour=None,**kwargs):
    """
    Scatter a marker on ``ax`` for each branch selected by ``target``.

    Defaults: leaves only, at (branch.x, branch.y), size 40, black fill, zorder 3.
    Unless ``outline`` is falsy, a second scatter one zorder lower draws an
    outline marker behind each point. By default it is black and twice the fill
    size. ``size``, ``colour``, ``outline_size`` and ``outline_colour`` may each be
    a constant or a function of the branch. Extra keyword arguments go to both
    ``ax.scatter`` calls. Returns ax.
    """
    if target==None:
        target=lambda branch: branch.is_leaf()
    if x_attr==None:
        x_attr=lambda branch: branch.x
    if y_attr==None:
        y_attr=lambda branch: branch.y
    if size==None:
        size=40
    if colour==None:
        colour=lambda branch: 'k'
    if zorder==None:
        zorder=3
    if outline==None:
        outline=True
    if outline_size==None:
        outline_size=lambda branch: size(branch)*2 if callable(size) else size*2
    if outline_colour==None:
        outline_colour='k'

    def resolve(value,branch):
        ## a styling argument is either a constant or a per-branch function
        return value(branch) if callable(value) else value

    dot_x,dot_y,dot_c,dot_s=[],[],[],[]
    ring_x,ring_y,ring_c,ring_s=[],[],[],[]
    for branch in filter(target,self.Objects):
        bx=x_attr(branch)
        by=y_attr(branch)
        dot_x.append(bx)
        dot_y.append(by)
        dot_c.append(resolve(colour,branch))
        dot_s.append(resolve(size,branch))
        if outline:
            ring_x.append(bx)
            ring_y.append(by)
            ring_c.append(resolve(outline_colour,branch))
            ring_s.append(resolve(outline_size,branch))
    ax.scatter(dot_x,dot_y,s=dot_s,facecolor=dot_c,edgecolor='none',zorder=zorder,**kwargs) ## filled markers
    if outline:
        ax.scatter(ring_x,ring_y,s=ring_s,facecolor=ring_c,edgecolor='none',zorder=zorder-1,**kwargs) ## outlines drawn underneath
    return ax
def plotTree(self,ax,connection_type=None,target=None,
             x_attr=None,y_attr=None,width=None,
             colour=None,**kwargs):
    """
    Draw the tree's branches on ``ax`` as a single LineCollection.

    connection_type:
        'baltic' (default): a horizontal segment from the parent's x to the
            branch's x at the branch's y, plus, for nodes, a vertical bar spanning
            the y values of the first and last child.
        'elbow': a right angle from (parent x, parent y) via (parent x, y) to (x, y).
        'direct': a straight segment from (parent x, parent y) to (x, y).
    target: predicate selecting branches to draw (default: all).
    x_attr, y_attr: functions giving a branch's coordinates (default: .x, .y).
    width, colour: constant or per-branch function (defaults 2 and 'k'); a KeyError
        raised by a colour function falls back to grey (0.7,0.7,0.7).
    A branch without a parent is drawn from its own position.
    Extra keyword arguments are passed to LineCollection. Returns ax.

    Raises AssertionError for an unrecognised connection_type.
    """
    if target is None: target=lambda k: True
    if x_attr is None: x_attr=lambda k: k.x
    if y_attr is None: y_attr=lambda k: k.y
    if width is None: width=2
    if colour is None: colour='k'
    if connection_type is None: connection_type='baltic'
    assert connection_type in ['baltic','direct','elbow'],'Unrecognised drawing type "%s"'%(connection_type)
    branches=[]
    colours=[]
    linewidths=[]
    for k in filter(target,self.Objects): ## iterate over branches
        x=x_attr(k) ## get branch x position
        xp=x_attr(k.parent) if k.parent else x ## get parent x position
        y=y_attr(k) ## get y position
        try:
            colours.append(colour(k) if callable(colour) else colour)
        except KeyError: ## e.g. a colour map lookup missing this branch's trait
            colours.append((0.7,0.7,0.7))
        linewidths.append(width(k) if callable(width) else width)
        if connection_type=='baltic':
            branches.append(((xp,y),(x,y)))
            if k.is_node(): ## vertical bar joining first and last child
                yl,yr=y_attr(k.children[0]),y_attr(k.children[-1])
                branches.append(((x,yl),(x,yr)))
                linewidths.append(linewidths[-1])
                colours.append(colours[-1])
        else:
            yp=y_attr(k.parent) if k.parent else y ## get parent y position
            if connection_type=='elbow':
                branches.append(((xp,yp),(xp,y),(x,y)))
            else: ## 'direct'
                branches.append(((xp,yp),(x,y)))
    line_segments = LineCollection(branches,lw=linewidths,color=colours,capstyle='projecting',**kwargs)
    ax.add_collection(line_segments)
    return ax
def plotCircularTree(self,ax,target=None,x_attr=None,y_attr=None,width=None,colour=None,
                     circStart=0.0,circFrac=1.0,inwardSpace=0.0,normaliseHeight=None,precision=15,**kwargs):
    """
    Draw the tree's branches on ``ax`` in a circular layout.

    A branch's x value (plus ``inwardSpace``) is mapped through ``normaliseHeight``
    to a radius. Its y value becomes an angle: circStart + circFrac*y/self.ySpan,
    with circStart and circFrac given as fractions of a full turn. Node bars
    become arcs approximated by ``precision`` points.

    target: predicate selecting branches to draw (default: all).
    x_attr, y_attr: functions giving a branch's coordinates (default: .x, .y).
    width, colour: constant or per-branch function (defaults 2 and 'k'); a KeyError
        raised by a colour function falls back to grey (0.7,0.7,0.7).
    inwardSpace: added to every x before normalisation; a negative value is
        further offset by -self.treeHeight.
    normaliseHeight: function mapping a shifted x value to a radius; by default
        min-max scaling over x_attr of all branches in self.Objects (this divides
        by zero if all those x values are equal).
    **kwargs: passed to LineCollection; may override the defaults ls='-',
        zorder=1 and capstyle='projecting'.
    Returns ax.
    """
    if target is None: target=lambda k: True
    if x_attr is None: x_attr=lambda k:k.x
    if y_attr is None: y_attr=lambda k:k.y
    if colour is None: colour='k'
    if width is None: width=2
    if inwardSpace<0: inwardSpace-=self.treeHeight
    branches=[]
    colours=[]
    linewidths=[]
    circ_s=circStart*math.pi*2 ## start angle in radians
    circ=circFrac*math.pi*2 ## angular extent in radians
    if normaliseHeight is None:
        allXs=list(map(x_attr,self.Objects))
        x_min,x_max=min(allXs),max(allXs) ## computed once instead of on every call
        normaliseHeight=lambda value: (value-x_min)/(x_max-x_min)

    def linspace(start,stop,n):
        ## n evenly spaced values from start to stop; always a list so it can be mapped
        if n>1:
            step=(stop-start)/(n-1)
            return [start+step*i for i in range(n)]
        return [stop]

    for k in filter(target,self.Objects): ## iterate over branches
        x=normaliseHeight(x_attr(k)+inwardSpace) ## radius of the branch end
        ## branches hanging off a parentless parent (or with no parent) start at their own radius
        xp=normaliseHeight(x_attr(k.parent)+inwardSpace) if k.parent and k.parent.parent else x
        y=y_attr(k) ## get y position
        try:
            colours.append(colour(k) if callable(colour) else colour)
        except KeyError:
            colours.append((0.7,0.7,0.7))
        linewidths.append(width(k) if callable(width) else width)
        y=circ_s+circ*y/self.ySpan ## angle of the branch
        X=math.sin(y)
        Y=math.cos(y)
        branches.append(((X*xp,Y*xp),(X*x,Y*x)))
        if k.is_node():
            yl,yr=y_attr(k.children[0]),y_attr(k.children[-1]) ## get leftmost and rightmost children's y coordinates
            yl=circ_s+circ*yl/self.ySpan ## transform y into an angle
            yr=circ_s+circ*yr/self.ySpan
            ybar=linspace(yl,yr,precision) ## what used to be vertical node bar is now a curved line
            xs=[yx*x for yx in map(math.sin,ybar)] ## convert to cartesian coordinates
            ys=[yy*x for yy in map(math.cos,ybar)]
            segments=list(zip(zip(xs,ys),zip(xs[1:],ys[1:]))) ## consecutive arc points
            branches.extend(segments) ## add curved segment
            linewidths.extend([linewidths[-1]]*len(segments)) ## repeat linewidths
            colours.extend([colours[-1]]*len(segments)) ## repeat colours
    style={'ls':'-','capstyle':'projecting','zorder':1}
    style.update(kwargs) ## caller-supplied options were previously ignored
    line_segments = LineCollection(branches,lw=linewidths,color=colours,**style) ## create line segments
    ax.add_collection(line_segments) ## add collection to axes
    return ax
def plotCircularPoints(self,ax,x_attr=None,y_attr=None,target=None,size=None,colour=None,circStart=0.0,circFrac=1.0,inwardSpace=0.0,normaliseHeight=None,
                       zorder=None,outline=None,outline_size=None,outline_colour=None,**kwargs):
    """
    Scatter markers on ``ax`` for branches in a circular layout.

    Positions use the same mapping as plotCircularTree. x (plus ``inwardSpace``)
    goes through ``normaliseHeight`` to a radius, and y becomes an angle
    circStart + circFrac*y/self.ySpan, with both given as fractions of a full turn.
    Defaults: leaves only, size 40, colour 'k', zorder 3. Unless ``outline`` is
    falsy, a second scatter one zorder lower draws an outline marker behind each
    point. By default it is black and twice the fill size.
    ``size``, ``colour``, ``outline_size`` and ``outline_colour`` may each be a
    constant or a function of the branch.
    normaliseHeight defaults to min-max scaling over x_attr of all branches in
    self.Objects (this divides by zero if all those x values are equal).
    Extra keyword arguments go to both ``ax.scatter`` calls. Returns ax.
    """
    if target is None: target=lambda k: k.is_leaf()
    if x_attr is None: x_attr=lambda k:k.x
    if y_attr is None: y_attr=lambda k:k.y
    if size is None: size=40
    if colour is None: colour='k'
    if zorder is None: zorder=3
    if outline is None: outline=True
    if outline_size is None: outline_size=lambda k: size(k)*2 if callable(size) else size*2
    if outline_colour is None: outline_colour=lambda k: 'k'
    if inwardSpace<0: inwardSpace-=self.treeHeight
    circ_s=circStart*math.pi*2 ## start angle in radians
    circ=circFrac*math.pi*2 ## angular extent in radians
    if normaliseHeight is None:
        allXs=list(map(x_attr,self.Objects))
        x_min,x_max=min(allXs),max(allXs) ## computed once instead of on every call
        normaliseHeight=lambda value: (value-x_min)/(x_max-x_min)
    xs=[]
    ys=[]
    colours=[]
    sizes=[]
    outline_xs=[]
    outline_ys=[]
    outline_colours=[]
    outline_sizes=[]
    for k in filter(target,self.Objects):
        x=normaliseHeight(x_attr(k)+inwardSpace) ## find normalised x position along circle's radius
        y=circ_s+circ*y_attr(k)/self.ySpan ## get y position along circle's perimeter
        X=math.sin(y)*x ## transform
        Y=math.cos(y)*x ## transform
        xs.append(X)
        ys.append(Y)
        colours.append(colour(k) if callable(colour) else colour)
        sizes.append(size(k) if callable(size) else size)
        if outline:
            outline_xs.append(X)
            outline_ys.append(Y)
            outline_colours.append(outline_colour(k) if callable(outline_colour) else outline_colour)
            outline_sizes.append(outline_size(k) if callable(outline_size) else outline_size)
    ax.scatter(xs,ys,s=sizes,facecolor=colours,edgecolor='none',zorder=zorder,**kwargs) ## put a circle at each tip
    if outline:
        ax.scatter(outline_xs,outline_ys,s=outline_sizes,facecolor=outline_colours,edgecolor='none',zorder=zorder-1,**kwargs) ## outline circles drawn underneath
    return ax
def untangle(trees,cost_function=None,iterations=None,verbose=False):
"""
Minimise y-axis discrepancies between tips of trees in a list.
Only the tangling of adjacent trees in the list is minimised, so the order of trees matters.
Trees do not need to have the same number of tips but tip names should match.
"""
from itertools import permutations
if iterations==None: iterations=3
if cost_function==None: cost_function=lambda pair: math.pow(abs(pair[0]-pair[1]),2)
y_positions={T: {k.name: k.y for k in T.getExternal()} for T in trees} ## get y positions of all the tips in every tree
for iteration in range(iterations):
if verbose: print('Untangling iteration %d'%(iteration+1))
first_trees=list(range(len(trees)-1))+[-1] ## trees up to next-to-last + last