diff --git a/include/astar/heapset.h b/include/astar/heapset.h index d126f55..a99a29a 100644 --- a/include/astar/heapset.h +++ b/include/astar/heapset.h @@ -5,86 +5,109 @@ #include #include -template, class VC = std::less> -class HeapSet { +enum class CloseOnPop : bool { FALSE, TRUE }; + +template +class Closeable { + public: + Closeable(const T& t, bool c = false) : t_(t), closed_(c) {} + operator T&() { return t_; } + operator const T&() const { return t_; } + bool operator!() const { return closed_; } + private: + T t_; bool closed_; +}; + +template +class Wrapper { + public: + Wrapper(const T&, bool = false) {} + operator T&() { return t_; } + operator const T&() const { return t_; } + bool operator!() const { return false; } + private: + T t_; +}; + +template struct SetT { + using type = Wrapper; +}; +template struct SetT { + using type = Closeable; +}; + +template, class VC = std::less, + CloseOnPop COP = CloseOnPop::FALSE> +class CachingHeapSet { public: bool empty() const; const T& top() const; T pop(); void push(const T& val); - void push(T&& val); private: - std::set states_; + using SetType = typename SetT::type; + + std::set states_; std::priority_queue,VC> heap_; // returns true if second arg is higher priority than the first VC higher_priority_; - T pop_heap(); + T pop_heap(CloseOnPop); void update_heap(const T& old, const T& updated); }; -template -bool HeapSet::empty() const { +template +bool CachingHeapSet::empty() const { return states_.empty(); } -template -const T& HeapSet::top() const { +template +const T& CachingHeapSet::top() const { return heap_.top(); } -template -T HeapSet::pop() { - T t = pop_heap(); - states_.erase(t); - return std::move(t); +template +T CachingHeapSet::pop() { + return pop_heap(COP); } -template -void HeapSet::push(const T& t) { - auto it_inserted = states_.insert(t); +template +void CachingHeapSet::push(const T& t) { + auto it_inserted = states_.insert(SetType{t}); + bool inserted = it_inserted.second; + auto& tt = *it_inserted.first; - if (it_inserted.second) { - heap_.push(t); + if (inserted) { + heap_.push(tt); } else { - auto old = *it_inserted.first; - if (higher_priority_(old, t)) { - update_heap(old, t); + if (!tt && higher_priority_(tt, t)) { + update_heap(tt, t); auto it = states_.erase(it_inserted.first); states_.insert(it, t); } } } -template -void HeapSet::push(T&& t) { - auto it_inserted = states_.insert(t); - - if (it_inserted.second) { - heap_.push(std::forward(t)); - } else { - auto old = *it_inserted.first; - if (higher_priority_(old, t)) { - update_heap(old, t); - auto it = states_.erase(it_inserted.first); - states_.insert(it, std::forward(t)); - } - } -} - -template -T HeapSet::pop_heap() { +template +T CachingHeapSet::pop_heap(CloseOnPop cop) { T t = heap_.top(); + if (COP == CloseOnPop::TRUE && cop == CloseOnPop::TRUE) { + SetType val{t, true}; + auto it = states_.find(val); + it = states_.erase(it); + states_.insert(it, val); + } heap_.pop(); return std::move(t); } -template -void HeapSet::update_heap(const T& old, const T& updated) { +template +void CachingHeapSet::update_heap(const T& old, const T& updated) { std::deque queue; do { - queue.push_back(pop_heap()); + queue.push_back(pop_heap(CloseOnPop::FALSE)); } while (queue.back() != old); queue.back() = updated; diff --git a/include/astar/solver.h b/include/astar/solver.h index 81ea67e..77e8b2f 100644 --- a/include/astar/solver.h +++ b/include/astar/solver.h @@ -68,8 +68,7 @@ class AStarSolver Distance distance_func_; Estimator cost_func_; SNode last_; - std::set closed_set_; - HeapSet open_set_; + CachingHeapSet states_; }; template @@ -85,7 +84,7 @@ template AStarSolver::AStarSolver(const T& s, const T& g, const Generator& gen, const Distance& d, const Estimator& c) : goal_(g), generator_func_(gen), distance_func_(d), cost_func_(c) { - open_set_.push(make_snode(*this,s)); + states_.push(make_snode(*this,s)); } template @@ -115,10 +114,8 @@ void AStarSolver::print_solution(std::ostream& o) const template bool AStarSolver::solve() { - while (!open_set_.empty()) { - auto it_inserted = closed_set_.insert(open_set_.pop()); - - auto snode = *it_inserted.first; + while (!states_.empty()) { + auto snode = states_.pop(); if (snode->state_ == goal_) { last_ = snode; @@ -126,10 +123,7 @@ bool AStarSolver::solve() } for (auto& n : generator_func_(snode->state_)) { - auto new_node = make_snode(*this, n, snode); - if (closed_set_.find(new_node) == end(closed_set_)) { - open_set_.push(std::move(new_node)); - } + states_.push(make_snode(*this, n, snode)); } }