Skip to content

Commit

Permalink
Improved UnscentedTransformProvider methods.
Browse files Browse the repository at this point in the history
Reduced duplications in UnscentedKalmanFilter.
  • Loading branch information
MaximeJo committed Mar 22, 2024
1 parent 402d89e commit b4fdee4
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 126 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import org.hipparchus.exception.LocalizedCoreFormats;
import org.hipparchus.exception.MathIllegalArgumentException;
import org.hipparchus.linear.ArrayRealVector;
import org.hipparchus.linear.MatrixUtils;
import org.hipparchus.linear.RealMatrix;
import org.hipparchus.linear.RealVector;
import org.hipparchus.linear.SemiDefinitePositiveCholeskyDecomposition;
Expand Down Expand Up @@ -82,43 +81,6 @@ public RealVector[] unscentedTransform(final RealVector state, final RealMatrix

}

/** {@inheritDoc}. */
@Override
public Pair<RealVector, RealMatrix> inverseUnscentedTransform(final RealVector[] sigmaPoints) {

// State dimension
final int stateDimension = sigmaPoints[0].getDimension();

// Compute weighted mean
// ---------------------

RealVector weightedMean = new ArrayRealVector(stateDimension);

// Compute the weight coefficients wm
final RealVector wm = getWm();

// Weight each sigma point and sum them
for (int i = 0; i <= 2 * stateDimension; i++) {
weightedMean = weightedMean.add(sigmaPoints[i].mapMultiply(wm.getEntry(i)));
}

// Compute covariance matrix
// -------------------------

RealMatrix covarianceMatrix = MatrixUtils.createRealMatrix(stateDimension, stateDimension);

// Compute the weight coefficients wc
final RealVector wc = getWc();

// Reconstruct the covariance
for (int i = 0; i <= 2 * stateDimension; i++) {
final RealMatrix diff = MatrixUtils.createColumnRealMatrix(sigmaPoints[i].subtract(weightedMean).toArray());
covarianceMatrix = covarianceMatrix.add(diff.multiplyTransposed(diff).scalarMultiply(wc.getEntry(i)));
}

return new Pair<>(weightedMean, covarianceMatrix);
}

/**
* Get the factor applied to the covariance matrix during the unscented transform.
* @return the factor applied to the covariance matrix during the unscented transform
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.hipparchus.util;

import org.hipparchus.linear.ArrayRealVector;
import org.hipparchus.linear.MatrixUtils;
import org.hipparchus.linear.RealMatrix;
import org.hipparchus.linear.RealVector;

Expand All @@ -33,12 +35,84 @@ public interface UnscentedTransformProvider {
*/
RealVector[] unscentedTransform(RealVector state, RealMatrix covariance);

/**
* Computes a weighted mean state from a given set of sigma points.
* <p>
* This method can be used for computing both the mean state and the mean measurement
* in an Unscented Kalman filter.
* <p>
* It corresponds to Equation 17 of "Wan, E. A., & Van Der Merwe, R. The unscented Kalman filter for nonlinear estimation"
* <p>
* @param sigmaPoints input samples
* @return weighted mean state
*/
default RealVector getUnscentedMeanState(RealVector[] sigmaPoints) {

// Sigma point dimension
final int sigmaPointDimension = sigmaPoints[0].getDimension();

// Compute weighted mean
// ---------------------

RealVector weightedMean = new ArrayRealVector(sigmaPointDimension);

// Compute the weight coefficients wm
final RealVector wm = getWm();

// Weight each sigma point and sum them
for (int i = 0; i < sigmaPoints.length; i++) {
weightedMean = weightedMean.add(sigmaPoints[i].mapMultiply(wm.getEntry(i)));
}

return weightedMean;
}

/** Computes the unscented covariance matrix from a weighted mean state and a set of sigma points.
* <p>
* This method can be used for computing both the predicted state
* covariance matrix and the innovation covariance matrix in an Unscented Kalman filter.
* <p>
* It corresponds to Equation 18 of "Wan, E. A., & Van Der Merwe, R. The unscented Kalman filter for nonlinear estimation"
* <p>
* @param sigmaPoints input sigma points
* @param meanState weighted mean state
* @return the unscented covariance matrix
*/
default RealMatrix getUnscentedCovariance(RealVector[] sigmaPoints, RealVector meanState) {

// State dimension
final int stateDimension = meanState.getDimension();

// Compute covariance matrix
// -------------------------

RealMatrix covarianceMatrix = MatrixUtils.createRealMatrix(stateDimension, stateDimension);

// Compute the weight coefficients wc
final RealVector wc = getWc();

// Reconstruct the covariance
for (int i = 0; i < sigmaPoints.length; i++) {
final RealMatrix diff = MatrixUtils.createColumnRealMatrix(sigmaPoints[i].subtract(meanState).toArray());
covarianceMatrix = covarianceMatrix.add(diff.multiplyTransposed(diff).scalarMultiply(wc.getEntry(i)));
}

return covarianceMatrix;
}

/**
* Perform the inverse unscented transform from an array of sigma points.
* @param sigmaPoints array containing the sigma points of the unscented transform
* @return mean state and associated covariance
*/
Pair<RealVector, RealMatrix> inverseUnscentedTransform(RealVector[] sigmaPoints);
default Pair<RealVector, RealMatrix> inverseUnscentedTransform(RealVector[] sigmaPoints) {

// Mean state
final RealVector meanState = getUnscentedMeanState(sigmaPoints);

// Return state and covariance
return new Pair<>(meanState, getUnscentedCovariance(sigmaPoints, meanState));
}

/**
* Get the covariance weights.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ public void testInverseUnscentedTransform() {
MatrixUtils.createRealVector(new double[] {0.0, 1.0}),
MatrixUtils.createRealVector(new double[] {1.0, 0.0})};
// Action
final Pair<RealVector, RealMatrix> out = julier.inverseUnscentedTransform(sigmaPoints);
final RealVector state = out.getFirst();
final RealMatrix covariance = out.getSecond();
final Pair<RealVector, RealMatrix> inverse = julier.inverseUnscentedTransform(sigmaPoints);
final RealVector state = inverse.getFirst();
final RealMatrix covariance = inverse.getSecond();

// Verify
Assert.assertEquals(2, state.getDimension());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public void testUnscentedTransform() {
checkSigmaPoint(sigma[3], 0.5, 1.0);
checkSigmaPoint(sigma[4], 1.0, 0.5);
}

/** Test inverse unscented transform */
@Test
public void testInverseUnscentedTransform() {
Expand All @@ -88,9 +88,9 @@ public void testInverseUnscentedTransform() {
MatrixUtils.createRealVector(new double[] {0.5, 1.0}),
MatrixUtils.createRealVector(new double[] {1.0, 0.5})};
// Action
final Pair<RealVector, RealMatrix> out = merwe.inverseUnscentedTransform(sigmaPoints);
final RealVector state = out.getFirst();
final RealMatrix covariance = out.getSecond();
final Pair<RealVector, RealMatrix> inverse = merwe.inverseUnscentedTransform(sigmaPoints);
final RealVector state = inverse.getFirst();
final RealMatrix covariance = inverse.getSecond();

// Verify
Assert.assertEquals(2, state.getDimension());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.hipparchus.filtering.kalman.KalmanFilter;
import org.hipparchus.filtering.kalman.Measurement;
import org.hipparchus.filtering.kalman.ProcessEstimate;
import org.hipparchus.linear.ArrayRealVector;
import org.hipparchus.linear.MatrixDecomposer;
import org.hipparchus.linear.MatrixUtils;
import org.hipparchus.linear.RealMatrix;
Expand Down Expand Up @@ -113,7 +112,7 @@ public ProcessEstimate predictionAndCorrectionSteps(final T measurement, final R

// Correction phase
final RealVector[] predictedMeasurements = process.getPredictedMeasurements(predictedSigmaPoints, measurement);
final RealVector predictedMeasurement = sum(predictedMeasurements, measurement.getValue().getDimension());
final RealVector predictedMeasurement = utProvider.getUnscentedMeanState(predictedMeasurements);
final RealMatrix r = computeInnovationCovarianceMatrix(predictedMeasurements, predictedMeasurement, measurement.getCovariance());
final RealMatrix crossCovarianceMatrix = computeCrossCovarianceMatrix(predictedSigmaPoints, predicted.getState(),
predictedMeasurements, predictedMeasurement);
Expand All @@ -131,10 +130,10 @@ public ProcessEstimate predictionAndCorrectionSteps(final T measurement, final R
private void predict(final double time, final RealVector[] predictedStates, final RealMatrix noise) {

// Computation of Eq. 17, weighted mean state
final RealVector predictedState = sum(predictedStates, n);
final RealVector predictedState = utProvider.getUnscentedMeanState(predictedStates);

// Computation of Eq. 18, predicted covariance matrix
final RealMatrix predictedCovariance = computeCovariance(predictedStates, predictedState).add(noise);
final RealMatrix predictedCovariance = utProvider.getUnscentedCovariance(predictedStates, predictedState).add(noise);

predicted = new ProcessEstimate(time, predictedState, predictedCovariance);
corrected = null;
Expand Down Expand Up @@ -220,7 +219,8 @@ private RealMatrix computeInnovationCovarianceMatrix(final RealVector[] predicte
return null;
}
// Computation of the innovation covariance matrix
final RealMatrix innovationCovarianceMatrix = computeCovariance(predictedMeasurements, predictedMeasurement);
final RealMatrix innovationCovarianceMatrix = utProvider.getUnscentedCovariance(predictedMeasurements, predictedMeasurement);

// Add the measurement covariance
return innovationCovarianceMatrix.add(r);
}
Expand Down Expand Up @@ -254,70 +254,7 @@ private RealMatrix computeCrossCovarianceMatrix(final RealVector[] predictedStat
return crossCovarianceMatrix;
}

/**
* Computes a weighted mean parameter from a given samples.
* <p>
* This method can be used for computing both the mean state and the mean measurement.
* <p>
* It corresponds to the Equation 17 of "Wan, E. A., & Van Der Merwe, R.
* The unscented Kalman filter for nonlinear estimation"
* </p>
* @param samples input samples
* @param size size of the weighted mean parameter
* @return weighted mean parameter
*/
private RealVector sum(final RealVector[] samples, final int size) {

// Initialize the weighted mean parameter
RealVector mean = new ArrayRealVector(size);

// Mean weights
final RealVector wm = utProvider.getWm();

// Compute weighted mean parameter
for (int i = 0; i <= 2 * n; i++) {
mean = mean.add(samples[i].mapMultiply(wm.getEntry(i)));
}

// Return the weighted mean value
return mean;

}

/** Computes the covariance matrix.
* <p>
* This method can be used for computing both the predicted state
* covariance matrix and the innovation covariance matrix.
* <p>
* It corresponds to the Equation 18 of "Wan, E. A., & Van Der Merwe, R.
* The unscented Kalman filter for nonlinear estimation"
* </p>
* @param samples input samples
* @param state weighted mean parameter
* @return the covariance matrix
*/
private RealMatrix computeCovariance(final RealVector[] samples,
final RealVector state) {

// Initialize the covariance matrix, by using the size of the weighted mean parameter
final int dim = state.getDimension();
RealMatrix covarianceMatrix = MatrixUtils.createRealMatrix(dim, dim);

// Covariance weights
final RealVector wc = utProvider.getWc();

// Compute the covariance matrix
for (int i = 0; i <= 2 * n; i++) {
final RealVector diff = samples[i].subtract(state);
covarianceMatrix = covarianceMatrix.add(outer(diff, diff).scalarMultiply(wc.getEntry(i)));
}

// Return the covariance
return covarianceMatrix;

}

/** Conputes the outer product of two vectors.
/** Computes the outer product of two vectors.
* @param a first vector
* @param b second vector
* @return the outer product of a and b
Expand All @@ -336,7 +273,5 @@ private RealMatrix outer(final RealVector a, final RealVector b) {

// Return
return outMatrix;

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import org.hipparchus.random.Well1024a;
import org.hipparchus.util.FastMath;
import org.hipparchus.util.MerweUnscentedTransform;
import org.hipparchus.util.Pair;
import org.hipparchus.util.UnscentedTransformProvider;
import org.junit.Assert;
import org.junit.Test;
Expand Down Expand Up @@ -72,14 +71,6 @@ public RealVector getWm() {
public RealVector getWc() {
return new ArrayRealVector();
}

@Override
public Pair<RealVector, RealMatrix> inverseUnscentedTransform(RealVector[] sigmaPoints) {

final int stateDimension = sigmaPoints[0].getDimension();

return new Pair<>(sigmaPoints[0], MatrixUtils.createRealIdentityMatrix(stateDimension));
}
};

new UnscentedKalmanFilter<>(new CholeskyDecomposer(1.0e-15, 1.0e-15), process, initial, provider);
Expand Down

0 comments on commit b4fdee4

Please sign in to comment.