Skip to content
This repository has been archived by the owner on Jan 14, 2025. It is now read-only.

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dougbrn committed Jan 30, 2024
1 parent 2bcb5a3 commit 78cafec
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/tape/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .column_mapper import ColumnMapper # noqa
from .column_mapper import ColumnMapper
from .utils import IndexCallable # noqa
23 changes: 23 additions & 0 deletions src/tape/utils/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
class IndexCallable:
"""Provide getitem syntax for functions
>>> def inc(x):
... return x + 1
>>> I = IndexCallable(inc)
>>> I[3]
4
"""

__slots__ = ("fn", "dirty", "ensemble")

def __init__(self, fn, dirty, ensemble):
self.fn = fn
self.dirty = dirty # propagate metadata
self.ensemble = ensemble # propagate ensemble metadata

def __getitem__(self, key):
result = self.fn(key)
result.dirty = self.dirty # propagate metadata
result.ensemble = self.ensemble # propagate ensemble metadata
return result
28 changes: 28 additions & 0 deletions tests/tape_tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,34 @@ def test_persist(dask_client):
assert new_graph_size < old_graph_size


@pytest.mark.parametrize("overwrite", [False, True])
def test_sample(parquet_ensemble_with_divisions, overwrite):
"""
Test Ensemble.sample
"""

ens = parquet_ensemble_with_divisions
ens.source.repartition(npartitions=10).update_ensemble()
ens.object.repartition(npartitions=5).update_ensemble()

prior_obj_len = len(ens.object)
prior_src_len = len(ens.source)

new_ens = ens.sample(frac=0.3, overwrite=overwrite)

assert len(new_ens.object) < prior_obj_len # frac is not exact
assert len(new_ens.source) < prior_src_len # should affect source

if overwrite:
# should have also affected ens in-place
assert len(ens.object) < prior_obj_len
assert len(ens.source) < prior_src_len
else:
# ens should not have been affected
assert len(ens.object) == prior_obj_len
assert len(ens.source) == prior_src_len


def test_update_column_map(dask_client):
"""
Test that we can update the column maps in an Ensemble.
Expand Down
30 changes: 30 additions & 0 deletions tests/tape_tests/test_ensemble_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,3 +409,33 @@ def test_coalesce(dask_client, drop_inputs):
# The input columns should still be present
for col in ["flux1", "flux2", "flux3"]:
assert col in ens.source.columns


def test_partition_slicing(parquet_ensemble_with_divisions):
"""
Test that partition slicing propagates EnsembleFrame metadata
"""
ens = parquet_ensemble_with_divisions

ens.source.repartition(npartitions=10).update_ensemble()
ens.object.repartition(npartitions=5).update_ensemble()

prior_obj_len = len(ens.object)
prior_src_len = len(ens.source)

# slice on object
ens.object.partitions[0:3].update_ensemble()
ens._lazy_sync_tables("all") # sync needed as len() won't trigger one

assert ens.object.npartitions == 3 # should return exactly 3 partitions
assert len(ens.source) < prior_src_len # should affect source

prior_obj_len = len(ens.object)
prior_src_len = len(ens.source)

# slice on source
ens.source.partitions[0:2].update_ensemble()
ens._lazy_sync_tables("all") # sync needed as len() won't trigger one

assert ens.source.npartitions == 2 # should return exactly 2 partitions
assert len(ens.object) < prior_src_len # should affect objects
19 changes: 18 additions & 1 deletion tests/tape_tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,22 @@
import pytest
from tape.utils import ColumnMapper
from tape.utils import ColumnMapper, IndexCallable


def test_index_callable(parquet_ensemble):
"""
Test the basic function of the IndexCallable object
"""

ens = parquet_ensemble

source_ic = IndexCallable(ens.source._partitions, True, "ensemble")

# grab the first (and only) source partition
sliced_source_frame = source_ic[0]

# ensure that the metadata propagates to the result
assert sliced_source_frame.dirty is True
assert sliced_source_frame.ensemble == "ensemble"


def test_column_mapper():
Expand Down

0 comments on commit 78cafec

Please sign in to comment.