Skip to content

Commit

Permalink
Add initial unit tests for get_and_store_user_stats
Browse files Browse the repository at this point in the history
- Introduced comprehensive test suite to validate functionality of the `get_and_store_user_stats` method.
- Covered core scenarios including:
  - Correct data aggregation and storage.
  - Handling cases with no trips or missing data.
  - Verifying behavior with partial data and multiple invocations.
  - handling of invalid UUID inputs.
- setup and teardown to ensure test isolation and clean up user data.
- Included data insertion for confirmed trips, composite trips, and server API timestamps to simulate realistic scenarios.

This initial test suite establishes a baseline for ensuring reliability of the `get_and_store_user_stats` function.

remove

Modified test

Refactored the test and simplified it; validated all new user stats

added store to intake

Updated Based on requested changes

Removed unnecessary wrapper

Changed to Aug 21 for the time being
  • Loading branch information
TeachMeTW authored and shankari committed Dec 22, 2024
1 parent 39a3358 commit 957c8fa
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 24 deletions.
25 changes: 2 additions & 23 deletions emission/analysis/result/user_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,6 @@
import emission.storage.timeseries.abstract_timeseries as esta
import emission.core.wrapper.user as ecwu

TIME_FORMAT = 'YYYY-MM-DD HH:mm:ss'

def count_trips(ts: esta.TimeSeries, key_list: list, extra_query_list: Optional[list] = None) -> int:
"""
Counts the number of trips based on the provided query.
:param ts: The time series object.
:type ts: esta.TimeSeries
:param key_list: List of keys to filter trips.
:type key_list: list
:param extra_query_list: Additional queries, defaults to None.
:type extra_query_list: Optional[list], optional
:return: The count of trips.
:rtype: int
"""
count = ts.find_entries_count(key_list=key_list, extra_query_list=extra_query_list)
logging.debug(f"Counted {len(key_list)} trips with additional queries {extra_query_list}: {count}")
return count


def get_last_call_timestamp(ts: esta.TimeSeries) -> Optional[int]:
"""
Retrieves the last API call timestamp.
Expand Down Expand Up @@ -81,9 +61,8 @@ def get_and_store_user_stats(user_id: str, trip_key: str) -> None:
end_ts_result = ts.get_first_value_for_field(trip_key, "data.end_ts", pymongo.DESCENDING)
end_ts = None if end_ts_result == -1 else end_ts_result

total_trips = count_trips(ts, key_list=["analysis/confirmed_trip"])
labeled_trips = count_trips(
ts,
total_trips = ts.find_entries_count(key_list=["analysis/confirmed_trip"])
labeled_trips = ts.find_entries_count(
key_list=["analysis/confirmed_trip"],
extra_query_list=[{'data.user_input': {'$ne': {}}}]
)
Expand Down
121 changes: 121 additions & 0 deletions emission/tests/analysisTests/intakeTests/TestUserStat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from __future__ import unicode_literals, print_function, division, absolute_import
import unittest
import uuid
import logging
import json
import os
import time
import pandas as pd

from builtins import *
from future import standard_library
standard_library.install_aliases()

# Standard imports
import emission.storage.json_wrappers as esj

# Our imports
import emission.core.get_database as edb
import emission.storage.timeseries.timequery as estt
import emission.storage.timeseries.abstract_timeseries as esta
import emission.storage.decorations.analysis_timeseries_queries as esda
import emission.core.wrapper.user as ecwu
import emission.net.api.stats as enac
import emission.pipeline.intake_stage as epi

# Test imports
import emission.tests.common as etc


class TestUserStats(unittest.TestCase):
def setUp(self):
"""
Set up the test environment by loading real example data for both Android and users.
"""
# Set up the real example data with entries
self.testUUID = uuid.uuid4()
with open("emission/tests/data/real_examples/shankari_2015-aug-21") as fp:
self.entries = json.load(fp, object_hook = esj.wrapped_object_hook)
# Retrieve the user profile
etc.setupRealExampleWithEntries(self)
profile = edb.get_profile_db().find_one({"user_id": self.testUUID})
if profile is None:
# Initialize the profile if it does not exist
edb.get_profile_db().insert_one({"user_id": self.testUUID})

#etc.runIntakePipeline(self.testUUID)
epi.run_intake_pipeline_for_user(self.testUUID, skip_if_no_new_data = False)
logging.debug("UUID = %s" % (self.testUUID))

def tearDown(self):
"""
Clean up the test environment by removing analysis configuration and deleting test data from databases.
"""

edb.get_timeseries_db().delete_many({"user_id": self.testUUID})
edb.get_pipeline_state_db().delete_many({"user_id": self.testUUID})
edb.get_analysis_timeseries_db().delete_many({"user_id": self.testUUID})
edb.get_profile_db().delete_one({"user_id": self.testUUID})

def testGetAndStoreUserStats(self):
"""
Test get_and_store_user_stats for the user to ensure that user statistics
are correctly aggregated and stored in the user profile.
"""

# Retrieve the updated user profile from the database
profile = edb.get_profile_db().find_one({"user_id": self.testUUID})

# Ensure that the profile exists
self.assertIsNotNone(profile, "User profile should exist after storing stats.")

# Verify that the expected fields are present
self.assertIn("total_trips", profile, "User profile should contain 'total_trips'.")
self.assertIn("labeled_trips", profile, "User profile should contain 'labeled_trips'.")
self.assertIn("pipeline_range", profile, "User profile should contain 'pipeline_range'.")
self.assertIn("last_call_ts", profile, "User profile should contain 'last_call_ts'.")

expected_total_trips = 5
expected_labeled_trips = 0

self.assertEqual(profile["total_trips"], expected_total_trips,
f"Expected total_trips to be {expected_total_trips}, got {profile['total_trips']}")
self.assertEqual(profile["labeled_trips"], expected_labeled_trips,
f"Expected labeled_trips to be {expected_labeled_trips}, got {profile['labeled_trips']}")

# Verify pipeline range
pipeline_range = profile.get("pipeline_range", {})
self.assertIn("start_ts", pipeline_range, "Pipeline range should contain 'start_ts'.")
self.assertIn("end_ts", pipeline_range, "Pipeline range should contain 'end_ts'.")

expected_start_ts = 1440168891.095
expected_end_ts = 1440209488.817

self.assertEqual(pipeline_range["start_ts"], expected_start_ts,
f"Expected start_ts to be {expected_start_ts}, got {pipeline_range['start_ts']}")
self.assertEqual(pipeline_range["end_ts"], expected_end_ts,
f"Expected end_ts to be {expected_end_ts}, got {pipeline_range['end_ts']}")

def testLastCall(self):
# Call the function with all required arguments
test_call_ts = time.time()
enac.store_server_api_time(self.testUUID, "test_call_ts", test_call_ts, 69420)
etc.runIntakePipeline(self.testUUID)

# Retrieve the profile from the database
profile = edb.get_profile_db().find_one({"user_id": self.testUUID})

# Verify that last_call_ts is updated correctly
expected_last_call_ts = test_call_ts
actual_last_call_ts = profile.get("last_call_ts")

self.assertEqual(
actual_last_call_ts,
expected_last_call_ts,
f"Expected last_call_ts to be {expected_last_call_ts}, got {actual_last_call_ts}"
)

if __name__ == '__main__':
# Configure logging for the test
etc.configLogging()
unittest.main()
5 changes: 4 additions & 1 deletion emission/tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def getRealExampleEmail(testObj):
def fillExistingUUID(testObj):
userObj = ecwu.User.fromEmail(getRealExampleEmail(testObj))
print("Setting testUUID to %s" % userObj.uuid)
testObj.testUUID = userObj.uuir
testObj.testUUID = userObj.uuid

def getRegEmailIfPresent(testObj):
if hasattr(testObj, "evaluation") and testObj.evaluation:
Expand Down Expand Up @@ -193,6 +193,7 @@ def runIntakePipeline(uuid):
import emission.analysis.userinput.expectations as eaue
import emission.analysis.classification.inference.labels.pipeline as eacilp
import emission.analysis.plotting.composite_trip_creation as eapcc
import emission.analysis.result.user_stat as eaurs

eaum.match_incoming_user_inputs(uuid)
eaicf.filter_accuracy(uuid)
Expand All @@ -205,6 +206,8 @@ def runIntakePipeline(uuid):
eaue.populate_expectations(uuid)
eaum.create_confirmed_objects(uuid)
eapcc.create_composite_objects(uuid)
eaurs.get_and_store_user_stats(uuid, "analysis/composite_trip")


def configLogging():
"""
Expand Down

0 comments on commit 957c8fa

Please sign in to comment.