Skip to content

Commit

Permalink
[XLA] Remove hard-coded constants from chi-square test
Browse files Browse the repository at this point in the history
This makes it easier to adjust.

PiperOrigin-RevId: 587176721
  • Loading branch information
majnemer authored and tensorflower-gardener committed Dec 2, 2023
1 parent 09a463f commit 598681e
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 21 deletions.
1 change: 1 addition & 0 deletions third_party/xla/xla/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1890,6 +1890,7 @@ xla_test(
"//xla/client:local_client",
"//xla/client:xla_builder",
"@com_google_absl//absl/types:span",
"@eigen_archive//:eigen3",
"@local_tsl//tsl/platform:protobuf",
"@local_tsl//tsl/platform:test",
],
Expand Down
84 changes: 63 additions & 21 deletions third_party/xla/xla/tests/prng_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "unsupported/Eigen/SpecialFunctions" // from @eigen_archive
#include "xla/client/local_client.h"
#include "xla/client/xla_builder.h"
#include "xla/literal.h"
Expand All @@ -43,8 +51,8 @@ class PrngTest : public ClientLibraryTestBase {
// of the given range size. `expected_count` is the number of times each
// possible value is expected to be generated. Thus, the sample size is
// `range_size * expected_count`.
double UniformChiSquared(int32_t range_size, int32_t expected_count,
int64_t seed = 42);
void UniformChiSquared(int32_t range_size, int32_t expected_count,
int64_t seed = 42);
};

template <typename T>
Expand Down Expand Up @@ -141,10 +149,30 @@ template <typename T>
T Square(T x) {
return x * x;
}

// Calculates the p-value (probability) of a given chi-square value and degrees
// of freedom.
double ChiSquarePValue(double chi_square, int dof) {
// We are doing a right-tailed test so the p-value is calculated as 1 - CDF.
//
// The CDF can be computed using the regularized lower incomplete gamma
// function like so:
// gammainc(dof/2, chi_square/2).
//
// Seeing as we are interested in 1-CDF, we can compute this using the
// regularized upper incomplete gamma function like so:
// gammaincc(dof/2, chi_square/2).
//
// NIST/SEMATECH e-Handbook of Statistical Methods, 1.3.6.6.6. Chi-Square
// Distribution: Cumulative Distribution Function
// https://www.itl.nist.gov/div898/handbook/eda/section3/eda3666.htm#cdf
return Eigen::numext::igammac(0.5 * dof, 0.5 * chi_square);
}

} // namespace

double PrngTest::UniformChiSquared(int32_t range_size, int32_t expected_count,
int64_t seed) {
void PrngTest::UniformChiSquared(int32_t range_size, int32_t expected_count,
int64_t seed) {
int32_t sample_size = range_size * expected_count;

XlaBuilder builder(TestName());
Expand All @@ -157,34 +185,48 @@ double PrngTest::UniformChiSquared(int32_t range_size, int32_t expected_count,
std::vector<int32_t> counts(range_size, 0);
actual.EachCell<int32_t>(
[&counts](absl::Span<const int64_t>, int32_t value) { ++counts[value]; });
LOG(INFO) << "sample_size = " << sample_size;
LOG(INFO) << "range_size = " << range_size;
LOG(INFO) << "expected_count = " << expected_count;
for (int32_t i = 0; i < range_size; ++i) {
LOG(INFO) << "counts[" << i << "] = " << counts[i];
}
int64_t sum = 0;
for (int32_t i = 0; i < range_size; ++i) {
sum += Square(static_cast<int64_t>(counts[i] - expected_count));
}
return static_cast<double>(sum) / expected_count;
double chi_square = static_cast<double>(sum) / expected_count;
int64_t dof = range_size - 1;
double p_value = ChiSquarePValue(chi_square, dof);
const double kLevelOfSignificance = 1e-5;
// We have two hypotheses:
// - null hypothesis: the distribution we sampled from cannot be distinguished
// from a uniform random distribution.
// - alternate hypothesis: the distribution we sampled from can be
// distinguished from a uniform random distribution.
//
// The lower our calculated p-value, the less likely we would get this result
// if the null hypothesis were true. If our p-value is greater than or equal
// to `kLevelOfSignificance`, we cannot reject the null hypothesis.
//
// Another way of saying this is that if our p-value is greater than or equal
// to `kLevelOfSignificance` then we can consider our data randomly
// distributed with a confidence of 1-kLevelOfSignificance; otherwise, if our
// p-value is less than `kLevelOfSignificance` then our data is non-random
// with a confidence of 1-kLevelOfSignificance.
EXPECT_GE(p_value, kLevelOfSignificance);
}

// We only test distribution of uniform discrete PRNG as other types are based
// on it.
// These range sizes are arbitrary but include prime numbers, powers of 2, and
// other composite numbers.
// The level of significance in all these cases is 1/20.
// TODO(b/35723038): Use parametrized tests where possible.
XLA_TEST_F(PrngTest, Uniformity7) {
EXPECT_LT(UniformChiSquared(7, 256), 12.5916);
}
XLA_TEST_F(PrngTest, Uniformity61) {
EXPECT_LT(UniformChiSquared(61, 256), 79.0819);
}
XLA_TEST_F(PrngTest, Uniformity64) {
EXPECT_LT(UniformChiSquared(64, 256), 82.5287);
}
XLA_TEST_F(PrngTest, Uniformity108) {
EXPECT_LT(UniformChiSquared(108, 256), 132.144);
}
XLA_TEST_F(PrngTest, Uniformity256) {
EXPECT_LT(UniformChiSquared(256, 512), 293.248);
}
XLA_TEST_F(PrngTest, Uniformity7) { UniformChiSquared(7, 256); }
XLA_TEST_F(PrngTest, Uniformity61) { UniformChiSquared(61, 256); }
XLA_TEST_F(PrngTest, Uniformity64) { UniformChiSquared(64, 256); }
XLA_TEST_F(PrngTest, Uniformity108) { UniformChiSquared(108, 256); }
XLA_TEST_F(PrngTest, Uniformity256) { UniformChiSquared(256, 256); }

// TODO(b/134770669): May remove this test if we decide not to support map
// computations with kRng instructions.
Expand Down

0 comments on commit 598681e

Please sign in to comment.