Skip to content

Commit

Permalink
Cache transposed A matrix and fix NaN bug (#6)
Browse files (browse the repository at this point in the history)
* Cache A transpose for faster subspace iterations

* Debug NaN (#5)

* Fixes for edge-case samples with a median coverage of 0

MosdepthUtils: cache raw input matrix in addition to normalized
NormalizationOperations: enforce a non-zero minimum depth
NGSPCA: pass filename for temporary raw matrix
  • Loading branch information
jlanej authored Jan 8, 2020
1 parent 47a8f6e commit 846fb69
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 17 deletions.
13 changes: 8 additions & 5 deletions ngspca/src/main/java/org/pankratzlab/ngspca/MosdepthUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@ static List<String> getRegionsToUse(String mosDepthResultFile, REGION_STRATEGY r
* @throws ExecutionException
*/
static BlockRealMatrix processFiles(List<String> mosDepthResultFiles, Set<String> ucscRegions,
int threads,
String tmpRawFile, int threads,
Logger log) throws InterruptedException, ExecutionException {
if (mosDepthResultFiles.isEmpty()) {
String err = "No input files provided";
log.severe(err);
throw new IllegalArgumentException(err);
}
return loadAndNormalizeData(mosDepthResultFiles, ucscRegions, threads, log);
return loadAndNormalizeData(mosDepthResultFiles, ucscRegions, tmpRawFile, threads, log);
}

/**
Expand All @@ -86,8 +86,8 @@ static BlockRealMatrix processFiles(List<String> mosDepthResultFiles, Set<String
*/

private static BlockRealMatrix loadAndNormalizeData(List<String> mosDepthResultFiles,
Set<String> ucscRegions, int threads,
Logger log) {
Set<String> ucscRegions, String tmpRawFile,
int threads, Logger log) {

log.info("Initializing matrix to " + mosDepthResultFiles.size() + " columns and "
+ ucscRegions.size() + " rows");
Expand Down Expand Up @@ -142,8 +142,11 @@ private static BlockRealMatrix loadAndNormalizeData(List<String> mosDepthResultF

}
executor.shutdown();
log.info("Saving temporary raw matrix to " + tmpRawFile);
FileOps.writeSerial(dm, tmpRawFile, log);

log.info("Normalizing input matrix");
NormalizationOperations.foldChangeAndCenterRows(dm);
NormalizationOperations.foldChangeAndCenterRows(dm, log);
return dm;

}
Expand Down
19 changes: 12 additions & 7 deletions ngspca/src/main/java/org/pankratzlab/ngspca/NGSPCA.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,22 @@ private static void run(String input, String outputDir, String bedExclude,
log.info("Sampled " + regions.size() + " bins");

}
String tmpDm = outputDir + "tmp.mat.ser.gz";
// Store the raw input matrix
String tmpRawDm = outputDir + "tmp.raw.ser.gz";
// Store the normalized input matrix
String tmpNormDm = outputDir + "tmp.mat.ser.gz";

// populate input matrix and normalize
BlockRealMatrix dm;
if (!FileOps.fileExists(tmpDm) || overwrite) {
dm = MosdepthUtils.processFiles(mosDepthResultFiles, new HashSet<String>(regions), threads,
log);
FileOps.writeSerial(dm, tmpDm, log);
if (!FileOps.fileExists(tmpNormDm) || overwrite) {
dm = MosdepthUtils.processFiles(mosDepthResultFiles, new HashSet<String>(regions), tmpRawDm,
threads, log);
FileOps.writeSerial(dm, tmpNormDm, log);
} else {
log.info("Loading existing serialized file " + tmpDm);
dm = (BlockRealMatrix) FileOps.readSerial(tmpDm, log);
System.out.print("Loading");
System.err.print("Loading");
log.info("Loading existing serialized file " + tmpNormDm);
dm = (BlockRealMatrix) FileOps.readSerial(tmpNormDm, log);
}

log.info("Oversampling set to: " + numOversamples);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.pankratzlab.ngspca;

import java.util.logging.Logger;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.stat.descriptive.rank.Median;
import org.apache.commons.math3.stat.ranking.NaNStrategy;
Expand All @@ -23,9 +24,9 @@ private NormalizationOperations() {
*
* @param dm the {@link RealMatrix} to be converted to fold change by column and then centered by row
*/
static void foldChangeAndCenterRows(RealMatrix dm) {
static void foldChangeAndCenterRows(RealMatrix dm, Logger log) {
// compute fold change
computeFoldChangeByColumn(dm);
computeFoldChangeByColumn(dm, log);
// center rows to median of 0
centerRowsToMedian(dm);
}
Expand All @@ -35,7 +36,7 @@ static void foldChangeAndCenterRows(RealMatrix dm) {
*
* @param dm the {@link RealMatrix} that will be converted
*/
private static void computeFoldChangeByColumn(RealMatrix dm) {
private static void computeFoldChangeByColumn(RealMatrix dm, Logger log) {
double[] medians = new double[dm.getColumnDimension()];

// convert columns to log2 fold-change from median
Expand All @@ -44,12 +45,16 @@ private static void computeFoldChangeByColumn(RealMatrix dm) {
for (int row = 0; row < dm.getRowDimension(); row++) {
tmp[row] += dm.getEntry(row, column);
}
medians[column] = median(tmp);
medians[column] = Math.max(median(tmp), MIN_DEPTH);
}
for (int row = 0; row < dm.getRowDimension(); row++) {
for (int column = 0; column < dm.getColumnDimension(); column++) {
double entry = dm.getEntry(row, column);
double standard = log2(Math.max(entry, MIN_DEPTH) / medians[column]);
if (Double.isNaN(standard)) {
throw new IllegalArgumentException("Invalid sample normalized value ("
+ Double.toString(Double.NaN) + ") detected");
}
dm.setEntry(row, column, standard);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularValueDecomposition;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.stat.descriptive.rank.Median;
import org.apache.commons.math3.stat.ranking.NaNStrategy;
import Jama.Matrix;
import Jama.QRDecomposition;

Expand Down Expand Up @@ -64,6 +66,7 @@ public void fit(BlockRealMatrix A, int numberOfComponentsToStore, int niters, in
+ numComponents);
}
log.info("Initializing matrices");

int m = A.getRowDimension();
int n = A.getColumnDimension();
transpose = m < n;
Expand All @@ -78,6 +81,10 @@ public void fit(BlockRealMatrix A, int numberOfComponentsToStore, int niters, in
log.info("Selecting randomized Q");

RealMatrix Y = A.multiply(randn(n, Math.min(n, numComponents + numOversamples), randomSeed));

log.info("Caching A_t");
BlockRealMatrix A_t = A.transpose();

log.info("Beginning LU decomp iterations");
for (int i = 0; i < niters; i++) {
log.info("Subspace iteration: " + Integer.toString(i));
Expand All @@ -86,7 +93,7 @@ public void fit(BlockRealMatrix A, int numberOfComponentsToStore, int niters, in
log.info("Converting to RealMatrix");
Y = MatrixUtils.createRealMatrix(qr.getQ().getArray());
log.info("Computing A Y cross prod");
RealMatrix Z = A.transpose().multiply(Y);
RealMatrix Z = A_t.multiply(Y);
log.info("Z QR decomp");
Z = MatrixUtils.createRealMatrix(new QRDecomposition(new Matrix(Z.getData())).getQ()
.getArray());
Expand Down

0 comments on commit 846fb69

Please sign in to comment.