Skip to content

Commit

Permalink
Remove SketchHelper class (#41)
Browse files Browse the repository at this point in the history
* Remove SketchHelper class

* Small cleanup

* Make action=triangle default
  • Loading branch information
RagnarGrootKoerkamp authored Feb 10, 2021
1 parent 7eccb6a commit 571b719
Showing 1 changed file with 26 additions and 140 deletions.
166 changes: 26 additions & 140 deletions sketch_main.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "sequence/alphabets.hpp"
#include "sequence/fasta_io.hpp"
#include "sketch/edit_distance.hpp"
#include "sketch/hash_base.hpp"
Expand All @@ -16,6 +17,7 @@

#include <filesystem>
#include <memory>
#include <numeric>
#include <random>
#include <sstream>
#include <utility>
Expand All @@ -25,7 +27,7 @@ using namespace ts;
// The main command this program should perform.
// Triangle: compute a triangular distance matrix.
// More actions will be added.
DEFINE_string(action, "", "Which action to do. One of: triangle, none");
DEFINE_string(action, "triangle", "Which action to do. One of: triangle, none");

DEFINE_string(alphabet,
"dna4",
Expand Down Expand Up @@ -113,83 +115,6 @@ void adjust_short_names() {
}
}

/**
 * Helper that reads sequences from the input file, sketches them, and writes the
 * sketches to the output file. All configuration is read from the global FLAGS_* values.
 *
 * @tparam seq_type   element type of the input sequences
 * @tparam kmer_type  integer type the k-mers are encoded into
 * @tparam embed_type element type of the computed sketches
 */
template <typename seq_type, class kmer_type, class embed_type>
class SketchHelper {
  public:
    /**
     * @param sketcher       maps one window of k-mers to its sketch vector; used by
     *                       compute_sketches()
     * @param slide_sketcher maps all k-mers of a sequence to a 2D sketch; used by
     *                       compute_slide()
     */
    SketchHelper(std::function<std::vector<embed_type>(const std::vector<kmer_type> &)> sketcher,
                 std::function<Vec2D<double>(const std::vector<uint64_t> &)> slide_sketcher)
        : sketcher(std::move(sketcher)), slide_sketcher(std::move(slide_sketcher)) {}

    /**
     * Sketches every sequence window by window. Windows start every FLAGS_stride k-mers
     * and span at most FLAGS_window_size k-mers; the last windows are truncated at the
     * end of the sequence. After the call, sketches[si][m] holds the m-th sketch
     * coordinate of each window of sequence si, in window order.
     */
    void compute_sketches() {
        const size_t num_seqs = seqs.size();
        sketches = new3D<embed_type>(num_seqs, FLAGS_embed_dim, 0);
        for (size_t si = 0; si < num_seqs; si++) {
            std::vector<kmer_type> kmers
                    = seq2kmer<seq_type, kmer_type>(seqs[si], FLAGS_kmer_length, alphabet_size);

            // NOTE(review): assumes FLAGS_stride > 0, otherwise this loop never advances.
            for (size_t i = 0; i < kmers.size(); i += FLAGS_stride) {
                // Clamp the window length *before* doing iterator arithmetic: forming an
                // iterator beyond kmers.end() is undefined behaviour.
                const size_t window_len = std::min<size_t>(kmers.size() - i, FLAGS_window_size);
                const auto window_begin = kmers.begin() + i;
                std::vector<kmer_type> kmer_slice(window_begin, window_begin + window_len);
                std::vector<embed_type> embed_slice = sketcher(kmer_slice);
                for (int m = 0; m < FLAGS_embed_dim; m++) {
                    sketches[si][m].push_back(embed_slice[m]);
                }
            }
        }
    }

    /**
     * Sketches every sequence in one go using slide_sketcher, which receives all k-mers
     * of the sequence and returns the full 2D sketch stored in sketches[si].
     */
    void compute_slide() {
        sketches = new3D<double>(seqs.size(), FLAGS_embed_dim, 0);

        for (size_t si = 0; si < seqs.size(); si++) {
            std::vector<uint64_t> kmers
                    = seq2kmer<uint8_t, uint64_t>(seqs[si], FLAGS_kmer_length, alphabet_size);
            sketches[si] = slide_sketcher(kmers);
        }
    }

    /** Reads the sequences and their names from the file FLAGS_i. */
    void read_input() {
        FastaFile<seq_type> file = read_fasta<seq_type>(FLAGS_i, FLAGS_input_format);
        seqs = std::move(file.sequences);
        seq_names = std::move(file.comments);
    }

    /**
     * Writes the sketches to FLAGS_o, creating the parent directory if needed. Each
     * sequence produces two lines: its name, followed by all its sketch values as one
     * comma-separated list. Exits the program with status 1 if the output directory or
     * file cannot be created.
     */
    void save_output() {
        std::filesystem::path ofile = std::filesystem::absolute(std::filesystem::path(FLAGS_o));
        std::filesystem::path opath = ofile.parent_path();
        if (!std::filesystem::exists(opath) && !std::filesystem::create_directories(opath)) {
            std::cerr << "Could not create output directory: " << opath << std::endl;
            std::exit(1);
        }

        std::ofstream fo(FLAGS_o);
        if (!fo.is_open()) {
            std::cerr << "Could not open " << ofile << " for writing." << std::endl;
            std::exit(1);
        }
        std::cout << "Writing sketches to: " << FLAGS_o << std::endl;

        for (size_t si = 0; si < seqs.size(); si++) {
            fo << seq_names[si] << '\n';
            // Emit the separator *between* values. Writing '\b' after a trailing comma
            // does not erase anything in a file; it only adds a stray control byte.
            bool first_value = true;
            for (size_t m = 0; m < sketches[si].size(); m++) {
                for (size_t i = 0; i < sketches[si][m].size(); i++) {
                    if (!first_value) {
                        fo << ',';
                    }
                    fo << sketches[si][m][i];
                    first_value = false;
                }
            }
            fo << '\n';
        }
        // The stream is flushed and closed when `fo` goes out of scope.
    }

  private:
    Vec2D<seq_type> seqs; // input sequences, one per fasta record
    std::vector<std::string> seq_names; // fasta comments, parallel to `seqs`
    Vec3D<embed_type> sketches; // sketches[si] is the 2D sketch of seqs[si]

    std::function<std::vector<embed_type>(const std::vector<kmer_type> &)> sketcher;
    std::function<Vec2D<double>(const std::vector<uint64_t> &)> slide_sketcher;
};

// Some global constant types.
using seq_type = uint8_t;

Expand All @@ -211,7 +136,12 @@ void run_triangle(SketchAlgorithm &algorithm) {
for (size_t i = 0; i < n; ++i) {
assert(files[i].sequences.size() == 1
&& "Each input file must contain exactly one sequence!");
sketches[i] = algorithm.compute(files[i].sequences[0]);
if constexpr (SketchAlgorithm::kmer_input) {
sketches[i]
= algorithm.compute(files[i].sequences[0], FLAGS_kmer_length, alphabet_size);
} else {
sketches[i] = algorithm.compute(files[i].sequences[0]);
}
progress_bar::iter();
}

Expand Down Expand Up @@ -263,12 +193,26 @@ void run_function_on_algorithm(F f) {
auto kmer_word_size = int_pow<kmer_type>(alphabet_size, FLAGS_kmer_length);

std::random_device rd;
if (FLAGS_sketch_method == "MH") {
f(MinHash<kmer_type>(kmer_word_size, FLAGS_embed_dim, HashAlgorithm::murmur, rd()));
return;
}
if (FLAGS_sketch_method == "WMH") {
f(WeightedMinHash<kmer_type>(kmer_word_size, FLAGS_embed_dim, FLAGS_max_len,
HashAlgorithm::murmur, rd()));
return;
}
if (FLAGS_sketch_method == "OMH") {
f(OrderedMinHash<kmer_type>(kmer_word_size, FLAGS_embed_dim, FLAGS_max_len,
FLAGS_tuple_length, HashAlgorithm::murmur, rd()));
return;
}
if (FLAGS_sketch_method == "ED") {
f(EditDistance<seq_type>());
return;
}
if (FLAGS_sketch_method == "TE") {
f(TensorEmbedding<seq_type>(kmer_word_size, FLAGS_tuple_length));
f(TensorEmbedding<seq_type>(alphabet_size, FLAGS_tuple_length, "TensorEmbedding"));
return;
}
if (FLAGS_sketch_method == "TS") {
Expand All @@ -285,6 +229,7 @@ void run_function_on_algorithm(F f) {
FLAGS_window_size, FLAGS_stride, rd()));
return;
}
std::cerr << "Unknown sketch method: " << FLAGS_sketch_method << "\n";
}


Expand All @@ -299,69 +244,10 @@ int main(int argc, char *argv[]) {
std::exit(1);
}

auto kmer_word_size = int_pow<uint64_t>(alphabet_size, FLAGS_kmer_length);

std::random_device rd;

if (FLAGS_action == "triangle") {
run_function_on_algorithm([](auto x) { run_triangle(x); });
return 0;
}

if (FLAGS_sketch_method.substr(FLAGS_sketch_method.size() - 2, 2) == "MH") {
std::function<std::vector<uint64_t>(const std::vector<uint64_t> &)> sketcher;

if (FLAGS_sketch_method == "MH") {
// The hash function is part of the lambda state.
sketcher = [&,
min_hash
= MinHash<uint64_t>(kmer_word_size, FLAGS_embed_dim, HashAlgorithm::uniform,
rd())](const std::vector<uint64_t> &seq) mutable {
return min_hash.compute(seq);
};
} else if (FLAGS_sketch_method == "WMH") {
sketcher = [&,
wmin_hash
= WeightedMinHash<uint64_t>(kmer_word_size, FLAGS_embed_dim, FLAGS_max_len,
HashAlgorithm::uniform, rd())](
const std::vector<uint64_t> &seq) mutable {
return wmin_hash.compute(seq);
};
} else if (FLAGS_sketch_method == "OMH") {
sketcher = [&,
omin_hash = OrderedMinHash<uint64_t>(kmer_word_size, FLAGS_embed_dim,
FLAGS_max_len, FLAGS_tuple_length,
HashAlgorithm::uniform, rd())](
const std::vector<uint64_t> &seq) mutable {
return omin_hash.compute(seq);
};
}
std::function<Vec2D<double>(const std::vector<uint64_t> &)> slide_sketcher
= [&](const std::vector<uint64_t> & /*unused*/) { return new2D<double>(0, 0); };
SketchHelper<uint8_t, uint64_t, uint64_t> sketch_helper(sketcher, slide_sketcher);
sketch_helper.read_input();
sketch_helper.compute_sketches();
sketch_helper.save_output();
} else if (FLAGS_sketch_method.rfind("TS", 0) == 0) {
Tensor<uint64_t> tensor_sketch(kmer_word_size, FLAGS_embed_dim, FLAGS_tuple_length, rd());
TensorBlock<uint64_t> tensor_block(kmer_word_size, FLAGS_embed_dim, FLAGS_tuple_length,
FLAGS_block_size, rd());
TensorSlide<uint64_t> tensor_slide(kmer_word_size, FLAGS_embed_dim, FLAGS_tuple_length,
FLAGS_window_size, FLAGS_stride, rd());
std::function<std::vector<double>(const std::vector<uint64_t> &)> sketcher
= [&](const std::vector<uint64_t> &seq) {
return FLAGS_block_size == 1 ? tensor_sketch.compute(seq)
: tensor_block.compute(seq);
};
std::function<Vec2D<double>(const std::vector<uint64_t> &)> slide_sketcher
= [&](const std::vector<uint64_t> &seq) { return tensor_slide.compute(seq); };
SketchHelper<uint8_t, uint64_t, double> sketch_helper(sketcher, slide_sketcher);
sketch_helper.read_input();
FLAGS_sketch_method == "TSS" ? sketch_helper.compute_slide()
: sketch_helper.compute_sketches();
sketch_helper.save_output();
} else {
std::cerr << "Unkknown method: " << FLAGS_sketch_method << std::endl;
exit(1);
}
std::cerr << "Unknown action: " << FLAGS_action << "\n";
}

0 comments on commit 571b719

Please sign in to comment.