From 074cf3d71cdbddc4f3c7d963fcefed3f623f4c14 Mon Sep 17 00:00:00 2001
From: Doug Branton
Date: Wed, 4 Oct 2023 14:23:50 -0700
Subject: [PATCH] add temporary cols test

---
 src/tape/ensemble.py              |  4 ++--
 tests/tape_tests/test_ensemble.py | 31 +++++++++++++++++++++++++++++++
 2 files changed, 33 insertions(+), 2 deletions(-)

diff --git a/src/tape/ensemble.py b/src/tape/ensemble.py
index 73d53176..a2da1447 100644
--- a/src/tape/ensemble.py
+++ b/src/tape/ensemble.py
@@ -1407,7 +1407,7 @@ def _sync_tables(self):
 
             # Drop Temporary Source Columns on Sync
             if len(self._source_temp):
-                self._source.drop(columns=self._source_temp)
+                self._source = self._source.drop(columns=self._source_temp)
                 print(f"Temporary columns dropped from Source Table: {self._source_temp}")
                 self._source_temp = []
 
@@ -1420,7 +1420,7 @@ def _sync_tables(self):
 
             # Drop Temporary Object Columns on Sync
             if len(self._object_temp):
-                self._object.drop(columns=self._object_temp)
+                self._object = self._object.drop(columns=self._object_temp)
                 print(f"Temporary columns dropped from Object Table: {self._object_temp}")
                 self._object_temp = []
 
diff --git a/tests/tape_tests/test_ensemble.py b/tests/tape_tests/test_ensemble.py
index 5dcf8ff2..ea6f0552 100644
--- a/tests/tape_tests/test_ensemble.py
+++ b/tests/tape_tests/test_ensemble.py
@@ -473,6 +473,37 @@ def test_lazy_sync_tables(parquet_ensemble):
     assert not parquet_ensemble._source_dirty
 
 
+def test_temporary_cols(parquet_ensemble):
+    """
+    Test that temporary columns are tracked and dropped as expected.
+    """
+
+    ens = parquet_ensemble
+    ens._object = ens._object.drop(columns=["nobs_r", "nobs_g", "nobs_total"])
+
+    # Make sure temp lists are available but empty
+    assert not len(ens._source_temp)
+    assert not len(ens._object_temp)
+
+    ens.calc_nobs(temporary=True)  # Generates "nobs_total"
+
+    # nobs_total should be a temporary column
+    assert "nobs_total" in ens._object_temp
+    assert "nobs_total" in ens._object.columns
+
+    # drop NaNs from source, source should be dirty now
+    ens.dropna(how="any", table="source")
+
+    assert ens._source_dirty
+
+    # try a sync
+    ens._sync_tables()
+
+    # nobs_total should be removed
+    assert "nobs_total" not in ens._object_temp
+    assert "nobs_total" not in ens._object.columns
+
+
 def test_dropna(parquet_ensemble):
     # Try passing in an unrecognized 'table' parameter and verify an exception is thrown
     with pytest.raises(ValueError):