Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try to use concurrent pair info buffer in distance estimation and around #1425

Merged
merged 7 commits into from
Dec 25, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .clangd
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
CompileFlags:
CompilationDatabase: build_spades/
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ CMakeCache.txt
cmake_install.cmake
assembler/src/tools/quality/results*
__pycache__
.clangd
.DS_Store
compile_commands.json
.cache
Expand Down
6 changes: 6 additions & 0 deletions src/common/paired_info/concurrent_pair_info_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ class ConcurrentPairedBuffer : public PairedBufferBase<ConcurrentPairedBuffer<G,
// Convenience alias: ConcurrentPairedBuffer for the given Graph, instantiated
// with RawPointTraits and btree_map as its remaining template arguments.
template<class Graph>
using ConcurrentPairedInfoBuffer = ConcurrentPairedBuffer<Graph, RawPointTraits, btree_map>;

// Convenience alias: ConcurrentPairedBuffer for the given Graph, instantiated
// with PointTraits and btree_map as its remaining template arguments.
template<class Graph>
using ConcurrentClusteredPairedInfoBuffer = ConcurrentPairedBuffer<Graph, PointTraits, btree_map>;

// Convenience alias: ConcurrentPairedBuffer for the given Graph, instantiated
// with PointTraits and phmap_map as its remaining template arguments.
template<class Graph>
using ConcurrentUnorderedClusteredPairedInfoBuffer = ConcurrentPairedBuffer<Graph, PointTraits, phmap_map>;

} // namespace de

} // namespace omnigraph
40 changes: 17 additions & 23 deletions src/common/paired_info/distance_estimation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
//***************************************************************************

#include "distance_estimation.hpp"
#include "pair_info_bounds.hpp"
#include "assembly_graph/paths/path_processor.hpp"

namespace omnigraph {
namespace de {
namespace omnigraph::de {

using namespace debruijn_graph;

Expand Down Expand Up @@ -75,32 +75,26 @@ AbstractDistanceEstimator::OutHistogram AbstractDistanceEstimator::ClusterResult
return result;
}

void AbstractDistanceEstimator::AddToResult(const OutHistogram &clustered, EdgePair ep,
PairedInfoBuffer<Graph> &result) const {
result.AddMany(ep.first, ep.second, clustered);
}

void DistanceEstimator::Estimate(PairedInfoIndexT<Graph> &result, size_t nthreads) const {
this->Init();
const auto &index = this->index();
ConcurrentUnorderedClusteredPairedInfoBuffer<Graph> buffer(graph());

DEBUG("Collecting edge infos");
std::vector<EdgeId> edges;
for (EdgeId e : this->graph().edges())
edges.push_back(e);
omnigraph::IterationHelper<Graph, EdgeId> edges(graph());
auto ranges = edges.Ranges(nthreads * 16);

DEBUG("Processing");
PairedInfoBuffersT<Graph> buffer(this->graph(), nthreads);
# pragma omp parallel for num_threads(nthreads) schedule(guided, 10)
for (size_t i = 0; i < edges.size(); ++i) {
EdgeId edge = edges[i];
ProcessEdge(edge, index, buffer[omp_get_thread_num()]);
}
# pragma omp parallel for schedule(guided) num_threads(nthreads)
for (size_t i = 0; i < ranges.size(); ++i) {
TRACE("Processing chunk #" << i);

for (size_t i = 0; i < nthreads; ++i) {
result.Merge(buffer[i]);
buffer[i].clear();
for (EdgeId e : ranges[i]) {
TRACE("Estimating for edge " << e);
ProcessEdge(e, index, buffer);
}
}

result.Merge(buffer);
}

DistanceEstimator::EstimHist DistanceEstimator::EstimateEdgePairDistances(EdgePair ep, const InHistogram &histogram,
Expand Down Expand Up @@ -158,7 +152,7 @@ DistanceEstimator::EstimHist DistanceEstimator::EstimateEdgePairDistances(EdgePa
return result;
}

void DistanceEstimator::ProcessEdge(EdgeId e1, const InPairedIndex &pi, PairedInfoBuffer<Graph> &result) const {
void DistanceEstimator::ProcessEdge(EdgeId e1, const InPairedIndex &pi, Buffer &result) const {
typename base::LengthMap second_edges;
auto inner_map = pi.GetHalf(e1);
for (auto i : inner_map)
Expand All @@ -181,5 +175,5 @@ void DistanceEstimator::ProcessEdge(EdgeId e1, const InPairedIndex &pi, PairedIn
this->AddToResult(res, ep, result);
}
}
}
}

} // namespace omnigraph::de
30 changes: 23 additions & 7 deletions src/common/paired_info/distance_estimation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
#define DISTANCE_ESTIMATION_HPP_

#include "paired_info.hpp"
#include "pair_info_bounds.hpp"
#include "pair_info_filters.hpp"
#include "concurrent_pair_info_buffer.hpp"

#include "assembly_graph/core/graph.hpp"
#include "utils/parallel/openmp_wrapper.h"

#include "math/xmath.h"

Expand Down Expand Up @@ -51,14 +51,18 @@ class AbstractDistanceEstimator {
typedef PairedInfoIndexT<debruijn_graph::Graph> OutPairedIndex;
typedef typename InPairedIndex::HistProxy InHistogram;
typedef typename OutPairedIndex::Histogram OutHistogram;
typedef AbstractPairInfoChecker<debruijn_graph::Graph> PairInfoChecker;


public:
AbstractDistanceEstimator(const debruijn_graph::Graph &graph,
const InPairedIndex &index,
const GraphDistanceFinder &distance_finder,
const PairInfoChecker &pair_info_checker,
size_t linkage_distance = 0)
: graph_(graph), index_(index),
distance_finder_(distance_finder), linkage_distance_(linkage_distance) { }
distance_finder_(distance_finder), pair_info_checker_(pair_info_checker),
linkage_distance_(linkage_distance) { }

virtual void Estimate(PairedInfoIndexT<debruijn_graph::Graph> &result, size_t nthreads) const = 0;

Expand All @@ -78,12 +82,21 @@ class AbstractDistanceEstimator {

OutHistogram ClusterResult(EdgePair /*ep*/, const EstimHist &estimated) const;

void AddToResult(const OutHistogram &clustered, EdgePair ep, PairedInfoBuffer<debruijn_graph::Graph> &result) const;
template<class Buffer>
void AddToResult(const OutHistogram &clustered, EdgePair ep, Buffer &result) const {
OutHistogram filtered;
for (Point p : clustered)
if (pair_info_checker_.Check(ep.first, ep.second, p))
filtered.insert(p);

result.AddMany(ep.first, ep.second, filtered);
}

private:
const debruijn_graph::Graph &graph_;
const InPairedIndex &index_;
const GraphDistanceFinder &distance_finder_;
const PairInfoChecker &pair_info_checker_;
const size_t linkage_distance_;

virtual const std::string Name() const = 0;
Expand All @@ -102,15 +115,18 @@ class DistanceEstimator : public AbstractDistanceEstimator {
typedef typename base::OutPairedIndex OutPairedIndex;
typedef typename base::InHistogram InHistogram;
typedef typename base::OutHistogram OutHistogram;
typedef ConcurrentUnorderedClusteredPairedInfoBuffer<debruijn_graph::Graph> Buffer;

public:
DistanceEstimator(const debruijn_graph::Graph &graph,
const InPairedIndex &index,
const GraphDistanceFinder &distance_finder,
const PairInfoChecker &checker,
size_t linkage_distance, size_t max_distance)
: base(graph, index, distance_finder, linkage_distance), max_distance_(max_distance) { }
: base(graph, index, distance_finder, checker, linkage_distance),
max_distance_(max_distance) { }

virtual ~DistanceEstimator() { }
virtual ~DistanceEstimator() = default;

void Init() const {
INFO("Using " << this->Name() << " distance estimator");
Expand All @@ -128,7 +144,7 @@ class DistanceEstimator : public AbstractDistanceEstimator {
private:
virtual void ProcessEdge(debruijn_graph::EdgeId e1,
const InPairedIndex &pi,
PairedInfoBuffer<debruijn_graph::Graph> &result) const;
Buffer &result) const;

virtual const std::string Name() const {
static const std::string my_name = "SIMPLE";
Expand Down
20 changes: 6 additions & 14 deletions src/common/paired_info/distance_estimation_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,9 @@ using namespace debruijn_graph;
using namespace omnigraph::de;

void EstimateWithEstimator(PairedInfoIndexT<Graph> &clustered_index,
const AbstractDistanceEstimator &estimator,
AbstractPairInfoChecker<Graph> &checker) {
DEBUG("Estimating distances");

const AbstractDistanceEstimator &estimator) {
INFO("Estimating distances");
estimator.Estimate(clustered_index, omp_get_max_threads());

INFO("Filtering info");
PairInfoFilter<Graph>(checker).Filter(clustered_index);
DEBUG("Info Filtered");
}

// Postprocessing, checking that clusters do not intersect
Expand Down Expand Up @@ -106,15 +100,15 @@ void EstimateScaffoldingDistances(PairedInfoIndexT<Graph> &scaffolding_index,
PairInfoWeightChecker<Graph> checker(graph, 0.);
DEBUG("Weight Filter Done");

SmoothingDistanceEstimator estimator(graph, paired_index, dist_finder,
SmoothingDistanceEstimator estimator(graph, paired_index, dist_finder, checker,
[&] (int i) {return wrapper.CountWeight(i);},
linkage_distance, max_distance,
ade.threshold, ade.range_coeff,
ade.delta_coeff, ade.cutoff,
ade.min_peak_points,
ade.percentage,
ade.derivative_threshold);
EstimateWithEstimator(scaffolding_index, estimator, checker);
EstimateWithEstimator(scaffolding_index, estimator);
}

void EstimatePairedDistances(PairedInfoIndexT<Graph> &clustered_index,
Expand All @@ -130,11 +124,9 @@ void EstimatePairedDistances(PairedInfoIndexT<Graph> &clustered_index,

PairInfoWeightChecker<Graph> checker(graph, de_config.clustered_filter_threshold);

INFO("Weight Filter Done");

DistanceEstimator estimator(graph, paired_index, dist_finder, linkage_distance, max_distance);
DistanceEstimator estimator(graph, paired_index, dist_finder, checker, linkage_distance, max_distance);

EstimateWithEstimator(clustered_index, estimator, checker);
EstimateWithEstimator(clustered_index, estimator);

INFO("Refining clustered pair information "); // this procedure checks, whether index
RefinePairedInfo(clustered_index, graph); // contains intersecting paired info clusters,
Expand Down
44 changes: 22 additions & 22 deletions src/common/paired_info/pair_info_filters.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
#define PAIR_INFO_FILTERS_HPP_

#include "paired_info_helpers.hpp"
#include "sequence/sequence.hpp"

namespace omnigraph {

namespace de {

template<class Graph>
class AbstractPairInfoChecker{
class AbstractPairInfoChecker {
private:
typedef typename Graph::VertexId VertexId;
typedef typename Graph::EdgeId EdgeId;
Expand All @@ -28,52 +29,51 @@ class AbstractPairInfoChecker{
public:
AbstractPairInfoChecker(const Graph &graph) : graph_(graph) { }

virtual bool Check(const PairInfoT&) {
virtual bool Check(EdgeId, EdgeId, Point) const {
return true;
}

virtual bool Check(EdgeId, EdgeId) {
virtual bool Check(EdgeId, EdgeId) const {
return true;
}

virtual ~AbstractPairInfoChecker() { }
virtual ~AbstractPairInfoChecker() = default;
};

template<class Graph>
class PairInfoWeightChecker : public AbstractPairInfoChecker<Graph>{
private:
typedef typename Graph::EdgeId EdgeId;
typedef PairInfo<EdgeId> PairInfoT;
double weight_threshold_;
DEWeight weight_threshold_;

public:
PairInfoWeightChecker(const Graph& graph, double weight_threshold) :
AbstractPairInfoChecker<Graph>(graph), weight_threshold_(weight_threshold) {
}

bool Check(const PairInfoT& info) {
return math::ge(info.weight(), weight_threshold_);
bool Check(EdgeId, EdgeId, Point p) const override {
return math::ge(p.weight, weight_threshold_);
}
};

template<class Graph>
class PairInfoWeightCheckerWithCoverage: public AbstractPairInfoChecker<Graph> {
private:
private:
typedef typename Graph::EdgeId EdgeId;
typedef PairInfo<EdgeId> PairInfoT;
double weight_threshold_;
DEWeight weight_threshold_;

public:
PairInfoWeightCheckerWithCoverage(const Graph& graph, double weight_threshold) :
AbstractPairInfoChecker<Graph>(graph), weight_threshold_(weight_threshold){
}
public:
PairInfoWeightCheckerWithCoverage(const Graph& graph, double weight_threshold)
: AbstractPairInfoChecker<Graph>(graph), weight_threshold_(weight_threshold) { }

bool Check(const PairInfoT& info) {
double info_weight = info.weight();
bool Check(EdgeId e1, EdgeId e2, Point p) const override {
double info_weight = p.weight;
return math::ge(info_weight, weight_threshold_)
|| (math::ge(info_weight, 0.1 * this->graph_.coverage(info.first)))
|| (math::ge(info_weight, 0.1 * this->graph_.coverage(info.second)));
}
|| (math::ge(info_weight, 0.1 * this->graph_.coverage(e1)))
|| (math::ge(info_weight, 0.1 * this->graph_.coverage(e2)));
}
};

template <class Graph>
Expand Down Expand Up @@ -148,7 +148,6 @@ class AmbiguousPairInfoChecker : public AbstractPairInfoChecker<Graph> {
}

bool InnerCheck(const PairInfoT& info){

EdgeId edge1 = info.first;
EdgeId edge2 = info.second;

Expand Down Expand Up @@ -182,14 +181,15 @@ class AmbiguousPairInfoChecker : public AbstractPairInfoChecker<Graph> {
relative_length_threshold_(relative_length_threshold),
relative_seq_threshold_(relative_seq_threshold) { }

bool Check(const PairInfoT& info) {
TRACE(this->graph_.int_id(info.first) << " " << this->graph_.int_id(info.second));
bool Check(EdgeId e1, EdgeId e2, Point p) const override {
PairInfoT info(e1, e2, p);
TRACE(this->graph_.int_id(e1) << " " << this->graph_.int_id(e2));
if (EdgesAreFromSimpleBulgeWithAmbPI(info)){
TRACE("Forward directed edges form a simple bulge");
return InnerCheck(info);
}

if (EdgesAreFromSimpleBulgeWithAmbPI(BackwardInfo(info))){
if (EdgesAreFromSimpleBulgeWithAmbPI(BackwardInfo(info))) {
TRACE("Backward directed edges form a simple bulge");
return InnerCheck(BackwardInfo(info));
}
Expand Down
Loading