Skip to content

Commit

Permalink
Fixing parallel reduction by converting it into collect.
Browse files Browse the repository at this point in the history
  • Loading branch information
Craigacp committed May 20, 2024
1 parent 8555926 commit 356a355
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,11 @@
import java.util.SplittableRandom;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.function.ToDoubleFunction;
import java.util.logging.Level;
import java.util.logging.Logger;
Expand Down Expand Up @@ -342,24 +344,24 @@ public GaussianMixtureModel train(Dataset<ClusterID> examples, Map<String, Prove
Stream<SGDVector> dataMStream = Arrays.stream(data);
Stream<DenseVector> resMStream = Arrays.stream(responsibilities);
Stream<Vectors> zipMStream = StreamUtil.zip(dataMStream, resMStream, Vectors::new);
Tensor[] zeroTensorArr = switch (covarianceType) {
case FULL -> {
Supplier<Tensor[]> zeroTensor = switch (covarianceType) {
case FULL -> () -> {
Tensor[] output = new Tensor[numGaussians];
for (int j = 0; j < numGaussians; j++) {
output[j] = new DenseMatrix(numFeatures, numFeatures);
}
yield output;
}
case DIAGONAL, SPHERICAL -> {
return output;
};
case DIAGONAL, SPHERICAL -> () -> {
Tensor[] output = new Tensor[numGaussians];
for (int j = 0; j < numGaussians; j++) {
output[j] = new DenseVector(numFeatures);
}
yield output;
}
return output;
};
};
// Fix parallel behaviour
BiFunction<Tensor[], Vectors, Tensor[]> mStep = switch (covarianceType) {
BiConsumer<Tensor[], Vectors> mStep = switch (covarianceType) {
case FULL -> (Tensor[] input, Vectors v) -> {
for (int j = 0; j < numGaussians; j++) {
// Compute covariance contribution from current input
Expand All @@ -369,7 +371,6 @@ public GaussianMixtureModel train(Dataset<ClusterID> examples, Map<String, Prove
diff.scaleInPlace(v.responsibility.get(j) / newMixingDistribution.get(j));
curCov.intersectAndAddInPlace(diff.outer(diff));
}
return input;
};
case DIAGONAL -> (Tensor[] input, Vectors v) -> {
for (int j = 0; j < numGaussians; j++) {
Expand All @@ -380,7 +381,6 @@ public GaussianMixtureModel train(Dataset<ClusterID> examples, Map<String, Prove
diff.scaleInPlace(v.responsibility.get(j) / newMixingDistribution.get(j));
curCov.intersectAndAddInPlace(diff);
}
return input;
};
case SPHERICAL -> (Tensor[] input, Vectors v) -> {
for (int j = 0; j < numGaussians; j++) {
Expand All @@ -393,33 +393,27 @@ public GaussianMixtureModel train(Dataset<ClusterID> examples, Map<String, Prove
diff.set(mean);
curCov.intersectAndAddInPlace(diff);
}
return input;
};
};
BinaryOperator<Tensor[]> combineTensor = (Tensor[] a, Tensor[] b) -> {
Tensor[] output = new Tensor[a.length];
BiConsumer<Tensor[], Tensor[]> combineTensor = (Tensor[] a, Tensor[] b) -> {
for (int j = 0; j < a.length; j++) {
if (a[j] instanceof DenseMatrix aMat && b[j] instanceof DenseMatrix bMat) {
output[j] = aMat.add(bMat);
aMat.intersectAndAddInPlace(bMat);
} else if (a[j] instanceof DenseVector aVec && b[j] instanceof DenseVector bVec) {
output[j] = aVec.add(bVec);
aVec.intersectAndAddInPlace(bVec);
} else {
throw new IllegalStateException("Invalid types in reduce, expected both DenseMatrix or DenseVector, found " + a[j].getClass() + " and " + b[j].getClass());
}
}
return output;
};
if (parallel) {
throw new RuntimeException("Parallel mstep not implemented");
/*
try {
covariances = fjp.submit(() -> zipMStream.parallel().reduce(zeroTensorArr, mStep, combineTensor)).get();
covariances = fjp.submit(() -> zipMStream.parallel().collect(zeroTensor, mStep, combineTensor)).get();
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException("Parallel execution failed", e);
}
*/
} else {
covariances = zipMStream.reduce(zeroTensorArr, mStep, combineTensor);
covariances = zipMStream.collect(zeroTensor, mStep, combineTensor);
}

// renormalize mixing distribution
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ public class TestGMM {
private static final GMMTrainer diagonal = new GMMTrainer(5, 50, MultivariateNormalDistribution.CovarianceType.DIAGONAL,
GMMTrainer.Initialisation.RANDOM, 1e-3, 1, 1);

private static final GMMTrainer fullParallel = new GMMTrainer(5, 50, MultivariateNormalDistribution.CovarianceType.FULL,
GMMTrainer.Initialisation.PLUSPLUS, 1e-3, 4, 1);

private static final GMMTrainer plusPlusFull = new GMMTrainer(5, 50, MultivariateNormalDistribution.CovarianceType.FULL,
GMMTrainer.Initialisation.PLUSPLUS, 1e-3, 1, 1);

Expand Down Expand Up @@ -75,6 +78,11 @@ public void testPlusPlusFullEvaluation() {
runEvaluation(plusPlusFull);
}

@Test
public void testParallelEvaluation() {
runEvaluation(fullParallel);
}

public static void runEvaluation(GMMTrainer trainer) {
Dataset<ClusterID> data = new MutableDataset<>(new GaussianClusterDataSource(500, 1L));
Dataset<ClusterID> test = ClusteringDataGenerator.gaussianClusters(500, 2L);
Expand Down Expand Up @@ -150,7 +158,6 @@ public void testPlusPlusInvalidExample() {
runInvalidExample(plusPlusFull);
}


public void runEmptyExample(GMMTrainer trainer) {
assertThrows(IllegalArgumentException.class, () -> {
Pair<Dataset<ClusterID>, Dataset<ClusterID>> p = ClusteringDataGenerator.denseTrainTest();
Expand Down Expand Up @@ -186,7 +193,7 @@ public void testSetInvocationCount() {

// The number of times to call train before final training.
// Original trainer will be trained numOfInvocations + 1 times
// New trainer will have it's invocation count set to numOfInvocations then trained once
// New trainer will have its invocation count set to numOfInvocations then trained once
int numOfInvocations = 2;

// Create the first model and train it numOfInvocations + 1 times
Expand Down

0 comments on commit 356a355

Please sign in to comment.