diff --git a/CMakeLists.txt b/CMakeLists.txt index cc2a97f..128b2e6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,6 +38,7 @@ set(STS_CPP_FILES ${CMAKE_CURRENT_SOURCE_DIR}/src/log/sampler.cc ${CMAKE_CURRENT_SOURCE_DIR}/src/metropolis_hastings_move.cc ${CMAKE_CURRENT_SOURCE_DIR}/src/node.cc + ${CMAKE_CURRENT_SOURCE_DIR}/src/node_deleter.cc ${CMAKE_CURRENT_SOURCE_DIR}/src/online_calculator.cc ${CMAKE_CURRENT_SOURCE_DIR}/src/rooted_merge.cc ${CMAKE_CURRENT_SOURCE_DIR}/src/smc_init.cc diff --git a/src/edge.cc b/src/edge.cc index dde47c9..fd74a51 100644 --- a/src/edge.cc +++ b/src/edge.cc @@ -36,10 +36,10 @@ Edge::Edge(std::shared_ptr node) : prior_log_likelihood(0), node(node) {} -std::shared_ptr Edge::of_tree(std::shared_ptr calc, bpp::TreeTemplate &tree, int node_number, std::unordered_map, std::string>& names) +std::shared_ptr Edge::of_tree(bpp::TreeTemplate &tree, int node_number, std::unordered_map, std::string>& names) { return std::make_shared( - Node::of_tree(calc, tree, node_number, names), + Node::of_tree(tree, node_number, names), tree.getDistanceToFather(node_number)); } diff --git a/src/edge.h b/src/edge.h index fb15443..7d27ad1 100644 --- a/src/edge.h +++ b/src/edge.h @@ -35,8 +35,7 @@ class Edge std::shared_ptr node; /// Make an edge from a bpp Tree and node number - static std::shared_ptr of_tree(std::shared_ptr, - bpp::TreeTemplate &, + static std::shared_ptr of_tree(bpp::TreeTemplate &, int, std::unordered_map, std::string>&); }; diff --git a/src/metropolis_hastings_move.cc b/src/metropolis_hastings_move.cc index f68e95c..af0b8fa 100644 --- a/src/metropolis_hastings_move.cc +++ b/src/metropolis_hastings_move.cc @@ -1,6 +1,7 @@ #include "metropolis_hastings_move.h" #include "edge.h" +#include "node_deleter.h" using namespace sts::particle; @@ -25,7 +26,7 @@ int Metropolis_hastings_move::operator()(long time, smc::particlenode->edge_prior_log_likelihood(); - particle::Node_ptr new_node = std::make_shared(*cur_part->node); + particle::Node_ptr new_node(new particle::Node(*cur_part->node), likelihood::Node_deleter(calc)); new_part->node = new_node; diff --git a/src/node.cc b/src/node.cc index f6eee50..d6437ad 100644 --- a/src/node.cc +++ b/src/node.cc @@ -3,16 +3,15 @@ #include "node.h" #include "edge.h" -#include "online_calculator.h" namespace sts { namespace particle { -// Implementation -Node::Node(std::shared_ptr calc) : calc(calc) {} -Node::Node(const Node & other) : calc(other.calc) +Node::Node() {}; + +Node::Node(const Node & other) { if(!other.is_leaf()) { child1 = std::make_shared(*other.child1); @@ -20,16 +19,8 @@ Node::Node(const Node & other) : calc(other.calc) } } -Node::~Node() -{ - auto p = calc.lock(); - if(p) - p->unregister_node(this); -} - Node & Node::operator=(const Node & other) { - calc = other.calc; if(!other.is_leaf()) { child1 = std::make_shared(*other.child1); child2 = std::make_shared(*other.child2); @@ -43,23 +34,23 @@ bool Node::is_leaf() const return this->child1 == nullptr && this->child2 == nullptr; } -Node_ptr Node::of_tree(std::shared_ptr calc, bpp::TreeTemplate &tree, int node_number, std::unordered_map& names) +Node_ptr Node::of_tree(bpp::TreeTemplate& tree, int node_number, std::unordered_map& names) { - Node_ptr n = std::make_shared(calc); + Node_ptr n = std::make_shared(); if(tree.isLeaf(node_number)) { names[n] = tree.getNodeName(node_number); return n; } std::vector children = tree.getSonsId(node_number); assert(children.size() == 2); - n->child1 = Edge::of_tree(calc, tree, children[0], names); - n->child2 = Edge::of_tree(calc, tree, children[1], names); + n->child1 = Edge::of_tree(tree, children[0], names); + n->child2 = Edge::of_tree(tree, children[1], names); return n; } -Node_ptr Node::of_tree(std::shared_ptr calc, bpp::TreeTemplate &tree, std::unordered_map& names) +Node_ptr Node::of_tree(bpp::TreeTemplate &tree, std::unordered_map& names) { - return Node::of_tree(calc, tree, tree.getRootId(), names); + return Node::of_tree(tree, tree.getRootId(), names); } diff --git a/src/node.h b/src/node.h index e42bc83..753cb57 100644 --- a/src/node.h +++ b/src/node.h @@ -10,25 +10,18 @@ namespace sts { -// Circular dependencies -namespace likelihood -{ -class Online_calculator; -} - namespace particle { class Edge; /// \class Node -/// Represents the merge of two trees in a forest. +/// \brief Represents the merge of two trees in a forest. class Node { public: - explicit Node(std::shared_ptr calc); + Node(); Node(const Node & other); - ~Node(); Node & operator=(const Node & other); @@ -39,18 +32,15 @@ class Node /// Make a Node from a bpp Tree static std::shared_ptr - of_tree(std::shared_ptr, - bpp::TreeTemplate&, + of_tree(bpp::TreeTemplate&, std::unordered_map, std::string>&); /// Make a Node from a bpp tree and node number static std::shared_ptr - of_tree(std::shared_ptr, bpp::TreeTemplate &, int, + of_tree(bpp::TreeTemplate &, int, std::unordered_map, std::string>&); double edge_prior_log_likelihood() const; -private: - std::weak_ptr calc; }; /// A node in a phylogenetic tree diff --git a/src/node_deleter.cc b/src/node_deleter.cc new file mode 100644 index 0000000..8009275 --- /dev/null +++ b/src/node_deleter.cc @@ -0,0 +1,23 @@ +#include "node_deleter.h" +#include "node.h" +#include "online_calculator.h" + +using sts::particle::Node; + +namespace sts +{ +namespace likelihood +{ + +void Node_deleter::operator()(Node* node) +{ + // Unregister from the calculator + if(auto p = calc.lock()) + p->unregister_node(node); + // ... before deleting the node + d(node); +} + + +} // namespace likelihood +} // namespace sts diff --git a/src/node_deleter.h b/src/node_deleter.h new file mode 100644 index 0000000..c456099 --- /dev/null +++ b/src/node_deleter.h @@ -0,0 +1,41 @@ +/// \file node_deleter.h +#ifndef STS_LIKELIHOOD_NODE_DELETER_H +#define STS_LIKELIHOOD_NODE_DELETER_H + +#include + +namespace sts +{ +// Forwards +namespace particle { class Node; } +namespace likelihood +{ +class Online_calculator; + +/// \brief Deleter for sts::particle::Node which unregisters nodes from an Online_calculator +/// during deletion +/// +/// Sample usage: +/// \code +/// std::shared_ptr c; +/// sts::particle::Node_ptr p = sts::particle::Node_ptr(new sts::particle::Node(), Node_deleter(c)); +/// \endcode +/// +/// \related Node +/// \related Node_ptr +struct Node_deleter +{ + Node_deleter() {}; + Node_deleter(const std::shared_ptr& c) : calc(c) {}; + + /// The online calculator to unregister the node with on deletion + std::weak_ptr calc; + std::default_delete d; + + void operator()(sts::particle::Node* node); +}; + +} +} + +#endif // STS_LIKELIHOOD_NODE_DELETER_H diff --git a/src/node_ptr.h b/src/node_ptr.h index c4fda72..60a445c 100644 --- a/src/node_ptr.h +++ b/src/node_ptr.h @@ -1,3 +1,6 @@ +/// \file node_ptr.h +/// \brief Typedef for shared_ptr + #ifndef STS_NODE_PTR_H #define STS_NODE_PTR_H @@ -8,6 +11,17 @@ namespace sts namespace particle { class Node; + +/// \brief A shared_ptr to a Node. + +/// NB: if using Node_ptr in conjunction with an Online_calculator instance, use +/// a Node_deleter, e.g.: +/// \code +/// std::shared_ptr calc = ...; +/// Node_ptr n = Node_ptr(new Node(), Node_deleter(calc)); +/// \endcode +/// \related sts::particle::Node +/// \related sts::likelihood::Node_deleter typedef std::shared_ptr Node_ptr; } // sts } // particle diff --git a/src/rooted_merge.cc b/src/rooted_merge.cc index feba6a3..301ad50 100644 --- a/src/rooted_merge.cc +++ b/src/rooted_merge.cc @@ -1,6 +1,7 @@ -#include "rooted_merge.h" #include "edge.h" +#include "node_deleter.h" #include "node.h" +#include "rooted_merge.h" #include "util.h" #include @@ -38,7 +39,7 @@ int Rooted_merge::do_move(long time, smc::particle& p_from, // The following gives the uniform distribution on legal choices that are not n1. Think of taking the uniform // distribution on [0,n-2], breaking it at n1 and moving the right hand bit one to the right. if(n2 >= n1) n2++; - pp->node = std::make_shared(calc); + pp->node = particle::Node_ptr(new particle::Node(), likelihood::Node_deleter(calc)); // Draw branch lengths. pp->node->child1 = std::make_shared(prop_vector[n1]); diff --git a/src/state.cc b/src/state.cc index b7236d9..4f54e92 100644 --- a/src/state.cc +++ b/src/state.cc @@ -7,11 +7,11 @@ namespace sts namespace particle { -Particle State::of_tree(std::shared_ptr calc, bpp::TreeTemplate &tree, +Particle State::of_tree(bpp::TreeTemplate& tree, std::unordered_map& names) { Particle p = std::make_shared(); - p->node = Node::of_tree(calc, tree, names); + p->node = Node::of_tree(tree, names); if(p->node->is_leaf()) return p; @@ -34,13 +34,11 @@ Particle State::of_tree(std::shared_ptr calc, bpp return p; } - -Particle State::of_newick_string(std::shared_ptr calc, std::string &tree_string, +Particle State::of_newick_string(std::string &tree_string, std::unordered_map& names) { - bpp::TreeTemplate *tree = bpp::TreeTemplateTools::parenthesisToTree(tree_string); - Particle node = State::of_tree(calc, *tree, names); - delete tree; + std::unique_ptr> tree(bpp::TreeTemplateTools::parenthesisToTree(tree_string)); + Particle node = State::of_tree(*tree, names); return node; } diff --git a/src/state.h b/src/state.h index f8b0372..95894dd 100644 --- a/src/state.h +++ b/src/state.h @@ -32,7 +32,7 @@ class State log_likelihood(0.0) {}; /// The merge novel to this particle. If \c nullptr then the particle is \f$\perp\f$. - std::shared_ptr node; + Node_ptr node; /// The predecessor particles, which specify the rest of the merges for this particle. std::shared_ptr predecessor; @@ -45,14 +45,14 @@ class State /// Make a State from a bpp Tree static std::shared_ptr - of_tree(std::shared_ptr, bpp::TreeTemplate &, std::unordered_map&); + of_tree(bpp::TreeTemplate &, std::unordered_map&); /// Make a State from a Newick tree string static std::shared_ptr - of_newick_string(std::shared_ptr, std::string &, std::unordered_map&); + of_newick_string(std::string &, std::unordered_map&); }; } // namespace particle } // namespace sts -#endif // STS_PARTICLE_PHYLO_PARTICLE_H \ No newline at end of file +#endif // STS_PARTICLE_PHYLO_PARTICLE_H diff --git a/src/sts.cc b/src/sts.cc index c2d0975..2989dee 100644 --- a/src/sts.cc +++ b/src/sts.cc @@ -5,6 +5,7 @@ #include "forest_likelihood.h" #include "online_calculator.h" #include "node.h" +#include "node_deleter.h" #include "state.h" #include "util.h" @@ -21,7 +22,6 @@ #include "gamma_branch_length_proposer.h" #include "uniform_branch_length_proposer.h" - #include #include #include @@ -246,7 +246,7 @@ model->getAlphabet())); leaf_nodes.resize(num_iters); unordered_map node_name_map; for(int i = 0; i < num_iters; i++) { - leaf_nodes[i] = make_shared(calc); + leaf_nodes[i] = shared_ptr(new Node(), Node_deleter(calc)); calc->register_leaf(leaf_nodes[i], aln->getSequencesNames()[i]); node_name_map[leaf_nodes[i]] = aln->getSequencesNames()[i]; } diff --git a/src/tests/test_sts_likelihood.hpp b/src/tests/test_sts_likelihood.hpp index 699479c..37c5b51 100644 --- a/src/tests/test_sts_likelihood.hpp +++ b/src/tests/test_sts_likelihood.hpp @@ -53,7 +53,7 @@ void test_known_tree_jc69(std::string fasta_path, std::string newick_path, doubl if(compress) calc->set_weights(weights); std::unordered_map names; - auto root = sts::particle::State::of_newick_string(calc, nwk_string, names); + auto root = sts::particle::State::of_newick_string(nwk_string, names); // Register sts::util::register_nodes(*calc, root->node, names); std::unordered_set visited; diff --git a/src/tests/test_sts_parsing.hpp b/src/tests/test_sts_parsing.hpp index f11e7e7..90e581e 100644 --- a/src/tests/test_sts_parsing.hpp +++ b/src/tests/test_sts_parsing.hpp @@ -2,7 +2,6 @@ #define TEST_STS_PARSING_HPP #include "edge.h" -#include "online_calculator.h" #include "node.h" #include "particle.h" #include "state.h" @@ -21,15 +20,12 @@ namespace parsing { using namespace sts::particle; -using sts::likelihood::Online_calculator; - -std::shared_ptr null_calculator; TEST_CASE("phylofunc/newick_parsing/one_leaf", "test parsing a newick tree with one leaf") { std::string tree = "A;"; std::unordered_map names; - sts::particle::Particle p = State::of_newick_string(null_calculator, tree, names); + sts::particle::Particle p = State::of_newick_string(tree, names); REQUIRE(p->node->is_leaf()); } @@ -37,7 +33,7 @@ TEST_CASE("phylofunc/newick_parsing/two_leaf", "test parsing a newick tree with { std::string tree = "(A:2,B:3);"; std::unordered_map names; - sts::particle::Particle p = State::of_newick_string(null_calculator, tree, names); + sts::particle::Particle p = State::of_newick_string(tree, names); REQUIRE(!p->node->is_leaf()); REQUIRE(p->node->child1->length == 2); REQUIRE(p->node->child1->node->is_leaf()); @@ -59,7 +55,7 @@ TEST_CASE("phylofunc/newick_parsing/three_leaf", "test parsing a newick tree wit { std::string tree = "((A:2,B:3):4,C:6);"; std::unordered_map names; - sts::particle::Particle p = State::of_newick_string(null_calculator, tree, names); + sts::particle::Particle p = State::of_newick_string(tree, names); REQUIRE(!p->node->is_leaf()); REQUIRE(p->node->child1->length == 4); REQUIRE(!p->node->child1->node->is_leaf()); @@ -87,7 +83,7 @@ TEST_CASE("phylofunc/newick_parsing/four_leaf", "test parsing a newick tree with { std::string tree = "((A:2,B:3):4,(C:6,D:7):9);"; std::unordered_map names; - sts::particle::Particle p = State::of_newick_string(null_calculator, tree, names); + sts::particle::Particle p = State::of_newick_string(tree, names); REQUIRE(!p->node->is_leaf()); REQUIRE(p->node->child1->length == 4); REQUIRE(!p->node->child1->node->is_leaf()); @@ -121,7 +117,7 @@ static std::string roundtrip(std::string &tree) { std::unordered_map names; - sts::particle::Particle p = State::of_newick_string(null_calculator, tree, names); + sts::particle::Particle p = State::of_newick_string(tree, names); std::ostringstream ostream; sts::util::write_tree(ostream, p->node, names); return ostream.str();