From 278d68dc0fcf7908bb305714ff804fdad1658cb3 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Mon, 4 Dec 2023 17:31:54 +0000 Subject: [PATCH] tidy, use new nowcasting_datamodel --- requirements.txt | 2 +- src/users.py | 28 +++++++++++----------------- 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/requirements.txt b/requirements.txt index 55b1538..d4c7373 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ altair==4.2.2 -nowcasting_datamodel==1.5.19 +nowcasting_datamodel==1.5.25 pvsite-datamodel==1.0.1 numpy==1.24.1 pandas==1.5.3 diff --git a/src/users.py b/src/users.py index 5d2262f..dc4434e 100644 --- a/src/users.py +++ b/src/users.py @@ -4,6 +4,10 @@ import os from nowcasting_datamodel.connection import DatabaseConnection from nowcasting_datamodel.models.api import UserSQL, APIRequestSQL +from nowcasting_datamodel.read.read_user import ( + get_all_last_api_request, + get_api_requests_for_one_user, +) from plots.users import make_api_requests_plot @@ -24,7 +28,10 @@ def user_page(): value=datetime.today() - timedelta(days=31), ) end_time = st.sidebar.date_input( - "End Date", min_value=datetime.today() - timedelta(days=365), max_value=datetime.today() + "End Date", + min_value=datetime.today() - timedelta(days=365), + max_value=datetime.today() + timedelta(days=1), + value=datetime.today() + timedelta(days=1), ) # get last call from the database @@ -32,13 +39,7 @@ def user_page(): connection = DatabaseConnection(url=url, echo=True) with connection.get_session() as session: - last_requests_sql = ( - session.query(APIRequestSQL) - .distinct(APIRequestSQL.user_uuid) - .join(UserSQL) - .order_by(APIRequestSQL.user_uuid, APIRequestSQL.created_utc.desc()) - .all() - ) + last_requests_sql = get_all_last_api_request(session=session) last_request = [ (last_request_sql.user.email, last_request_sql.created_utc) @@ -56,13 +57,8 @@ def user_page(): # get all calls for selected user with connection.get_session() as session: - api_requests_sql = ( - session.query(APIRequestSQL) - .join(UserSQL) - .where(UserSQL.email == email_selected) - .where(APIRequestSQL.created_utc >= start_time) - .where(APIRequestSQL.created_utc <= end_time) - .all() + api_requests_sql = get_api_requests_for_one_user( + session=session, email=email_selected, start_datetime=start_time, end_datetime=end_time ) api_requests = [ @@ -73,5 +69,3 @@ def user_page(): fig = make_api_requests_plot(api_requests, email_selected, end_time, start_time) st.plotly_chart(fig, theme="streamlit") - -