diff --git a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java index b257fb59..9571b434 100644 --- a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java +++ b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java @@ -375,18 +375,10 @@ public GaussianMixtureModel train(Dataset examples, Map a * a); + diff.scaleInPlace(v.responsibility.get(j) / newMixingDistribution.get(j)); + curCov.intersectAndAddInPlace(diff); } return input; }; @@ -394,19 +386,12 @@ public GaussianMixtureModel train(Dataset examples, Map a * a); + diff.scaleInPlace(v.responsibility.get(j) / newMixingDistribution.get(j)); + double mean = diff.sum() / numFeatures; + diff.set(mean); + curCov.intersectAndAddInPlace(diff); } return input; };