From 77dc9f0db4f53e0d00b23f2256b54df5ac3fca7f Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Mon, 4 Dec 2023 17:10:06 +0000 Subject: [PATCH] add start and end datetime filter --- nowcasting_datamodel/read/read_user.py | 29 +++++++++++++++++--------- tests/read/test_read_user.py | 22 +++++++++++++++++++ 2 files changed, 41 insertions(+), 10 deletions(-) diff --git a/nowcasting_datamodel/read/read_user.py b/nowcasting_datamodel/read/read_user.py index 3f055a72..8552cae1 100644 --- a/nowcasting_datamodel/read/read_user.py +++ b/nowcasting_datamodel/read/read_user.py @@ -1,6 +1,7 @@ """ Read user""" import logging -from typing import List +from typing import Optional, List +from datetime import datetime from sqlalchemy.orm.session import Session @@ -62,21 +63,29 @@ def get_all_last_api_request(session: Session) -> List[APIRequestSQL]: return last_requests_sql -def get_api_requests_for_one_user(session: Session, email: str) -> List[APIRequestSQL]: +def get_api_requests_for_one_user( + session: Session, + email: str, + start_datetime: Optional[datetime] = None, + end_datetime: Optional[datetime] = None, +) -> List[APIRequestSQL]: """ Get all api requests for one user :param session: database session :param email: user email - :return: + :param start_datetime: only get api requests after start datetime + :param end_datetime: only get api requests before end datetime """ - api_requests = ( - session.query(APIRequestSQL) - .join(UserSQL) - .filter(UserSQL.email == email) - .order_by(APIRequestSQL.created_utc.desc()) - .all() - ) + query = session.query(APIRequestSQL).join(UserSQL).filter(UserSQL.email == email) + + if start_datetime is not None: + query = query.filter(APIRequestSQL.created_utc >= start_datetime) + + if end_datetime is not None: + query = query.filter(APIRequestSQL.created_utc <= end_datetime) + + api_requests = query.order_by(APIRequestSQL.created_utc.desc()).all() return api_requests diff --git a/tests/read/test_read_user.py b/tests/read/test_read_user.py index bc38f0ac..d3c57ba5 100644 --- a/tests/read/test_read_user.py +++ b/tests/read/test_read_user.py @@ -5,6 +5,7 @@ get_api_requests_for_one_user, ) +from datetime import datetime, timedelta def test_get_user(db_session): db_session.add(UserSQL(email="test@test.com")) @@ -41,3 +42,24 @@ def test_get_api_requests_for_one_user(db_session): requests_sql = get_api_requests_for_one_user(session=db_session, email=user.email) assert len(requests_sql) == 1 assert requests_sql[0].url == "test" + + +def test_get_api_requests_for_one_user_start_datetime(db_session): + user = get_user(session=db_session, email="test@test.com") + db_session.add(APIRequestSQL(user_uuid=user.uuid, url="test")) + + requests_sql = get_api_requests_for_one_user( + session=db_session, email=user.email, start_datetime=datetime.now() + timedelta(hours=1) + ) + assert len(requests_sql) == 0 + + +def test_get_api_requests_for_one_user_end_datetime(db_session): + user = get_user(session=db_session, email="test@test.com") + db_session.add(APIRequestSQL(user_uuid=user.uuid, url="test")) + + requests_sql = get_api_requests_for_one_user( + session=db_session, email=user.email, start_datetime=datetime.now() - timedelta(hours=1) + ) + assert len(requests_sql) == 0 +