From 8a9f772cc5d863bd91b71bc14131378ef414af29 Mon Sep 17 00:00:00 2001 From: David Hensle Date: Thu, 28 Mar 2024 15:18:16 -0700 Subject: [PATCH] estimation mode tour checking --- .../models/non_mandatory_tour_frequency.py | 27 ++++++++++++------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/activitysim/abm/models/non_mandatory_tour_frequency.py b/activitysim/abm/models/non_mandatory_tour_frequency.py index d032e3aae..972c4b3dc 100644 --- a/activitysim/abm/models/non_mandatory_tour_frequency.py +++ b/activitysim/abm/models/non_mandatory_tour_frequency.py @@ -236,7 +236,9 @@ def non_mandatory_tour_frequency( locals_dict = { "person_max_window": lambda x: person_max_window(state, x), - "person_available_periods": lambda x: person_available_periods(state, x), + "person_available_periods": lambda persons, start_bin, end_bin, continuous: person_available_periods( + state, persons, start_bin, end_bin, continuous + ), } expressions.assign_columns( @@ -425,14 +427,21 @@ def non_mandatory_tour_frequency( if estimator: # make sure they created the right tours survey_tours = estimation.manager.get_survey_table("tours").sort_index() - # FIXME below check needs to remove the pure-escort tours from the survey tours table - # non_mandatory_survey_tours = survey_tours[ - # survey_tours.tour_category == "non_mandatory" - # ] - # assert len(non_mandatory_survey_tours) == len(non_mandatory_tours) - # assert non_mandatory_survey_tours.index.equals( - # non_mandatory_tours.sort_index().index - # ) + non_mandatory_survey_tours = survey_tours[ + survey_tours.tour_category == "non_mandatory" + ] + # need to remove the pure-escort tours from the survey tours table for comparison below + if state.is_table("school_escort_tours"): + non_mandatory_survey_tours = non_mandatory_survey_tours[ + ~non_mandatory_survey_tours.index.isin( + state.get_table("school_escort_tours").index + ) + ] + + assert len(non_mandatory_survey_tours) == len(non_mandatory_tours) + assert non_mandatory_survey_tours.index.equals( + non_mandatory_tours.sort_index().index + ) # make sure they created tours with the expected tour_ids columns = ["person_id", "household_id", "tour_type", "tour_category"]