Skip to content

Commit

Permalink
Fixing diagonal and spherical coveriance estimation.
Browse files Browse the repository at this point in the history
  • Loading branch information
Craigacp committed Apr 29, 2024
1 parent ceaeb21 commit 3f4614b
Showing 1 changed file with 10 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -375,38 +375,23 @@ public GaussianMixtureModel train(Dataset<ClusterID> examples, Map<String, Prove
for (int j = 0; j < numGaussians; j++) {
// Compute covariance contribution from current input
DenseVector curCov = (DenseVector) input[j];
double curResp = v.responsibility.get(j);
double mixing = newMixingDistribution.get(j);
for (int k = 0; k < numFeatures; k++) {
double currentCovValue = curCov.get(k);
double curMean = meanVectors[j].get(k);
double curData = v.data.get(k);
double dataSq = curResp * curData * curData / mixing;
double meanSq = curMean * curMean;
double dataMean = 2 * curResp * curData * curMean / mixing;
double update = currentCovValue + dataSq - dataMean + meanSq;
curCov.set(k, update);
}
DenseVector diff = (DenseVector) v.data.subtract(meanVectors[j]);
diff.foreachInPlace(a -> a * a);
diff.scaleInPlace(v.responsibility.get(j) / newMixingDistribution.get(j));
curCov.intersectAndAddInPlace(diff);
}
return input;
};
case SPHERICAL -> (Tensor[] input, Vectors v) -> {
for (int j = 0; j < numGaussians; j++) {
// Compute covariance contribution from current input
DenseVector curCov = (DenseVector) input[j];
double curResp = v.responsibility.get(j);
double mixing = newMixingDistribution.get(j);
double update = 0;
for (int k = 0; k < numFeatures; k++) {
double curMean = meanVectors[j].get(k);
double curData = v.data.get(k);
double dataSq = curResp * curData * curData / mixing;
double meanSq = curMean * curMean;
double dataMean = 2 * curResp * curData * curMean / mixing;
update += dataSq + meanSq - dataMean;
}
update = update / numFeatures;
curCov.scalarAddInPlace(update);
DenseVector diff = (DenseVector) v.data.subtract(meanVectors[j]);
diff.foreachInPlace(a -> a * a);
diff.scaleInPlace(v.responsibility.get(j) / newMixingDistribution.get(j));
double mean = diff.sum() / numFeatures;
diff.set(mean);
curCov.intersectAndAddInPlace(diff);
}
return input;
};
Expand Down

0 comments on commit 3f4614b

Please sign in to comment.