diff --git a/Cargo.lock b/Cargo.lock index 4e3fa969..f2758b59 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -21,9 +21,9 @@ dependencies = [ [[package]] name = "ahash" -version = "0.8.9" +version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d713b3834d76b85304d4d525563c1276e2e30dc97cc67bfb4585a4a29fc2c89f" +checksum = "8b79b82693f705137f8fb9b37871d99e4f9a7df12b917eed79c3d3954830a60b" dependencies = [ "cfg-if", "getrandom", @@ -70,9 +70,9 @@ dependencies = [ [[package]] name = "anstream" -version = "0.6.12" +version = "0.6.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96b09b5178381e0874812a9b157f7fe84982617e48f71f4e3235482775e5b540" +checksum = "d96bd03f33fe50a863e394ee9718a706f988b9079b20c3784fb726e7678b62fb" dependencies = [ "anstyle", "anstyle-parse", @@ -198,7 +198,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.50", + "syn 2.0.52", ] [[package]] @@ -227,9 +227,9 @@ dependencies = [ [[package]] name = "bstr" -version = "1.9.0" +version = "1.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c48f0051a4b4c5e0b6d365cd04af53aeaa209e3cc15ec2cdb69e73cc87fbd0dc" +checksum = "05efc5cfd9110c8416e471df0e96702d58690178e206e61b7173706673c93706" dependencies = [ "memchr", "regex-automata", @@ -330,9 +330,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.86" +version = "1.0.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f9fa1897e4325be0d68d48df6aa1a71ac2ed4d27723887e7754192705350730" +checksum = "02f341c093d19155a6e41631ce5971aac4e9a868262212153124c15fa22d1cdc" dependencies = [ "libc", ] @@ -363,7 +363,7 @@ dependencies = [ "js-sys", "num-traits", "wasm-bindgen", - "windows-targets 0.52.3", + "windows-targets 0.52.4", ] [[package]] @@ -489,7 +489,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.52", ] [[package]] @@ -614,7 +614,7 @@ version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" dependencies = [ - "ahash 0.8.9", + "ahash 0.8.10", "allocator-api2", "rayon", ] @@ -811,9 +811,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.20" +version = "0.4.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" +checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" [[package]] name = "lz4-sys" @@ -1055,7 +1055,7 @@ dependencies = [ "proc-macro2", "proc-macro2-diagnostics", "quote", - "syn 2.0.50", + "syn 2.0.52", ] [[package]] @@ -1174,7 +1174,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a41cf62165e97c7f814d2221421dbb9afcbcdb0a88068e5ea206e19951c2cbb5" dependencies = [ "proc-macro2", - "syn 2.0.50", + "syn 2.0.52", ] [[package]] @@ -1237,7 +1237,7 @@ checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.52", "version_check", "yansi", ] @@ -1310,7 +1310,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.50", + "syn 2.0.52", ] [[package]] @@ -1323,7 +1323,7 @@ dependencies = [ "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.50", + "syn 2.0.52", ] [[package]] @@ -1398,9 +1398,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" [[package]] name = "rayon" -version = "1.8.1" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa7237101a77a10773db45d62004a272517633fbcc3df19d96455ede1122e051" +checksum = "e4963ed1bc86e4f3ee217022bd855b297cef07fb9eac5dfa1f788b220b49b3bd" dependencies = [ "either", "rayon-core", @@ -1554,7 +1554,7 @@ version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7dc64577832a6bcfd10fa6d452c3b5fe7a4ca228375d236f65a1ab0db953ba34" dependencies = [ - "ahash 0.8.9", + "ahash 0.8.10", "fixedbitset", "hashbrown 0.14.3", "indexmap 2.2.3", @@ -1617,7 +1617,7 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.52", ] [[package]] @@ -1731,6 +1731,7 @@ dependencies = [ "camino", "csv", "env_logger", + "glob", "log", "needletail", "niffler", @@ -1787,9 +1788,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.50" +version = "2.0.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74f1bdc9872430ce9b75da68329d1c1746faf50ffac5f19e02b71e37ff881ffb" +checksum = "b699d15b36d1f02c3e7c69f8ffef53de37aefae075d8488d4ba1a7788d574a07" dependencies = [ "proc-macro2", "quote", @@ -1810,9 +1811,9 @@ checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f" [[package]] name = "tempfile" -version = "3.10.0" +version = "3.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a365e8cd18e44762ef95d87f284f4b5cd04107fec2ff3052bd6a3e6069669e67" +checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" dependencies = [ "cfg-if", "fastrand", @@ -1843,7 +1844,7 @@ checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.52", ] [[package]] @@ -1889,7 +1890,7 @@ checksum = "563b3b88238ec95680aef36bdece66896eaa7ce3c0f1b4f39d38fb2435261352" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.52", ] [[package]] @@ -1987,7 +1988,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.52", "wasm-bindgen-shared", ] @@ -2009,7 +2010,7 @@ checksum = "642f325be6301eb8107a83d12a8ac6c1e1c54345a7ef1a9261962dfefda09e66" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.52", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2046,7 +2047,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.52.3", + "windows-targets 0.52.4", ] [[package]] @@ -2064,7 +2065,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.3", + "windows-targets 0.52.4", ] [[package]] @@ -2084,17 +2085,17 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d380ba1dc7187569a8a9e91ed34b8ccfc33123bbacb8c0aed2d1ad7f3ef2dc5f" +checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b" dependencies = [ - "windows_aarch64_gnullvm 0.52.3", - "windows_aarch64_msvc 0.52.3", - "windows_i686_gnu 0.52.3", - "windows_i686_msvc 0.52.3", - "windows_x86_64_gnu 0.52.3", - "windows_x86_64_gnullvm 0.52.3", - "windows_x86_64_msvc 0.52.3", + "windows_aarch64_gnullvm 0.52.4", + "windows_aarch64_msvc 0.52.4", + "windows_i686_gnu 0.52.4", + "windows_i686_msvc 0.52.4", + "windows_x86_64_gnu 0.52.4", + "windows_x86_64_gnullvm 0.52.4", + "windows_x86_64_msvc 0.52.4", ] [[package]] @@ -2105,9 +2106,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68e5dcfb9413f53afd9c8f86e56a7b4d86d9a2fa26090ea2dc9e40fba56c6ec6" +checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9" [[package]] name = "windows_aarch64_msvc" @@ -2117,9 +2118,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8dab469ebbc45798319e69eebf92308e541ce46760b49b18c6b3fe5e8965b30f" +checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675" [[package]] name = "windows_i686_gnu" @@ -2129,9 +2130,9 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a4e9b6a7cac734a8b4138a4e1044eac3404d8326b6c0f939276560687a033fb" +checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3" [[package]] name = "windows_i686_msvc" @@ -2141,9 +2142,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28b0ec9c422ca95ff34a78755cfa6ad4a51371da2a5ace67500cf7ca5f232c58" +checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02" [[package]] name = "windows_x86_64_gnu" @@ -2153,9 +2154,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "704131571ba93e89d7cd43482277d6632589b18ecf4468f591fbae0a8b101614" +checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03" [[package]] name = "windows_x86_64_gnullvm" @@ -2165,9 +2166,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42079295511643151e98d61c38c0acc444e52dd42ab456f7ccfd5152e8ecf21c" +checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177" [[package]] name = "windows_x86_64_msvc" @@ -2177,9 +2178,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0770833d60a970638e989b3fa9fd2bb1aaadcf88963d1659fd7d9990196ed2d6" +checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" [[package]] name = "wyz" @@ -2222,7 +2223,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.52", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index a2fb414b..edbfb094 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ tempfile = "3.10" needletail = "0.5.1" csv = "1.3.0" camino = "1.1.6" +glob = "0.3.1" rustworkx-core = "0.14.0" [dev-dependencies] diff --git a/doc/README.md b/doc/README.md index 0008d649..65d13d78 100644 --- a/doc/README.md +++ b/doc/README.md @@ -101,6 +101,8 @@ The following formats are accepted: >`genome_filename` entries are considered DNA FASTA, `protein_filename` entries are considered protein FASTA. - 3 columns: `name,read1,read2` > All entries considered DNA FASTA, and both `read1` and `read2` files are used as input for a single sketch with name `name`. +- 4 columns: `name,input_moltype,prefix,exclude` + > This filetype uses `glob` to find files that match `prefix` but do not match `exclude`. As such, `*` are ok in the `prefix` and `exclude` columns. Since we are dealing with "prefixes" here, we automatically search with `*` on the end of the `prefix` entry. A simple way to build a manysketch input file for a directory is this command snippet: ``` diff --git a/src/lib.rs b/src/lib.rs index 4a607f08..f4647156 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -269,8 +269,9 @@ fn do_manysketch( param_str: String, output: String, singleton: bool, + force: bool, ) -> anyhow::Result { - match manysketch::manysketch(filelist, param_str, output, singleton) { + match manysketch::manysketch(filelist, param_str, output, singleton, force) { Ok(_) => Ok(0), Err(e) => { eprintln!("Error: {e}"); diff --git a/src/manysketch.rs b/src/manysketch.rs index 65364398..a23e0f36 100644 --- a/src/manysketch.rs +++ b/src/manysketch.rs @@ -85,7 +85,7 @@ fn build_siginfo(params: &[Params], moltype: &str) -> Vec { for param in params.iter().cloned() { match moltype { // if dna, only build dna sigs. if protein, only build protein sigs - "dna" if !param.is_dna => continue, + "dna" | "DNA" if !param.is_dna => continue, "protein" if !param.is_protein => continue, _ => (), } @@ -118,8 +118,9 @@ pub fn manysketch( param_str: String, output: String, singleton: bool, + force: bool, ) -> Result<(), Box> { - let (fileinfo, n_fastas) = match load_fasta_fromfile(filelist) { + let (fileinfo, n_fastas) = match load_fasta_fromfile(filelist, force) { Ok((file_info, n_fastas)) => (file_info, n_fastas), Err(e) => bail!("Could not load fromfile csv. Underlying error: {}", e), }; diff --git a/src/python/sourmash_plugin_branchwater/__init__.py b/src/python/sourmash_plugin_branchwater/__init__.py index eadb1397..39dd6b59 100755 --- a/src/python/sourmash_plugin_branchwater/__init__.py +++ b/src/python/sourmash_plugin_branchwater/__init__.py @@ -343,6 +343,8 @@ def __init__(self, p): help='number of cores to use (default is all available)') p.add_argument('-s', '--singleton', action="store_true", help='build one sketch per FASTA record, i.e. multiple sketches per FASTA file') + p.add_argument('-f', '--force', action="store_true", + help='allow use of individual FASTA files in more than more sketch') def main(self, args): print_version() @@ -363,7 +365,8 @@ def main(self, args): status = sourmash_plugin_branchwater.do_manysketch(args.fromfile_csv, args.param_string, args.output, - args.singleton) + args.singleton, + args.force) if status == 0: notify(f"...manysketch is done! results in '{args.output}'") return status diff --git a/src/python/tests/test_sketch.py b/src/python/tests/test_sketch.py index b05703b6..ecfae2a7 100644 --- a/src/python/tests/test_sketch.py +++ b/src/python/tests/test_sketch.py @@ -589,3 +589,200 @@ def test_manysketch_reads_singleton(runtmp, capfd): assert sig == ss_sketch2 elif sig.name == 'other': assert sig == ss_sketch3 + + +def test_manysketch_prefix(runtmp, capfd): + fa_csv = runtmp.output('db-fa.csv') + + fa1 = get_test_data('short.fa') + + fa_path = os.path.dirname(fa1) + dna_prefix = os.path.join(fa_path, "short*fa") # need to avoid matching short-protein.fa + prot_prefix = os.path.join(fa_path, "*protein.fa") + + # make prefix input file + with open(fa_csv, 'wt') as fp: + fp.write("name,input_moltype,prefix,exclude\n") + fp.write(f"short,DNA,{dna_prefix},{prot_prefix}\n") # short.fa, short2.fa, short3.fa, short-protein.fa + fp.write(f"short_protein,protein,{prot_prefix},\n") # short-protein.fa only + + output = runtmp.output('prefix.zip') + + runtmp.sourmash('scripts', 'manysketch', fa_csv, '-o', output, + '--param-str', "dna,k=31,scaled=1", '-p', "protein,k=10,scaled=1") + + assert os.path.exists(output) + assert not runtmp.last_result.out # stdout should be empty + captured = capfd.readouterr() + print(captured.out) + print(captured.err) + assert "Found 'prefix' CSV. Using 'glob' to find files based on 'prefix' column." in captured.out + assert "DONE. Processed 4 fasta files" in captured.err + + idx = sourmash.load_file_as_index(output) + sigs = list(idx.signatures()) + print(sigs) + + assert len(sigs) == 2 + + # make same sigs with sourmash + fa2 = get_test_data('short2.fa') + fa3 = get_test_data('short3.fa') + fa4 = get_test_data('short-protein.fa') + s1 = runtmp.output('short.sig') + runtmp.sourmash('sketch', 'dna', fa1, fa2, fa3, '-o', s1, + '--param-str', "dna,k=31,scaled=1", '--name', 'short') + sig1 = sourmash.load_one_signature(s1) + s2 = runtmp.output('short-protein.sig') + runtmp.sourmash('sketch', 'protein', fa4, '-o', s2, + '--param-str', "protein,k=10,scaled=1", '--name', 'short_protein') + sig2 = sourmash.load_one_signature(s2) + + expected_signames = ['short', 'short_protein'] + for sig in sigs: + assert sig.name in expected_signames + if sig.name == 'short': + assert sig,minhash.hashes == sig1.minhash.hashes + if sig.name == 'short_protein': + assert sig == sig2 + + +def test_manysketch_prefix2(runtmp, capfd): + fa_csv = runtmp.output('db-fa.csv') + + fa1 = get_test_data('short.fa') + + fa_path = os.path.dirname(fa1) + # test without '*' + dna_prefix = os.path.join(fa_path, "short") # need to avoid matching short-protein.fa + prot_prefix = os.path.join(fa_path, "*protein") + zip_exclude = os.path.join(fa_path, "*zip") + + # make prefix input file + with open(fa_csv, 'wt') as fp: + fp.write("name,input_moltype,prefix,exclude\n") + fp.write(f"short,DNA,{dna_prefix},{prot_prefix}\n") # short.fa, short2.fa, short3.fa, short-protein.fa + fp.write(f"short_protein,protein,{prot_prefix},{zip_exclude}\n") # short-protein.fa only + + output = runtmp.output('prefix.zip') + + runtmp.sourmash('scripts', 'manysketch', fa_csv, '-o', output, + '--param-str', "dna,k=31,scaled=1", '-p', "protein,k=10,scaled=1") + + assert os.path.exists(output) + assert not runtmp.last_result.out # stdout should be empty + captured = capfd.readouterr() + print(captured.out) + print(captured.err) + assert "Found 'prefix' CSV. Using 'glob' to find files based on 'prefix' column." in captured.out + assert "DONE. Processed 4 fasta files" in captured.err + + idx = sourmash.load_file_as_index(output) + sigs = list(idx.signatures()) + print(sigs) + + assert len(sigs) == 2 + + # make same sigs with sourmash + fa2 = get_test_data('short2.fa') + fa3 = get_test_data('short3.fa') + fa4 = get_test_data('short-protein.fa') + s1 = runtmp.output('short.sig') + runtmp.sourmash('sketch', 'dna', fa1, fa2, fa3, '-o', s1, + '--param-str', "dna,k=31,scaled=1", '--name', 'short') + sig1 = sourmash.load_one_signature(s1) + s2 = runtmp.output('short-protein.sig') + runtmp.sourmash('sketch', 'protein', fa4, '-o', s2, + '--param-str', "protein,k=10,scaled=1", '--name', 'short_protein') + sig2 = sourmash.load_one_signature(s2) + + expected_signames = ['short', 'short_protein'] + for sig in sigs: + assert sig.name in expected_signames + if sig.name == 'short': + assert sig,minhash.hashes == sig1.minhash.hashes + if sig.name == 'short_protein': + assert sig == sig2 + + +def test_manysketch_prefix_duplicated_fail(runtmp, capfd): + fa_csv = runtmp.output('db-fa.csv') + + fa1 = get_test_data('short.fa') + + fa_path = os.path.dirname(fa1) + # test without '*' + dna_prefix = os.path.join(fa_path, "short") # need to avoid matching short-protein.fa + prot_prefix = os.path.join(fa_path, "*protein") + zip_exclude = os.path.join(fa_path, "*zip") + + # make prefix input file + with open(fa_csv, 'wt') as fp: + fp.write("name,input_moltype,prefix,exclude\n") + fp.write(f"short,DNA,{dna_prefix},{prot_prefix}\n") # short.fa, short2.fa, short3.fa, short-protein.fa + fp.write(f"short,DNA,{dna_prefix},{prot_prefix}\n") # duplicate of row one -- this should just be skipped + fp.write(f"short_protein,protein,{prot_prefix},{zip_exclude}\n") # short-protein.fa only + # ALSO short-protein.fa, but different name. should raise err without force + fp.write(f"second_protein,protein,{prot_prefix},{zip_exclude}\n") + + output = runtmp.output('prefix.zip') + + with pytest.raises(utils.SourmashCommandFailed): + runtmp.sourmash('scripts', 'manysketch', fa_csv, '-o', output, + '--param-str', "dna,k=31,scaled=1", '-p', "protein,k=10,scaled=1") + + assert not os.path.exists(output) + assert not runtmp.last_result.out # stdout should be empty + captured = capfd.readouterr() + print(captured.out) + print(captured.err) + assert "Found 'prefix' CSV. Using 'glob' to find files based on 'prefix' column." in captured.out + assert "Found identical FASTA paths in more than one row!" in captured.err + assert "Duplicated paths:" in captured.err + assert "short-protein.fa" in captured.err + assert "Duplicated FASTA files found. Please use --force to bypass this check" in captured.err + + +def test_manysketch_prefix_duplicated_force(runtmp, capfd): + fa_csv = runtmp.output('db-fa.csv') + + fa1 = get_test_data('short.fa') + + fa_path = os.path.dirname(fa1) + # test without '*' + dna_prefix = os.path.join(fa_path, "short") # need to avoid matching short-protein.fa + prot_prefix = os.path.join(fa_path, "*protein") + zip_exclude = os.path.join(fa_path, "*zip") + + # make prefix input file + with open(fa_csv, 'wt') as fp: + fp.write("name,input_moltype,prefix,exclude\n") + fp.write(f"short,DNA,{dna_prefix},{prot_prefix}\n") # short.fa, short2.fa, short3.fa, short-protein.fa + fp.write(f"short,DNA,{dna_prefix},{prot_prefix}\n") # duplicate of row one -- this should just be skipped + fp.write(f"short_protein,protein,{prot_prefix},{zip_exclude}\n") # short-protein.fa only + # ALSO short-protein.fa, but different name. should raise err without force + fp.write(f"second_protein,protein,{prot_prefix},{zip_exclude}\n") + + output = runtmp.output('prefix.zip') + + runtmp.sourmash('scripts', 'manysketch', fa_csv, '-o', output, + '--param-str', "dna,k=31,scaled=1", '-p', "protein,k=10,scaled=1", + '--force') + + assert os.path.exists(output) + assert not runtmp.last_result.out # stdout should be empty + captured = capfd.readouterr() + print(captured.out) + print(captured.err) + assert "Found 'prefix' CSV. Using 'glob' to find files based on 'prefix' column." in captured.out + assert "Loaded 3 rows in total (3 DNA FASTA and 2 protein FASTA), 1 duplicate rows skipped." in captured.out + assert "Found identical FASTA paths in more than one row!" in captured.err + assert "Duplicated paths:" in captured.err + assert "short-protein.fa" in captured.err + assert "--force is set. Continuing..." in captured.err + + idx = sourmash.load_file_as_index(output) + sigs = list(idx.signatures()) + print(sigs) + + assert len(sigs) == 3 diff --git a/src/utils.rs b/src/utils.rs index d71c20f2..4c1bb912 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -7,6 +7,7 @@ use anyhow::{anyhow, Context, Result}; use camino::Utf8Path as Path; use camino::Utf8PathBuf as PathBuf; use csv::Writer; +use glob::glob; use serde::{Deserialize, Serialize}; use std::cmp::{Ordering, PartialOrd}; use std::collections::BinaryHeap; @@ -22,6 +23,7 @@ use sourmash::selection::Selection; use sourmash::signature::{Signature, SigsTrait}; use sourmash::sketch::minhash::KmerMinHash; use sourmash::storage::{FSStorage, InnerStorage, SigStore}; +use std::collections::{HashMap, HashSet}; /// Track a name/minhash. pub struct SmallSignature { @@ -139,6 +141,7 @@ pub struct FastaData { enum CSVType { Assembly, Reads, + Prefix, Unknown, } @@ -155,12 +158,22 @@ fn detect_csv_type(headers: &csv::StringRecord) -> CSVType { && headers.get(2).unwrap() == "read2" { CSVType::Reads + } else if headers.len() == 4 + && headers.get(0).unwrap() == "name" + && headers.get(1).unwrap() == "input_moltype" + && headers.get(2).unwrap() == "prefix" + && headers.get(3).unwrap() == "exclude" + { + CSVType::Prefix } else { CSVType::Unknown } } -pub fn load_fasta_fromfile(sketchlist_filename: String) -> Result<(Vec, usize)> { +pub fn load_fasta_fromfile( + sketchlist_filename: String, + force: bool, +) -> Result<(Vec, usize)> { let mut rdr = csv::Reader::from_path(sketchlist_filename)?; // Check for right header @@ -169,8 +182,9 @@ pub fn load_fasta_fromfile(sketchlist_filename: String) -> Result<(Vec process_assembly_csv(rdr), CSVType::Reads => process_reads_csv(rdr), + CSVType::Prefix => process_prefix_csv(rdr, force), CSVType::Unknown => Err(anyhow!( - "Invalid header. Expected 'name,genome_filename,protein_filename' or 'name,read1,read2', but got '{}'", + "Invalid header. Expected 'name,genome_filename,protein_filename', 'name,read1,read2', or 'name,input_moltype,prefix,exclude', but got '{}'", headers.iter().collect::>().join(",") )), } @@ -291,6 +305,128 @@ fn process_reads_csv(mut rdr: csv::Reader) -> Result<(Vec, + force: bool, +) -> Result<(Vec, usize)> { + let mut results = Vec::new(); + let mut dna_count = 0; + let mut protein_count = 0; + let mut processed_rows = HashSet::new(); + let mut duplicate_count = 0; + let mut all_paths = HashSet::new(); // track FASTA in use + let mut duplicate_paths_count = HashMap::new(); + + for result in rdr.records() { + let record = result?; + let row_string = record.iter().collect::>().join(","); + if processed_rows.contains(&row_string) { + duplicate_count += 1; + continue; + } + processed_rows.insert(row_string.clone()); + + let name = record + .get(0) + .ok_or_else(|| anyhow!("Missing 'name' field"))? + .to_string(); + + let moltype = record + .get(1) + .ok_or_else(|| anyhow!("Missing 'input_moltype' field"))? + .to_string(); + + // Validate moltype + match moltype.as_str() { + "protein" | "dna" | "DNA" => (), + _ => return Err(anyhow!("Invalid 'input_moltype' field value: {}", moltype)), + } + + // For both prefix and exclude, automatically append wildcard for expected "prefix" matching + let prefix = record + .get(2) + .ok_or_else(|| anyhow!("Missing 'prefix' field"))? + .to_string() + + "*"; + + // optional exclude pattern + let exclude = record.get(3).map(|s| s.to_string() + "*"); + + // Use glob to find and collect all paths that match the prefix + let included_paths = glob(&prefix) + .expect("Failed to read glob pattern for included paths") + .filter_map(Result::ok) + .map(|path| PathBuf::from(path.to_str().expect("Path is not valid UTF-8"))) + .collect::>(); + + // Use glob to find and collect all paths that match the exclude_prefix, if any + let excluded_paths = if let Some(ref exclude_pattern) = exclude { + glob(exclude_pattern) + .expect("Failed to read glob pattern for excluded paths") + .filter_map(Result::ok) + .map(|path| PathBuf::from(path.to_str().expect("Path is not valid UTF-8"))) + .collect::>() + } else { + HashSet::new() + }; + + // Exclude the excluded_paths from included_paths + let filtered_paths: Vec = included_paths + .difference(&excluded_paths) + .cloned() + .collect(); + + // Track duplicates among filtered paths + for path in &filtered_paths { + if !all_paths.insert(path.clone()) { + *duplicate_paths_count.entry(path.clone()).or_insert(0) += 1; + } + } + + if !filtered_paths.is_empty() { + match moltype.as_str() { + "dna" | "DNA" => dna_count += filtered_paths.len(), + "protein" => protein_count += filtered_paths.len(), + _ => {} // should not get here b/c validated earlier + } + results.push(FastaData { + name: name.clone(), + paths: filtered_paths.to_vec(), + input_type: moltype.clone(), + }); + } + } + + let total_duplicate_paths: usize = duplicate_paths_count.values().sum(); + + println!("Found 'prefix' CSV. Using 'glob' to find files based on 'prefix' column."); + if total_duplicate_paths > 0 { + eprintln!("Found identical FASTA paths in more than one row!"); + eprintln!("Duplicated paths:"); + for path in duplicate_paths_count.keys() { + eprintln!("{:?}", path); + } + if !force { + return Err(anyhow!( + "Duplicated FASTA files found. Please use --force to bypass this check." + )); + } else { + eprintln!("--force is set. Continuing...") + } + } + println!( + "Loaded {} rows in total ({} DNA FASTA and {} protein FASTA), {} duplicate rows skipped.", + processed_rows.len(), + dna_count, + protein_count, + duplicate_count, + ); + + let n_fastas = dna_count + protein_count; + + Ok((results, n_fastas)) +} + // Load all compatible minhashes from a collection into memory // also store sig name and md5 alongside, as we usually need those pub fn load_sketches( @@ -931,8 +1067,7 @@ pub fn sigwriter( let mut zip = zip::ZipWriter::new(file_writer); let mut manifest_rows: Vec = Vec::new(); // keep track of md5sum occurrences to prevent overwriting duplicates - let mut md5sum_occurrences: std::collections::HashMap = - std::collections::HashMap::new(); + let mut md5sum_occurrences: HashMap = HashMap::new(); while let Ok(message) = recv.recv() { match message {