Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for read groups in EstimatePoolingFractions #639

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 134 additions & 77 deletions src/main/scala/com/fulcrumgenomics/bam/EstimatePoolingFractions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
package com.fulcrumgenomics.bam

import java.lang.Math.{max, min}
import java.util

import com.fulcrumgenomics.FgBioDef._
import com.fulcrumgenomics.bam.api.SamSource
Expand All @@ -33,11 +34,9 @@ import com.fulcrumgenomics.commons.util.LazyLogging
import com.fulcrumgenomics.sopt.{arg, clp}
import com.fulcrumgenomics.util.Metric.{Count, Proportion}
import com.fulcrumgenomics.util.{Io, Metric, Sequences}
import com.fulcrumgenomics.vcf.ByIntervalListVariantContextIterator
import com.fulcrumgenomics.vcf.api.{Variant, VcfSource}
import htsjdk.samtools.util.SamLocusIterator.LocusInfo
import htsjdk.samtools.util._
import htsjdk.variant.variantcontext.VariantContext
import htsjdk.variant.vcf.VCFFileReader
import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression

@clp(group=ClpGroups.SamOrBam, description=
Expand All @@ -48,6 +47,11 @@ import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression
|for the alternative allele fractions at each SNP locus, using as inputs the individual sample's genotypes.
|Only SNPs that are bi-allelic within the pooled samples are used.
|
|Each sample's contribution of REF vs. ALT alleles at each site is derived in one of two ways. If
|the sample's genotype in the VCF has the `AF` attribute then the value from that field will be used. If the
|genotype has no AF attribute then the contribution is estimated based on the genotype (e.g. 0/0 will be 100%
|ref, 0/1 will be 50% ref and 50% alt, etc.).
|
|Various filtering parameters can be used to control which loci are used:
|
|- _--intervals_ will restrict analysis to variants within the described intervals
Expand All @@ -66,84 +70,111 @@ class EstimatePoolingFractions
@arg(flag='g', doc="Minimum genotype quality. Use -1 to disable.") val minGenotypeQuality: Int = 30,
@arg(flag='c', doc="Minimum (sequencing coverage @ SNP site / n_samples).") val minMeanSampleCoverage: Int = 6,
@arg(flag='m', doc="Minimum mapping quality.") val minMappingQuality: Int = 20,
@arg(flag='q', doc="Minimum base quality.") val minBaseQuality:Int = 5
@arg(flag='q', doc="Minimum base quality.") val minBaseQuality:Int = 5,
@arg(doc="Examine input reads by sample given in each read's read group.") bySample: Boolean = false
) extends FgBioTool with LazyLogging {
Io.assertReadable(vcf :: bam :: intervals.toList)

private val Ci99Width = 2.58 // Width of a 99% confidence interval in units of std err

private val AllReadGroupsName: String = "all"

/* Class to hold information about a single locus. */
case class Locus(chrom: String, pos: Int, ref: Char, alt: Char, expectedSampleFractions: Array[Double], var observedFraction: Option[Double] = None)
/** Holds a single bi-allelic SNP site used in the pooling-fraction estimate.
  *
  * @param chrom the chromosome/contig name of the SNP
  * @param pos the position of the SNP as reported in the VCF
  * @param ref the reference base
  * @param alt the single alternate base
  * @param expectedSampleFractions expected ALT fraction per pooled sample, in the same
  *                                order as the sample names used to build the locus
  * @param observedFraction observed ALT fraction keyed by input sample name (the read-group
  *                         sample when examining by sample, otherwise the "all" key);
  *                         empty until filled in, and left empty if coverage is insufficient
  */
case class Locus(chrom: String,
                 pos: Int,
                 ref: Char,
                 alt: Char,
                 expectedSampleFractions: Array[Double],
                 var observedFraction: Map[String, Double] = Map.empty)

override def execute(): Unit = {
val vcfReader = new VCFFileReader(vcf.toFile)
val vcfReader = VcfSource(vcf)
val sampleNames = pickSamplesToUse(vcfReader)
val intervals = loadIntervals

// Get the expected fractions from the VCF
val vcfIterator = constructVcfIterator(vcfReader, intervals, sampleNames)
val loci = vcfIterator.filterNot(v => this.nonAutosomes.contains(v.getContig)).map { v => Locus(
chrom = v.getContig,
pos = v.getStart,
ref = v.getReference.getBaseString.charAt(0),
alt = v.getAlternateAllele(0).getBaseString.charAt(0),
expectedSampleFractions = sampleNames.map { s => val gt = v.getGenotype(s); if (gt.isHomRef) 0 else if (gt.isHet) 0.5 else 1.0 }
val loci = vcfIterator.filterNot(v => this.nonAutosomes.contains(v.chrom)).map { v => Locus(
chrom = v.chrom,
pos = v.pos,
ref = v.alleles.ref.bases.charAt(0),
alt = v.alleles.alts.head.value.charAt(0),
expectedSampleFractions = sampleNames.map { s =>
val gt = v.gt(s)
gt.get[IndexedSeq[Float]]("AF") match {
case None => if (gt.isHomRef) 0 else if (gt.isHet) 0.5 else 1.0
case Some(afs) => afs(0)
}
}
)}.toArray

logger.info(s"Loaded ${loci.length} bi-allelic SNPs from VCF.")

val coveredLoci = fillObserveredFractionAndFilter(loci, this.minMeanSampleCoverage * sampleNames.length)
fillObserveredFractionAndFilter(loci, this.minMeanSampleCoverage * sampleNames.length)

val observedSamples = loci.flatMap { locus => locus.observedFraction.keySet }.distinct.sorted
logger.info(f"Regressing on ${observedSamples.length}%,d input samples.")

logger.info(s"Regressing on ${coveredLoci.length} of ${loci.length} that met coverage requirements.")
val regression = new OLSMultipleLinearRegression
regression.setNoIntercept(true) // Intercept should be at 0!
regression.newSampleData(
coveredLoci.map(_.observedFraction.getOrElse(unreachable("observed fraction must be defined"))),
coveredLoci.map(_.expectedSampleFractions)
)

val regressionParams = regression.estimateRegressionParameters()
val total = regressionParams.sum
val fractions = regressionParams.map(_ / total)
val stderrs = regression.estimateRegressionParametersStandardErrors().map(_ / total)
logger.info(s"R^2 = ${regression.calculateRSquared()}")
logger.info(s"Sum of regression parameters = ${total}")

val metrics = sampleNames.toSeq.zipWithIndex.map { case (sample, index) =>
val sites = coveredLoci.count(l => l.expectedSampleFractions(index) > 0)
val singletons = coveredLoci.count(l => l.expectedSampleFractions(index) > 0 && l.expectedSampleFractions.sum == l.expectedSampleFractions(index))
PoolingFractionMetric(
sample = sample,
variant_sites = sites,
singletons = singletons,
estimated_fraction = fractions(index),
standard_error = stderrs(index),
ci99_low = max(0, fractions(index) - stderrs(index)*Ci99Width),
ci99_high = min(1, fractions(index) + stderrs(index)*Ci99Width))
}

Metric.write(output, metrics)
val metrics = observedSamples.flatMap { observedSample =>
logger.info(f"Examining $observedSample")
val (observedFractions, lociExpectedSampleFractions) = loci.flatMap { locus =>
locus.observedFraction.get(observedSample).map { observedFraction =>
(observedFraction, locus.expectedSampleFractions)
}
}.unzip
logger.info(f"Regressing on ${observedFractions.length}%,d of ${loci.length}%,d loci that met coverage requirements.")
regression.newSampleData(
observedFractions,
lociExpectedSampleFractions
)

val regressionParams = regression.estimateRegressionParameters()
val total = regressionParams.sum
val fractions = regressionParams.map(_ / total)
val stderrs = regression.estimateRegressionParametersStandardErrors().map(_ / total)
logger.info(s"R^2 = ${regression.calculateRSquared()}")
logger.info(s"Sum of regression parameters = ${total}")

if (regression.estimateRegressionParameters().exists(_ < 0)) {
logger.error("#################################################################################")
logger.error("# One or more samples is estimated to have fraction < 0. This is likely due to #")
logger.error("# incorrect samples being used, insufficient coverage and/or too few SNPs. #")
logger.error("#################################################################################")
fail(1)
}

if (regression.estimateRegressionParameters().exists(_ < 0)) {
logger.error("#################################################################################")
logger.error("# One or more samples is estimated to have fraction < 0. This is likely due to #")
logger.error("# incorrect samples being used, insufficient coverage and/or too few SNPs. #")
logger.error("#################################################################################")
fail(1)
sampleNames.toSeq.zipWithIndex.map { case (pool_sample, index) =>
val sites = lociExpectedSampleFractions.count(expectedSampleFractions => expectedSampleFractions(index) > 0)
val singletons = lociExpectedSampleFractions.count { expectedSampleFractions =>
expectedSampleFractions(index) > 0 && expectedSampleFractions.sum == expectedSampleFractions(index)
}
PoolingFractionMetric(
observed_sample = observedSample,
pool_sample = pool_sample,
variant_sites = sites,
singletons = singletons,
estimated_fraction = fractions(index),
standard_error = stderrs(index),
ci99_low = max(0, fractions(index) - stderrs(index)*Ci99Width),
ci99_high = min(1, fractions(index) + stderrs(index)*Ci99Width))
}
}

logger.info("Writing metrics")
Metric.write(output, metrics)
}

/** Verify a provided sample list, or if none provided retrieve the set of samples from the VCF. */
private def pickSamplesToUse(vcfReader: VCFFileReader): Array[String] = {
if (samples.nonEmpty) {
val samplesInVcf = vcfReader.getFileHeader.getSampleNamesInOrder.iterator.toSet
val missingSamples = samples.filterNot(samplesInVcf.contains)
private def pickSamplesToUse(vcfIn: VcfSource): Array[String] = {
if (this.samples.isEmpty) vcfIn.header.samples.toArray else {
val samplesInVcf = vcfIn.header.samples
val missingSamples = samples.toSet.diff(samplesInVcf.toSet)
if (missingSamples.nonEmpty) fail(s"Samples not present in VCF: ${missingSamples.mkString(", ")}")
else samples.toArray.sorted
}
else {
vcfReader.getFileHeader.getSampleNamesInOrder.iterator.toSeq.toArray.sorted // toSeq.toArray is necessary cos util.ArrayList.toArray() exists
}
}

/** Loads up and merges all the interval lists provided. Returns None if no intervals were specified. */
Expand All @@ -163,20 +194,18 @@ class EstimatePoolingFractions
}

/** Generates an iterator over non-filtered bi-allelic SNPs where all the required samples are genotyped. */
def constructVcfIterator(in: VCFFileReader, intervals: Option[IntervalList], samples: Array[String]): Iterator[VariantContext] = {
val vcfIterator: Iterator[VariantContext] = intervals match {
def constructVcfIterator(in: VcfSource, intervals: Option[IntervalList], samples: Seq[String]): Iterator[Variant] = {
val iterator: Iterator[Variant] = intervals match {
case None => in.iterator
case Some(is) => ByIntervalListVariantContextIterator(in, is)
case Some(is) => is.flatMap(i => in.query(i.getContig, i.getStart, i.getEnd))
}

val samplesAsUtilSet = CollectionUtil.makeSet(samples:_*)

vcfIterator
.filterNot(_.isFiltered)
.map(_.subContextFromSamples(samplesAsUtilSet, true))
.filter(v => v.isSNP && v.isBiallelic && !v.isMonomorphicInSamples)
.filter(_.getNoCallCount == 0)
.filter(v => v.getGenotypesOrderedByName.iterator.forall(gt => gt.getGQ >= this.minGenotypeQuality))
iterator
.filter(v => v.filters.isEmpty || v.filters == Variant.PassingFilters)
.filter(v => v.alleles.size == 2 && v.alleles.forall(a => a.value.length == 1)) // Just biallelic SNPs
.filter(v => samples.map(v.gt).forall(gt => gt.isFullyCalled && (this.minGenotypeQuality <= 0 || gt.get[Int]("GQ").exists(_ >= this.minGenotypeQuality))))
.map (v => v.copy(genotypes=v.genotypes.filter { case (s, _) => samples.contains(s)} ))
.filter(v => v.gts.flatMap(_.calls).toSet.size > 1) // Not monomorphic
}

/** Constructs a SamLocusIterator that will visit every locus in the input. */
Expand All @@ -192,45 +221,73 @@ class EstimatePoolingFractions
javaIteratorAsScalaIterator(iterator)
}

/** Computes the observed fraction of ALT-supporting observations at the given locus.
  *
  * Returns `None` when the pileup fails the minimum coverage requirement, either on the
  * raw number of reads or on the number of reads carrying the REF or ALT base once Ns
  * and other alleles are excluded.
  */
private def getObservedFraction(recordAndOffsets: Seq[SamLocusIterator.RecordAndOffset],
                                locus: Locus,
                                minCoverage: Int): Option[Double] = {
  if (recordAndOffsets.length < minCoverage) None else {
    val counts      = BaseCounts(recordAndOffsets)
    val refCount    = counts(locus.ref)
    val altCount    = counts(locus.alt)
    val informative = refCount + altCount

    // Re-check coverage using only REF+ALT observations; this guards against pileups with
    // a large fraction of Ns/other alleles and against a large proportion of overlapping reads.
    if (informative < minCoverage) None
    else Some(altCount / informative.toDouble)
  }
}

/**
* Fills in the observedFraction map for each locus that meets the coverage requirement,
* keyed by read-group sample when `bySample` is true, or under the "all" key otherwise.
* Loci that fail the coverage requirement are left with an empty map; nothing is returned.
*/
def fillObserveredFractionAndFilter(loci: Array[Locus], minCoverage: Int): Array[Locus] = {
def fillObserveredFractionAndFilter(loci: Array[Locus], minCoverage: Int): Unit = {
val locusIterator = constructBamIterator(loci)
locusIterator.zip(loci.iterator).foreach { case (locusInfo, locus) =>
if (locusInfo.getSequenceName != locus.chrom || locusInfo.getPosition != locus.pos) fail("VCF and BAM iterators out of sync.")

// A gross coverage check here to avoid a lot of work; better check below
if (locusInfo.getRecordAndOffsets.size() > minCoverage) {
val counts = BaseCounts(locusInfo)
val (ref, alt) = (counts(locus.ref), counts(locus.alt))

// Somewhat redundant with check above, but this protects against a large fraction
// of Ns or other alleles, and also against a large proportion of overlapping reads
if (ref + alt >= minCoverage) {
locus.observedFraction = Some(alt / (ref + alt).toDouble)
if (bySample) {
locus.observedFraction = locusInfo.getRecordAndOffsets.toSeq
.groupBy(_.getRecord.getReadGroup.getSample)
.flatMap { case (sample, recordAndOffsets) =>
val observedFraction = getObservedFraction(
recordAndOffsets = recordAndOffsets,
locus = locus,
minCoverage = minCoverage
)
observedFraction.map(frac => sample -> frac)
}
}
else {
val observedFraction = getObservedFraction(
recordAndOffsets = locusInfo.getRecordAndOffsets.toSeq,
locus = locus,
minCoverage = minCoverage
)
observedFraction.foreach { frac =>
locus.observedFraction = Map(AllReadGroupsName -> frac)
}
}
}

loci.filter(_.observedFraction.isDefined)
}
}

/**
* Metrics produced by `EstimatePoolingFractions` to quantify the estimated proportion of a sample
* mixture that is attributable to a specific sample with a known set of genotypes.
*
* @param sample The name of the sample within the pool being reported on.
* @param observed_sample The name of the input sample as reported in the read group, or "all" if all read groups are
* being treated as a single input sample.
* @param pool_sample The name of the sample within the pool being reported on.
* @param variant_sites How many sites were examined at which the reported sample is known to be variant.
* @param singletons How many of the variant sites were sites at which only this sample was variant.
* @param estimated_fraction The estimated fraction of the pool that comes from this sample.
* @param standard_error The standard error of the estimated fraction.
* @param ci99_low The lower bound of the 99% confidence interval for the estimated fraction.
* @param ci99_high The upper bound of the 99% confidence interval for the estimated fraction.
*/
case class PoolingFractionMetric(sample: String,
case class PoolingFractionMetric(observed_sample: String,
pool_sample: String,
variant_sites: Count,
singletons: Count,
estimated_fraction: Proportion,
Expand Down
Loading