Skip to content

Commit

Permalink
Added missing unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
wenneman committed Oct 5, 2023
1 parent 977ecd8 commit 7bd3615
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions tests/tape_tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,52 @@ def test_from_source_dict(dask_client):
assert obj_table.iloc[1][ens._nobs_tot_col] == 5


def test_read_source_dict(dask_client):
"""
Test that ensemble.from_source_dict() successfully creates data from a dictionary.
"""
ens = Ensemble(client=dask_client)

# Create some fake data with two IDs (8001, 8002), two bands ["g", "b"]
# and a few time steps. Leave out the flux data initially.
rows = {
"id": [8001, 8001, 8001, 8001, 8002, 8002, 8002, 8002, 8002],
"time": [10.1, 10.2, 10.2, 11.1, 11.2, 11.3, 11.4, 15.0, 15.1],
"band": ["g", "g", "b", "g", "b", "g", "g", "g", "g"],
"err": [1.0, 2.0, 1.0, 3.0, 2.0, 3.0, 4.0, 5.0, 6.0],
}

# We get an error without all of the required rows.
with pytest.raises(ValueError):
tape.read_source_dict(rows)

# Add the last row and build the ensemble.
rows["flux"] = [1.0, 2.0, 5.0, 3.0, 1.0, 2.0, 3.0, 4.0, 5.0]

cmap = ColumnMapper(id_col="id", time_col="time", flux_col="flux", err_col="err", band_col="band")

ens = tape.read_source_dict(
rows,
column_mapper=cmap,
dask_client=dask_client
)

(obj_table, src_table) = ens.compute()

# Check that the loaded source table is correct.
assert src_table.shape[0] == 9
for i in range(9):
assert src_table.iloc[i][ens._flux_col] == rows[ens._flux_col][i]
assert src_table.iloc[i][ens._time_col] == rows[ens._time_col][i]
assert src_table.iloc[i][ens._band_col] == rows[ens._band_col][i]
assert src_table.iloc[i][ens._err_col] == rows[ens._err_col][i]

# Check that the derived object table is correct.
assert obj_table.shape[0] == 2
assert obj_table.iloc[0][ens._nobs_tot_col] == 4
assert obj_table.iloc[1][ens._nobs_tot_col] == 5


def test_insert(parquet_ensemble):
num_partitions = parquet_ensemble._source.npartitions
(old_object, old_source) = parquet_ensemble.compute()
Expand Down

0 comments on commit 7bd3615

Please sign in to comment.