Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make localdate queries behave like canonical datetime range filters #968

Merged
merged 2 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions emission/core/wrapper/localdate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import arrow
import emission.core.wrapper.wrapperbase as ecwb

# specify the order of time units, from largest to smallest
DATETIME_UNITS = ['year', 'month', 'day', 'hour', 'minute', 'second']

class LocalDate(ecwb.WrapperBase):
"""
Supporting wrapper class that stores the expansions of the components
Expand Down
79 changes: 36 additions & 43 deletions emission/storage/decorations/local_date_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,47 +11,40 @@

import emission.core.wrapper.localdate as ecwl

def get_range_query(field_name, start_local_dt, end_local_dt):
if list(start_local_dt.keys()) != list(end_local_dt.keys()):
raise RuntimeError("start_local_dt.keys() = %s does not match end_local_dt.keys() = %s" %
(list(start_local_dt.keys()), list(end_local_dt.keys())))
query_result = {}
for key in start_local_dt:
curr_field = "%s.%s" % (field_name, key)
gte_lte_query = {}
try:
start_int = int(start_local_dt[key])
except:
logging.info("start_local_dt[%s] = %s, not an integer, skipping" %
(key, start_local_dt[key]))
continue

try:
end_int = int(end_local_dt[key])
except:
logging.info("end_local_dt[%s] = %s, not an integer, skipping" %
(key, end_local_dt[key]))
continue

is_rollover = start_int > end_int

if is_rollover:
gte_lte_query = get_rollover_query(start_int, end_int)
else:
gte_lte_query = get_standard_query(start_int, end_int)

if len(gte_lte_query) > 0:
query_result.update({curr_field: gte_lte_query})
def get_range_query(field_prefix, start_ld, end_ld):
units = [u for u in ecwl.DATETIME_UNITS if u in start_ld and u in end_ld]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example, what if u is in start_ld but not in end_ld. If I specify a month in start_ld but no month in end_ld, I intuitively want to get all trips after the start but be open-ended wrt the end. Alternatively, you should check that there always is a matching pair as I did before.

logging.debug(f'get_range_query: units = {units}')
try:
gt_query = get_comparison_query(field_prefix, start_ld, end_ld, units, 'gt')
lt_query = get_comparison_query(field_prefix, end_ld, start_ld, units, 'lt')
logging.debug(f'get_range_query: gt_query = {gt_query}, lt_query = {lt_query}')
return { "$and": [gt_query, lt_query] } if gt_query and lt_query else {}
except AssertionError as e:
logging.error(f'Invalid range from {str(start_ld)} to {str(end_ld)}: {str(e)}')
return None

def get_comparison_query(field_prefix, base_ld, limit_ld, units, gt_or_lt):
field_name = lambda i: f'{field_prefix}.{units[i]}'
and_conditions, or_conditions = [], []
tiebreaker_index = -1
for i, unit in enumerate(units):
# the range is inclusive, so if on the last unit we should use $lte / $gte instead of $lt / $gt
op = f'${gt_or_lt}e' if i == len(units)-1 else f'${gt_or_lt}'
if tiebreaker_index >= 0:
tiebreaker_conditions = [{ field_name(j): base_ld[units[j]] } for j in range(tiebreaker_index, i)]
tiebreaker_conditions.append({ field_name(i): { op: base_ld[unit] }})
or_conditions.append({ "$and": tiebreaker_conditions })
elif base_ld[unit] == limit_ld[unit]:
and_conditions.append({field_name(i): base_ld[unit]})
else:
logging.info("key %s exists, skipping because upper AND lower range are missing" % key)

logging.debug("In get_range_query, returning query %s" % query_result)
return query_result

def get_standard_query(start_int, end_int):
assert(start_int <= end_int)
return {'$gte': start_int, '$lte': end_int}

def get_rollover_query(start_int, end_int):
assert(start_int > end_int)
return {'$not': {'$gt': end_int, '$lt': start_int}}
assert (base_ld[unit] < limit_ld[unit]) if gt_or_lt == 'gt' else (base_ld[unit] > limit_ld[unit])
or_conditions.append({field_name(i): { op: base_ld[unit] }})
tiebreaker_index = i
if and_conditions and or_conditions:
return { "$and": and_conditions + [{ "$or": or_conditions }] }
elif and_conditions:
return { "$and": and_conditions }
elif or_conditions:
return { "$or": or_conditions }
else:
return {}
16 changes: 3 additions & 13 deletions emission/tests/storageTests/TestLocalDateQueries.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,24 +77,14 @@ def testLocalRangeStandardQuery(self):

def testLocalRangeRolloverQuery(self):
"""
Search for all entries between 8:18 and 8:20 local time, both inclusive
Search for all entries between 8:18 and 9:08 local time, both inclusive
"""
start_local_dt = ecwl.LocalDate({'year': 2015, 'month': 8, 'hour': 8, 'minute': 18})
end_local_dt = ecwl.LocalDate({'year': 2015, 'month': 8, 'hour': 9, 'minute': 8})
final_query = {"user_id": self.testUUID}
final_query.update(esdl.get_range_query("data.local_dt", start_local_dt, end_local_dt))
entries = edb.get_timeseries_db().find(final_query).sort('data.ts', pymongo.ASCENDING)
self.assertEqual(448, edb.get_timeseries_db().count_documents(final_query))

entries_list = list(entries)

# Note that since this is a set of filters, as opposed to a range, this
# returns all entries between 18 and 8 in both hours.
# so 8:18 is valid, but so is 9:57
self.assertEqual(ecwe.Entry(entries_list[0]).data.local_dt.hour, 8)
self.assertEqual(ecwe.Entry(entries_list[0]).data.local_dt.minute, 18)
self.assertEqual(ecwe.Entry(entries_list[-1]).data.local_dt.hour, 9)
self.assertEqual(ecwe.Entry(entries_list[-1]).data.local_dt.minute, 57)
entriesCnt = edb.get_timeseries_db().count_documents(final_query)
self.assertEqual(232, entriesCnt)

def testLocalMatchingQuery(self):
"""
Expand Down
Loading