diff --git a/sketch_main.cpp b/sketch_main.cpp index 4431852..cea9a48 100644 --- a/sketch_main.cpp +++ b/sketch_main.cpp @@ -1,3 +1,4 @@ +#include "sequence/alphabets.hpp" #include "sequence/fasta_io.hpp" #include "sketch/edit_distance.hpp" #include "sketch/hash_base.hpp" @@ -16,6 +17,7 @@ #include #include +#include #include #include #include @@ -25,7 +27,7 @@ using namespace ts; // The main command this program should perform. // Triangle: compute a triangular distance matrix. // More actions will be added. -DEFINE_string(action, "", "Which action to do. One of: triangle, none"); +DEFINE_string(action, "triangle", "Which action to do. One of: triangle, none"); DEFINE_string(alphabet, "dna4", @@ -113,83 +115,6 @@ void adjust_short_names() { } } -template -class SketchHelper { - public: - SketchHelper(std::function(const std::vector &)> sketcher, - std::function(const std::vector &)> slide_sketcher) - : sketcher(std::move(sketcher)), slide_sketcher(std::move(slide_sketcher)) {} - - void compute_sketches() { - size_t num_seqs = seqs.size(); - sketches = new3D(seqs.size(), FLAGS_embed_dim, 0); - for (size_t si = 0; si < num_seqs; si++) { - std::vector kmers - = seq2kmer(seqs[si], FLAGS_kmer_length, alphabet_size); - - for (size_t i = 0; i < kmers.size(); i += FLAGS_stride) { - auto end = std::min(kmers.begin() + i + FLAGS_window_size, kmers.end()); - std::vector kmer_slice(kmers.begin() + i, end); - std::vector embed_slice = sketcher(kmer_slice); - for (int m = 0; m < FLAGS_embed_dim; m++) { - sketches[si][m].push_back(embed_slice[m]); - } - } - } - } - - void compute_slide() { - sketches = new3D(seqs.size(), FLAGS_embed_dim, 0); - - for (size_t si = 0; si < seqs.size(); si++) { - std::vector kmers - = seq2kmer(seqs[si], FLAGS_kmer_length, alphabet_size); - sketches[si] = slide_sketcher(kmers); - } - } - - void read_input() { - FastaFile file = read_fasta(FLAGS_i, FLAGS_input_format); - seqs = std::move(file.sequences); - seq_names = std::move(file.comments); - } - - void save_output() { - std::filesystem::path ofile = std::filesystem::absolute(std::filesystem::path(FLAGS_o)); - std::filesystem::path opath = ofile.parent_path(); - if (!std::filesystem::exists(opath) && !std::filesystem::create_directories(opath)) { - std::cerr << "Could not create output directory: " << opath << std::endl; - std::exit(1); - } - - std::ofstream fo(FLAGS_o); - if (!fo.is_open()) { - std::cerr << "Could not open " << ofile << " for writing." << std::endl; - std::exit(1); - } - std::cout << "Writing sketches to: " << FLAGS_o << std::endl; - - for (size_t si = 0; si < seqs.size(); si++) { - fo << seq_names[si] << std::endl; - for (size_t m = 0; m < sketches[si].size(); m++) { - for (size_t i = 0; i < sketches[si][m].size(); i++) { - fo << sketches[si][m][i] << ","; - } - } - fo << '\b' << std::endl; - } - fo.close(); - } - - private: - Vec2D seqs; - std::vector seq_names; - Vec3D sketches; - - std::function(const std::vector &)> sketcher; - std::function(const std::vector &)> slide_sketcher; -}; - // Some global constant types. using seq_type = uint8_t; @@ -211,7 +136,12 @@ void run_triangle(SketchAlgorithm &algorithm) { for (size_t i = 0; i < n; ++i) { assert(files[i].sequences.size() == 1 && "Each input file must contain exactly one sequence!"); - sketches[i] = algorithm.compute(files[i].sequences[0]); + if constexpr (SketchAlgorithm::kmer_input) { + sketches[i] + = algorithm.compute(files[i].sequences[0], FLAGS_kmer_length, alphabet_size); + } else { + sketches[i] = algorithm.compute(files[i].sequences[0]); + } progress_bar::iter(); } @@ -263,12 +193,26 @@ void run_function_on_algorithm(F f) { auto kmer_word_size = int_pow(alphabet_size, FLAGS_kmer_length); std::random_device rd; + if (FLAGS_sketch_method == "MH") { + f(MinHash(kmer_word_size, FLAGS_embed_dim, HashAlgorithm::murmur, rd())); + return; + } + if (FLAGS_sketch_method == "WMH") { + f(WeightedMinHash(kmer_word_size, FLAGS_embed_dim, FLAGS_max_len, + HashAlgorithm::murmur, rd())); + return; + } + if (FLAGS_sketch_method == "OMH") { + f(OrderedMinHash(kmer_word_size, FLAGS_embed_dim, FLAGS_max_len, + FLAGS_tuple_length, HashAlgorithm::murmur, rd())); + return; + } if (FLAGS_sketch_method == "ED") { f(EditDistance()); return; } if (FLAGS_sketch_method == "TE") { - f(TensorEmbedding(kmer_word_size, FLAGS_tuple_length)); + f(TensorEmbedding(alphabet_size, FLAGS_tuple_length, "TensorEmbedding")); return; } if (FLAGS_sketch_method == "TS") { @@ -285,6 +229,7 @@ void run_function_on_algorithm(F f) { FLAGS_window_size, FLAGS_stride, rd())); return; } + std::cerr << "Unknown sketch method: " << FLAGS_sketch_method << "\n"; } @@ -299,69 +244,10 @@ int main(int argc, char *argv[]) { std::exit(1); } - auto kmer_word_size = int_pow(alphabet_size, FLAGS_kmer_length); - - std::random_device rd; - if (FLAGS_action == "triangle") { run_function_on_algorithm([](auto x) { run_triangle(x); }); return 0; } - if (FLAGS_sketch_method.substr(FLAGS_sketch_method.size() - 2, 2) == "MH") { - std::function(const std::vector &)> sketcher; - - if (FLAGS_sketch_method == "MH") { - // The hash function is part of the lambda state. - sketcher = [&, - min_hash - = MinHash(kmer_word_size, FLAGS_embed_dim, HashAlgorithm::uniform, - rd())](const std::vector &seq) mutable { - return min_hash.compute(seq); - }; - } else if (FLAGS_sketch_method == "WMH") { - sketcher = [&, - wmin_hash - = WeightedMinHash(kmer_word_size, FLAGS_embed_dim, FLAGS_max_len, - HashAlgorithm::uniform, rd())]( - const std::vector &seq) mutable { - return wmin_hash.compute(seq); - }; - } else if (FLAGS_sketch_method == "OMH") { - sketcher = [&, - omin_hash = OrderedMinHash(kmer_word_size, FLAGS_embed_dim, - FLAGS_max_len, FLAGS_tuple_length, - HashAlgorithm::uniform, rd())]( - const std::vector &seq) mutable { - return omin_hash.compute(seq); - }; - } - std::function(const std::vector &)> slide_sketcher - = [&](const std::vector & /*unused*/) { return new2D(0, 0); }; - SketchHelper sketch_helper(sketcher, slide_sketcher); - sketch_helper.read_input(); - sketch_helper.compute_sketches(); - sketch_helper.save_output(); - } else if (FLAGS_sketch_method.rfind("TS", 0) == 0) { - Tensor tensor_sketch(kmer_word_size, FLAGS_embed_dim, FLAGS_tuple_length, rd()); - TensorBlock tensor_block(kmer_word_size, FLAGS_embed_dim, FLAGS_tuple_length, - FLAGS_block_size, rd()); - TensorSlide tensor_slide(kmer_word_size, FLAGS_embed_dim, FLAGS_tuple_length, - FLAGS_window_size, FLAGS_stride, rd()); - std::function(const std::vector &)> sketcher - = [&](const std::vector &seq) { - return FLAGS_block_size == 1 ? tensor_sketch.compute(seq) - : tensor_block.compute(seq); - }; - std::function(const std::vector &)> slide_sketcher - = [&](const std::vector &seq) { return tensor_slide.compute(seq); }; - SketchHelper sketch_helper(sketcher, slide_sketcher); - sketch_helper.read_input(); - FLAGS_sketch_method == "TSS" ? sketch_helper.compute_slide() - : sketch_helper.compute_sketches(); - sketch_helper.save_output(); - } else { - std::cerr << "Unkknown method: " << FLAGS_sketch_method << std::endl; - exit(1); - } + std::cerr << "Unknown action: " << FLAGS_action << "\n"; }