Unblock NumPy 2.0 (#6991)
* Unblock NumPy 2.0

* .

* Revert tensorflow min version

* Add CI tests for numpy2

* Implement test require_numpy1_on_windows

* Mark tests with require_numpy1_on_windows

* Fix test skip reason

* Add clarifying comment

---------

Co-authored-by: Albert Villanova del Moral <[email protected]>
NeilGirdhar and albertvillanova authored Jul 12, 2024
1 parent 8419c40 commit dfc2b1b
Showing 11 changed files with 69 additions and 7 deletions.
27 changes: 27 additions & 0 deletions .github/workflows/ci.yml
@@ -96,3 +96,30 @@ jobs:
      - name: Test with pytest
        run: |
          python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/

  test_py310_numpy2:
    needs: check_code_quality
    strategy:
      matrix:
        test: ['unit']
        os: [ubuntu-latest, windows-latest]
        deps_versions: [deps-latest]
    continue-on-error: false
    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Set up Python 3.10
        uses: actions/setup-python@v5
        with:
          python-version: "3.10"
      - name: Upgrade pip
        run: python -m pip install --upgrade pip
      - name: Install uv
        run: pip install --upgrade uv
      - name: Install dependencies
        run: uv pip install --system "datasets[tests_numpy2] @ ."
      - name: Test with pytest
        run: |
          python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/
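The new job installs the package with the tests_numpy2 extra (uv pip install --system "datasets[tests_numpy2] @ ."), so the resolver is free to pick NumPy 2.x. A minimal, hypothetical sanity check (not part of this commit) that such a job could run before the suite to confirm the environment really resolved to NumPy 2:

import importlib.metadata

from packaging import version

# Hypothetical check, not in the commit: fail fast if the tests_numpy2
# environment still resolved to a 1.x NumPy.
numpy_version = version.parse(importlib.metadata.version("numpy"))
assert numpy_version.major >= 2, f"expected NumPy 2.x, got {numpy_version}"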
19 changes: 15 additions & 4 deletions setup.py
@@ -111,7 +111,7 @@
# For file locking
"filelock",
# We use numpy>=1.17 to have np.random.Generator (Dataset shuffling)
"numpy>=1.17,<2.0.0", # Temporary upper version
"numpy>=1.17",
# Backend and serialization.
# Minimum 15.0.0 to be able to cast dictionary types to their underlying types
"pyarrow>=15.0.0",
@@ -166,7 +166,7 @@
"pytest-xdist",
# optional dependencies
"elasticsearch<8.0.0", # 8.0 asks users to provide hosts or cloud_id when instantiating ElasticSearch()
"faiss-cpu>=1.6.4",
"faiss-cpu>=1.8.0.post1", # Pins numpy < 2
"jax>=0.3.14; sys_platform != 'win32'",
"jaxlib>=0.3.14; sys_platform != 'win32'",
"lz4",
@@ -176,11 +176,11 @@
"sqlalchemy",
"s3fs>=2021.11.1", # aligned with fsspec[http]>=2021.11.1; test only on python 3.7 for now
"protobuf<4.0.0", # 4.0.0 breaks compatibility with tensorflow<2.12
"tensorflow>=2.6.0",
"tensorflow>=2.6.0", # Issue installing 2.16.0 with Python 3.8; we rely on other dependencies pinning numpy < 2
"tiktoken",
"torch>=2.0.0",
"soundfile>=0.12.1",
"transformers",
"transformers>=4.42.0", # Pins numpy < 2
"zstandard",
"polars[timezone]>=0.20.0",
]
@@ -189,6 +189,16 @@
TESTS_REQUIRE.extend(VISION_REQUIRE)
TESTS_REQUIRE.extend(AUDIO_REQUIRE)

NUMPY2_INCOMPATIBLE_LIBRARIES = [
"faiss-cpu",
"librosa",
"tensorflow",
"transformers",
]
TESTS_NUMPY2_REQUIRE = [
library for library in TESTS_REQUIRE if library.partition(">")[0] not in NUMPY2_INCOMPATIBLE_LIBRARIES
]

QUALITY_REQUIRE = ["ruff>=0.3.0"]

DOCS_REQUIRE = [
@@ -213,6 +223,7 @@
"streaming": [], # for backward compatibility
"dev": TESTS_REQUIRE + QUALITY_REQUIRE + DOCS_REQUIRE,
"tests": TESTS_REQUIRE,
"tests_numpy2": TESTS_NUMPY2_REQUIRE,
"quality": QUALITY_REQUIRE,
"benchmarks": BENCHMARKS_REQUIRE,
"docs": DOCS_REQUIRE,
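For context, a minimal sketch of how the new tests_numpy2 extra is derived: each test requirement is reduced to its distribution name by splitting on the first ">", and anything in the numpy-2-incompatible list is dropped. The requirement strings below are a shortened illustration, not the full TESTS_REQUIRE list.

NUMPY2_INCOMPATIBLE_LIBRARIES = ["faiss-cpu", "librosa", "tensorflow", "transformers"]

# Shortened, illustrative subset of TESTS_REQUIRE.
tests_require = ["faiss-cpu>=1.8.0.post1", "librosa", "tiktoken", "torch>=2.0.0", "transformers>=4.42.0"]

# Same filtering expression as in setup.py: keep a requirement only if the
# text before its ">" version specifier is not a numpy-2-incompatible library.
tests_numpy2_require = [
    library for library in tests_require if library.partition(">")[0] not in NUMPY2_INCOMPATIBLE_LIBRARIES
]

print(tests_numpy2_require)  # ['tiktoken', 'torch>=2.0.0']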
3 changes: 2 additions & 1 deletion tests/features/test_features.py
@@ -25,7 +25,7 @@
from datasets.info import DatasetInfo
from datasets.utils.py_utils import asdict

- from ..utils import require_jax, require_tf, require_torch
+ from ..utils import require_jax, require_numpy1_on_windows, require_tf, require_torch


class FeaturesTest(TestCase):
@@ -543,6 +543,7 @@ def test_cast_to_python_objects_pandas_timedelta(self):
casted_obj = cast_to_python_objects(pd.DataFrame({"a": [obj]}))
self.assertDictEqual(casted_obj, {"a": [expected_obj]})

@require_numpy1_on_windows
@require_torch
def test_cast_to_python_objects_torch(self):
import torch
3 changes: 2 additions & 1 deletion tests/packaged_modules/test_webdataset.py
@@ -7,7 +7,7 @@
from datasets import Audio, DownloadManager, Features, Image, Sequence, Value
from datasets.packaged_modules.webdataset.webdataset import WebDataset

- from ..utils import require_librosa, require_pil, require_sndfile, require_torch
+ from ..utils import require_librosa, require_numpy1_on_windows, require_pil, require_sndfile, require_torch


@pytest.fixture
@@ -226,6 +226,7 @@ def test_webdataset_with_features(image_wds_file):
assert isinstance(decoded["jpg"], PIL.Image.Image)


@require_numpy1_on_windows
@require_torch
def test_tensor_webdataset(tensor_wds_file):
import torch
4 changes: 4 additions & 0 deletions tests/test_arrow_dataset.py
@@ -57,6 +57,7 @@
require_dill_gt_0_3_2,
require_jax,
require_not_windows,
require_numpy1_on_windows,
require_pil,
require_polars,
require_pyspark,
@@ -420,6 +421,7 @@ def test_set_format_numpy_multiple_columns(self, in_memory):
self.assertIsInstance(dset[0]["col_2"], np.str_)
self.assertEqual(dset[0]["col_2"].item(), "a")

@require_numpy1_on_windows
@require_torch
def test_set_format_torch(self, in_memory):
import torch
@@ -1525,6 +1527,7 @@ def func_return_multi_row_pd_dataframe(x):
with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
self.assertRaises(ValueError, dset.map, func_return_multi_row_pd_dataframe)

@require_numpy1_on_windows
@require_torch
def test_map_torch(self, in_memory):
import torch
@@ -1590,6 +1593,7 @@ def func(example):
)
self.assertListEqual(dset_test[0]["tensor"], [1, 2, 3])

@require_numpy1_on_windows
@require_torch
def test_map_tensor_batched(self, in_memory):
import torch
2 changes: 2 additions & 0 deletions tests/test_dataset_dict.py
@@ -16,6 +16,7 @@
from .utils import (
assert_arrow_memory_doesnt_increase,
assert_arrow_memory_increases,
require_numpy1_on_windows,
require_polars,
require_tf,
require_torch,
@@ -109,6 +110,7 @@ def test_set_format_numpy(self):
self.assertEqual(dset_split[0]["col_2"].item(), "a")
del dset

@require_numpy1_on_windows
@require_torch
def test_set_format_torch(self):
import torch
3 changes: 3 additions & 0 deletions tests/test_fingerprint.py
@@ -21,6 +21,7 @@

from .utils import (
require_not_windows,
require_numpy1_on_windows,
require_regex,
require_spacy,
require_tiktoken,
@@ -303,6 +304,7 @@ def test_hash_tiktoken_encoding(self):
self.assertEqual(hash1, hash3)
self.assertNotEqual(hash1, hash2)

@require_numpy1_on_windows
@require_torch
def test_hash_torch_tensor(self):
import torch
@@ -316,6 +318,7 @@ def test_hash_torch_tensor(self):
self.assertEqual(hash1, hash3)
self.assertNotEqual(hash1, hash2)

@require_numpy1_on_windows
@require_torch
def test_hash_torch_generator(self):
import torch
5 changes: 5 additions & 0 deletions tests/test_formatting.py
@@ -21,6 +21,7 @@
from .utils import (
require_jax,
require_librosa,
require_numpy1_on_windows,
require_pil,
require_polars,
require_sndfile,
@@ -353,6 +354,7 @@ def test_polars_formatter(self):
assert pl.Series.eq(batch["a"], pl.Series("a", _COL_A)).all()
assert pl.Series.eq(batch["b"], pl.Series("b", _COL_B)).all()

@require_numpy1_on_windows
@require_torch
def test_torch_formatter(self):
import torch
@@ -373,6 +375,7 @@ def test_torch_formatter(self):
torch.testing.assert_close(batch["c"], torch.tensor(_COL_C, dtype=torch.float32))
assert batch["c"].shape == np.array(_COL_C).shape

@require_numpy1_on_windows
@require_torch
def test_torch_formatter_torch_tensor_kwargs(self):
import torch
@@ -389,6 +392,7 @@ def test_torch_formatter_torch_tensor_kwargs(self):
self.assertEqual(batch["a"].dtype, torch.float16)
self.assertEqual(batch["c"].dtype, torch.float16)

@require_numpy1_on_windows
@require_torch
@require_pil
def test_torch_formatter_image(self):
@@ -975,6 +979,7 @@ def test_tf_formatter_sets_default_dtypes(cast_schema, arrow_table):
tf.debugging.assert_equal(batch["col_float"], tf.ragged.constant(list_float, dtype=tf.float32))


@require_numpy1_on_windows
@require_torch
@pytest.mark.parametrize(
"cast_schema",
2 changes: 2 additions & 0 deletions tests/test_iterable_dataset.py
@@ -51,6 +51,7 @@
is_rng_equal,
require_dill_gt_0_3_2,
require_not_windows,
require_numpy1_on_windows,
require_pyspark,
require_tf,
require_torch,
@@ -1279,6 +1280,7 @@ def gen(shard_names):
assert dataset.n_shards == len(shard_names)


@require_numpy1_on_windows
def test_iterable_dataset_from_file(dataset: IterableDataset, arrow_file: str):
with assert_arrow_memory_doesnt_increase():
dataset_from_file = IterableDataset.from_file(arrow_file)
3 changes: 2 additions & 1 deletion tests/test_py_utils.py
@@ -18,7 +18,7 @@
zip_dict,
)

- from .utils import require_tf, require_torch
+ from .utils import require_numpy1_on_windows, require_tf, require_torch


def np_sum(x): # picklable for multiprocessing
@@ -151,6 +151,7 @@ def gen_random_output():
np.testing.assert_equal(out1, out2)
self.assertGreater(np.abs(out1 - out3).sum(), 0)

@require_numpy1_on_windows
@require_torch
def test_torch(self):
import torch
5 changes: 5 additions & 0 deletions tests/utils.py
@@ -70,6 +70,11 @@ def parse_flag_from_env(key, default=False):

require_faiss = pytest.mark.skipif(find_spec("faiss") is None or sys.platform == "win32", reason="test requires faiss")

require_numpy1_on_windows = pytest.mark.skipif(
version.parse(importlib.metadata.version("numpy")) >= version.parse("2.0.0") and sys.platform == "win32",
reason="test requires numpy < 2.0 on windows",
)


def require_regex(test_case):
"""
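Taken together, the pattern added in this commit looks roughly like the sketch below: tests/utils.py defines require_numpy1_on_windows as a pytest.mark.skipif that triggers only when NumPy >= 2.0 is installed on Windows, and the affected tests are marked with it, typically alongside require_torch. The test body here is a hypothetical stand-in, not one of the tests touched above.

import importlib.metadata
import sys

import pytest
from packaging import version

# Same condition as the marker added to tests/utils.py above.
require_numpy1_on_windows = pytest.mark.skipif(
    version.parse(importlib.metadata.version("numpy")) >= version.parse("2.0.0") and sys.platform == "win32",
    reason="test requires numpy < 2.0 on windows",
)


@require_numpy1_on_windows
def test_torch_tensor_roundtrip():  # hypothetical example test
    import torch

    # Torch/NumPy interop is what the marker guards on Windows under NumPy 2.
    assert torch.tensor([1, 2, 3]).numpy().tolist() == [1, 2, 3]

Running pytest on Windows with NumPy 2.x installed then reports such tests as skipped with the reason string above, while they keep running everywhere else.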
