Skip to content

Commit

Permalink
refactor test_read.py (#134)
Browse files Browse the repository at this point in the history
* refactor test_read.py

* reformat using black

* refactor - run-isort
  • Loading branch information
vishalj0501 authored May 7, 2024
1 parent fa81c60 commit 1916eb5
Show file tree
Hide file tree
Showing 17 changed files with 624 additions and 606 deletions.
1 change: 1 addition & 0 deletions pvsite_datamodel/connection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Database Connection class."""

import logging

from sqlalchemy import create_engine
Expand Down
1 change: 1 addition & 0 deletions pvsite_datamodel/pydantic_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" Pydantic models."""

from datetime import datetime

from pydantic import BaseModel, Field
Expand Down
1 change: 1 addition & 0 deletions pvsite_datamodel/read/generation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Read pv generation functions."""

import logging
import uuid
from datetime import datetime
Expand Down
1 change: 1 addition & 0 deletions pvsite_datamodel/read/user.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" Functions for reading user data from the database. """

import logging
from datetime import datetime
from typing import List, Optional
Expand Down
1 change: 1 addition & 0 deletions pvsite_datamodel/write/data/dno.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
3. For each site add dno
"""

import logging
import os

Expand Down
1 change: 1 addition & 0 deletions pvsite_datamodel/write/data/gsp.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" GSP functions for UK regions. """

import logging
import os

Expand Down
1 change: 1 addition & 0 deletions pvsite_datamodel/write/data/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" Utils for GSP and DNO. """

import pyproj

# OSGB is also called "OSGB 1936 / British National Grid -- United
Expand Down
1 change: 1 addition & 0 deletions pvsite_datamodel/write/database.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Functions to read from the database and format.
"""

import logging

from pvsite_datamodel.read.user import get_user_by_email as get_user_by_db
Expand Down
1 change: 1 addition & 0 deletions pvsite_datamodel/write/user_and_site.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" Tools for making fake users and sites in the database."""

import json
from datetime import datetime, timezone
from typing import Optional
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Pytest fixtures for tests."""

import datetime as dt
import json
import uuid
Expand Down
50 changes: 50 additions & 0 deletions tests/read/test_get_api_requests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import datetime as dt

from pvsite_datamodel import APIRequestSQL
from pvsite_datamodel.read import (
get_all_last_api_request,
get_api_requests_for_one_user,
get_user_by_email,
)


def test_get_all_last_api_request(db_session):
user = get_user_by_email(session=db_session, email="[email protected]")
db_session.add(APIRequestSQL(user_uuid=user.user_uuid, url="test"))
db_session.add(APIRequestSQL(user_uuid=user.user_uuid, url="test2"))

last_requests_sql = get_all_last_api_request(session=db_session)
assert len(last_requests_sql) == 1
assert last_requests_sql[0].url == "test2"
assert last_requests_sql[0].user_uuid == user.user_uuid


def test_get_api_requests_for_one_user(db_session):
user = get_user_by_email(session=db_session, email="[email protected]")
db_session.add(APIRequestSQL(user_uuid=user.user_uuid, url="test"))

requests_sql = get_api_requests_for_one_user(session=db_session, email=user.email)
assert len(requests_sql) == 1
assert requests_sql[0].url == "test"


def test_get_api_requests_for_one_user_start_datetime(db_session):
user = get_user_by_email(session=db_session, email="[email protected]")
db_session.add(APIRequestSQL(user_uuid=user.user_uuid, url="test"))

requests_sql = get_api_requests_for_one_user(
session=db_session,
email=user.email,
start_datetime=dt.datetime.now() + dt.timedelta(hours=1),
)
assert len(requests_sql) == 0


def test_get_api_requests_for_one_user_end_datetime(db_session):
user = get_user_by_email(session=db_session, email="[email protected]")
db_session.add(APIRequestSQL(user_uuid=user.user_uuid, url="test"))

requests_sql = get_api_requests_for_one_user(
session=db_session, email=user.email, end_datetime=dt.datetime.now() - dt.timedelta(hours=1)
)
assert len(requests_sql) == 0
123 changes: 123 additions & 0 deletions tests/read/test_get_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import datetime as dt
from typing import List

import pytest
from sqlalchemy.orm import Query

from pvsite_datamodel import SiteSQL
from pvsite_datamodel.read import get_pv_generation_by_sites, get_pv_generation_by_user_uuids
from pvsite_datamodel.write.user_and_site import create_site_group, create_user


class TestGetPVGenerationByUser:
"""Tests for the get_pv_generation_by_client function."""

def test_returns_all_generations_without_input_user(self, generations, db_session):
generations = get_pv_generation_by_user_uuids(session=db_session)

assert len(generations) == 40

def test_returns_all_generations_for_input_user(self, generations, db_session):
# associate site to one user
site: SiteSQL = db_session.query(SiteSQL).first()
site_group = create_site_group(db_session=db_session)
user = create_user(
session=db_session, site_group_name=site_group.site_group_name, email="[email protected]"
)
site_group.sites.append(site)

generations = get_pv_generation_by_user_uuids(
session=db_session, user_uuids=[user.user_uuid]
)

assert len(generations) == 10

def test_returns_all_generations_in_datetime_window(self, generations, db_session):
# associate site to one user
site: SiteSQL = db_session.query(SiteSQL).first()
site_group = create_site_group(db_session=db_session)
user = create_user(
session=db_session, site_group_name=site_group.site_group_name, email="[email protected]"
)
site_group.sites.append(site)

window_lower: dt.datetime = dt.datetime.now(dt.timezone.utc) - dt.timedelta(minutes=7)
window_upper: dt.datetime = dt.datetime.now(dt.timezone.utc) + dt.timedelta(minutes=8)

generations = get_pv_generation_by_user_uuids(
session=db_session,
user_uuids=[user.user_uuid],
start_utc=window_lower,
end_utc=window_upper,
)

assert len(generations) == 7


class TestGetPVGenerationBySites:
"""Tests for the get_pv_generation_by_sites function."""

def test_gets_generation_for_single_input_site(self, generations, db_session):
query: Query = db_session.query(SiteSQL)
site: SiteSQL = query.first()

generations = get_pv_generation_by_sites(session=db_session, site_uuids=[site.site_uuid])

assert len(generations) == 10
assert generations[0].start_utc is not None
assert generations[0].site is not None

def test_gets_generation_for_multiple_input_sites(self, generations, db_session):
query: Query = db_session.query(SiteSQL)
sites: List[SiteSQL] = query.all()

generations = get_pv_generation_by_sites(
session=db_session, site_uuids=[site.site_uuid for site in sites]
)

assert len(generations) == 10 * len(sites)

def test_returns_empty_list_for_no_input_sites(self, generations, db_session):
generations = get_pv_generation_by_sites(session=db_session, site_uuids=[])

assert len(generations) == 0

def test_gets_generation_for_multiple_sum_total(self, generations, db_session):
query: Query = db_session.query(SiteSQL)
sites: List[SiteSQL] = query.all()

generations = get_pv_generation_by_sites(
session=db_session, site_uuids=[site.site_uuid for site in sites], sum_by="total"
)

assert len(generations) == 10
assert generations[0].power_kw == 4
assert generations[1].power_kw == 8
assert (generations[2].start_utc - generations[1].start_utc).seconds == 60

def test_gets_generation_for_multiple_sum_gsp(self, generations, db_session):
query: Query = db_session.query(SiteSQL)
sites: List[SiteSQL] = query.all()

generations = get_pv_generation_by_sites(
session=db_session, site_uuids=[site.site_uuid for site in sites], sum_by="gsp"
)
assert len(generations) == 10 * len(sites)

def test_gets_generation_for_multiple_sum_dno(self, generations, db_session):
query: Query = db_session.query(SiteSQL)
sites: List[SiteSQL] = query.all()

generations = get_pv_generation_by_sites(
session=db_session, site_uuids=[site.site_uuid for site in sites], sum_by="dno"
)
assert len(generations) == 10 * len(sites)

def test_gets_generation_for_multiple_sum_error(self, generations, db_session):
query: Query = db_session.query(SiteSQL)
sites: List[SiteSQL] = query.all()

with pytest.raises(ValueError): # noqa
_ = get_pv_generation_by_sites(
session=db_session, site_uuids=[site.site_uuid for site in sites], sum_by="blah"
)
Loading

0 comments on commit 1916eb5

Please sign in to comment.