forked from e-mission/e-mission-eval-private-data
Commit
Merge pull request e-mission#34 from hlu109/poster_update
Merge label-assist paper updates
Showing 26 changed files with 14,179 additions and 0 deletions.
@@ -0,0 +1,5 @@
This folder contains code used in the TRB paper submission on label-assist
algorithms. In addition, the notebooks contain code that can be used to
generate the figures displayed in the paper.

These notebooks rely on modifications in a [branch](https://github.com/hlu109/e-mission-server/tree/eval-private-data-compatibility) of the e-mission-server. Ensure that this version of the server is set up according to the [instructions in this repository's main README](../README.md).
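Before running the notebooks, it can help to confirm that the patched server branch is importable and that the local database the notebooks query is reachable. The snippet below is a minimal sanity check, not part of the repository; it only reuses modules that the notebooks themselves import and assumes the standard e-mission MongoDB setup.

```python
# Quick environment check before opening the notebooks (assumes the
# e-mission-server checkout described above is on the Python path and that
# its MongoDB instance is running; not part of the repository).
import emission.core.get_database as edb
import emission.storage.timeseries.abstract_timeseries as esta  # noqa: F401

# An exception here usually means the server branch is not installed in the
# current environment or the database is not reachable.
n_users = edb.get_uuid_db().count_documents({})
print(f"uuid database reachable; {n_users} registered users found")
```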
@@ -0,0 +1,292 @@
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### imports\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn import svm\n",
    "\n",
    "from uuid import UUID\n",
    "\n",
    "import emission.storage.timeseries.abstract_timeseries as esta\n",
    "import emission.storage.decorations.trip_queries as esdtq\n",
    "import emission.core.get_database as edb\n",
    "\n",
    "import data_wrangling\n",
    "from clustering import add_loc_clusters"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### load data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Please contact the corresponding author for the emails/uuids of the data used in the paper.\n",
    "suburban_email = \"stage_d3GBLDSVzn4\"\n",
    "college_campus_email = \"stage_OlTMUasqLbxL7KfI\"\n",
    "suburban_uuid = list(edb.get_uuid_db().find({\"user_email\":\n",
    "                                             suburban_email}))[0]['uuid']\n",
    "college_campus_uuid = list(edb.get_uuid_db().find(\n",
    "    {\"user_email\": college_campus_email}))[0]['uuid']\n",
    "\n",
    "uuids = [suburban_uuid, college_campus_uuid]\n",
    "confirmed_trip_df_map = {}\n",
    "labeled_trip_df_map = {}\n",
    "expanded_trip_df_map = {}\n",
    "for u in uuids:\n",
    "    ts = esta.TimeSeries.get_time_series(u)\n",
    "    ct_df = ts.get_data_df(\"analysis/confirmed_trip\")\n",
    "    confirmed_trip_df_map[u] = ct_df\n",
    "    labeled_trip_df_map[u] = esdtq.filter_labeled_trips(ct_df)\n",
    "    expanded_trip_df_map[u] = esdtq.expand_userinputs(labeled_trip_df_map[u])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SVM exploration for single clusters\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Set up data so that we can look at some specific clusters first.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def setup(user_df,\n",
    "          alg,\n",
    "          loc_type,\n",
    "          radii=[50, 100, 150, 200],\n",
    "          cluster_unlabeled=False):\n",
    "    \"\"\" copied and modified from find_plot_clusters() in mapping \"\"\"\n",
    "    # clean up the dataframe by dropping entries with NaN locations and\n",
    "    # resetting index because oursim needs the position of each trip to match\n",
    "    # its nominal index\n",
    "    all_trips_df = user_df.dropna(subset=['start_loc', 'end_loc']).reset_index(\n",
    "        drop=True)\n",
    "\n",
    "    # expand the 'start/end_loc' column into 'start/end_lat/lon' columns\n",
    "    all_trips_df = data_wrangling.expand_coords(all_trips_df)\n",
    "\n",
    "    labeled_trips_df = all_trips_df.loc[all_trips_df.user_input != {}]\n",
    "    df_for_cluster = all_trips_df if cluster_unlabeled else labeled_trips_df\n",
    "\n",
    "    df_for_cluster = add_loc_clusters(df_for_cluster,\n",
    "                                      radii=radii,\n",
    "                                      alg=alg,\n",
    "                                      loc_type=loc_type,\n",
    "                                      min_samples=2)\n",
    "\n",
    "    return df_for_cluster\n",
    "\n",
    "\n",
    "suburb_df = expanded_trip_df_map[suburban_uuid]\n",
    "college_df = expanded_trip_df_map[college_campus_uuid]\n",
    "\n",
    "suburb_df = setup(suburb_df,\n",
    "                  alg='DBSCAN',\n",
    "                  loc_type='end',\n",
    "                  radii=[100, 150],\n",
    "                  cluster_unlabeled=False)\n",
    "college_df = setup(college_df,\n",
    "                   alg='DBSCAN',\n",
    "                   loc_type='end',\n",
    "                   radii=[150],\n",
    "                   cluster_unlabeled=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "alg = 'DBSCAN'\n",
    "loc_type = 'end'\n",
    "\n",
    "c1 = college_df.loc[(college_df[f'{loc_type}_{alg}_clusters_150_m'] == 1\n",
    "                     )].loc[:,\n",
    "                            ['end_lat', 'end_lon', 'purpose_confirm']].dropna()\n",
    "\n",
    "c2 = suburb_df.loc[(suburb_df[f'{loc_type}_{alg}_clusters_100_m'] == 1\n",
    "                    )].loc[:,\n",
    "                           ['end_lat', 'end_lon', 'purpose_confirm']].dropna()\n",
    "c3 = suburb_df.loc[(suburb_df[f'{loc_type}_{alg}_clusters_150_m'] == 1\n",
    "                    )].loc[:,\n",
    "                           ['end_lat', 'end_lon', 'purpose_confirm']].dropna()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.style.use(\"default\")\n",
    "\n",
    "\n",
    "def vis_pred(cluster):\n",
    "    # set up the model\n",
    "    X = cluster.loc[:, ['end_lon', 'end_lat']]\n",
    "    Y = cluster.loc[:, 'purpose_confirm']\n",
    "\n",
    "    model = make_pipeline(\n",
    "        StandardScaler(),\n",
    "        svm.SVC(kernel='rbf', gamma=0.05, C=1, decision_function_shape='ovr'))\n",
    "\n",
    "    # train the model\n",
    "    # fit() wants Y as an array, not a column vector\n",
    "    model.fit(X, Y.to_list())\n",
    "\n",
    "    # vars for visualizing decision functions\n",
    "    min_lat = X[['end_lat']].min()\n",
    "    max_lat = X[['end_lat']].max()\n",
    "    min_lon = X[['end_lon']].min()\n",
    "    max_lon = X[['end_lon']].max()\n",
    "    width = max_lon - min_lon\n",
    "    height = max_lat - min_lat\n",
    "    xx, yy = np.meshgrid(\n",
    "        np.linspace(min_lon - 0.05 * width, max_lon + 0.05 * width, 500),\n",
    "        np.linspace(min_lat - 0.05 * height, max_lat + 0.05 * height, 500))\n",
    "\n",
    "    num_classes = len(model.classes_)\n",
    "    label_map = {model.classes_[i]: i for i in range(num_classes)}\n",
    "\n",
    "    pred = model.predict(np.c_[xx.ravel(), yy.ravel()])\n",
    "    pred = [label_map[p] for p in pred]\n",
    "    pred = np.array(pred).reshape((xx.shape[0], xx.shape[1]))\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(5.45, 4))\n",
    "\n",
    "    ## Prepare bins for the normalizer\n",
    "    ## normalize the colors\n",
    "    norm_bins = np.sort([*label_map.values()]) + 0.5\n",
    "    norm_bins = np.insert(norm_bins, 0, np.min(norm_bins) - 1.0)\n",
    "    norm = matplotlib.colors.BoundaryNorm(norm_bins, num_classes, clip=True)\n",
    "\n",
    "    if num_classes <= 10:\n",
    "        cm = plt.cm.tab10\n",
    "    else:\n",
    "        cm = plt.cm.tab20\n",
    "\n",
    "    im = ax.imshow(\n",
    "        pred,\n",
    "        interpolation=\"none\",\n",
    "        extent=(xx.min(), xx.max(), yy.min(), yy.max()),\n",
    "        aspect=\"auto\",\n",
    "        origin=\"lower\",\n",
    "        cmap=cm,\n",
    "        norm=norm,\n",
    "    )\n",
    "\n",
    "    ax.scatter(\n",
    "        X['end_lon'],\n",
    "        X['end_lat'],\n",
    "        c=Y.map(label_map).to_list(),\n",
    "        cmap=cm,\n",
    "        edgecolors=\"k\",\n",
    "        norm=norm,\n",
    "    )\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    fig.subplots_adjust(bottom=0.1, top=0.9, left=0.5, right=1)\n",
    "\n",
    "    plt.axis('scaled')\n",
    "    # plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vis_pred(c3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vis_pred(c1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vis_pred(c2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "5edc29c2ed010d6458d71a83433b383a96a8cbd3efe8531bc90c4b8a5b8bcec9"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
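The trip data used in the notebook is private (the emails/uuids must be requested from the corresponding author), so the decision-region plots cannot be reproduced directly from the diff above. As a reference, the sketch below exercises the same technique on made-up points: standardize the coordinates, fit an RBF SVM on the purpose labels, predict over a lon/lat grid, and draw the predicted regions behind the labeled points. Everything in it (coordinates, labels, figure styling) is illustrative and not taken from the paper.

```python
# Standalone sketch of the decision-region visualization on synthetic points;
# mirrors the notebook's pipeline but is not part of the repository.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn import svm

# Two synthetic "purpose" groups of trip-end points in a small lon/lat box.
rng = np.random.default_rng(0)
lon = np.concatenate([rng.normal(-105.080, 0.001, 40),
                      rng.normal(-105.078, 0.001, 40)])
lat = np.concatenate([rng.normal(39.980, 0.001, 40),
                      rng.normal(39.981, 0.001, 40)])
labels = np.array(['home'] * 40 + ['work'] * 40)
X = np.column_stack([lon, lat])

# Same pipeline shape as the notebook: scale, then an RBF SVM.
model = make_pipeline(
    StandardScaler(),
    svm.SVC(kernel='rbf', gamma=0.05, C=1, decision_function_shape='ovr'))
model.fit(X, labels)

# Predict a class for every cell of a grid covering the points.
xx, yy = np.meshgrid(np.linspace(lon.min(), lon.max(), 300),
                     np.linspace(lat.min(), lat.max(), 300))
pred = (model.predict(np.c_[xx.ravel(), yy.ravel()]) == 'work').astype(int)
pred = pred.reshape(xx.shape)

# Shade the predicted regions and overlay the labeled training points.
fig, ax = plt.subplots(figsize=(5, 4))
ax.contourf(xx, yy, pred, alpha=0.3, cmap=plt.cm.coolwarm)
ax.scatter(lon, lat, c=(labels == 'work').astype(int),
           cmap=plt.cm.coolwarm, edgecolors='k')
ax.set_xticks([])
ax.set_yticks([])
plt.show()
```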