diff --git a/tests/tape_tests/test_ensemble.py b/tests/tape_tests/test_ensemble.py index 20c57d0f..69acd85d 100644 --- a/tests/tape_tests/test_ensemble.py +++ b/tests/tape_tests/test_ensemble.py @@ -832,12 +832,19 @@ def test_persist(dask_client): assert new_graph_size < old_graph_size -def test_sample_objects(parquet_ensemble_with_divisions): +@pytest.mark.parametrize( + "data_fixture", + [ + "parquet_ensemble_with_divisions", + "parquet_ensemble_without_client", + ], +) +def test_sample_objects(data_fixture, request): """ Test Ensemble.sample_objects """ - ens = parquet_ensemble_with_divisions + ens = request.getfixturevalue(data_fixture) ens.source.repartition(npartitions=10).update_ensemble() ens.object.repartition(npartitions=5).update_ensemble() @@ -853,7 +860,8 @@ def test_sample_objects(parquet_ensemble_with_divisions): assert len(ens.object) == prior_obj_len assert len(ens.source) == prior_src_len - ens.client.close() + if data_fixture == "parquet_ensemble_with_divisions": + ens.client.close() def test_update_column_map(dask_client):