Skip to content

Commit

Permalink
Update unit tests for expand_userinputs()
Browse files Browse the repository at this point in the history
  • Loading branch information
iantei committed Nov 28, 2024
1 parent ef760c6 commit ef1cb25
Showing 1 changed file with 39 additions and 0 deletions.
39 changes: 39 additions & 0 deletions viz_scripts/tests/test_scaffolding.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import numpy as np
import collections as colls
import pytest
import asyncio
import matplotlib.pyplot as plt
import emcommon.util as emcu

# Dynamically import saved-notebooks.plots
scaffolding = importlib.import_module('saved-notebooks.scaffolding')
Expand Down Expand Up @@ -258,3 +260,40 @@ def test_filter_labeled_trips_no_labeled_trips():

# Assert the returned DataFrame is empty
assert len(labeled_ct) == 0

@pytest.fixture
def labeled_ct():
return pd.DataFrame({
'user_input':[
{'purpose_confirm': 'work', 'mode_confirm':'own_car'},
{'mode_confirm':'bus'},
{'purpose_confirm': 'school'},
{'purpose_confirm': 'at_work', 'mode_confirm': 'own_car'},
{'purpose_confirm': 'access_recreation', 'mode_confirm':'car'},
{'mode_confirm':'bike', 'purpose_confirm':'pick_drop_person'},
{'purpose_confirm':'work', 'mode_confirm':'bike'}
],
"distance": [100, 150, 50, 20, 50, 10, 60],
"user_id":["user_1", "user_1", "user_1", "user_2", "user_2", "user_3", "user_4"],
"raw_trip":["trip_0", "trip_1", "trip_2", "trip_3", "trip_4", "trip_5", "trip_6"],
"start_ts":[1.690e+09, 1.690e+09, 1.690e+09, 1.690e+09, 1.690e+09, 1.690e+09, 1.690e+09],
"duration": [1845.26, 1200.89, 1000.56, 564.54, 456.456, 156.45, 1564.456],
"distance": [100, 150, 600, 500, 300, 200, 50]
})

def test_expand_userinputs(labeled_ct):
expanded_ct = scaffolding.expand_userinputs(labeled_ct)

# Assert the length of the dataframe is not changed
# Assert the columns have increased with labels_per_trip
labels_per_trip = len(pd.DataFrame(labeled_ct.user_input.to_list()).columns)

assert len(expanded_ct) == len(labeled_ct)
assert labels_per_trip == 2
assert len(expanded_ct.columns) == len(labeled_ct.columns) + labels_per_trip

# Assert new columns and their values
assert 'mode_confirm' in expanded_ct.columns
assert 'purpose_confirm' in expanded_ct.columns
assert expanded_ct['purpose_confirm'].fillna('NaN').tolist() == ['work', 'NaN', 'school', 'at_work', 'access_recreation', 'pick_drop_person', 'work']
assert expanded_ct['mode_confirm'].fillna('NaN').tolist() == ['own_car', 'bus', 'NaN', 'own_car', 'car', 'bike', 'bike']

0 comments on commit ef1cb25

Please sign in to comment.