diff --git a/test/test_mtc.py b/test/test_mtc.py index cc6b055..a488d4b 100644 --- a/test/test_mtc.py +++ b/test/test_mtc.py @@ -227,13 +227,14 @@ def regress(ext, out_dir): @pytest.mark.parametrize( - "chunk_training_mode,recode_pipeline_columns", + "chunk_training_mode,recode_pipeline_columns,sharrow_enabled", [ ("disabled", True), ("explicit", False), + ("explicit", True, True), ], ) -def test_mtc_extended_progressive(chunk_training_mode, recode_pipeline_columns): +def test_mtc_extended_progressive(chunk_training_mode, recode_pipeline_columns, sharrow_enabled): import activitysim.abm # register components # noqa: F401 out_dir = _test_path(f"output-progressive-recode{recode_pipeline_columns}") @@ -267,8 +268,16 @@ def test_mtc_extended_progressive(chunk_training_mode, recode_pipeline_columns): ], }, "recode_pipeline_columns": recode_pipeline_columns, + "trace_hh_id": 1196298, } + if sharrow_enabled and not recode_pipeline_columns: + raise ValueError("sharrow_enabled requires recode_pipeline_columns") + + if sharrow_enabled: + settings["sharrow"] = "test" + del settings["trace_hh_id"] + state = workflow.State.make_default( working_dir=working_dir, configs_dir=("ext-configs", "configs"), @@ -279,11 +288,11 @@ def test_mtc_extended_progressive(chunk_training_mode, recode_pipeline_columns): ) state.filesystem.persist_sharrow_cache() state.logging.config_logger() - state.settings.trace_hh_id = 1196298 assert state.settings.models == EXPECTED_MODELS assert state.settings.chunk_size == 0 - assert not state.settings.sharrow + if not sharrow_enabled: + assert not state.settings.sharrow ref_pipeline = Path(__file__).parent.joinpath( f"reference-pipeline-extended-recode{recode_pipeline_columns}.zip" @@ -319,3 +328,7 @@ def test_mtc_extended_progressive(chunk_training_mode, recode_pipeline_columns): if __name__ == "__main__": + test_mtc_extended_progressive("disabled", True, False) + test_mtc_extended_progressive("explicit", False, False) + test_mtc_extended_progressive("explicit", True, True) +