
Too low tolerance in test_solve_sample_methods #618

Open
gspr opened this issue Apr 12, 2024 · 1 comment

Comments

@gspr
Contributor

gspr commented Apr 12, 2024

Describe the bug

In test_solve_sample_methods, ot.solve_sample is called with the same random array as both of its first two arguments, so the expected solution value is 0. This is checked with np.testing.assert_allclose(sol2.value, 0), where sol2 is that solution. assert_allclose passes only if |actual - desired| <= atol + rtol * |desired|, and it defaults to rtol=1e-7 and atol=0. Because the desired value is 0, the allowed deviation is exactly 0 (zero tolerance). The test therefore checks for exact equality and can easily fail because of floating-point rounding.
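The tolerance rule can be checked directly with NumPy. This is a minimal sketch: allclose_tolerance is a hypothetical helper that only restates the documented assert_allclose formula, and the 1e-15 residue is an illustrative value, not one measured on the attached data.

```python
import numpy as np


def allclose_tolerance(desired, rtol=1e-7, atol=0.0):
    """Hypothetical helper: the bound used by np.testing.assert_allclose,
    i.e. atol + rtol * |desired|, with NumPy's default rtol and atol."""
    return atol + rtol * abs(desired)


residue = 1e-15  # illustrative rounding residue (assumption, not from the report)

# With desired == 0 and the default atol == 0, the allowed deviation is exactly 0.
print(allclose_tolerance(0.0))  # 0.0

# A residue of 1e-15 is therefore rejected under the default settings.
try:
    np.testing.assert_allclose(residue, 0)  # default rtol=1e-7, atol=0
    default_passes = True
except AssertionError:
    default_passes = False

print(default_passes)  # False
```

Any non-zero atol larger than the residue makes the same assertion pass.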

To Reproduce

Steps to reproduce the behavior:

  1. Load the attached pot-bug.npy.gz file into a variable x. (The attachment is gzip-compressed because GitHub apparently does not accept the .npy file extension.)
  2. Compute sol2 = ot.solve_sample(x, x, method='gaussian').
  3. Observe that np.testing.assert_allclose(sol2.value, 0) fails.
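The steps above can be sketched as a short script. load_npy_gz and reproduce are hypothetical helper names introduced for illustration; the file name comes from the attachment, and the ot.solve_sample call mirrors step 2. POT is imported inside reproduce so that the loader works without it.

```python
import gzip

import numpy as np


def load_npy_gz(path):
    """Load a gzip-compressed .npy file such as the attached pot-bug.npy.gz."""
    with gzip.open(path, "rb") as f:
        return np.load(f)


def reproduce(path="pot-bug.npy.gz"):
    """Hypothetical reproduction of steps 1-3 (requires POT to be installed)."""
    import ot

    x = load_npy_gz(path)                              # step 1
    sol2 = ot.solve_sample(x, x, method="gaussian")    # step 2
    print("sol2.value =", sol2.value)
    np.testing.assert_allclose(sol2.value, 0)          # step 3: fails on the attached data
```

Calling reproduce() from the directory containing the attachment should run all three steps.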

Code sample

See steps to reproduce above.

Expected behavior

The test should use a non-zero atol in assert_allclose to allow for floating point rounding errors.
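To illustrate why exact equality is fragile here, the following sketch computes the closed-form squared 2-Wasserstein (Bures-Wasserstein) distance between a Gaussian fitted to x and itself, using NumPy and SciPy directly. It assumes the 'gaussian' method reduces to this closed form; it is not POT's implementation. The synthetic data and atol=1e-10 are illustrative choices, not values agreed on in this thread.

```python
import numpy as np
from scipy.linalg import sqrtm


def gaussian_w2_squared(xs, xt):
    """Squared 2-Wasserstein distance between Gaussians fitted to xs and xt:
    ||m_s - m_t||^2 + tr(C_s + C_t - 2 (C_s^{1/2} C_t C_s^{1/2})^{1/2})."""
    ms, mt = xs.mean(axis=0), xt.mean(axis=0)
    Cs, Ct = np.cov(xs, rowvar=False), np.cov(xt, rowvar=False)
    Cs_half = sqrtm(Cs)
    cross = sqrtm(Cs_half @ Ct @ Cs_half)
    value = np.sum((ms - mt) ** 2) + np.trace(Cs + Ct - 2 * cross)
    # sqrtm may return a complex array with negligible imaginary parts.
    return float(np.real(value))


rng = np.random.default_rng(0)
x = rng.normal(size=(100, 5))

value = gaussian_w2_squared(x, x)
print(value)  # mathematically 0; the computed value may carry rounding error

# Comparing against 0 needs an absolute tolerance.
np.testing.assert_allclose(value, 0, atol=1e-10)
```

Because rtol scales with |desired|, only atol has any effect when the expected value is 0, which is why the fix is a non-zero atol rather than a larger rtol.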

Environment:

  • OS: Linux
  • Python version: Python 3.11.2
  • How was POT installed: source, from git commit ab12dd6
  • Build command you used (if compiling from source): pip3 install --user --break-system-packages . from source tree

Output of the following code snippet:

import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)

Linux-6.1.0-18-amd64-x86_64-with-glibc2.36
Python 3.11.2 (main, Mar 13 2023, 12:18:29) [GCC 12.2.0]
NumPy 1.24.2
SciPy 1.10.1
POT 0.9.3dev
@rflamary
Collaborator

Hello @gspr, I agree with you. Could you open a PR to implement this change?
