diff --git a/viz_scripts/tests/test_scaffolding.py b/viz_scripts/tests/test_scaffolding.py index b58aeb0..b0cd280 100644 --- a/viz_scripts/tests/test_scaffolding.py +++ b/viz_scripts/tests/test_scaffolding.py @@ -6,7 +6,9 @@ import numpy as np import collections as colls import pytest +import asyncio import matplotlib.pyplot as plt +import emcommon.util as emcu # Dynamically import saved-notebooks.plots scaffolding = importlib.import_module('saved-notebooks.scaffolding') @@ -258,3 +260,40 @@ def test_filter_labeled_trips_no_labeled_trips(): # Assert the returned DataFrame is empty assert len(labeled_ct) == 0 + +@pytest.fixture +def labeled_ct(): + return pd.DataFrame({ + 'user_input':[ + {'purpose_confirm': 'work', 'mode_confirm':'own_car'}, + {'mode_confirm':'bus'}, + {'purpose_confirm': 'school'}, + {'purpose_confirm': 'at_work', 'mode_confirm': 'own_car'}, + {'purpose_confirm': 'access_recreation', 'mode_confirm':'car'}, + {'mode_confirm':'bike', 'purpose_confirm':'pick_drop_person'}, + {'purpose_confirm':'work', 'mode_confirm':'bike'} + ], + "distance": [100, 150, 50, 20, 50, 10, 60], + "user_id":["user_1", "user_1", "user_1", "user_2", "user_2", "user_3", "user_4"], + "raw_trip":["trip_0", "trip_1", "trip_2", "trip_3", "trip_4", "trip_5", "trip_6"], + "start_ts":[1.690e+09, 1.690e+09, 1.690e+09, 1.690e+09, 1.690e+09, 1.690e+09, 1.690e+09], + "duration": [1845.26, 1200.89, 1000.56, 564.54, 456.456, 156.45, 1564.456], + "distance": [100, 150, 600, 500, 300, 200, 50] + }) + +def test_expand_userinputs(labeled_ct): + expanded_ct = scaffolding.expand_userinputs(labeled_ct) + + # Assert the length of the dataframe is not changed + # Assert the columns have increased with labels_per_trip + labels_per_trip = len(pd.DataFrame(labeled_ct.user_input.to_list()).columns) + + assert len(expanded_ct) == len(labeled_ct) + assert labels_per_trip == 2 + assert len(expanded_ct.columns) == len(labeled_ct.columns) + labels_per_trip + + # Assert new columns and their values + assert 'mode_confirm' in expanded_ct.columns + assert 'purpose_confirm' in expanded_ct.columns + assert expanded_ct['purpose_confirm'].fillna('NaN').tolist() == ['work', 'NaN', 'school', 'at_work', 'access_recreation', 'pick_drop_person', 'work'] + assert expanded_ct['mode_confirm'].fillna('NaN').tolist() == ['own_car', 'bus', 'NaN', 'own_car', 'car', 'bike', 'bike']