From 78cafec22d3ab4b12db4ccc6240cb9595451ceb3 Mon Sep 17 00:00:00 2001 From: Doug Branton Date: Tue, 30 Jan 2024 14:54:11 -0800 Subject: [PATCH] add tests --- src/tape/utils/__init__.py | 3 ++- src/tape/utils/utils.py | 23 +++++++++++++++++++ tests/tape_tests/test_ensemble.py | 28 +++++++++++++++++++++++ tests/tape_tests/test_ensemble_frame.py | 30 +++++++++++++++++++++++++ tests/tape_tests/test_utils.py | 19 +++++++++++++++- 5 files changed, 101 insertions(+), 2 deletions(-) create mode 100644 src/tape/utils/utils.py diff --git a/src/tape/utils/__init__.py b/src/tape/utils/__init__.py index f9bb7851..81d85d44 100644 --- a/src/tape/utils/__init__.py +++ b/src/tape/utils/__init__.py @@ -1 +1,2 @@ -from .column_mapper import ColumnMapper # noqa +from .column_mapper import ColumnMapper +from .utils import IndexCallable # noqa diff --git a/src/tape/utils/utils.py b/src/tape/utils/utils.py new file mode 100644 index 00000000..9216b528 --- /dev/null +++ b/src/tape/utils/utils.py @@ -0,0 +1,23 @@ +class IndexCallable: + """Provide getitem syntax for functions + + >>> def inc(x): + ... return x + 1 + + >>> I = IndexCallable(inc) + >>> I[3] + 4 + """ + + __slots__ = ("fn", "dirty", "ensemble") + + def __init__(self, fn, dirty, ensemble): + self.fn = fn + self.dirty = dirty # propagate metadata + self.ensemble = ensemble # propagate ensemble metadata + + def __getitem__(self, key): + result = self.fn(key) + result.dirty = self.dirty # propagate metadata + result.ensemble = self.ensemble # propagate ensemble metadata + return result diff --git a/tests/tape_tests/test_ensemble.py b/tests/tape_tests/test_ensemble.py index 6ee17f75..a1297e42 100644 --- a/tests/tape_tests/test_ensemble.py +++ b/tests/tape_tests/test_ensemble.py @@ -832,6 +832,34 @@ def test_persist(dask_client): assert new_graph_size < old_graph_size +@pytest.mark.parametrize("overwrite", [False, True]) +def test_sample(parquet_ensemble_with_divisions, overwrite): + """ + Test Ensemble.sample + """ + + ens = parquet_ensemble_with_divisions + ens.source.repartition(npartitions=10).update_ensemble() + ens.object.repartition(npartitions=5).update_ensemble() + + prior_obj_len = len(ens.object) + prior_src_len = len(ens.source) + + new_ens = ens.sample(frac=0.3, overwrite=overwrite) + + assert len(new_ens.object) < prior_obj_len # frac is not exact + assert len(new_ens.source) < prior_src_len # should affect source + + if overwrite: + # should have also affected ens in-place + assert len(ens.object) < prior_obj_len + assert len(ens.source) < prior_src_len + else: + # ens should not have been affected + assert len(ens.object) == prior_obj_len + assert len(ens.source) == prior_src_len + + def test_update_column_map(dask_client): """ Test that we can update the column maps in an Ensemble. diff --git a/tests/tape_tests/test_ensemble_frame.py b/tests/tape_tests/test_ensemble_frame.py index 5ed01488..7c99ae7a 100644 --- a/tests/tape_tests/test_ensemble_frame.py +++ b/tests/tape_tests/test_ensemble_frame.py @@ -409,3 +409,33 @@ def test_coalesce(dask_client, drop_inputs): # The input columns should still be present for col in ["flux1", "flux2", "flux3"]: assert col in ens.source.columns + + +def test_partition_slicing(parquet_ensemble_with_divisions): + """ + Test that partition slicing propagates EnsembleFrame metadata + """ + ens = parquet_ensemble_with_divisions + + ens.source.repartition(npartitions=10).update_ensemble() + ens.object.repartition(npartitions=5).update_ensemble() + + prior_obj_len = len(ens.object) + prior_src_len = len(ens.source) + + # slice on object + ens.object.partitions[0:3].update_ensemble() + ens._lazy_sync_tables("all") # sync needed as len() won't trigger one + + assert ens.object.npartitions == 3 # should return exactly 3 partitions + assert len(ens.source) < prior_src_len # should affect source + + prior_obj_len = len(ens.object) + prior_src_len = len(ens.source) + + # slice on source + ens.source.partitions[0:2].update_ensemble() + ens._lazy_sync_tables("all") # sync needed as len() won't trigger one + + assert ens.source.npartitions == 2 # should return exactly 2 partitions + assert len(ens.object) < prior_src_len # should affect objects diff --git a/tests/tape_tests/test_utils.py b/tests/tape_tests/test_utils.py index 124e3ab2..7fbec401 100644 --- a/tests/tape_tests/test_utils.py +++ b/tests/tape_tests/test_utils.py @@ -1,5 +1,22 @@ import pytest -from tape.utils import ColumnMapper +from tape.utils import ColumnMapper, IndexCallable + + +def test_index_callable(parquet_ensemble): + """ + Test the basic function of the IndexCallable object + """ + + ens = parquet_ensemble + + source_ic = IndexCallable(ens.source._partitions, True, "ensemble") + + # grab the first (and only) source partition + sliced_source_frame = source_ic[0] + + # ensure that the metadata propagates to the result + assert sliced_source_frame.dirty is True + assert sliced_source_frame.ensemble == "ensemble" def test_column_mapper():