Skip to content

Commit

Permalink
testing with sharrow
Browse files Browse the repository at this point in the history
  • Loading branch information
jpn-- committed Apr 23, 2024
1 parent 11af951 commit 535c123
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions test/test_mtc.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,14 @@ def regress(ext, out_dir):


@pytest.mark.parametrize(
"chunk_training_mode,recode_pipeline_columns",
"chunk_training_mode,recode_pipeline_columns,sharrow_enabled",
[
("disabled", True),
("explicit", False),
("explicit", True, True),
],
)
def test_mtc_extended_progressive(chunk_training_mode, recode_pipeline_columns):
def test_mtc_extended_progressive(chunk_training_mode, recode_pipeline_columns, sharrow_enabled):
import activitysim.abm # register components # noqa: F401

out_dir = _test_path(f"output-progressive-recode{recode_pipeline_columns}")
Expand Down Expand Up @@ -267,8 +268,16 @@ def test_mtc_extended_progressive(chunk_training_mode, recode_pipeline_columns):
],
},
"recode_pipeline_columns": recode_pipeline_columns,
"trace_hh_id": 1196298,
}

if sharrow_enabled and not recode_pipeline_columns:
raise ValueError("sharrow_enabled requires recode_pipeline_columns")

if sharrow_enabled:
settings["sharrow"] = "test"
del settings["trace_hh_id"]

state = workflow.State.make_default(
working_dir=working_dir,
configs_dir=("ext-configs", "configs"),
Expand All @@ -279,11 +288,11 @@ def test_mtc_extended_progressive(chunk_training_mode, recode_pipeline_columns):
)
state.filesystem.persist_sharrow_cache()
state.logging.config_logger()
state.settings.trace_hh_id = 1196298

assert state.settings.models == EXPECTED_MODELS
assert state.settings.chunk_size == 0
assert not state.settings.sharrow
if not sharrow_enabled:
assert not state.settings.sharrow

ref_pipeline = Path(__file__).parent.joinpath(
f"reference-pipeline-extended-recode{recode_pipeline_columns}.zip"
Expand Down Expand Up @@ -319,3 +328,7 @@ def test_mtc_extended_progressive(chunk_training_mode, recode_pipeline_columns):


if __name__ == "__main__":
test_mtc_extended_progressive("disabled", True, False)
test_mtc_extended_progressive("explicit", False, False)
test_mtc_extended_progressive("explicit", True, True)

0 comments on commit 535c123

Please sign in to comment.