diff --git a/Cargo.lock b/Cargo.lock
index a17275a..7ae14fb 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1289,7 +1289,7 @@ version = "0.1.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "c5f5376ea5e30ce23c03eb77cbe4962b988deead10910c372b226388b594c084"
 dependencies = [
- "semver",
+ "semver 0.1.20",
 ]
 
 [[package]]
@@ -1348,6 +1348,12 @@ version = "0.1.20"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d4f410fedcf71af0345d7607d246e7ad15faaadd49d240ee3b24e5dc21a820ac"
 
+[[package]]
+name = "semver"
+version = "1.0.23"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b"
+
 [[package]]
 name = "serde"
 version = "1.0.197"
@@ -1626,7 +1632,7 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
 
 [[package]]
 name = "trgt"
-version = "1.0.0"
+version = "1.1.0"
 dependencies = [
  "arrayvec",
  "bio",
@@ -1643,6 +1649,7 @@ dependencies = [
  "rayon",
  "resvg",
  "rust-htslib",
+ "semver 1.0.23",
  "svg2pdf",
  "tempfile",
  "tiny-skia",
diff --git a/Cargo.toml b/Cargo.toml
index 0f013a7..707a480 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "trgt"
-version = "1.0.0"
+version = "1.1.0"
 edition = "2021"
 build = "build.rs"
 
@@ -26,4 +26,5 @@ tiny-skia = "0.11"
 svg2pdf = "0.9"
 tempfile = "3"
 rayon = "1.10"
-crossbeam-channel = "0.5"
\ No newline at end of file
+crossbeam-channel = "0.5"
+semver = "1.0"
\ No newline at end of file
diff --git a/README.md b/README.md
index 949c344..0b2fc68 100644
--- a/README.md
+++ b/README.md
@@ -19,15 +19,14 @@ changes to the input and output file formats of TRGT.
 - Repeat definition files are available in
   [this Zenodo repository](https://zenodo.org/record/8329210) and definitions
   of known pathogenic repeats are [also available here](repeats/).
 
-## TRGTdb
+## Joint analysis of multiple samples
 
-TRGT outputs VCFs containing TR alleles from each region in the repeat catalog.
-To facilitate analysis of alleles across multiple samples, we provide the TRGTdb
-which can be found [here](https://github.com/PacificBiosciences/trgt/pull/6).
-After cloning that fork, the TRGTdb can be installed by running
-`python3 -m pip install trgt/`. See the fork's `notebooks/` directory for tutorials
-converting results into TRGTdb as well as example analyses. TRGTdb was developed by
-[Adam English](https://github.com/ACEnglish).
+TRGT outputs VCFs containing repeat alleles from each region in the repeat
+catalog. To facilitate analysis of repeats across multiple samples, VCFs can be
+either merged into a multi-sample VCF using the `merge` sub-command or converted
+into a database using the [TDB tool](https://github.com/ACEnglish/tdb) (formerly
+called TRGTdb). TDB offers many advantages over multi-sample VCFs, including
+simpler data extraction, support for queries, and reduced file sizes.
 
 ## Documentation
@@ -119,6 +118,11 @@ tandem repeats at genome scale. 2024](https://www.nature.com/articles/s41587-023
 - Lower memory footprint: Better memory management significantly reduces memory usage with large repeat catalogs.
 - Updated error handling: Malformed entries are now logged as errors without terminating the program.
 - Added shorthand CLI options to simplify command usage.
+- 1.1.0
+  - Added a new subcommand `trgt merge` that merges VCF files generated by `trgt genotype` into a joint VCF file; see the example below. **Works with VCFs generated by all versions of TRGT** (the resulting joint VCF is always in the TRGT v1.0+ format, which includes padding bases).
+  - Added subsampling of regions with ultra-high coverage (above `MAX_DEPTH * 3`, i.e., 750 by default), implemented via reservoir sampling.
+  - Fixed a cluster genotyper bug that occurred when only a single read covered a locus.
+  - Added new logic for filtering non-HiFi reads: up to 3% of lower-quality reads that do not match the expected repeat sequence are removed.
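+
+    A minimal `trgt merge` invocation (a sketch; file names are placeholders, flags as
+    defined in `src/cli.rs`):
+
+    ```bash
+    trgt merge --vcf sample1.vcf.gz sample2.vcf.gz --genome reference.fa \
+        --output merged.vcf.gz --output-type z
+    ```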
 
 ### DISCLAIMER
 
diff --git a/src/cli.rs b/src/cli.rs
index b99c88a..9eebea4 100644
--- a/src/cli.rs
+++ b/src/cli.rs
@@ -1,4 +1,7 @@
-use crate::utils::{Genotyper, Result, TrgtScoring};
+use crate::{
+    merge::vcf_writer::OutputType,
+    utils::{Genotyper, Result, TrgtScoring},
+};
 use chrono::Datelike;
 use clap::{ArgAction, ArgGroup, Parser, Subcommand};
 use env_logger::fmt::Color;
@@ -46,6 +49,91 @@ pub enum Command {
     Plot(PlotArgs),
     #[clap(about = "Tandem Repeat Catalog Validator")]
     Validate(ValidateArgs),
+    #[clap(about = "Tandem Repeat VCF Merger")]
+    Merge(MergeArgs),
+}
+
+#[derive(Parser, Debug)]
+#[command(group(ArgGroup::new("merge")))]
+#[command(arg_required_else_help(true))]
+pub struct MergeArgs {
+    #[clap(required = true)]
+    #[clap(short = 'v')]
+    #[clap(long = "vcf")]
+    #[clap(help = "VCF files to merge")]
+    #[clap(value_name = "VCF")]
+    #[clap(num_args = 1..)]
+    pub vcfs: Vec<PathBuf>,
+
+    #[clap(short = 'g')]
+    #[clap(long = "genome")]
+    #[clap(help = "Path to reference genome FASTA")]
+    #[clap(value_name = "FASTA")]
+    #[arg(value_parser = check_file_exists)]
+    pub genome_path: PathBuf,
+
+    #[clap(short = 'o')]
+    #[clap(long = "output")]
+    #[clap(value_name = "FILE")]
+    #[clap(help = "Write output to a file [standard output]")]
+    #[arg(value_parser = check_prefix_path)]
+    pub output: Option<String>,
+
+    #[clap(help_heading("Advanced"))]
+    #[clap(short = 'O')]
+    #[clap(long = "output-type")]
+    #[clap(value_name = "OUTPUT_TYPE")]
+    #[clap(help = "Output type: u|b|v|z, u/b: un/compressed BCF, v/z: un/compressed VCF")]
+    #[clap(value_parser = merge_validate_output_type)]
+    pub output_type: Option<OutputType>,
+
+    #[clap(help_heading("Advanced"))]
+    #[clap(long = "skip-n")]
+    #[clap(value_name = "SKIP_N")]
+    #[clap(help = "Skip the first N records")]
+    pub skip_n: Option<usize>,
+
+    #[clap(help_heading("Advanced"))]
+    #[clap(long = "process-n")]
+    #[clap(value_name = "PROCESS_N")]
+    #[clap(help = "Only process N records")]
+    pub process_n: Option<usize>,
+
+    #[clap(help_heading("Advanced"))]
+    #[clap(long = "print-header")]
+    #[clap(help = "Print only the merged header and exit")]
+    pub print_header: bool,
+
+    #[clap(help_heading("Advanced"))]
+    #[clap(long = "force-single")]
+    #[clap(help = "Run even if there is only one file on input")]
+    pub force_single: bool,
+
+    #[clap(help_heading("Advanced"))]
+    #[clap(hide = true)]
+    #[clap(long = "force-samples")]
+    #[clap(help = "Resolve duplicate sample names")]
+    pub force_samples: bool,
+
+    #[clap(help_heading("Advanced"))]
+    #[clap(long = "no-version")]
+    #[clap(help = "Do not append version and command line to the header")]
+    pub no_version: bool,
+
+    #[clap(help_heading("Advanced"))]
+    #[clap(hide = true)]
+    #[clap(long = "missing-to-ref")]
+    #[clap(help = "Assume genotypes at missing sites are 0/0")]
+    pub missing_to_ref: bool,
+
+    #[clap(help_heading("Advanced"))]
+    #[clap(hide = true)]
+    #[clap(long = "strategy")]
+    #[clap(value_name = "STRATEGY")]
+    #[clap(help = "Set variant merging strategy to use")]
+    #[clap(value_parser(["exact"]))]
+    #[clap(default_value = "exact")]
+    pub merge_strategy: String,
 }
 
 #[derive(Parser, Debug)]
@@ -391,3 +479,56 @@ fn scoring_from_string(s: &str) -> Result<TrgtScoring> {
         bandwidth: values[5] as usize,
     })
 }
+
+fn merge_validate_output_type(s: &str) -> Result<OutputType> {
+    let valid_prefixes = ["u", "b", "v", "z"];
+    if valid_prefixes.contains(&s) {
+        return match s {
+            "u" => Ok(OutputType::Bcf {
+                is_uncompressed: true,
+                level: None,
+            }),
+            "v" => Ok(OutputType::Vcf {
+                is_uncompressed: true,
+                level: None,
+            }),
+            "b" => Ok(OutputType::Bcf {
+                is_uncompressed: false,
+                level: None,
+            }),
+            "z" => Ok(OutputType::Vcf {
+                is_uncompressed: false,
+                level: None,
+            }),
+            _ => unreachable!(),
+        };
+    }
+
+    // NOTE: Can't actually set compression level in rust/htslib at the moment
+    // if s.len() == 2 {
+    //     let (prefix, suffix) = s.split_at(1);
+    //     if (prefix == "b" || prefix == "z") && suffix.chars().all(|c| c.is_digit(10)) {
+    //         return match prefix {
+    //             "b" => Ok(OutputType::Bcf {
+    //                 is_uncompressed: false,
+    //                 level: Some(suffix.parse().unwrap()),
+    //             }),
+    //             "z" => Ok(OutputType::Vcf {
+    //                 is_uncompressed: false,
+    //                 level: Some(suffix.parse().unwrap()),
+    //             }),
+    //             _ => unreachable!(),
+    //         };
+    //     } else if (prefix == "u" || prefix == "v") && suffix.chars().all(|c| c.is_digit(10)) {
+    //         return Err(format!(
+    //             "Error: compression level ({}) cannot be set on uncompressed streams ({})",
+    //             suffix, prefix
+    //         ));
+    //     }
+    // }
+
+    Err(format!(
+        "Invalid output type: {}. Must be one of u, b, v, z.",
+        s
+    ))
+}
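+
+// NOTE (editorial): the single-letter codes above deliberately mirror
+// `bcftools --output-type`: v = plain VCF, z = compressed VCF,
+// u = uncompressed BCF, b = compressed BCF.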
diff --git a/src/commands.rs b/src/commands.rs
index 2a60758..1265c59 100644
--- a/src/commands.rs
+++ b/src/commands.rs
@@ -1,3 +1,4 @@
 pub mod genotype;
+pub mod merge;
 pub mod plot;
 pub mod validate;
diff --git a/src/commands/genotype.rs b/src/commands/genotype.rs
index 3b1f439..d00c590 100644
--- a/src/commands/genotype.rs
+++ b/src/commands/genotype.rs
@@ -27,9 +27,9 @@ struct ThreadLocalData {
 }
 
 thread_local! {
-    static LOCAL_BAM_READER: ThreadLocalData = ThreadLocalData {
+    static LOCAL_BAM_READER: ThreadLocalData = const { ThreadLocalData {
         bam: RefCell::new(None),
-    };
+    } };
 }
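+// NOTE (editorial): the `const { ... }` form is thread_local!'s const initializer
+// (stable since Rust 1.59); it skips the lazy-initialization bookkeeping of the
+// plain form, since the initial value here is a compile-time constant.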
 
 const CHANNEL_BUFFER_SIZE: usize = 2048;
diff --git a/src/commands/merge.rs b/src/commands/merge.rs
new file mode 100644
index 0000000..e94127e
--- /dev/null
+++ b/src/commands/merge.rs
@@ -0,0 +1,20 @@
+use crate::cli::MergeArgs;
+use crate::merge::vcf_processor::VcfProcessor;
+use crate::utils::Result;
+use std::time;
+
+pub fn merge(args: MergeArgs) -> Result<()> {
+    let start_timer = time::Instant::now();
+
+    let mut vcf_processor = VcfProcessor::new(&args)?;
+
+    if args.print_header {
+        return Ok(());
+    }
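+    // NOTE (editorial): VcfProcessor::new has already constructed the writer, which
+    // emits the merged header on creation, so returning here leaves a header-only
+    // output for --print-header.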
+
+    vcf_processor.merge_variants();
+
+    // TODO: If --output and --write-index are set and the output is compressed, index the file
+    log::info!("Total execution time: {:.2?}", start_timer.elapsed());
+    Ok(())
+}
diff --git a/src/hmm/builder.rs b/src/hmm/builder.rs
index 3e0f2b8..ee237b0 100644
--- a/src/hmm/builder.rs
+++ b/src/hmm/builder.rs
@@ -188,7 +188,7 @@ mod tests {
     use super::*;
     use crate::hmm::spans::Span;
 
-    fn summarize(spans: &Vec<Span>) -> Vec<(usize, usize, usize)> {
+    fn summarize(spans: &[Span]) -> Vec<(usize, usize, usize)> {
         let mut summary = Vec::new();
         for (motif_index, group) in &spans
             .iter()
diff --git a/src/hmm/purity.rs b/src/hmm/purity.rs
index 1544536..59a7a8f 100644
--- a/src/hmm/purity.rs
+++ b/src/hmm/purity.rs
@@ -80,7 +80,7 @@ mod tests {
         let hmm = build_hmm(&motifs);
         // GCNGCNGCNGXN
         let query = "GCAGCCGCTGAG";
-        let states = hmm.label(&query);
+        let states = hmm.label(query);
         let purity = calc_purity(query.as_bytes(), &hmm, &motifs, &states);
         assert_eq!(purity, 11.0 / 12.0);
     }
@@ -90,8 +90,8 @@
         let motifs = vec!["CAG".as_bytes().to_vec(), "CCG".as_bytes().to_vec()];
         let hmm = build_hmm(&motifs);
         let query = "";
-        let states = hmm.label(&query);
-        let purity = calc_purity(&query.as_bytes(), &hmm, &motifs, &states);
+        let states = hmm.label(query);
+        let purity = calc_purity(query.as_bytes(), &hmm, &motifs, &states);
         assert!(purity.is_nan());
     }
 }
diff --git a/src/lib.rs b/src/lib.rs
index 379efa6..aa02c00 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,6 +1,7 @@
 pub mod cli;
 pub mod commands;
 pub mod hmm;
+pub mod merge;
 pub mod trgt;
 pub mod trvz;
 pub mod utils;
diff --git a/src/main.rs b/src/main.rs
index f74525d..09f6c4d 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,7 +1,7 @@
 use clap::Parser;
 use trgt::{
     cli::{init_verbose, Cli, Command, FULL_VERSION},
-    commands::{genotype, plot, validate},
+    commands::{genotype, merge, plot, validate},
     utils::{handle_error_and_exit, Result},
 };
 
@@ -12,6 +12,7 @@ fn runner() -> Result<()> {
         Command::Genotype(_) => "genotype",
         Command::Plot(_) => "plot",
         Command::Validate(_) => "validate",
+        Command::Merge(_) => "merge",
     };
 
     log::info!(
@@ -24,6 +25,7 @@ fn runner() -> Result<()> {
         Command::Genotype(args) => genotype::trgt(args)?,
         Command::Plot(args) => plot::trvz(args)?,
         Command::Validate(args) => validate::validate(args)?,
+        Command::Merge(args) => merge::merge(args)?,
     }
     log::info!("{} end", env!("CARGO_PKG_NAME"));
     Ok(())
diff --git a/src/merge/mod.rs b/src/merge/mod.rs
new file mode 100644
index 0000000..71b5691
--- /dev/null
+++ b/src/merge/mod.rs
@@ -0,0 +1,4 @@
+pub mod strategy;
+pub mod vcf_processor;
+pub mod vcf_reader;
+pub mod vcf_writer;
diff --git a/src/merge/strategy/exact.rs b/src/merge/strategy/exact.rs
new file mode 100644
index 0000000..74df07b
--- /dev/null
+++ b/src/merge/strategy/exact.rs
@@ -0,0 +1,196 @@
+use rust_htslib::bcf::record::GenotypeAllele;
+use std::collections::{HashMap, HashSet};
+
+pub fn merge_exact(
+    sample_gts: Vec<Vec<GenotypeAllele>>,
+    sample_alleles: Vec<Vec<&[u8]>>,
+) -> (Vec<Vec<GenotypeAllele>>, Vec<&[u8]>) {
+    let mut ref_allele: Option<&[u8]> = None;
+    let mut all_alleles: HashSet<&[u8]> = HashSet::new();
+
+    for sample_allele in sample_alleles.iter() {
+        if !sample_allele.is_empty() {
+            if let Some(ref_allele) = &ref_allele {
+                assert_eq!(
+                    ref_allele, &sample_allele[0],
+                    "Reference alleles do not match"
+                );
+            } else {
+                ref_allele = Some(sample_allele[0]);
+            }
+            for allele in &sample_allele[1..] {
+                all_alleles.insert(*allele);
+            }
+        }
+    }
+    let ref_allele = ref_allele.expect("No reference allele found");
+
+    let mut sorted_alleles: Vec<&[u8]> = all_alleles.into_iter().collect();
+    sorted_alleles.sort_by_key(|a| a.len());
+    sorted_alleles.insert(0, ref_allele);
+
+    let allele_to_index: HashMap<&[u8], usize> = sorted_alleles
+        .iter()
+        .enumerate()
+        .map(|(idx, &allele)| (allele, idx))
+        .collect();
+
+    let mut out_sample_gts: Vec<Vec<GenotypeAllele>> = Vec::new();
+    for (i, sample_gt) in sample_gts.iter().enumerate() {
+        let mut out_gt: Vec<GenotypeAllele> = Vec::new();
+        for gt in sample_gt {
+            match gt {
+                GenotypeAllele::PhasedMissing | GenotypeAllele::UnphasedMissing => out_gt.push(*gt),
+                GenotypeAllele::Phased(pos) | GenotypeAllele::Unphased(pos) => {
+                    let pos_usize: usize = (*pos).try_into().expect("Index out of range");
+                    let pos_converted = allele_to_index[&sample_alleles[i][pos_usize]];
+                    let new_gt = match gt {
+                        GenotypeAllele::Phased(_) => GenotypeAllele::Phased(pos_converted as i32),
+                        GenotypeAllele::Unphased(_) => {
+                            GenotypeAllele::Unphased(pos_converted as i32)
+                        }
+                        _ => unreachable!(),
+                    };
+                    out_gt.push(new_gt);
+                }
+            }
+        }
+        out_sample_gts.push(out_gt);
+    }
+    (out_sample_gts, sorted_alleles)
+}
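+
+// Worked example (editorial sketch, mirroring the tests below): given two samples
+// with alleles [REF, A1] and [REF, A1, A2] over the same REF, the merged allele list
+// is REF followed by the union {A1, A2} sorted by length, and each sample's GT
+// indices are remapped through `allele_to_index` so that identical sequences share
+// one index; phasing of each genotype entry is preserved.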
+
+#[allow(dead_code)]
+fn vec_to_comma_separated_string(vec: Vec<&[u8]>) -> String {
+    vec.into_iter()
+        .map(|slice| String::from_utf8_lossy(slice).to_string())
+        .collect::<Vec<String>>()
+        .join(",")
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_merge_exact() {
+        let sample_gts = vec![
+            vec![GenotypeAllele::Unphased(1), GenotypeAllele::Unphased(2)],
+            vec![GenotypeAllele::Unphased(1), GenotypeAllele::Unphased(2)],
+            vec![GenotypeAllele::Unphased(0), GenotypeAllele::Unphased(0)],
+            vec![
+                GenotypeAllele::UnphasedMissing,
+                GenotypeAllele::UnphasedMissing,
+            ],
+            vec![GenotypeAllele::Unphased(1), GenotypeAllele::Unphased(2)],
+        ];
+
+        let sample_alleles = vec![
+            vec![
+                b"CAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATA"
+                    .as_ref(),
+                b"CAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATA".as_ref(),
+                b"CAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATA".as_ref(),
+            ],
+            vec![
+                b"CAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATA"
+                    .as_ref(),
+                b"CAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATA".as_ref(),
+                b"CAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATA"
+                    .as_ref(),
+            ],
+            vec![
+                b"CAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATA"
+                    .as_ref(),
+            ],
+            vec![
+                b"CAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATA"
+                    .as_ref(),
+            ],
+            vec![
+                b"CAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATA"
+                    .as_ref(),
+                b"CGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGCGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTG"
+                    .as_ref(),
+                b"CGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGCGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTG"
+                    .as_ref(),
+            ],
+        ];
+
+        let (out_gts, sorted_alleles) = merge_exact(sample_gts, sample_alleles);
+
+        assert_eq!(
+            sorted_alleles[0],
+            b"CAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATA"
+        );
+
+        assert_eq!(sorted_alleles.len(), 6);
+
+        assert_eq!(
+            out_gts[0],
+            vec![GenotypeAllele::Unphased(1), GenotypeAllele::Unphased(2)]
+        );
+        assert_eq!(
+            out_gts[1],
+            vec![GenotypeAllele::Unphased(1), GenotypeAllele::Unphased(3)]
+        );
+        assert_eq!(
+            out_gts[2],
+            vec![GenotypeAllele::Unphased(0), GenotypeAllele::Unphased(0)]
+        );
+        assert_eq!(
+            out_gts[3],
+            vec![
+                GenotypeAllele::UnphasedMissing,
+                GenotypeAllele::UnphasedMissing
+            ]
+        );
+        assert_eq!(
+            out_gts[4],
+            vec![GenotypeAllele::Unphased(4), GenotypeAllele::Unphased(5)]
+        );
+    }
+
+    #[test]
+    fn test_merge_exact_phasing() {
+        let sample_gts = vec![
+            vec![GenotypeAllele::Unphased(1), GenotypeAllele::Unphased(2)],
+            vec![GenotypeAllele::Phased(1), GenotypeAllele::Phased(2)],
+        ];
+
+        let sample_alleles = vec![
+            vec![
+                b"CAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATA"
+                    .as_ref(),
+                b"CAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATA".as_ref(),
+                b"CAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATA".as_ref(),
+            ],
+            vec![
+                b"CAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATA"
+                    .as_ref(),
+                b"CGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGCGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTG"
+                    .as_ref(),
+                b"CGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGCGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTGGGGTG"
+                    .as_ref(),
+            ],
+        ];
+
+        let (out_gts, sorted_alleles) = merge_exact(sample_gts, sample_alleles);
+
+        assert_eq!(
+            sorted_alleles[0],
+            b"CAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATAAAATA"
+        );
+
+        assert_eq!(sorted_alleles.len(), 5);
+
+        assert_eq!(
+            out_gts[0],
+            vec![GenotypeAllele::Unphased(1), GenotypeAllele::Unphased(2)]
+        );
+        assert_eq!(
+            out_gts[1],
+            vec![GenotypeAllele::Phased(3), GenotypeAllele::Phased(4)]
+        );
+    }
+}
diff --git a/src/merge/strategy/mod.rs b/src/merge/strategy/mod.rs
new file mode 100644
index 0000000..e6fbbc7
--- /dev/null
+++ b/src/merge/strategy/mod.rs
@@ -0,0 +1 @@
+pub mod exact;
diff --git a/src/merge/vcf_processor.rs b/src/merge/vcf_processor.rs
new file mode 100644
index 0000000..663207a
--- /dev/null
+++ b/src/merge/vcf_processor.rs
@@ -0,0 +1,585 @@
+use super::{
+    strategy::exact::merge_exact,
+    vcf_reader::{VcfReader, VcfReaders},
+    vcf_writer::VcfWriter,
+};
+use crate::{
+    cli::MergeArgs,
+    utils::{open_genome_reader, Result},
+};
+use once_cell::sync::Lazy;
+use rust_htslib::{
+    bcf::{self, header::HeaderView, record::GenotypeAllele, Record},
+    faidx,
+};
+use semver::Version;
+use std::{any::Any, cmp::Ordering, collections::BinaryHeap, env};
+
+const MISSING_INTEGER: i32 = i32::MIN;
+const VECTOR_END_INTEGER: i32 = i32::MIN + 1;
+static MISSING_FLOAT: Lazy<f32> = Lazy::new(|| f32::from_bits(0x7F80_0001));
+static VECTOR_END_FLOAT: Lazy<f32> = Lazy::new(|| f32::from_bits(0x7F80_0002));
+
+fn _vec_to_comma_separated_string(vec: Vec<&[u8]>) -> String {
+    vec.into_iter()
+        .map(|slice| String::from_utf8_lossy(slice).to_string())
+        .collect::<Vec<String>>()
+        .join(",")
+}
+
+fn _header_to_string(header: &bcf::Header) -> String {
+    unsafe {
+        let header_ptr = header.inner;
+        let mut header_len: i32 = 0;
+        let header_cstr = rust_htslib::htslib::bcf_hdr_fmt_text(header_ptr, 0, &mut header_len);
+        std::ffi::CStr::from_ptr(header_cstr)
+            .to_string_lossy()
+            .into_owned()
+    }
+}
+
+trait PushMissingAndEnd: Any {
+    fn missing() -> Self;
+    fn vector_end() -> Self;
+
+    fn push_missing_and_end(vec: &mut Vec<Self>)
+    where
+        Self: Sized,
+    {
+        vec.push(Self::missing());
+        vec.push(Self::vector_end());
+    }
+}
+
+macro_rules! impl_push_missing_and_end {
+    ($type:ty, $missing:expr, $end:expr) => {
+        impl PushMissingAndEnd for $type {
+            fn missing() -> Self {
+                $missing
+            }
+
+            fn vector_end() -> Self {
+                $end
+            }
+        }
+    };
+    ($type:ty, $missing:expr, $end:expr, $custom_push:expr) => {
+        impl PushMissingAndEnd for $type {
+            fn missing() -> Self {
+                $missing
+            }
+
+            fn vector_end() -> Self {
+                $end
+            }
+
+            #[allow(clippy::redundant_closure_call)]
+            fn push_missing_and_end(vec: &mut Vec<Self>) {
+                ($custom_push)(vec);
+            }
+        }
+    };
+}
+
+impl_push_missing_and_end!(i32, MISSING_INTEGER, VECTOR_END_INTEGER);
+impl_push_missing_and_end!(f32, *MISSING_FLOAT, *VECTOR_END_FLOAT);
+impl_push_missing_and_end!(Vec<u8>, Vec::new(), Vec::new(), |vec: &mut Vec<Vec<u8>>| {
+    vec.push(vec![b'.']);
+});
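+
+// NOTE (editorial): the sentinel constants above mirror htslib's BCF encoding --
+// i32::MIN marks a missing integer and i32::MIN + 1 marks end-of-vector, with the
+// two special NaN bit patterns as the float equivalents; they let FORMAT arrays of
+// unequal length across samples be padded to a rectangular shape.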
+
+enum FieldType {
+    String,
+    Integer,
+}
+
+#[derive(Debug)]
+struct VcfRecordWithSource {
+    record: bcf::Record,
+    reader_index: usize,
+}
+
+impl PartialEq for VcfRecordWithSource {
+    fn eq(&self, other: &Self) -> bool {
+        self.record.rid() == other.record.rid() && self.record.pos() == other.record.pos()
+    }
+}
+
+impl Eq for VcfRecordWithSource {}
+
+impl PartialOrd for VcfRecordWithSource {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl Ord for VcfRecordWithSource {
+    fn cmp(&self, other: &Self) -> Ordering {
+        match self.record.rid().cmp(&other.record.rid()) {
+            Ordering::Equal => self.record.pos().cmp(&other.record.pos()).reverse(),
+            other => other.reverse(),
+        }
+    }
+}
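+
+// NOTE (editorial): both comparisons are reversed because std's BinaryHeap is a
+// max-heap; reversing the (rid, pos) ordering turns it into a min-heap, so pop()
+// always yields the record with the smallest coordinate among all open readers.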
+
+pub struct VcfProcessor {
+    pub readers: Vec<VcfReader>,
+    pub writer: VcfWriter,
+    pub genome_reader: faidx::Reader, // TODO: Make optional? Only needed for <1.0
+    // TODO: add args struct
+    pub skip_n: usize,
+    pub process_n: usize,
+    pub needs_padding: bool,
+}
+
+impl VcfProcessor {
+    pub fn new(args: &MergeArgs) -> Result<VcfProcessor> {
+        let vcf_readers = VcfReaders::new(args.vcfs.clone())?;
+        if vcf_readers.readers.len() == 1 && !args.force_single {
+            return Err(
+                "Expected two or more files to merge, got only one. Use --force-single to proceed anyway"
+                    .into(),
+            );
+        }
+
+        let genome_reader = open_genome_reader(&args.genome_path)?;
+        let out_header = Self::create_output_header(&vcf_readers, args)?;
+        let writer = VcfWriter::new(&out_header, &args.output_type, args.output.as_ref())?;
+
+        let needs_padding = vcf_readers
+            .readers
+            .iter()
+            .any(|reader| reader.version.major < Version::new(1, 0, 0).major);
+
+        Ok(VcfProcessor {
+            readers: vcf_readers.readers,
+            writer,
+            genome_reader,
+            skip_n: args.skip_n.unwrap_or(0),
+            process_n: args.process_n.unwrap_or(usize::MAX),
+            needs_padding,
+        })
+    }
+
+    fn create_output_header(vcf_readers: &VcfReaders, args: &MergeArgs) -> Result<bcf::Header> {
+        let mut out_header = bcf::Header::new();
+        vcf_readers.merge_headers(&mut out_header, args.force_samples)?;
+
+        // Update header fields to be consistent
+        out_header.remove_format(b"ALCI");
+        out_header.remove_format(b"AM");
+        out_header.push_record(
+            b"##FORMAT=<ID=ALLR,Number=.,Type=String,Description=\"Length range per allele\">",
+        );
+        out_header.push_record(
+            b"##FORMAT=<ID=AM,Number=.,Type=Float,Description=\"Mean methylation level per allele\">",
+        );
+
+        if !args.no_version {
+            Self::add_version_info(&mut out_header);
+        }
+
+        Ok(out_header)
+    }
+
+    fn add_version_info(out_header: &mut bcf::Header) {
+        let version_line = format!(
+            "##{}Version={}",
+            env!("CARGO_PKG_NAME"),
+            *crate::cli::FULL_VERSION
+        );
+        out_header.push_record(version_line.as_bytes());
+
+        let command_line = env::args().collect::<Vec<String>>().join(" ");
+        let command_line = format!("##{}Command={}", env!("CARGO_PKG_NAME"), command_line);
+        out_header.push_record(command_line.as_bytes());
+    }
+
+    fn set_info_field(&mut self, record: &Record, field_name: &[u8], field_type: FieldType) {
+        match field_type {
+            FieldType::String => {
+                let info_field = record.info(field_name).string().unwrap().unwrap();
+                self.writer
+                    .dummy_record
+                    .push_info_string(field_name, &info_field)
+                    .unwrap();
+            }
+            FieldType::Integer => {
+                let info_field = record.info(field_name).integer().unwrap().unwrap();
+                self.writer
+                    .dummy_record
+                    .push_info_integer(field_name, &info_field)
+                    .unwrap();
+            }
+        }
+    }
+
+    fn merge_variant(&mut self, sample_records: &[Option<Record>]) {
+        let template_index = sample_records.iter().position(|r| r.is_some()).unwrap();
+        let template_record = sample_records[template_index].as_ref().unwrap();
+
+        self.writer.dummy_record.set_rid(template_record.rid());
+        self.writer.dummy_record.set_pos(template_record.pos());
+        self.writer.dummy_record.set_qual(template_record.qual());
+
+        // TODO: Consolidate logic to allow for generic INFO fields: i32, f32, etc...
+        self.set_info_field(template_record, b"TRID", FieldType::String);
+        self.set_info_field(template_record, b"END", FieldType::Integer);
+        self.set_info_field(template_record, b"MOTIFS", FieldType::String);
+        self.set_info_field(template_record, b"STRUC", FieldType::String);
+
+        // TODO: Clean this up
+        // TODO: Consolidate logic to allow for generic FORMAT fields: i32, f32, etc...
+        let mut als = Vec::new();
+        let mut allrs = Vec::new();
+        let mut sds = Vec::new();
+        let mut mcs = Vec::new();
+        let mut mss = Vec::new();
+        let mut aps = Vec::new();
+        let mut ams = Vec::new();
+        let mut gt_vecs = Vec::new();
+        let mut alleles = Vec::new();
+
+        for record in sample_records.iter() {
+            if let Some(record) = record {
+                // TODO: Allow multiple samples per record; at the moment we just take the first element
+                alleles.push(record.alleles());
+
+                // GT
+                let gt_field = record.genotypes().unwrap();
+                let gt = gt_field.get(0);
+                gt_vecs.push(gt.iter().copied().collect());
+
+                // TODO: Factor out redundancy
+                let al_field = record
+                    .format(b"AL")
+                    .integer()
+                    .expect("Error accessing FORMAT AL");
+                als.extend(al_field[0].iter().copied());
+                if al_field[0].len() == 1 {
+                    als.push(VECTOR_END_INTEGER);
+                }
+
+                let allr = match record.format(b"ALLR").string() {
+                    Ok(field) => field[0].to_vec(),
+                    // Handle TRGT <=v0.3.4
+                    Err(_) => {
+                        let alci_field = record.format(b"ALCI").string().unwrap();
+                        alci_field[0].to_vec()
+                    }
+                };
+                allrs.push(allr);
+
+                let sd_field = record
+                    .format(b"SD")
+                    .integer()
+                    .expect("Error accessing FORMAT SD");
+                sds.extend(sd_field[0].iter().copied());
+                if sd_field[0].len() == 1 {
+                    sds.push(VECTOR_END_INTEGER);
+                }
+
+                let mc_field = record
+                    .format(b"MC")
+                    .string()
+                    .expect("Error accessing FORMAT MC");
+                let mc = mc_field[0].to_vec();
+                mcs.push(mc);
+
+                let ms_field = record
+                    .format(b"MS")
+                    .string()
+                    .expect("Error accessing FORMAT MS");
+                let ms = ms_field[0].to_vec();
+                mss.push(ms);
+
+                let ap_field = record
+                    .format(b"AP")
+                    .float()
+                    .expect("Error accessing FORMAT AP");
+                aps.extend(ap_field[0].iter().copied());
+                if ap_field[0].len() == 1 {
+                    aps.push(*VECTOR_END_FLOAT);
+                }
+
+                let am_field = match record.format(b"AM").float() {
+                    Ok(field) => field[0].to_vec(),
+                    // Handle TRGT <=v0.4.0
+                    Err(_) => {
+                        let int_field = record
+                            .format(b"AM")
+                            .integer()
+                            .expect("Error accessing FORMAT AM as an integer");
+                        int_field[0]
+                            .iter()
+                            .map(|&i| {
+                                // Account for missing values
+                                if i == i32::MIN {
+                                    *MISSING_FLOAT
+                                } else {
+                                    i as f32 / 255.0
+                                }
+                            })
+                            .collect::<Vec<f32>>()
+                    }
+                };
+                ams.extend(am_field.iter().copied());
+                if am_field.len() == 1 {
+                    ams.push(*VECTOR_END_FLOAT);
+                }
+            } else {
+                gt_vecs.push(vec![GenotypeAllele::UnphasedMissing]);
+                alleles.push(vec![]);
+
+                PushMissingAndEnd::push_missing_and_end(&mut als);
+                PushMissingAndEnd::push_missing_and_end(&mut sds);
+                PushMissingAndEnd::push_missing_and_end(&mut aps);
+                PushMissingAndEnd::push_missing_and_end(&mut ams);
+                PushMissingAndEnd::push_missing_and_end(&mut allrs);
+                PushMissingAndEnd::push_missing_and_end(&mut mcs);
+                PushMissingAndEnd::push_missing_and_end(&mut mss);
+            }
+        }
+
+        // Merge alleles and genotypes
+        let (out_gts, out_alleles) = merge_exact(gt_vecs, alleles);
+        self.writer.dummy_record.set_alleles(&out_alleles).unwrap();
+
+        // Flatten the per-sample (2D) genotypes into a single 1D array, padding
+        // single-entry genotypes with VECTOR_END_INTEGER
+        let mut gts_new: Vec<i32> = Vec::new();
+        for sample_gt in out_gts {
+            let mut converted_sample_gt: Vec<i32> =
+                sample_gt.iter().map(|gt| i32::from(*gt)).collect();
+            if converted_sample_gt.len() == 1 {
+                converted_sample_gt.push(VECTOR_END_INTEGER);
+            }
+            gts_new.extend(converted_sample_gt);
+        }
+
+        self.writer
+            .dummy_record
+            .push_format_integer(b"GT", &gts_new)
+            .unwrap();
+
+        self.writer
+            .dummy_record
+            .push_format_integer(b"AL", &als)
+            .unwrap();
+
+        self.writer
+            .dummy_record
+            .push_format_string(b"ALLR", &allrs)
+            .unwrap();
+
+        self.writer
+            .dummy_record
+            .push_format_integer(b"SD", &sds)
+            .unwrap();
+
+        self.writer
+            .dummy_record
+            .push_format_string(b"MC", &mcs)
+            .unwrap();
+
+        self.writer
+            .dummy_record
+            .push_format_string(b"MS", &mss)
+            .unwrap();
+
+        self.writer
+            .dummy_record
+            .push_format_float(b"AP", &aps)
+            .unwrap();
+
+        self.writer
+            .dummy_record
+            .push_format_float(b"AM", &ams)
+            .unwrap();
+
+        self.writer.writer.write(&self.writer.dummy_record).unwrap();
+
+        self.writer.dummy_record.clear();
+    }
+
+    fn init_heap(&mut self) -> BinaryHeap<VcfRecordWithSource> {
+        let mut heap = BinaryHeap::new();
+        for (index, reader) in self.readers.iter_mut().enumerate() {
+            if reader.advance() {
+                heap.push(VcfRecordWithSource {
+                    record: reader.current_record.clone(),
+                    reader_index: index,
+                });
+            }
+        }
+        heap
+    }
+
+    fn update_heap(
+        &mut self,
+        heap: &mut BinaryHeap<VcfRecordWithSource>,
+        sample_records: &[Option<Record>],
+    ) {
+        for (index, record) in sample_records.iter().enumerate() {
+            if record.is_some() && self.readers[index].advance() {
+                heap.push(VcfRecordWithSource {
+                    record: self.readers[index].current_record.clone(),
+                    reader_index: index,
+                });
+            }
+        }
+    }
+
+    pub fn merge_variants(&mut self) {
+        let mut n = 0;
+        let mut n_processed = 0;
+
+        let mut sample_records = vec![None; self.readers.len()];
+        let mut heap = self.init_heap();
+        while let Some(min_element) = heap.pop() {
+            let min_rid = min_element.record.rid().unwrap();
+            let min_pos = min_element.record.pos();
+            sample_records[min_element.reader_index] = Some(min_element.record);
+
+            while let Some(peek_next_element) = heap.peek() {
+                if peek_next_element.record.rid().unwrap() == min_rid
+                    && peek_next_element.record.pos() == min_pos
+                {
+                    let next_element = heap.pop().unwrap();
+                    sample_records[next_element.reader_index] = Some(next_element.record);
+                } else {
+                    break;
+                }
+            }
+
+            if n >= self.skip_n {
+                log::info!("Processing: {}:{}", min_rid, min_pos);
+                if self.needs_padding {
+                    let padding_base = self.get_padding_base(
+                        min_rid,
+                        min_pos,
+                        &self.readers[min_element.reader_index].header,
+                    );
+                    self.add_padding_base(&mut sample_records, padding_base);
+                }
+
+                self.merge_variant(&sample_records);
+                n_processed += 1;
+                if n_processed >= self.process_n {
+                    break;
+                }
+            }
+            n += 1;
+
+            self.update_heap(&mut heap, &sample_records);
+            sample_records.fill(None);
+        }
+    }
+
+    fn add_padding_base(&mut self, sample_records: &mut [Option<Record>], padding_base: Vec<u8>) {
+        for (index, record) in sample_records.iter_mut().enumerate() {
+            if self.readers[index].version.major < Version::new(1, 0, 0).major {
+                if let Some(record) = record {
+                    let al_0 = record
+                        .format(b"AL")
+                        .integer()
+                        .expect("Error accessing FORMAT AL")[0]
+                        .iter()
+                        .min()
+                        .cloned()
+                        .unwrap();
+                    // Zero-length allele records do not need to be updated
+                    if al_0 != 0 {
+                        let new_alleles: Vec<Vec<u8>> = record
+                            .alleles()
+                            .iter()
+                            .map(|allele| {
+                                let mut new_allele = padding_base.to_vec();
+                                new_allele.extend_from_slice(allele);
+                                new_allele
+                            })
+                            .collect();
+                        let new_alleles_refs: Vec<&[u8]> =
+                            new_alleles.iter().map(|a| a.as_slice()).collect();
+                        record
+                            .set_alleles(&new_alleles_refs)
+                            .expect("Failed to set alleles")
+                    }
+                }
+            }
+        }
+    }
+
+    fn get_padding_base(&self, rid: u32, pos: i64, header: &HeaderView) -> Vec<u8> {
+        let chrom = header.rid2name(rid).unwrap();
+        let chrom_str = std::str::from_utf8(chrom).expect("Invalid UTF-8 sequence");
+        self.genome_reader
+            .fetch_seq(chrom_str, pos as usize, pos as usize)
+            .map(|seq| seq.to_vec())
+            .ok()
+            .unwrap()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use rust_htslib::bcf::{Read, Reader};
+    use std::collections::BinaryHeap;
+    use std::io::Write;
+    use tempfile::NamedTempFile;
+
+    #[test]
+    fn test_vcf_record_wrapper_heap() {
+        let mut temp_file = NamedTempFile::new().unwrap();
+        writeln!(temp_file, "##fileformat=VCFv4.2").unwrap();
+        writeln!(temp_file, "#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO").unwrap();
+        let reader = Reader::from_path(temp_file.path()).unwrap();
+
+        let mut record1 = reader.empty_record();
+        record1.set_rid(Some(1));
+        record1.set_pos(100);
+
+        let mut record2 = reader.empty_record();
+        record2.set_rid(Some(1));
+        record2.set_pos(2000);
+
+        let mut record3 = reader.empty_record();
+        record3.set_rid(Some(1));
+        record3.set_pos(50);
+
+        let mut record4 = reader.empty_record();
+        record4.set_rid(Some(10));
+        record4.set_pos(99);
+
+        let mut heap = BinaryHeap::new();
+        heap.push(VcfRecordWithSource {
+            record: record1,
+            reader_index: 0,
+        });
+        heap.push(VcfRecordWithSource {
+            record: record4,
+            reader_index: 3,
+        });
+        heap.push(VcfRecordWithSource {
+            record: record2,
+            reader_index: 1,
+        });
+        heap.push(VcfRecordWithSource {
+            record: record3,
+            reader_index: 2,
+        });
+
+        let r = heap.pop().unwrap();
+        assert_eq!(r.record.rid(), Some(1));
+        assert_eq!(r.record.pos(), 50);
+
+        let r = heap.pop().unwrap();
+        assert_eq!(r.record.rid(), Some(1));
+        assert_eq!(r.record.pos(), 100);
+
+        let r = heap.pop().unwrap();
+        assert_eq!(r.record.rid(), Some(1));
+        assert_eq!(r.record.pos(), 2000);
+
+        let r = heap.pop().unwrap();
+        assert_eq!(r.record.rid(), Some(10));
+        assert_eq!(r.record.pos(), 99);
+    }
+}
diff --git a/src/merge/vcf_reader.rs b/src/merge/vcf_reader.rs
new file mode 100644
index 0000000..e5f6d64
--- /dev/null
+++ b/src/merge/vcf_reader.rs
@@ -0,0 +1,174 @@
+use crate::utils::Result;
+use rust_htslib::bcf::{self, header::HeaderView, Header, HeaderRecord, Read};
+use semver::Version;
+use std::{
+    collections::HashSet,
+    path::{Path, PathBuf},
+};
+
+pub struct VcfReader {
+    pub reader: bcf::IndexedReader,
+    pub header: bcf::header::HeaderView,
+    pub current_record: bcf::Record,
+    pub version: Version,
+    pub index: usize,
+}
+
+impl VcfReader {
+    pub fn new(file: PathBuf, index: usize) -> Result<VcfReader> {
+        log::info!("Start loading VCF {:?}", &file);
+        // TODO: Check if file is a VCF
+        // TODO: Check if indexed VCF
+        // TODO: Check if valid VCF
+        let reader = bcf::IndexedReader::from_path(&file)
+            .map_err(|e| format!("Failed to open VCF file {}: {}", file.display(), e))?;
+        let header = reader.header().clone();
+
+        let version = get_trgt_version(&header, &file)?;
+        log::debug!("{:?} has version: {}", file.file_name().unwrap(), version);
+
+        if header.sample_count() > 1 {
+            return Err(format!(
+                "Unsupported: VCF file with multiple samples: {}",
+                file.display()
+            ));
+        }
+
+        // TODO: Create a normalized struct for variant records
+        let current_record = reader.empty_record();
+        log::info!("Finished loading VCF {:?}", &file);
+        Ok(VcfReader {
+            reader,
+            header,
+            current_record,
+            version,
+            index,
+        })
+    }
+
+    pub fn advance(&mut self) -> bool {
+        match self.reader.read(&mut self.current_record) {
+            Some(Ok(())) => {
+                self.update_record_for_version();
+                true
+            }
+            Some(Err(_)) | None => false,
+        }
+    }
+
+    fn update_record_for_version(&mut self) {
+        if self.version.major < Version::new(1, 0, 0).major {
+            // Only zero-length alleles had padding in earlier versions
+            let al_0 = self
+                .current_record
+                .format(b"AL")
+                .integer()
+                .expect("Error accessing FORMAT AL")[0]
+                .iter()
+                .min()
+                .cloned()
+                .unwrap();
+            if al_0 != 0 {
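+                // NOTE (editorial): pre-1.0 TRGT records (except zero-length alleles)
+                // lack the leading padding base, so POS is shifted back by one here;
+                // the base itself is restored via VcfProcessor::add_padding_base.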
+                self.current_record.set_pos(self.current_record.pos() - 1);
+            }
+        }
+    }
+}
+
+fn get_trgt_version(vcf_header: &HeaderView, file: &Path) -> Result<Version> {
+    // TODO: Add logic to deal with merged TRGT VCFs (assume latest version?)
+    let mut trgt_version = None;
+
+    for record in vcf_header.header_records().iter() {
+        if let HeaderRecord::Generic { key, value } = record {
+            if key == "trgtVersion" {
+                trgt_version = Some(value.clone());
+                break;
+            }
+        }
+    }
+
+    // If trgtVersion is not in the header, it is either a TRGT VCF from a version
+    // that predates the trgtVersion header line or not a TRGT VCF at all; assume
+    // the former and fall back to the oldest version
+    let trgt_version = trgt_version.unwrap_or_else(|| {
+        log::warn!(
+            "Could not find trgtVersion in header of {}, assuming an early (pre-1.0) TRGT VCF",
+            file.display()
+        );
+        "0.0.0".to_string()
+    });
+
+    let version = Version::parse(&trgt_version)
+        .map_err(|e| format!("Failed to parse trgtVersion of {}: {}", file.display(), e))?;
+
+    Ok(version)
+}
+
+pub struct VcfReaders {
+    pub readers: Vec<VcfReader>,
+}
+
+impl VcfReaders {
+    pub fn new(vcf_files: Vec<PathBuf>) -> Result<VcfReaders> {
+        let readers = vcf_files
+            .into_iter()
+            .enumerate()
+            .map(|(index, file)| VcfReader::new(file, index))
+            .collect::<Result<Vec<VcfReader>>>()?;
+        Ok(VcfReaders { readers })
+    }
+
+    pub fn merge_headers(&self, dst_header: &mut Header, force_samples: bool) -> Result<()> {
+        let mut observed_sample_ids = HashSet::new();
+
+        for reader in &self.readers {
+            let src_header = &reader.header;
+            // TODO: error handling
+            unsafe {
+                dst_header.inner =
+                    rust_htslib::htslib::bcf_hdr_merge(dst_header.inner, src_header.inner);
+            }
+
+            for sample_id in src_header.samples() {
+                if observed_sample_ids.contains(sample_id) {
+                    if force_samples {
+                        continue; // If forcing samples, skip duplicates
+                    } else {
+                        return Err(format!(
+                            "Duplicate sample ID found: {}",
+                            String::from_utf8_lossy(sample_id)
+                        ));
+                    }
+                }
+                observed_sample_ids.insert(sample_id.to_vec());
+                dst_header.push_sample(sample_id);
+            }
+        }
+
+        unsafe {
+            rust_htslib::htslib::bcf_hdr_sync(dst_header.inner);
+        }
+
+        Ok(())
+    }
+}
diff --git a/src/merge/vcf_writer.rs b/src/merge/vcf_writer.rs
new file mode 100644
index 0000000..1594765
--- /dev/null
+++ b/src/merge/vcf_writer.rs
@@ -0,0 +1,96 @@
+use crate::utils::Result;
+use rust_htslib::bcf;
+
+#[derive(Debug, Clone)]
+pub enum OutputType {
+    Vcf {
+        is_uncompressed: bool,
+        level: Option<u8>,
+    },
+    Bcf {
+        is_uncompressed: bool,
+        level: Option<u8>,
+    },
+}
+
+pub struct VcfWriter {
+    pub writer: bcf::Writer,
+    pub dummy_record: bcf::Record,
+}
+
+impl VcfWriter {
+    pub fn new(
+        header: &bcf::Header,
+        output_type: &Option<OutputType>,
+        output: Option<&String>,
+    ) -> Result<VcfWriter> {
+        let output_type = match (output_type, output) {
+            (Some(output_type), _) => output_type.clone(),
+            (None, Some(path)) => Self::infer_output_type_from_extension(path)?,
+            (None, None) => OutputType::Vcf {
+                is_uncompressed: true,
+                level: None,
+            },
+        };
+
+        log::debug!("{:?}", &output_type);
+
+        let writer = match output {
+            Some(path) => {
+                let (is_uncompressed, format) = match output_type {
+                    OutputType::Vcf {
+                        is_uncompressed, ..
+                    } => (is_uncompressed, bcf::Format::Vcf),
+                    OutputType::Bcf {
+                        is_uncompressed, ..
+                    } => (is_uncompressed, bcf::Format::Bcf),
+                };
+                bcf::Writer::from_path(path, header, is_uncompressed, format)
+            }
+            None => {
+                let (is_uncompressed, format) = match output_type {
+                    OutputType::Vcf {
+                        is_uncompressed, ..
+                    } => (is_uncompressed, bcf::Format::Vcf),
+                    OutputType::Bcf {
+                        is_uncompressed, ..
+                    } => (is_uncompressed, bcf::Format::Bcf),
+                };
+                bcf::Writer::from_stdout(header, is_uncompressed, format)
+            }
+        }
+        .map_err(|e| format!("Failed to create writer: {}", e))?;
+        // writer.set_threads(4).unwrap();
+        let dummy_record = writer.empty_record();
+        Ok(VcfWriter {
+            writer,
+            dummy_record,
+        })
+    }
+
+    fn infer_output_type_from_extension(path: &str) -> Result<OutputType> {
+        let path_lower = path.to_lowercase();
+        match path_lower.as_str() {
+            s if s.ends_with(".bcf.gz") => Ok(OutputType::Bcf {
+                is_uncompressed: false,
+                level: None,
+            }),
+            s if s.ends_with(".vcf.gz") || s.ends_with(".vcf.bgz") => Ok(OutputType::Vcf {
+                is_uncompressed: false,
+                level: None,
+            }),
+            s if s.ends_with(".bcf") => Ok(OutputType::Bcf {
+                is_uncompressed: true,
+                level: None,
+            }),
+            s if s.ends_with(".vcf") => Ok(OutputType::Vcf {
+                is_uncompressed: true,
+                level: None,
+            }),
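+            // NOTE (editorial): unrecognized extensions fall back to uncompressed
+            // VCF, matching the stdout default above.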
+            _ => Ok(OutputType::Vcf {
+                is_uncompressed: true,
+                level: None,
+            }),
+        }
+    }
+}
diff --git a/src/trgt/genotype/genotype_cluster.rs b/src/trgt/genotype/genotype_cluster.rs
index b4b5554..d13d62d 100644
--- a/src/trgt/genotype/genotype_cluster.rs
+++ b/src/trgt/genotype/genotype_cluster.rs
@@ -52,12 +52,18 @@ pub fn make_consensus(
 
 pub fn genotype(ploidy: Ploidy, seqs: &[&[u8]], trs: &[&str]) -> (Gt, Vec<String>, Vec<usize>) {
     let mut dists = get_dist_matrix(seqs);
     let num_seqs = seqs.len();
-    if ploidy == Ploidy::One {
+    if ploidy == Ploidy::One || num_seqs == 1 {
         let group: Vec<usize> = (0..num_seqs).collect();
         let (allele, size) = make_consensus(num_seqs, trs, &dists, &group);
-        let gt = Gt::from(size);
         let classifications = vec![0; num_seqs];
-        return (gt, vec![allele], classifications);
+        if ploidy == Ploidy::One {
+            let gt = Gt::from(size);
+            return (gt, vec![allele], classifications);
+        }
+
+        // one read, two alleles
+        let gt = Gt::from([size.clone(), size]);
+        return (gt, vec![allele.clone(), allele], classifications);
     }
 
     let mut groups = cluster(num_seqs, &mut dists);
@@ -133,15 +139,8 @@ pub fn genotype(ploidy: Ploidy, seqs: &[&[u8]], trs: &[&str]) -> (Gt, Vec<String>, Vec<usize>) {
 }
 
 fn cluster(num_seqs: usize, dists: &mut Vec<f32>) -> Vec<Vec<usize>> {
-    if num_seqs == 0 {
-        return Vec::new();
-    }
-
+    assert!(num_seqs >= 2);
     assert_eq!(num_seqs * (num_seqs - 1) / 2, dists.len());
-    if num_seqs == 1 {
-        return vec![vec![0]];
-    }
-
     if num_seqs == 2 {
         return vec![vec![0], vec![1]];
     }
diff --git a/src/trgt/reads/mod.rs b/src/trgt/reads/mod.rs
index 8cb0bf4..c2a04d2 100644
--- a/src/trgt/reads/mod.rs
+++ b/src/trgt/reads/mod.rs
@@ -5,4 +5,5 @@ mod meth;
 mod snp;
 mod read;
 
+pub use read::get_rq_tag;
 pub use read::HiFiRead;
diff --git a/src/trgt/reads/read.rs b/src/trgt/reads/read.rs
index 9f9babe..6f7d028 100644
--- a/src/trgt/reads/read.rs
+++ b/src/trgt/reads/read.rs
@@ -197,7 +197,7 @@ fn get_ml_tag(rec: &bam::Record) -> Option<Vec<u8>> {
 ///
 /// # Returns
 /// Returns an `Option<f64>` which is `Some` if the RQ tag is present and can be parsed as a float, otherwise `None`.
-fn get_rq_tag(rec: &bam::Record) -> Option<f64> {
+pub fn get_rq_tag(rec: &bam::Record) -> Option<f64> {
     match rec.aux(b"rq") {
         Ok(Aux::Float(value)) => Some(f64::from(value)),
         _ => None,
diff --git a/src/trgt/workflows/tr.rs b/src/trgt/workflows/tr.rs
index 25fb02b..60b8010 100644
--- a/src/trgt/workflows/tr.rs
+++ b/src/trgt/workflows/tr.rs
@@ -1,14 +1,17 @@
 use super::{Allele, Genotype, LocusResult};
 use crate::hmm::{
-    build_hmm, calc_purity, collapse_labels, count_motifs, replace_invalid_bases, Annotation,
+    build_hmm, calc_purity, collapse_labels, count_motifs, replace_invalid_bases, Annotation, Hmm,
 };
+use crate::trgt::reads::get_rq_tag;
 use crate::trgt::{
     genotype::{find_tr_spans, genotype_cluster, genotype_flank, genotype_size, Gt},
     locus::Locus,
     reads::HiFiRead,
 };
 use crate::utils::{Genotyper, Ploidy, Result, TrgtScoring};
-use itertools::Itertools;
+use itertools::{izip, Itertools};
+use rand::rngs::StdRng;
+use rand::{Rng, SeedableRng};
 use rust_htslib::bam::{self, Read, Record};
 use std::vec;
 
@@ -28,23 +31,33 @@ pub fn analyze(
     if locus.ploidy == Ploidy::Zero {
         return Ok(LocusResult::empty());
     }
-    let reads = extract_reads(
-        locus,
-        bam,
-        params.search_flank_len as u32,
-        params.min_read_qual,
-    )?;
-    log::debug!("{}: Collected {} reads", locus.id, reads.len());
+    let reads = extract_reads(locus, bam, params)?;
 
     let clip_radius = 2 * params.search_flank_len;
     let reads = clip_reads(locus, clip_radius, reads);
     log::debug!("{}: {} reads left after clipping", locus.id, reads.len());
 
     let (reads, spans) = get_spanning_reads(locus, params, reads);
+
     if reads.is_empty() {
         return Ok(LocusResult::empty());
     }
 
+    const MIN_RQ_FOR_PURITY: f64 = 0.9;
+    let (reads, spans) = if params.min_read_qual < MIN_RQ_FOR_PURITY {
+        let ret = filter_impure_trs(locus, &reads, &spans, MIN_RQ_FOR_PURITY);
+        if ret.0.len() < reads.len() {
+            log::warn!(
+                "{}: Filtered out {} impure reads",
+                locus.id,
+                reads.len() - ret.0.len()
+            );
+        }
+        ret
+    } else {
+        (reads, spans)
+    };
+
     let trs = reads
         .iter()
        .zip(spans.iter())
@@ -256,9 +269,12 @@ fn assign_read(gt: &Gt, tr_len: usize) -> Assignment {
 fn extract_reads(
     locus: &Locus,
     bam: &mut bam::IndexedReader,
-    flank_len: u32,
-    min_read_qual: f64,
+    params: &Params,
 ) -> Result<Vec<HiFiRead>> {
+    let flank_len = params.search_flank_len as u32;
+    let min_read_qual = params.min_read_qual;
+    let reservoir_threshold = params.max_depth * 3;
+
     let extraction_region = (
         locus.region.contig.as_str(),
         locus.region.start.saturating_sub(flank_len),
@@ -267,32 +283,81 @@ fn extract_reads(
     let mut reads = Vec::new();
     if let Err(msg) = bam.fetch(extraction_region) {
-        log::warn!("{}", msg);
+        log::warn!("Fetch error: {}", msg);
         return Ok(reads);
     }
 
-    let mut num_filtered = 0;
+    let mut n_filt = 0;
+    let mut n_reads = 0;
     let mut record = Record::new();
-    while let Some(result) = bam.read(&mut record) {
-        match result {
-            Ok(_) => {
+    while n_reads < reservoir_threshold {
+        match bam.read(&mut record) {
+            Some(Ok(_)) => {
                 if record.is_supplementary() || record.is_secondary() {
                     continue;
                 }
-                let read = HiFiRead::from_hts_rec(&record, &locus.region);
-                match read.read_qual {
-                    Some(qual) if qual < min_read_qual => num_filtered += 1,
-                    _ => reads.push(read),
-                }
+
+                if get_rq_tag(&record).unwrap_or(0.0) < min_read_qual {
+                    n_filt += 1;
+                    continue;
+                }
+
+                reads.push(HiFiRead::from_hts_rec(&record, &locus.region));
+                n_reads += 1;
             }
-            Err(_) => result.map_err(|e| e.to_string())?,
+            Some(Err(err)) => Err(err.to_string())?,
+            None => break,
         }
     }
+
+    // If more reads are available and the reservoir is full -> reservoir sample
+    if n_reads >= reservoir_threshold {
+        log::warn!("{}: Reservoir sampling reads", locus.id);
+        let mut rng = StdRng::seed_from_u64(42);
+
+        while let Some(result) = bam.read(&mut record) {
+            match result {
+                Ok(_) => {
+                    if record.is_supplementary() || record.is_secondary() {
+                        continue;
+                    }
+
+                    if get_rq_tag(&record).unwrap_or(0.0) < min_read_qual {
+                        n_filt += 1;
+                        continue;
+                    }
+
+                    let j = rng.gen_range(0..=n_reads);
+                    if j < reservoir_threshold {
+                        reads[j] = HiFiRead::from_hts_rec(&record, &locus.region);
+                    }
+                    n_reads += 1;
+                }
+                Err(_) => result.map_err(|e| e.to_string())?,
+            }
+        }
+    }
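+
+    // NOTE (editorial): this is Algorithm R reservoir sampling -- once the reservoir
+    // of `reservoir_threshold` reads is full, the m-th extra read replaces a random
+    // slot with probability threshold/(threshold+m), giving every quality-passing
+    // read an equal chance of being kept; the fixed seed keeps runs reproducible.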
 
-    if num_filtered > 0 {
-        let total = num_filtered + reads.len();
-        log::warn!("Quality filtered {} out of {} reads", num_filtered, total);
+    if n_filt > 0 {
+        log::warn!(
+            "{}: Quality filtered {}/{} reads",
+            locus.id,
+            n_filt,
+            n_filt + n_reads
+        );
     }
+
+    if n_reads > reads.len() {
+        log::debug!(
+            "{}: Randomly sampled {} out of {} reads",
+            locus.id,
+            reads.len(),
+            n_reads
+        );
+    } else {
+        log::debug!("{}: Collected {} reads", locus.id, reads.len());
+    }
+
     Ok(reads)
 }
 
@@ -333,6 +398,52 @@ fn get_tr_meth(read: &HiFiRead, span: &(usize, usize)) -> Option<f64> {
     }
 }
 
+fn filter_impure_trs(
+    locus: &Locus,
+    reads: &[HiFiRead],
+    spans: &[(usize, usize)],
+    rq_cutoff: f64,
+) -> (Vec<HiFiRead>, Vec<(usize, usize)>) {
+    let max_filter = std::cmp::max(1_usize, (0.02 * (reads.len() as f64)).round() as usize);
+    let mut num_filtered = 0;
+    let mut hmm = Hmm::new(0);
+    let mut motifs = Vec::new();
+    const PURITY_CUTOFF: f64 = 0.7;
+    izip!(reads, spans)
+        .filter(|(read, span)| {
+            if num_filtered == max_filter {
+                return true;
+            }
+            if let Some(rq) = read.read_qual {
+                if rq >= rq_cutoff {
+                    return true;
+                }
+            }
+            // since HMM building is costly, we will do a "lazy build", i.e., only
+            // run the constructor if we find a read with low rq
+            if hmm.num_states == 0 {
+                motifs = locus
+                    .motifs
+                    .iter()
+                    .map(|m| replace_invalid_bases(m, &['A', 'T', 'C', 'G', 'N']))
+                    .map(|m| m.as_bytes().to_vec())
+                    .collect_vec();
+
+                hmm = build_hmm(&motifs);
+            }
+            let seq = std::str::from_utf8(&read.bases[span.0..span.1]).unwrap();
+            let seq = replace_invalid_bases(seq, &['A', 'T', 'C', 'G']);
+            let labels = hmm.label(&seq);
+            if calc_purity(seq.as_bytes(), &hmm, &motifs, &labels) >= PURITY_CUTOFF {
+                return true;
+            }
+            num_filtered += 1;
+            false
+        })
+        .map(|(r, s)| (r.clone(), *s))
+        .multiunzip()
+}
+
 fn label_with_hmm(locus: &Locus, seqs: &Vec<String>) -> Vec<Annotation> {
     let motifs = locus
         .motifs