Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moving dependence from custom branch's tour_model to master's trip_model #933

Merged
merged 10 commits into from
Sep 14, 2023
10 changes: 8 additions & 2 deletions emission/analysis/modelling/similarity/similarity_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,21 @@ def similarity(self, a: List[float], b: List[float]) -> List[float]:
"""
pass

def similar(self, a: List[float], b: List[float], thresh: float) -> bool:
def similar(self, a: List[float], b: List[float], thresh: float, clusteringWay :str = 'origin-destination') -> bool:
"""compares the features, returning true if they are similar
within some threshold

:param a: features for a trip
:param b: features for another trip
:param thresh: threshold for similarity
:clusterinWay: clustering based on origin/destination/origin-destination-pair
humbleOldSage marked this conversation as resolved.
Show resolved Hide resolved
:return: true if the feature similarity is within some threshold
"""
similarity_values = self.similarity(a, b)
shankari marked this conversation as resolved.
Show resolved Hide resolved
is_similar = all(map(lambda sim: sim <= thresh, similarity_values))
shankari marked this conversation as resolved.
Show resolved Hide resolved
if clusteringWay == 'origin':
is_similar = similarity_values[0] <= thresh
elif clusteringWay == 'destination':
is_similar = similarity_values[1] <= thresh
else:
is_similar = all(map(lambda sim: sim <= thresh, similarity_values))
humbleOldSage marked this conversation as resolved.
Show resolved Hide resolved
return is_similar
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ class label to apply:
self.sim_thresh = config['similarity_threshold_meters']
self.apply_cutoff = config['apply_cutoff']
self.is_incremental = config['incremental_evaluation']
if config.get('clustering_way') is None:
shankari marked this conversation as resolved.
Show resolved Hide resolved
self.clusteringWay='origin-destination' # previous default
else:
self.clusteringWay= config['clustering_way']
self.tripLabels=[]

self.bins: Dict[str, Dict] = {}

Expand Down Expand Up @@ -184,9 +189,11 @@ def _assign_bins(self, trips: List[ecwc.Confirmedtrip]):
logging.debug(f"adding trip to bin {bin_id} with features {trip_features}")
self.bins[bin_id]['feature_rows'].append(trip_features)
self.bins[bin_id]['labels'].append(trip_labels)
self.tripLabels.append(bin_id)
else:
# create new bin
new_bin_id = str(len(self.bins))
self.tripLabels.append(new_bin_id)
new_bin_record = {
'feature_rows': [trip_features],
'labels': [trip_labels],
Expand All @@ -204,7 +211,7 @@ def _find_matching_bin_id(self, trip_features: List[float]) -> Optional[str]:
:return: the id of a bin if a match was found, otherwise None
"""
for bin_id, bin_record in self.bins.items():
matches_bin = all([self.metric.similar(trip_features, bin_sample, self.sim_thresh)
matches_bin = all([self.metric.similar(trip_features, bin_sample, self.sim_thresh,self.clusteringWay)
for bin_sample in bin_record['feature_rows']])
if matches_bin:
return bin_id
Expand Down