Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tdigest getPMF() and getCDF() #612

Merged
merged 3 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package org.apache.datasketches.quantilescommon;

import static org.apache.datasketches.quantilescommon.QuantileSearchCriteria.INCLUSIVE;
import org.apache.datasketches.common.SketchesArgumentException;

/**
* The Quantiles API for item type <i>double</i>.
Expand All @@ -33,7 +34,7 @@ public interface QuantilesDoublesAPI extends QuantilesAPI {
* This is equivalent to {@link #getCDF(double[], QuantileSearchCriteria) getCDF(splitPoints, INCLUSIVE)}
* @param splitPoints an array of <i>m</i> unique, monotonically increasing items.
* @return a discrete CDF array of m+1 double ranks (or cumulative probabilities) on the interval [0.0, 1.0].
* @throws IllegalArgumentException if sketch is empty.
* @throws SketchesArgumentException if sketch is empty.
*/
default double[] getCDF(double[] splitPoints) {
return getCDF(splitPoints, INCLUSIVE);
Expand Down Expand Up @@ -70,7 +71,7 @@ default double[] getCDF(double[] splitPoints) {
*
* @param searchCrit the desired search criteria.
* @return a discrete CDF array of m+1 double ranks (or cumulative probabilities) on the interval [0.0, 1.0].
* @throws IllegalArgumentException if sketch is empty.
* @throws SketchesArgumentException if sketch is empty.
*/
double[] getCDF(double[] splitPoints, QuantileSearchCriteria searchCrit);

Expand All @@ -79,7 +80,7 @@ default double[] getCDF(double[] splitPoints) {
* item returned by <i>getQuantile(1.0)</i>.
*
* @return the maximum item of the stream
* @throws IllegalArgumentException if sketch is empty.
* @throws SketchesArgumentException if sketch is empty.
*/
double getMaxItem();

Expand All @@ -88,15 +89,15 @@ default double[] getCDF(double[] splitPoints) {
* item returned by <i>getQuantile(0.0)</i>.
*
* @return the minimum item of the stream
* @throws IllegalArgumentException if sketch is empty.
* @throws SketchesArgumentException if sketch is empty.
*/
double getMinItem();

/**
* This is equivalent to {@link #getPMF(double[], QuantileSearchCriteria) getPMF(splitPoints, INCLUSIVE)}
* @param splitPoints an array of <i>m</i> unique, monotonically increasing items.
* @return a PMF array of m+1 probability masses as doubles on the interval [0.0, 1.0].
* @throws IllegalArgumentException if sketch is empty.
* @throws SketchesArgumentException if sketch is empty.
*/
default double[] getPMF(double[] splitPoints) {
return getPMF(splitPoints, INCLUSIVE);
Expand Down Expand Up @@ -140,15 +141,15 @@ default double[] getPMF(double[] splitPoints) {
*
* @param searchCrit the desired search criteria.
* @return a PMF array of m+1 probability masses as doubles on the interval [0.0, 1.0].
* @throws IllegalArgumentException if sketch is empty.
* @throws SketchesArgumentException if sketch is empty.
*/
double[] getPMF(double[] splitPoints, QuantileSearchCriteria searchCrit);

/**
* This is equivalent to {@link #getQuantile(double, QuantileSearchCriteria) getQuantile(rank, INCLUSIVE)}
* @param rank the given normalized rank, a double in the range [0.0, 1.0].
* @return the approximate quantile given the normalized rank.
* @throws IllegalArgumentException if sketch is empty.
* @throws SketchesArgumentException if sketch is empty.
*/
default double getQuantile(double rank) {
return getQuantile(rank, INCLUSIVE);
Expand All @@ -163,7 +164,7 @@ default double getQuantile(double rank) {
* If EXCLUSIVE, he given rank includes all quantiles &lt;
* the quantile directly corresponding to the given rank.
* @return the approximate quantile given the normalized rank.
* @throws IllegalArgumentException if sketch is empty.
* @throws SketchesArgumentException if sketch is empty.
* @see org.apache.datasketches.quantilescommon.QuantileSearchCriteria
*/
double getQuantile(double rank, QuantileSearchCriteria searchCrit);
Expand All @@ -180,7 +181,7 @@ default double getQuantile(double rank) {
* @param rank the given normalized rank
* @return the lower bound of the quantile confidence interval in which the quantile of the
* given rank exists.
* @throws IllegalArgumentException if sketch is empty.
* @throws SketchesArgumentException if sketch is empty.
*/
double getQuantileLowerBound(double rank);

Expand All @@ -196,7 +197,7 @@ default double getQuantile(double rank) {
* @param rank the given normalized rank
* @return the upper bound of the quantile confidence interval in which the true quantile of the
* given rank exists.
* @throws IllegalArgumentException if sketch is empty.
* @throws SketchesArgumentException if sketch is empty.
*/
double getQuantileUpperBound(double rank);

Expand All @@ -205,7 +206,7 @@ default double getQuantile(double rank) {
* @param ranks the given array of normalized ranks, each of which must be
* in the interval [0.0,1.0].
* @return an array of quantiles corresponding to the given array of normalized ranks.
* @throws IllegalArgumentException if sketch is empty.
* @throws SketchesArgumentException if sketch is empty.
*/
default double[] getQuantiles(double[] ranks) {
return getQuantiles(ranks, INCLUSIVE);
Expand All @@ -219,7 +220,7 @@ default double[] getQuantiles(double[] ranks) {
* @param searchCrit if INCLUSIVE, the given ranks include all quantiles &le;
* the quantile directly corresponding to each rank.
* @return an array of quantiles corresponding to the given array of normalized ranks.
* @throws IllegalArgumentException if sketch is empty.
* @throws SketchesArgumentException if sketch is empty.
* @see org.apache.datasketches.quantilescommon.QuantileSearchCriteria
*/
double[] getQuantiles(double[] ranks, QuantileSearchCriteria searchCrit);
Expand All @@ -228,7 +229,7 @@ default double[] getQuantiles(double[] ranks) {
* This is equivalent to {@link #getRank(double, QuantileSearchCriteria) getRank(quantile, INCLUSIVE)}
* @param quantile the given quantile
* @return the normalized rank corresponding to the given quantile
* @throws IllegalArgumentException if sketch is empty.
* @throws SketchesArgumentException if sketch is empty.
*/
default double getRank(double quantile) {
return getRank(quantile, INCLUSIVE);
Expand All @@ -240,7 +241,7 @@ default double getRank(double quantile) {
* @param quantile the given quantile
* @param searchCrit if INCLUSIVE the given quantile is included into the rank.
* @return the normalized rank corresponding to the given quantile
* @throws IllegalArgumentException if sketch is empty.
* @throws SketchesArgumentException if sketch is empty.
* @see org.apache.datasketches.quantilescommon.QuantileSearchCriteria
*/
double getRank(double quantile, QuantileSearchCriteria searchCrit);
Expand All @@ -249,7 +250,7 @@ default double getRank(double quantile) {
* This is equivalent to {@link #getRanks(double[], QuantileSearchCriteria) getRanks(quantiles, INCLUSIVE)}
* @param quantiles the given array of quantiles
* @return an array of normalized ranks corresponding to the given array of quantiles.
* @throws IllegalArgumentException if sketch is empty.
* @throws SketchesArgumentException if sketch is empty.
*/
default double[] getRanks(double[] quantiles) {
return getRanks(quantiles, INCLUSIVE);
Expand All @@ -262,7 +263,7 @@ default double[] getRanks(double[] quantiles) {
* @param quantiles the given array of quantiles
* @param searchCrit if INCLUSIVE, the given quantiles include the rank directly corresponding to each quantile.
* @return an array of normalized ranks corresponding to the given array of quantiles.
* @throws IllegalArgumentException if sketch is empty.
* @throws SketchesArgumentException if sketch is empty.
* @see org.apache.datasketches.quantilescommon.QuantileSearchCriteria
*/
double[] getRanks(double[] quantiles, QuantileSearchCriteria searchCrit);
Expand Down
47 changes: 47 additions & 0 deletions src/main/java/org/apache/datasketches/tdigest/TDigestDouble.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.apache.datasketches.memory.WritableBuffer;
import org.apache.datasketches.memory.WritableMemory;
import org.apache.datasketches.quantilescommon.QuantilesAPI;
import org.apache.datasketches.quantilescommon.QuantilesUtil;

/**
* t-Digest for estimating quantiles and ranks.
Expand Down Expand Up @@ -125,6 +126,7 @@ public void merge(final TDigestDouble other) {
/**
* Process buffered values and merge centroids if needed
*/
// this method will become private in the next major version
public void compress() {
if (numBuffered_ == 0) { return; }
final int num = numBuffered_ + numCentroids_;
Expand Down Expand Up @@ -277,6 +279,51 @@ public double getQuantile(final double rank) {
return weightedAverage(centroidWeights_[numCentroids_ - 1], w1, maxValue_, w2);
}

/**
* Returns an approximation to the Probability Mass Function (PMF) of the input stream
* given a set of split points.
*
* @param splitPoints an array of <i>m</i> unique, monotonically increasing values
* that divide the input domain into <i>m+1</i> consecutive disjoint intervals (bins).
*
* @return an array of m+1 doubles each of which is an approximation
* to the fraction of the input stream values (the mass) that fall into one of those intervals.
* @throws SketchesStateException if sketch is empty.
*/
public double[] getPMF(final double[] splitPoints) {
final double[] buckets = getCDF(splitPoints);
for (int i = buckets.length; i-- > 1; ) {
buckets[i] -= buckets[i - 1];
}
return buckets;
}

/**
* Returns an approximation to the Cumulative Distribution Function (CDF), which is the
* cumulative analog of the PMF, of the input stream given a set of split points.
*
* @param splitPoints an array of <i>m</i> unique, monotonically increasing values
* that divide the input domain into <i>m+1</i> consecutive disjoint intervals.
*
* @return an array of m+1 doubles, which are a consecutive approximation to the CDF
* of the input stream given the splitPoints. The value at array position j of the returned
* CDF array is the sum of the returned values in positions 0 through j of the returned PMF
* array. This can be viewed as array of ranks of the given split points plus one more value
* that is always 1.
* @throws SketchesStateException if sketch is empty.
*/
public double[] getCDF(final double[] splitPoints) {
if (isEmpty()) { throw new SketchesStateException(QuantilesAPI.EMPTY_MSG); }
QuantilesUtil.checkDoublesSplitPointsOrder(splitPoints);
final int len = splitPoints.length + 1;
final double[] ranks = new double[len];
for (int i = 0; i < len - 1; i++) {
ranks[i] = getRank(splitPoints[i]);
}
ranks[len - 1] = 1.0;
return ranks;
}

/**
* Computes size needed to serialize the current state.
* @return size in bytes needed to serialize this tdigest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ public void empty() {
assertThrows(SketchesStateException.class, () -> td.getMaxValue());
assertThrows(SketchesStateException.class, () -> td.getRank(0));
assertThrows(SketchesStateException.class, () -> td.getQuantile(0.5));
assertThrows(SketchesStateException.class, () -> td.getPMF(new double[]{0}));
assertThrows(SketchesStateException.class, () -> td.getCDF(new double[]{0}));
}

@Test
Expand All @@ -65,9 +67,6 @@ public void manyValues() {
final TDigestDouble td = new TDigestDouble();
final int n = 10000;
for (int i = 0; i < n; i++) td.update(i);
// System.out.println(td.toString(true));
// td.compress();
// System.out.println(td.toString(true));
assertFalse(td.isEmpty());
assertEquals(td.getTotalWeight(), n);
assertEquals(td.getMinValue(), 0);
Expand All @@ -82,6 +81,14 @@ public void manyValues() {
assertEquals(td.getQuantile(0.9), n * 0.9, n * 0.9 * 0.01);
assertEquals(td.getQuantile(0.95), n * 0.95, n * 0.95 * 0.01);
assertEquals(td.getQuantile(1), n - 1);
final double[] pmf = td.getPMF(new double[] {n / 2});
assertEquals(pmf.length, 2);
assertEquals(pmf[0], 0.5, 0.0001);
assertEquals(pmf[1], 0.5, 0.0001);
final double[] cdf = td.getCDF(new double[] {n / 2});
assertEquals(cdf.length, 2);
assertEquals(cdf[0], 0.5, 0.0001);
assertEquals(cdf[1], 1.0);
}

@Test
Expand Down
Loading