Skip to content

Commit

Permalink
add temporary cols test
Browse files Browse the repository at this point in the history
  • Loading branch information
dougbrn committed Oct 4, 2023
1 parent 4049e03 commit 074cf3d
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/tape/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1407,7 +1407,7 @@ def _sync_tables(self):

# Drop Temporary Source Columns on Sync
if len(self._source_temp):
self._source.drop(columns=self._source_temp)
self._source = self._source.drop(columns=self._source_temp)
print(f"Temporary columns dropped from Source Table: {self._source_temp}")
self._source_temp = []

Check warning on line 1412 in src/tape/ensemble.py

View check run for this annotation

Codecov / codecov/patch

src/tape/ensemble.py#L1410-L1412

Added lines #L1410 - L1412 were not covered by tests

Expand All @@ -1420,7 +1420,7 @@ def _sync_tables(self):

# Drop Temporary Object Columns on Sync
if len(self._object_temp):
self._object.drop(columns=self._object_temp)
self._object = self._object.drop(columns=self._object_temp)
print(f"Temporary columns dropped from Object Table: {self._object_temp}")
self._object_temp = []

Expand Down
31 changes: 31 additions & 0 deletions tests/tape_tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,37 @@ def test_lazy_sync_tables(parquet_ensemble):
assert not parquet_ensemble._source_dirty


def test_temporary_cols(parquet_ensemble):
"""
Test that temporary columns are tracked and dropped as expected.
"""

ens = parquet_ensemble
ens._object = ens._object.drop(columns=["nobs_r", "nobs_g", "nobs_total"])

# Make sure temp lists are available but empty
assert not len(ens._source_temp)
assert not len(ens._object_temp)

ens.calc_nobs(temporary=True) # Generates "nobs_total"

# nobs_total should be a temporary column
assert "nobs_total" in ens._object_temp
assert "nobs_total" in ens._object.columns

# drop NaNs from source, source should be dirty now
ens.dropna(how="any", table="source")

assert ens._source_dirty

# try a sync
ens._sync_tables()

# nobs_total should be removed
assert "nobs_total" not in ens._object_temp
assert "nobs_total" not in ens._object.columns


def test_dropna(parquet_ensemble):
# Try passing in an unrecognized 'table' parameter and verify an exception is thrown
with pytest.raises(ValueError):
Expand Down

0 comments on commit 074cf3d

Please sign in to comment.