Skip to content

Commit

Permalink
Adding a mixture distribution and a distribution interface.
Browse files Browse the repository at this point in the history
  • Loading branch information
Craigacp committed May 2, 2024
1 parent 3f4614b commit 8555926
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 9 deletions.
19 changes: 19 additions & 0 deletions Core/src/main/java/org/tribuo/util/Util.java
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,25 @@ public static double[] generateCDF(double[] pmf) {
return cumulativeSum(pmf);
}

/**
 * Validates that the supplied double array is a probability mass function.
 * <p>
 * That is, each element lies in the closed interval [0,1] and all elements
 * sum to 1 within an absolute tolerance of 1e-10.
 * <p>
 * An empty array is rejected (its sum is 0), as is any array containing NaN.
 * @param pmf The PMF to check.
 * @return True if it's a valid pmf.
 */
public static boolean validatePMF(double[] pmf) {
    double total = 0.0;
    for (double v : pmf) {
        // Negated range check: every comparison against NaN is false, so this
        // form rejects NaN instead of letting it slip through into the sum.
        if (!(v >= 0.0 && v <= 1.0)) {
            return false;
        }
        total += v;
    }
    return Math.abs(total - 1.0) <= 1e-10;
}

/**
* Produces a cumulative sum array.
* @param input The input to sum.
Expand Down
51 changes: 51 additions & 0 deletions Math/src/main/java/org/tribuo/math/distributions/Distribution.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.tribuo.math.distributions;

import org.tribuo.math.la.DenseVector;

import java.util.random.RandomGenerator;

/**
 * Contract for probability distributions that can produce samples.
 * <p>
 * Each call yields one draw from the (possibly multivariate) distribution,
 * expressed as a vector; it is not a sequence of independent draws.
 */
public interface Distribution {

    /**
     * Draws one vector from this distribution.
     * <p>
     * No RNG is supplied, so the source of randomness is left to the implementation.
     * @return A vector sampled from the distribution.
     */
    DenseVector sampleVector();

    /**
     * Draws one vector from this distribution using the caller's RNG.
     * @param otherRNG The RNG to use.
     * @return A vector sampled from this distribution.
     */
    DenseVector sampleVector(RandomGenerator otherRNG);

    /**
     * Draws one vector via {@link #sampleVector()} and converts it to an array.
     * @return An array sampled from this distribution.
     */
    default double[] sampleArray() {
        return this.sampleVector().toArray();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.tribuo.math.distributions;

import org.tribuo.math.la.DenseVector;
import org.tribuo.util.Util;

import java.util.Arrays;
import java.util.List;
import java.util.SplittableRandom;
import java.util.random.RandomGenerator;

/**
 * A mixture distribution which samples from a set of internal distributions mixed by some probability distribution.
 * <p>
 * Sampling first draws a component index from the CDF of the mixing distribution,
 * then draws a vector from the selected component.
 * @param <T> The inner distribution type.
 */
public final class MixtureDistribution<T extends Distribution> implements Distribution {

    private final List<T> dists;

    private final double[] mixingDistribution;

    private final double[] cdf;

    private final RandomGenerator rng;

    private final long seed;

    /**
     * Construct a mixture distribution over the supplied components.
     * @param distributions The distribution components.
     * @param mixingDistribution The mixing distribution, must be a valid PMF.
     * @param seed The RNG seed.
     * @throws IllegalArgumentException If the number of components and probabilities differ,
     * or the mixing distribution is not a valid PMF.
     */
    public MixtureDistribution(List<T> distributions, DenseVector mixingDistribution, long seed) {
        this(distributions, mixingDistribution.toArray(), seed);
    }

    /**
     * Construct a mixture distribution over the supplied components.
     * <p>
     * Both the component list and the mixing array are defensively copied.
     * @param distributions The distribution components.
     * @param mixingDistribution The mixing distribution, must be a valid PMF.
     * @param seed The RNG seed.
     * @throws IllegalArgumentException If the number of components and probabilities differ,
     * or the mixing distribution is not a valid PMF.
     */
    public MixtureDistribution(List<T> distributions, double[] mixingDistribution, long seed) {
        this.dists = List.copyOf(distributions);
        this.mixingDistribution = Arrays.copyOf(mixingDistribution, mixingDistribution.length);
        // Validate the arguments before building any further state from them.
        if (dists.size() != this.mixingDistribution.length) {
            throw new IllegalArgumentException("Invalid distribution, expected the same number of components as probabilities, found " + dists.size() + " components, and " + this.mixingDistribution.length + " probabilities");
        }
        if (!Util.validatePMF(this.mixingDistribution)) {
            throw new IllegalArgumentException("Invalid mixing distribution, was not a valid PMF, found " + Arrays.toString(this.mixingDistribution));
        }
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
        this.cdf = Util.generateCDF(this.mixingDistribution);
    }

    /**
     * Returns the number of distributions.
     * @return The number of distributions.
     */
    public int getNumComponents() {
        return dists.size();
    }

    /**
     * Return a mixture component.
     * @param i The index of the mixture component.
     * @return The ith component.
     * @throws IndexOutOfBoundsException If {@code i} is not a valid component index.
     */
    public T getComponent(int i) {
        return dists.get(i);
    }

    /**
     * Returns a copy of the mixing distribution.
     * @return A copy of the mixing distribution.
     */
    public double[] getMixingDistribution() {
        return Arrays.copyOf(mixingDistribution, mixingDistribution.length);
    }

    /**
     * Sample a single vector from this mixture.
     * <p>
     * The component is chosen with this mixture's own RNG, and the chosen component
     * then samples through its own no-argument {@code sampleVector()}.
     * @return A vector sampled from the selected component.
     */
    @Override
    public DenseVector sampleVector() {
        int idx = Util.sampleFromCDF(cdf, rng);
        return dists.get(idx).sampleVector();
    }

    /**
     * Sample a single vector from this mixture using the supplied RNG.
     * <p>
     * The supplied RNG drives both the choice of component and the draw from that component.
     * @param otherRNG The RNG to use.
     * @return A vector sampled from the selected component.
     */
    @Override
    public DenseVector sampleVector(RandomGenerator otherRNG) {
        int idx = Util.sampleFromCDF(cdf, otherRNG);
        // Forward the caller's RNG; the no-arg overload would ignore it for the component draw.
        return dists.get(idx).sampleVector(otherRNG);
    }

    @Override
    public String toString() {
        return "Mixture(seed="+seed+",mixingDistribution="+ Arrays.toString(mixingDistribution) +",components="+dists+")";
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
/**
* A class for sampling from multivariate normal distributions.
*/
public final class MultivariateNormalDistribution {
public final class MultivariateNormalDistribution implements Distribution {

private final long seed;
private final Random rng;
Expand Down Expand Up @@ -231,6 +231,7 @@ public MultivariateNormalDistribution(DenseVector means, Tensor covariance, Cova
* Sample a vector from this multivariate normal distribution.
* @return A sample from this distribution.
*/
@Override
public DenseVector sampleVector() {
    // Delegate to the RNG-taking overload, supplying this instance's own rng field.
    return this.sampleVector(this.rng);
}
Expand All @@ -239,6 +240,7 @@ public DenseVector sampleVector() {
* Sample a vector from this multivariate normal distribution.
* @return A sample from this distribution.
*/
@Override
public DenseVector sampleVector(RandomGenerator otherRNG) {
DenseVector sampled = new DenseVector(means.size());
for (int i = 0; i < means.size(); i++) {
Expand All @@ -256,14 +258,6 @@ public DenseVector sampleVector(RandomGenerator otherRNG) {
return sampled;
}

/**
 * Draws one sample via {@code sampleVector()} and returns it converted to an array.
 * @return The sampled vector as an array.
 */
public double[] sampleArray() {
    return this.sampleVector().toArray();
}

/**
* Gets a copy of the mean vector.
* @return A copy of the mean vector.
Expand Down

0 comments on commit 8555926

Please sign in to comment.