From fbd32f1d971558f7f32530c99d43e7891b185b00 Mon Sep 17 00:00:00 2001 From: Kuan Butts Date: Mon, 17 Jun 2019 19:32:56 -0700 Subject: [PATCH] Do not prune trips with just one arrival time (#143) * break out mask var name * Overwrite one null val stop wait times with fallback * New test to ensure that 1 row stops are passed through * Drop excess logic --- peartree/summarizer.py | 5 +++-- tests/test_summarizer.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 2 deletions(-) create mode 100644 tests/test_summarizer.py diff --git a/peartree/summarizer.py b/peartree/summarizer.py index 72e380d..ed2b9d2 100644 --- a/peartree/summarizer.py +++ b/peartree/summarizer.py @@ -147,7 +147,7 @@ def generate_summary_wait_times( df_sub = df[['stop_id', 'wait_dir_0', 'wait_dir_1']].reset_index(drop=True) - init_of_stop_ids = df_sub.stop_id.unique() + init_of_stop_ids = df_sub['stop_id'].unique() # Default values for average waits with not enough data should be # NaN types, but let's make sure all null types are NaNs to be safe @@ -157,7 +157,8 @@ def generate_summary_wait_times( # Convert anything that is 0 or less seconds to a NaN as well # to remove negative or 0 second waits in the system - df_sub.loc[~(df_sub[col] > 0), col] = np.nan + over_zero_mask = df_sub[col] > 0 + df_sub.loc[~over_zero_mask, col] = np.nan # With all null types converted to NaN, we can cast col as float df_sub[col] = df_sub[col].astype(float) diff --git a/tests/test_summarizer.py b/tests/test_summarizer.py new file mode 100644 index 0000000..10c2f4e --- /dev/null +++ b/tests/test_summarizer.py @@ -0,0 +1,35 @@ +import numpy as np +import pandas as pd +from peartree.summarizer import generate_summary_wait_times + + +def test_generate_summary_wait_times(): + df = pd.DataFrame({ + 'stop_id': [ + 1, + 1, + 2, + 2, + 3, + 4], + 'wait_dir_0': [ + 10, + 10, + 19, + 21, + np.nan, + 12], + 'wait_dir_1': [ + np.nan, + np.nan, + 9, + 11, + np.nan, + np.nan], + }) + + fallback_stop_cost = 40.0 # seconds + res = generate_summary_wait_times(df, fallback_stop_cost) + res = res.sort_values(by='stop_id') + assert res['stop_id'].tolist() == [1, 2, 3, 4] + assert res['avg_cost'].tolist() == [10.0, 15.0, 40.0, 12.0]