diff --git a/peartree/summarizer.py b/peartree/summarizer.py index e94c2a5..b5887ab 100644 --- a/peartree/summarizer.py +++ b/peartree/summarizer.py @@ -11,6 +11,22 @@ from peartree.utilities import log + +class InvalidParsedWaitTimes(Exception): + pass + +def _format_summarized_outputs(summarized: pd.Series) -> pd.DataFrame: + # The output of the group by produces a Series, but we want to extract + # the values from the index and the Series itself and generate a + # pandas DataFrame instead + original_stop_ids_index = summarized.index.values + original_series_values = summarized.values + + return pd.DataFrame({ + 'stop_id': original_stop_ids_index, + 'avg_cost': original_series_values}) + + def calculate_average_wait(direction_times: pd.DataFrame) -> float: # Exit early if we do not have enough values to calculate a mean at = direction_times.arrival_time @@ -131,14 +147,27 @@ def generate_summary_wait_times( dir_0_check_2 = df_sub[np.isnan(df_sub.wait_dir_0)] dir_1_check_2 = df_sub[np.isnan(df_sub.wait_dir_1)] - if (len(dir_0_check_2) > 0) or (len(dir_1_check_2) > 0): - raise Exception('NaN values for both directions on some stop IDs.') - - grouped = df_sub.groupby('stop_id') - summarized = grouped.apply(summarize_waits_at_one_stop) + dir_0_trigger = len(dir_0_check_2) > 0 + dir_1_trigger = len(dir_1_check_2) > 0 + if dir_0_trigger or dir_1_trigger: + raise InvalidParsedWaitTimes( + 'NaN values for both directions on some stop IDs.') + + # At this point, we should make sure that there are still values + # in the DataFrame - otherwise we are in a situation where there are + # no valid times to evaluate. 
This is okay; we just need to skip straight + # to the application of the fallback value + if df_sub.empty: + # So just make a fallback empty dataframe for now + summed_reset = pd.DataFrame({'stop_id': [], 'avg_cost': []}) + + # Only attempt this group by summary if at least one row to group on + else: + grouped = df_sub.groupby('stop_id') + summarized = grouped.apply(summarize_waits_at_one_stop) - summed_reset = summarized.reset_index(drop=False) - summed_reset.columns = ['stop_id', 'avg_cost'] + # Clean up summary results, reformat pandas DataFrame result + summed_reset = _format_summarized_outputs(summarized) end_of_stop_ids = summed_reset.stop_id.unique() log('Original stop id count: {}'.format(len(init_of_stop_ids))) diff --git a/tests/fixtures/highdesertpointorus-2018-03-20.zip b/tests/fixtures/highdesertpointorus-2018-03-20.zip new file mode 100644 index 0000000..bb24169 Binary files /dev/null and b/tests/fixtures/highdesertpointorus-2018-03-20.zip differ diff --git a/tests/test_paths.py b/tests/test_paths.py index 2800622..7df9972 100644 --- a/tests/test_paths.py +++ b/tests/test_paths.py @@ -97,6 +97,17 @@ def test_loading_in_invalid_timeframes(): load_feed_as_graph(feed_1, start, end) +def test_parsing_when_just_one_trip_during_target_window(): + path = fixture('highdesertpointorus-2018-03-20.zip') + feed = get_representative_feed(path) + + start = 7*60*60 # 7:00 AM + end = 8*60*60 # 8:00 AM + G = load_feed_as_graph(feed, start, end) + assert len(list(G.nodes())) == 2 + assert len(list(G.edges())) == 1 + + def test_synthetic_network(): # Load in the GeoJSON as a JSON and convert to a dictionary geojson_path = fixture('synthetic_east_bay.geojson')