Skip to content

Commit

Permalink
add read and tests #130 (#131)
Browse files Browse the repository at this point in the history
* add read and tests #130

* fix

* fix
  • Loading branch information
peterdudfield authored Apr 26, 2024
1 parent 1e2a526 commit a6e3985
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 4 deletions.
1 change: 1 addition & 0 deletions pvsite_datamodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from .connection import DatabaseConnection
from .sqlmodels import (
APIRequestSQL,
ForecastSQL,
ForecastValueSQL,
GenerationSQL,
Expand Down
9 changes: 8 additions & 1 deletion pvsite_datamodel/read/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,11 @@
get_sites_from_user,
)
from .status import get_latest_status
from .user import get_all_site_groups, get_all_users, get_site_group_by_name, get_user_by_email
from .user import (
get_all_last_api_request,
get_all_site_groups,
get_all_users,
get_api_requests_for_one_user,
get_site_group_by_name,
get_user_by_email,
)
56 changes: 53 additions & 3 deletions pvsite_datamodel/read/user.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
""" Functions for reading user data from the database. """
import logging
from typing import List
from datetime import datetime
from typing import List, Optional

from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, contains_eager

from pvsite_datamodel.sqlmodels import SiteGroupSQL, UserSQL
from pvsite_datamodel.sqlmodels import APIRequestSQL, SiteGroupSQL, UserSQL

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -90,3 +91,52 @@ def get_all_site_groups(session: Session) -> List[SiteGroupSQL]:
site_groups = query.all()

return site_groups


def get_all_last_api_request(session: Session) -> List[APIRequestSQL]:
"""
Get all last api requests for all users.
:param session: database session
:return:
"""

last_requests_sql: [APIRequestSQL] = (
session.query(APIRequestSQL)
.distinct(APIRequestSQL.user_uuid)
.join(UserSQL)
.options(contains_eager(APIRequestSQL.user))
.populate_existing()
.order_by(APIRequestSQL.user_uuid, APIRequestSQL.created_utc.desc())
.all()
)

return last_requests_sql


def get_api_requests_for_one_user(
session: Session,
email: str,
start_datetime: Optional[datetime] = None,
end_datetime: Optional[datetime] = None,
) -> List[APIRequestSQL]:
"""
Get all api requests for one user.
:param session: database session
:param email: user email
:param start_datetime: only get api requests after start datetime
:param end_datetime: only get api requests before end datetime
"""

query = session.query(APIRequestSQL).join(UserSQL).filter(UserSQL.email == email)

if start_datetime is not None:
query = query.filter(APIRequestSQL.created_utc >= start_datetime)

if end_datetime is not None:
query = query.filter(APIRequestSQL.created_utc <= end_datetime)

api_requests: [APIRequestSQL] = query.order_by(APIRequestSQL.created_utc.desc()).all()

return api_requests
45 changes: 45 additions & 0 deletions tests/test_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sqlalchemy.orm import Query

from pvsite_datamodel import (
APIRequestSQL,
ForecastSQL,
ForecastValueSQL,
SiteGroupSQL,
Expand All @@ -17,9 +18,11 @@
)
from pvsite_datamodel.pydantic_models import LatitudeLongitudeLimits
from pvsite_datamodel.read import (
get_all_last_api_request,
get_all_site_groups,
get_all_sites,
get_all_users,
get_api_requests_for_one_user,
get_latest_forecast_values_by_site,
get_latest_status,
get_pv_generation_by_sites,
Expand Down Expand Up @@ -559,3 +562,45 @@ def test_get_site_list_min(db_session, user_with_sites):
lat_lon = LatitudeLongitudeLimits(latitude_min=50, longitude_min=2)
sites = get_sites_from_user(session=db_session, user=user_with_sites, lat_lon_limits=lat_lon)
assert len(sites) > 0


def test_get_all_last_api_request(db_session):
user = get_user_by_email(session=db_session, email="[email protected]")
db_session.add(APIRequestSQL(user_uuid=user.user_uuid, url="test"))
db_session.add(APIRequestSQL(user_uuid=user.user_uuid, url="test2"))

last_requests_sql = get_all_last_api_request(session=db_session)
assert len(last_requests_sql) == 1
assert last_requests_sql[0].url == "test2"
assert last_requests_sql[0].user_uuid == user.user_uuid


def test_get_api_requests_for_one_user(db_session):
user = get_user_by_email(session=db_session, email="[email protected]")
db_session.add(APIRequestSQL(user_uuid=user.user_uuid, url="test"))

requests_sql = get_api_requests_for_one_user(session=db_session, email=user.email)
assert len(requests_sql) == 1
assert requests_sql[0].url == "test"


def test_get_api_requests_for_one_user_start_datetime(db_session):
user = get_user_by_email(session=db_session, email="[email protected]")
db_session.add(APIRequestSQL(user_uuid=user.user_uuid, url="test"))

requests_sql = get_api_requests_for_one_user(
session=db_session,
email=user.email,
start_datetime=dt.datetime.now() + dt.timedelta(hours=1),
)
assert len(requests_sql) == 0


def test_get_api_requests_for_one_user_end_datetime(db_session):
user = get_user_by_email(session=db_session, email="[email protected]")
db_session.add(APIRequestSQL(user_uuid=user.user_uuid, url="test"))

requests_sql = get_api_requests_for_one_user(
session=db_session, email=user.email, end_datetime=dt.datetime.now() - dt.timedelta(hours=1)
)
assert len(requests_sql) == 0

0 comments on commit a6e3985

Please sign in to comment.