Skip to content

Commit

Permalink
add start and end datetime filter
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdudfield committed Dec 4, 2023
1 parent 459a790 commit 77dc9f0
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 10 deletions.
29 changes: 19 additions & 10 deletions nowcasting_datamodel/read/read_user.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
""" Read user"""
import logging
from typing import List
from typing import Optional, List
from datetime import datetime

from sqlalchemy.orm.session import Session

Expand Down Expand Up @@ -62,21 +63,29 @@ def get_all_last_api_request(session: Session) -> List[APIRequestSQL]:
return last_requests_sql


def get_api_requests_for_one_user(session: Session, email: str) -> List[APIRequestSQL]:
def get_api_requests_for_one_user(
session: Session,
email: str,
start_datetime: Optional[datetime] = None,
end_datetime: Optional[datetime] = None,
) -> List[APIRequestSQL]:
"""
Get all api requests for one user
:param session: database session
:param email: user email
:return:
:param start_datetime: only get api requests after start datetime
:param end_datetime: only get api requests before end datetime
"""

api_requests = (
session.query(APIRequestSQL)
.join(UserSQL)
.filter(UserSQL.email == email)
.order_by(APIRequestSQL.created_utc.desc())
.all()
)
query = session.query(APIRequestSQL).join(UserSQL).filter(UserSQL.email == email)

if start_datetime is not None:
query = query.filter(APIRequestSQL.created_utc >= start_datetime)

if end_datetime is not None:
query = query.filter(APIRequestSQL.created_utc <= end_datetime)

api_requests = query.order_by(APIRequestSQL.created_utc.desc()).all()

return api_requests
22 changes: 22 additions & 0 deletions tests/read/test_read_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
get_api_requests_for_one_user,
)

from datetime import datetime, timedelta

def test_get_user(db_session):
db_session.add(UserSQL(email="[email protected]"))
Expand Down Expand Up @@ -41,3 +42,24 @@ def test_get_api_requests_for_one_user(db_session):
requests_sql = get_api_requests_for_one_user(session=db_session, email=user.email)
assert len(requests_sql) == 1
assert requests_sql[0].url == "test"


def test_get_api_requests_for_one_user_start_datetime(db_session):
user = get_user(session=db_session, email="[email protected]")
db_session.add(APIRequestSQL(user_uuid=user.uuid, url="test"))

requests_sql = get_api_requests_for_one_user(
session=db_session, email=user.email, start_datetime=datetime.now() + timedelta(hours=1)
)
assert len(requests_sql) == 0


def test_get_api_requests_for_one_user_end_datetime(db_session):
user = get_user(session=db_session, email="[email protected]")
db_session.add(APIRequestSQL(user_uuid=user.uuid, url="test"))

requests_sql = get_api_requests_for_one_user(
session=db_session, email=user.email, start_datetime=datetime.now() - timedelta(hours=1)
)
assert len(requests_sql) == 0

0 comments on commit 77dc9f0

Please sign in to comment.