diff --git a/activitysim/abm/models/disaggregate_accessibility.py b/activitysim/abm/models/disaggregate_accessibility.py index ab4f9acef..06164e6b2 100644 --- a/activitysim/abm/models/disaggregate_accessibility.py +++ b/activitysim/abm/models/disaggregate_accessibility.py @@ -846,6 +846,10 @@ def compute_disaggregate_accessibility( state.tracing.register_traceable_table(tablename, df) del df + disagg_model_settings = read_disaggregate_accessibility_yaml( + "disaggregate_accessibility.yaml" + ) + # Run location choice logsums = get_disaggregate_logsums( state, @@ -906,4 +910,22 @@ def compute_disaggregate_accessibility( for k, df in logsums.items(): state.add_table(k, df) + # available post-processing + for annotations in disagg_model_settings.get("postprocess_proto_tables", []): + tablename = annotations["tablename"] + df = state.get_dataframe(tablename) + assert df is not None + assert annotations is not None + assign_columns( + df=df, + model_settings={ + **annotations["annotate"], + **disagg_model_settings["suffixes"], + }, + trace_label=tracing.extend_trace_label( + "disaggregate_accessibility.postprocess", tablename + ), + ) + state.add_table(tablename, df) + return diff --git a/activitysim/abm/tables/disaggregate_accessibility.py b/activitysim/abm/tables/disaggregate_accessibility.py index 8ab0e0820..0372bd6f4 100644 --- a/activitysim/abm/tables/disaggregate_accessibility.py +++ b/activitysim/abm/tables/disaggregate_accessibility.py @@ -172,14 +172,13 @@ def disaggregate_accessibility(state: workflow.State): accessibility_cols = [ x for x in proto_accessibility_df.columns if "accessibility" in x ] + keep_cols = model_settings.get("KEEP_COLS", accessibility_cols) # Parse the merging parameters assert merging_params is not None # Check if already assigned! - if set(accessibility_cols).intersection(persons_merged_df.columns) == set( - accessibility_cols - ): + if set(keep_cols).intersection(persons_merged_df.columns) == set(keep_cols): return # Find the nearest zone (spatially) with accessibilities calculated @@ -211,7 +210,7 @@ def disaggregate_accessibility(state: workflow.State): # because it will get slightly different logsums for households in the same zone. # This is because different destination zones were selected. To resolve, get mean by cols. right_df = ( - proto_accessibility_df.groupby(merge_cols)[accessibility_cols] + proto_accessibility_df.groupby(merge_cols)[keep_cols] .mean() .sort_values(nearest_cols) .reset_index() @@ -244,9 +243,9 @@ def disaggregate_accessibility(state: workflow.State): ) # Predict the nearest person ID and pull the logsums - matched_logsums_df = right_df.loc[clf.predict(x_pop)][ - accessibility_cols - ].reset_index(drop=True) + matched_logsums_df = right_df.loc[clf.predict(x_pop)][keep_cols].reset_index( + drop=True + ) merge_df = pd.concat( [left_df.reset_index(drop=False), matched_logsums_df], axis=1 ).set_index("person_id") @@ -278,9 +277,9 @@ def disaggregate_accessibility(state: workflow.State): # Check that it was correctly left-joined assert all(persons_merged_df[merge_cols] == merge_df[merge_cols]) - assert any(merge_df[accessibility_cols].isnull()) + assert any(merge_df[keep_cols].isnull()) # Inject merged accessibilities so that it can be included in persons_merged function - state.add_table("disaggregate_accessibility", merge_df[accessibility_cols]) + state.add_table("disaggregate_accessibility", merge_df[keep_cols]) - return merge_df[accessibility_cols] + return merge_df[keep_cols]