diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 2874504..2329e21 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "pycountry==24.6.1", "cryptography==42.0.8", "PyJWT==2.8.0", + "paramiko==3.4.0", ] license = {text = "GPL-3.0-or-later"} classifiers = [ @@ -37,6 +38,7 @@ Homepage = "https://github.com/kiwix/mirrors-qa" [project.scripts] update-mirrors = "mirrors_qa_backend.entrypoint:main" +mirrors-qa-scheduler = "mirrors_qa_backend.scheduler:main" [project.optional-dependencies] scripts = [ @@ -53,7 +55,6 @@ test = [ "pytest==8.0.0", "coverage==7.4.1", "Faker==25.8.0", - "paramiko==3.4.0", "httpx==0.27.0", ] dev = [ @@ -189,6 +190,8 @@ ignore = [ "S603", # Ignore complexity "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", + # Ignore warnings on missing timezone info + "DTZ005", "DTZ001", "DTZ006", ] unfixable = [ # Don't touch unused imports @@ -215,7 +218,7 @@ testpaths = ["tests"] pythonpath = [".", "src"] addopts = "--strict-markers" markers = [ - "num_tests: number of tests to create in the database (default: 10)", + "num_tests(num=10, *, status=..., country_code=...): create num tests in the database using status and/or country_code. Random data is chosen for country_code or status if either is not set", ] [tool.coverage.paths] diff --git a/backend/src/mirrors_qa_backend/cryptography.py b/backend/src/mirrors_qa_backend/cryptography.py index d4286c4..dd1f161 100644 --- a/backend/src/mirrors_qa_backend/cryptography.py +++ b/backend/src/mirrors_qa_backend/cryptography.py @@ -1,14 +1,13 @@ # pyright: strict, reportGeneralTypeIssues=false -import datetime +from pathlib import Path -import jwt +import paramiko from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import padding -from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey from mirrors_qa_backend.exceptions import PEMPublicKeyLoadError -from mirrors_qa_backend.settings import Settings def verify_signed_message(public_key: bytes, signature: bytes, message: bytes) -> bool: @@ -44,13 +43,26 @@ def sign_message(private_key: RSAPrivateKey, message: bytes) -> bytes: ) -def generate_access_token(worker_id: str) -> str: - issue_time = datetime.datetime.now(datetime.UTC) - expire_time = issue_time + datetime.timedelta(hours=Settings.TOKEN_EXPIRY) - payload = { - "iss": "mirrors-qa-backend", # issuer - "exp": expire_time.timestamp(), # expiration time - "iat": issue_time.timestamp(), # issued at - "subject": worker_id, - } - return jwt.encode(payload, key=Settings.JWT_SECRET, algorithm="HS256") +def load_private_key_from_path(private_key_fpath: Path) -> RSAPrivateKey: + with private_key_fpath.open("rb") as key_file: + return serialization.load_pem_private_key( + key_file.read(), password=None + ) # pyright: ignore[reportReturnType] + + +def generate_public_key(private_key: RSAPrivateKey) -> RSAPublicKey: + return private_key.public_key() + + +def serialize_public_key(public_key: RSAPublicKey) -> bytes: + return public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + +def get_public_key_fingerprint(public_key: RSAPublicKey) -> str: + """Compute the SHA256 fingerprint of the public key""" + return paramiko.RSAKey( + key=public_key + ).fingerprint # pyright: ignore[reportUnknownMemberType, UnknownVariableType] diff --git a/backend/src/mirrors_qa_backend/db/__init__.py b/backend/src/mirrors_qa_backend/db/__init__.py index 4336f88..e84bfec 100644 --- a/backend/src/mirrors_qa_backend/db/__init__.py +++ b/backend/src/mirrors_qa_backend/db/__init__.py @@ -7,7 +7,8 @@ from sqlalchemy.orm import sessionmaker from mirrors_qa_backend import logger -from mirrors_qa_backend.db import mirrors, models +from mirrors_qa_backend.db import models +from mirrors_qa_backend.db.mirrors import create_or_update_mirror_status from mirrors_qa_backend.extract import get_current_mirrors from mirrors_qa_backend.settings import Settings @@ -44,16 +45,16 @@ def initialize_mirrors() -> None: if nb_mirrors == 0: logger.info("No mirrors exist in database.") if not current_mirrors: - logger.info(f"No mirrors were found on {Settings.MIRRORS_URL!r}") + logger.info(f"No mirrors were found on {Settings.MIRRORS_URL}") return - result = mirrors.create_or_update_status(session, current_mirrors) + result = create_or_update_mirror_status(session, current_mirrors) logger.info( f"Registered {result.nb_mirrors_added} mirrors " - f"from {Settings.MIRRORS_URL!r}" + f"from {Settings.MIRRORS_URL}" ) else: logger.info(f"Found {nb_mirrors} mirrors in database.") - result = mirrors.create_or_update_status(session, current_mirrors) + result = create_or_update_mirror_status(session, current_mirrors) logger.info( f"Added {result.nb_mirrors_added} mirrors. " f"Disabled {result.nb_mirrors_disabled} mirrors." diff --git a/backend/src/mirrors_qa_backend/db/country.py b/backend/src/mirrors_qa_backend/db/country.py new file mode 100644 index 0000000..a3c173c --- /dev/null +++ b/backend/src/mirrors_qa_backend/db/country.py @@ -0,0 +1,16 @@ +from sqlalchemy import select +from sqlalchemy.orm import Session as OrmSession + +from mirrors_qa_backend.db.models import Country + + +def get_countries(session: OrmSession, *country_codes: str) -> list[Country]: + return list( + session.scalars(select(Country).where(Country.code.in_(country_codes))).all() + ) + + +def get_country_or_none(session: OrmSession, country_code: str) -> Country | None: + return session.scalars( + select(Country).where(Country.code == country_code) + ).one_or_none() diff --git a/backend/src/mirrors_qa_backend/db/exceptions.py b/backend/src/mirrors_qa_backend/db/exceptions.py index ec251c9..3c330c5 100644 --- a/backend/src/mirrors_qa_backend/db/exceptions.py +++ b/backend/src/mirrors_qa_backend/db/exceptions.py @@ -7,3 +7,7 @@ def __init__(self, message: str, *args: object) -> None: class EmptyMirrorsError(Exception): """An empty list was used to update the mirrors in the database.""" + + +class DuplicatePrimaryKeyError(Exception): + """A database record with the same primary key exists.""" diff --git a/backend/src/mirrors_qa_backend/db/mirrors.py b/backend/src/mirrors_qa_backend/db/mirrors.py index 8e07777..904672d 100644 --- a/backend/src/mirrors_qa_backend/db/mirrors.py +++ b/backend/src/mirrors_qa_backend/db/mirrors.py @@ -50,13 +50,13 @@ def create_mirrors(session: OrmSession, mirrors: list[schemas.Mirror]) -> int: db_mirror.country = country session.add(db_mirror) logger.debug( - f"Registered new mirror: {db_mirror.id!r} for country: {country.name!r}" + f"Registered new mirror: {db_mirror.id} for country: {country.name}" ) nb_created += 1 return nb_created -def create_or_update_status( +def create_or_update_mirror_status( session: OrmSession, mirrors: list[schemas.Mirror] ) -> MirrorsUpdateResult: """Updates the status of mirrors in the database and creates any new mirrors. @@ -96,16 +96,16 @@ def create_or_update_status( for db_mirror_id, db_mirror in db_mirrors.items(): if db_mirror_id not in current_mirrors: logger.debug( - f"Disabling mirror: {db_mirror.id!r} for " - f"country: {db_mirror.country.name!r}" + f"Disabling mirror: {db_mirror.id} for " + f"country: {db_mirror.country.name}" ) db_mirror.enabled = False session.add(db_mirror) result.nb_mirrors_disabled += 1 elif not db_mirror.enabled: # re-enable mirror if it was disabled logger.debug( - f"Re-enabling mirror: {db_mirror.id!r} for " - f"country: {db_mirror.country.name!r}" + f"Re-enabling mirror: {db_mirror.id} for " + f"country: {db_mirror.country.name}" ) db_mirror.enabled = True session.add(db_mirror) diff --git a/backend/src/mirrors_qa_backend/db/models.py b/backend/src/mirrors_qa_backend/db/models.py index 3ad3a35..125095e 100644 --- a/backend/src/mirrors_qa_backend/db/models.py +++ b/backend/src/mirrors_qa_backend/db/models.py @@ -64,6 +64,8 @@ class Country(Base): cascade="all, delete-orphan", ) + tests: Mapped[list[Test]] = relationship(back_populates="country", init=False) + __table_args__ = (UniqueConstraint("name", "code"),) @@ -131,7 +133,11 @@ class Test(Base): ip_address: Mapped[IPv4Address | None] = mapped_column(default=None) # autonomous system based on IP asn: Mapped[str | None] = mapped_column(default=None) - country: Mapped[str | None] = mapped_column(default=None) # country based on IP + country_code: Mapped[str | None] = mapped_column( + ForeignKey("country.code"), + init=False, + default=None, + ) location: Mapped[str | None] = mapped_column(default=None) # city based on IP latency: Mapped[int | None] = mapped_column(default=None) # milliseconds download_size: Mapped[int | None] = mapped_column(default=None) # bytes @@ -142,3 +148,5 @@ class Test(Base): ) worker: Mapped[Worker | None] = relationship(back_populates="tests", init=False) + + country: Mapped[Country | None] = relationship(back_populates="tests", init=False) diff --git a/backend/src/mirrors_qa_backend/db/tests.py b/backend/src/mirrors_qa_backend/db/tests.py index 4aefd1e..e8e9013 100644 --- a/backend/src/mirrors_qa_backend/db/tests.py +++ b/backend/src/mirrors_qa_backend/db/tests.py @@ -3,10 +3,11 @@ from ipaddress import IPv4Address from uuid import UUID -from sqlalchemy import UnaryExpression, asc, desc, func, select +from sqlalchemy import UnaryExpression, asc, desc, func, select, update from sqlalchemy.orm import Session as OrmSession from mirrors_qa_backend.db import models +from mirrors_qa_backend.db.country import get_country_or_none from mirrors_qa_backend.db.exceptions import RecordDoesNotExistError from mirrors_qa_backend.enums import SortDirectionEnum, StatusEnum, TestSortColumnEnum from mirrors_qa_backend.settings import Settings @@ -24,7 +25,7 @@ def filter_test( test: models.Test, *, worker_id: str | None = None, - country: str | None = None, + country_code: str | None = None, statuses: list[StatusEnum] | None = None, ) -> bool: """Checks if a test has the same attribute as the provided attribute. @@ -34,7 +35,7 @@ def filter_test( """ if worker_id is not None and test.worker_id != worker_id: return False - if country is not None and test.country != country: + if country_code is not None and test.country_code != country_code: return False if statuses is not None and test.status not in statuses: return False @@ -51,7 +52,7 @@ def list_tests( session: OrmSession, *, worker_id: str | None = None, - country: str | None = None, + country_code: str | None = None, statuses: list[StatusEnum] | None = None, page_num: int = 1, page_size: int = Settings.MAX_PAGE_SIZE, @@ -87,7 +88,7 @@ def list_tests( select(func.count().over().label("total_records"), models.Test) .where( (models.Test.worker_id == worker_id) | (worker_id is None), - (models.Test.country == country) | (country is None), + (models.Test.country_code == country_code) | (country_code is None), (models.Test.status.in_(statuses)), ) .order_by(*order_by) @@ -113,7 +114,7 @@ def create_or_update_test( error: str | None = None, ip_address: IPv4Address | None = None, asn: str | None = None, - country: str | None = None, + country_code: str | None = None, location: str | None = None, latency: int | None = None, download_size: int | None = None, @@ -135,7 +136,9 @@ def create_or_update_test( test.error = error if error else test.error test.ip_address = ip_address if ip_address else test.ip_address test.asn = asn if asn else test.asn - test.country = country if country else test.country + test.country = ( + get_country_or_none(session, country_code) if country_code else test.country + ) test.location = location if location else test.location test.latency = latency if latency else test.latency test.download_size = download_size if download_size else test.download_size @@ -144,5 +147,59 @@ def create_or_update_test( test.started_on = started_on if started_on else test.started_on session.add(test) + session.flush() return test + + +def create_test( + session: OrmSession, + *, + worker_id: str | None = None, + status: StatusEnum = StatusEnum.PENDING, + error: str | None = None, + ip_address: IPv4Address | None = None, + asn: str | None = None, + country_code: str | None = None, + location: str | None = None, + latency: int | None = None, + download_size: int | None = None, + duration: int | None = None, + speed: float | None = None, + started_on: datetime.datetime | None = None, +) -> models.Test: + return create_or_update_test( + session, + test_id=None, + worker_id=worker_id, + status=status, + error=error, + ip_address=ip_address, + asn=asn, + country_code=country_code, + location=location, + latency=latency, + download_size=download_size, + duration=duration, + speed=speed, + started_on=started_on, + ) + + +def expire_tests( + session: OrmSession, interval: datetime.timedelta +) -> list[models.Test]: + """Change the status of PENDING tests created before the interval to MISSED""" + end = datetime.datetime.now() - interval + begin = datetime.datetime.fromtimestamp(0) + return list( + session.scalars( + update(models.Test) + .where( + models.Test.requested_on.between(begin, end), + models.Test.status == StatusEnum.PENDING, + ) + .values(status=StatusEnum.MISSED) + .returning(models.Test) + ).all() + ) diff --git a/backend/src/mirrors_qa_backend/db/worker.py b/backend/src/mirrors_qa_backend/db/worker.py index 17898f2..6c88425 100644 --- a/backend/src/mirrors_qa_backend/db/worker.py +++ b/backend/src/mirrors_qa_backend/db/worker.py @@ -1,10 +1,83 @@ +import datetime +from pathlib import Path + from sqlalchemy import select from sqlalchemy.orm import Session as OrmSession -from mirrors_qa_backend.db import models +from mirrors_qa_backend.cryptography import ( + generate_public_key, + get_public_key_fingerprint, + load_private_key_from_path, + serialize_public_key, +) +from mirrors_qa_backend.db.country import get_countries +from mirrors_qa_backend.db.exceptions import DuplicatePrimaryKeyError +from mirrors_qa_backend.db.models import Worker +from mirrors_qa_backend.exceptions import PEMPrivateKeyLoadError + + +def get_worker(session: OrmSession, worker_id: str) -> Worker | None: + return session.scalars(select(Worker).where(Worker.id == worker_id)).one_or_none() + + +def create_worker( + session: OrmSession, + worker_id: str, + country_codes: list[str], + private_key_fpath: Path, +) -> Worker: + """Creates a worker using RSA private key.""" + if get_worker(session, worker_id) is not None: + raise DuplicatePrimaryKeyError(f"A worker with id {worker_id} already exists.") + try: + private_key = load_private_key_from_path(private_key_fpath) + except Exception as exc: + raise PEMPrivateKeyLoadError("unable to load private key from file") from exc + + public_key = generate_public_key(private_key) + public_key_pkcs8 = serialize_public_key(public_key).decode(encoding="ascii") + worker = Worker( + id=worker_id, + pubkey_pkcs8=public_key_pkcs8, + pubkey_fingerprint=get_public_key_fingerprint(public_key), + ) + session.add(worker) + + for db_country in get_countries(session, *country_codes): + db_country.worker_id = worker_id + session.add(db_country) + + return worker + + +def get_workers_last_seen_in_range( + session: OrmSession, begin: datetime.datetime, end: datetime.datetime +) -> list[Worker]: + """Get workers whose last_seen_on falls between begin and end dates""" + return list( + session.scalars( + select(Worker).where( + Worker.last_seen_on.between(begin, end), + ) + ).all() + ) + + +def get_idle_workers(session: OrmSession, interval: datetime.timedelta) -> list[Worker]: + end = datetime.datetime.now() - interval + begin = datetime.datetime(1970, 1, 1) + return get_workers_last_seen_in_range(session, begin, end) + + +def get_active_workers( + session: OrmSession, interval: datetime.timedelta +) -> list[Worker]: + end = datetime.datetime.now() + begin = end - interval + return get_workers_last_seen_in_range(session, begin, end) -def get_worker(session: OrmSession, worker_id: str) -> models.Worker | None: - return session.scalars( - select(models.Worker).where(models.Worker.id == worker_id) - ).one_or_none() +def update_worker_last_seen(session: OrmSession, worker: Worker) -> Worker: + worker.last_seen_on = datetime.datetime.now() + session.add(worker) + return worker diff --git a/backend/src/mirrors_qa_backend/entrypoint.py b/backend/src/mirrors_qa_backend/entrypoint.py index c8699c7..69f314e 100644 --- a/backend/src/mirrors_qa_backend/entrypoint.py +++ b/backend/src/mirrors_qa_backend/entrypoint.py @@ -1,8 +1,9 @@ import argparse import logging -from mirrors_qa_backend import db, logger -from mirrors_qa_backend.db import mirrors +from mirrors_qa_backend import logger +from mirrors_qa_backend.db import Session +from mirrors_qa_backend.db.mirrors import create_or_update_mirror_status from mirrors_qa_backend.extract import get_current_mirrors @@ -17,8 +18,8 @@ def main(): if args.verbose: logger.setLevel(logging.DEBUG) - with db.Session.begin() as session: - mirrors.create_or_update_status(session, get_current_mirrors()) + with Session.begin() as session: + create_or_update_mirror_status(session, get_current_mirrors()) if __name__ == "__main__": diff --git a/backend/src/mirrors_qa_backend/enums.py b/backend/src/mirrors_qa_backend/enums.py index 7dab556..bc3369f 100644 --- a/backend/src/mirrors_qa_backend/enums.py +++ b/backend/src/mirrors_qa_backend/enums.py @@ -17,7 +17,7 @@ class TestSortColumnEnum(Enum): started_on = "started_on" status = "status" worker_id = "worker_id" - country = "country" + country_code = "country_code" city = "city" diff --git a/backend/src/mirrors_qa_backend/exceptions.py b/backend/src/mirrors_qa_backend/exceptions.py index 0c2f85c..0398309 100644 --- a/backend/src/mirrors_qa_backend/exceptions.py +++ b/backend/src/mirrors_qa_backend/exceptions.py @@ -17,3 +17,9 @@ class PEMPublicKeyLoadError(Exception): """Unable to deserialize a public key from PEM encoded data""" pass + + +class PEMPrivateKeyLoadError(Exception): + """Unable to deserialize a private key from PEM encoded data""" + + pass diff --git a/backend/src/mirrors_qa_backend/extract.py b/backend/src/mirrors_qa_backend/extract.py index e5e9026..ee8dc63 100644 --- a/backend/src/mirrors_qa_backend/extract.py +++ b/backend/src/mirrors_qa_backend/extract.py @@ -28,17 +28,21 @@ def is_country_row(tag: Tag) -> bool: return tag.name == "tr" and tag.findChild("td", class_="newregion") is None try: - resp = requests.get(Settings.MIRRORS_URL, timeout=Settings.REQUESTS_TIMEOUT) + resp = requests.get( + Settings.MIRRORS_URL, timeout=Settings.REQUESTS_TIMEOUT_SECONDS + ) resp.raise_for_status() except requests.RequestException as exc: - raise MirrorsRequestError from exc + raise MirrorsRequestError( + "network error while fetching mirrors from url" + ) from exc soup = BeautifulSoup(resp.text, features="html.parser") body = soup.find("tbody") if body is None or isinstance(body, NavigableString | int): raise MirrorsExtractError( - f"unable to parse mirrors information from {Settings.MIRRORS_URL!r}" + f"unable to parse mirrors information from {Settings.MIRRORS_URL}" ) mirrors: list[schemas.Mirror] = [] @@ -54,7 +58,7 @@ def is_country_row(tag: Tag) -> bool: try: country: Any = pycountry.countries.search_fuzzy(country_name)[0] except LookupError: - logger.error(f"Could not get information for country: {country_name!r}") + logger.error(f"Could not get information for country: {country_name}") continue else: mirrors.append( diff --git a/backend/src/mirrors_qa_backend/main.py b/backend/src/mirrors_qa_backend/main.py index df17118..06933e7 100644 --- a/backend/src/mirrors_qa_backend/main.py +++ b/backend/src/mirrors_qa_backend/main.py @@ -2,14 +2,14 @@ from fastapi import FastAPI -from mirrors_qa_backend import db +from mirrors_qa_backend.db import initialize_mirrors, upgrade_db_schema from mirrors_qa_backend.routes import auth, tests @asynccontextmanager async def lifespan(_: FastAPI): - db.upgrade_db_schema() - db.initialize_mirrors() + upgrade_db_schema() + initialize_mirrors() yield diff --git a/backend/src/mirrors_qa_backend/migrations/versions/88e49e681048_add_country_code_to_tests.py b/backend/src/mirrors_qa_backend/migrations/versions/88e49e681048_add_country_code_to_tests.py new file mode 100644 index 0000000..7175a56 --- /dev/null +++ b/backend/src/mirrors_qa_backend/migrations/versions/88e49e681048_add_country_code_to_tests.py @@ -0,0 +1,40 @@ +"""add country code to tests + +Revision ID: 88e49e681048 +Revises: 5c376f6fb191 +Create Date: 2024-06-20 21:43:32.830017 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "88e49e681048" +down_revision = "5c376f6fb191" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("test", sa.Column("country_code", sa.String(), nullable=True)) + op.create_foreign_key( + op.f("fk_test_country_code_country"), + "test", + "country", + ["country_code"], + ["code"], + ) + op.drop_column("test", "country") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "test", sa.Column("country", sa.VARCHAR(), autoincrement=False, nullable=True) + ) + op.drop_constraint(op.f("fk_test_country_code_country"), "test", type_="foreignkey") + op.drop_column("test", "country_code") + # ### end Alembic commands ### diff --git a/backend/src/mirrors_qa_backend/routes/auth.py b/backend/src/mirrors_qa_backend/routes/auth.py index 59ace0e..1463f78 100644 --- a/backend/src/mirrors_qa_backend/routes/auth.py +++ b/backend/src/mirrors_qa_backend/routes/auth.py @@ -5,12 +5,19 @@ from fastapi import APIRouter, Header -from mirrors_qa_backend import cryptography, logger, schemas -from mirrors_qa_backend.db import worker +from mirrors_qa_backend import logger +from mirrors_qa_backend.cryptography import verify_signed_message +from mirrors_qa_backend.db.worker import get_worker from mirrors_qa_backend.exceptions import PEMPublicKeyLoadError -from mirrors_qa_backend.routes import http_errors from mirrors_qa_backend.routes.dependencies import DbSession -from mirrors_qa_backend.settings import Settings +from mirrors_qa_backend.routes.http_errors import ( + BadRequestError, + ForbiddenError, + UnauthorizedError, +) +from mirrors_qa_backend.schemas import Token +from mirrors_qa_backend.settings.api import APISettings +from mirrors_qa_backend.tokens import generate_access_token router = APIRouter(prefix="/auth", tags=["auth"]) @@ -25,52 +32,50 @@ def authenticate_worker( x_sshauth_signature: Annotated[ str, Header(description="signature, base64-encoded") ], -) -> schemas.Token: +) -> Token: """Authenticate using signed message and generate tokens.""" try: signature = base64.standard_b64decode(x_sshauth_signature) except binascii.Error as exc: - raise http_errors.BadRequestError( - "Invalid signature format (not base64)" - ) from exc + raise BadRequestError("Invalid signature format (not base64)") from exc try: # decode message: worker_id:timestamp(UTC ISO) worker_id, timestamp_str = x_sshauth_message.split(":", 1) timestamp = datetime.datetime.fromisoformat(timestamp_str) except ValueError as exc: - raise http_errors.BadRequestError("Invalid message format.") from exc + raise BadRequestError("Invalid message format.") from exc # verify timestamp is less than MESSAGE_VALIDITY if ( datetime.datetime.now(datetime.UTC) - timestamp - ).total_seconds() > Settings.MESSAGE_VALIDITY: - raise http_errors.UnauthorizedError( + ).total_seconds() > APISettings.MESSAGE_VALIDITY: + raise UnauthorizedError( "Difference betweeen message time and server time is " - f"greater than {Settings.MESSAGE_VALIDITY}s" + f"greater than {APISettings.MESSAGE_VALIDITY}s" ) # verify worker with worker_id exists in database - db_worker = worker.get_worker(session, worker_id) + db_worker = get_worker(session, worker_id) if db_worker is None: - raise http_errors.UnauthorizedError() + raise UnauthorizedError() # verify signature of message with worker's public keys try: - if not cryptography.verify_signed_message( + if not verify_signed_message( bytes(db_worker.pubkey_pkcs8, encoding="ascii"), signature, bytes(x_sshauth_message, encoding="ascii"), ): - raise http_errors.UnauthorizedError() + raise UnauthorizedError() except PEMPublicKeyLoadError as exc: logger.exception("error while verifying message using public key") - raise http_errors.ForbiddenError("Unable to load public_key") from exc + raise ForbiddenError("Unable to load public_key") from exc # generate tokens - access_token = cryptography.generate_access_token(worker_id) - return schemas.Token( + access_token = generate_access_token(worker_id) + return Token( access_token=access_token, token_type="bearer", - expires_in=datetime.timedelta(hours=Settings.TOKEN_EXPIRY).total_seconds(), + expires_in=datetime.timedelta(hours=APISettings.TOKEN_EXPIRY).total_seconds(), ) diff --git a/backend/src/mirrors_qa_backend/routes/dependencies.py b/backend/src/mirrors_qa_backend/routes/dependencies.py index 6d317aa..963c811 100644 --- a/backend/src/mirrors_qa_backend/routes/dependencies.py +++ b/backend/src/mirrors_qa_backend/routes/dependencies.py @@ -9,9 +9,11 @@ from sqlalchemy.orm import Session from mirrors_qa_backend import schemas -from mirrors_qa_backend.db import gen_dbsession, models, tests, worker -from mirrors_qa_backend.routes import http_errors -from mirrors_qa_backend.settings import Settings +from mirrors_qa_backend.db import gen_dbsession, models +from mirrors_qa_backend.db.tests import get_test as db_get_test +from mirrors_qa_backend.db.worker import get_worker +from mirrors_qa_backend.routes.http_errors import NotFoundError, UnauthorizedError +from mirrors_qa_backend.settings.api import APISettings DbSession = Annotated[Session, Depends(gen_dbsession)] @@ -25,22 +27,22 @@ def get_current_worker( ) -> models.Worker: token = authorization.credentials try: - jwt_claims = jwt.decode(token, Settings.JWT_SECRET, algorithms=["HS256"]) + jwt_claims = jwt.decode(token, APISettings.JWT_SECRET, algorithms=["HS256"]) except jwt_exceptions.ExpiredSignatureError as exc: - raise http_errors.UnauthorizedError("Token has expired.") from exc + raise UnauthorizedError("Token has expired.") from exc except (jwt_exceptions.InvalidTokenError, jwt_exceptions.PyJWTError) as exc: - raise http_errors.UnauthorizedError from exc + raise UnauthorizedError from exc try: claims = schemas.JWTClaims(**jwt_claims) except PydanticValidationError as exc: - raise http_errors.UnauthorizedError from exc + raise UnauthorizedError from exc # At this point, we know that the JWT is all OK and we can # trust the data in it. We extract the worker_id from the claims - db_worker = worker.get_worker(session, claims.subject) + db_worker = get_worker(session, claims.subject) if db_worker is None: - raise http_errors.UnauthorizedError() + raise UnauthorizedError() return db_worker @@ -49,9 +51,9 @@ def get_current_worker( def get_test(session: DbSession, test_id: Annotated[UUID4, Path()]) -> models.Test: """Fetches the test specified in the request.""" - test = tests.get_test(session, test_id) + test = db_get_test(session, test_id) if test is None: - raise http_errors.NotFoundError(f"Test with id {test_id} does not exist.") + raise NotFoundError(f"Test with id {test_id} does not exist.") return test @@ -60,4 +62,4 @@ def get_test(session: DbSession, test_id: Annotated[UUID4, Path()]) -> models.Te def verify_worker_owns_test(worker: CurrentWorker, test: RetrievedTest): if test.worker_id != worker.id: - raise http_errors.UnauthorizedError("Insufficient privileges to update test.") + raise UnauthorizedError("Insufficient privileges to update test.") diff --git a/backend/src/mirrors_qa_backend/routes/tests.py b/backend/src/mirrors_qa_backend/routes/tests.py index 5c5e14e..a463368 100644 --- a/backend/src/mirrors_qa_backend/routes/tests.py +++ b/backend/src/mirrors_qa_backend/routes/tests.py @@ -3,8 +3,10 @@ from fastapi import APIRouter, Depends, Query from fastapi import status as status_codes -from mirrors_qa_backend import schemas, serializer -from mirrors_qa_backend.db import tests +from mirrors_qa_backend import schemas +from mirrors_qa_backend.db.tests import create_or_update_test +from mirrors_qa_backend.db.tests import list_tests as db_list_tests +from mirrors_qa_backend.db.worker import update_worker_last_seen from mirrors_qa_backend.enums import SortDirectionEnum, StatusEnum, TestSortColumnEnum from mirrors_qa_backend.routes.dependencies import ( CurrentWorker, @@ -12,6 +14,8 @@ RetrievedTest, verify_worker_owns_test, ) +from mirrors_qa_backend.schemas import Test, TestsList, calculate_pagination_metadata +from mirrors_qa_backend.serializer import serialize_test from mirrors_qa_backend.settings import Settings router = APIRouter(prefix="/tests", tags=["tests"]) @@ -27,7 +31,7 @@ def list_tests( session: DbSession, worker_id: Annotated[str | None, Query()] = None, - country: Annotated[str | None, Query(min_length=3)] = None, + country_code: Annotated[str | None, Query(min_length=2, max_length=2)] = None, status: Annotated[list[StatusEnum] | None, Query()] = None, page_size: Annotated[ int, Query(le=Settings.MAX_PAGE_SIZE, ge=1) @@ -35,11 +39,11 @@ def list_tests( page_num: Annotated[int, Query(ge=1)] = 1, sort_by: Annotated[TestSortColumnEnum, Query()] = TestSortColumnEnum.requested_on, order: Annotated[SortDirectionEnum, Query()] = SortDirectionEnum.asc, -) -> schemas.TestsList: - result = tests.list_tests( +) -> TestsList: + result = db_list_tests( session, worker_id=worker_id, - country=country, + country_code=country_code, statuses=status, page_size=page_size, page_num=page_num, @@ -47,8 +51,8 @@ def list_tests( sort_direction=order, ) return schemas.TestsList( - tests=[serializer.serialize_test(test) for test in result.tests], - metadata=schemas.calculate_pagination_metadata( + tests=[serialize_test(test) for test in result.tests], + metadata=calculate_pagination_metadata( result.nb_tests, page_size=page_size, current_page=page_num ), ) @@ -64,8 +68,8 @@ def list_tests( }, }, ) -def get_test(test: RetrievedTest) -> schemas.Test: - return serializer.serialize_test(test) +def get_test(test: RetrievedTest) -> Test: + return serialize_test(test) @router.patch( @@ -78,26 +82,26 @@ def get_test(test: RetrievedTest) -> schemas.Test: ) def update_test( session: DbSession, - worker: CurrentWorker, + current_worker: CurrentWorker, test: RetrievedTest, update: schemas.UpdateTestModel, -) -> schemas.Test: +) -> Test: data = update.model_dump(exclude_unset=True) body = schemas.UpdateTestModel().model_copy(update=data) - updated_test = tests.create_or_update_test( + updated_test = create_or_update_test( session, test_id=test.id, - worker_id=worker.id, + worker_id=current_worker.id, status=body.status, error=body.error, ip_address=body.ip_address, asn=body.asn, - country=body.country, + country_code=body.country_code, location=body.location, latency=body.latency, download_size=body.download_size, duration=body.duration, speed=body.speed, ) - - return serializer.serialize_test(updated_test) + update_worker_last_seen(session, current_worker) + return serialize_test(updated_test) diff --git a/backend/src/mirrors_qa_backend/scheduler.py b/backend/src/mirrors_qa_backend/scheduler.py new file mode 100644 index 0000000..da008bf --- /dev/null +++ b/backend/src/mirrors_qa_backend/scheduler.py @@ -0,0 +1,77 @@ +import datetime +import time + +from mirrors_qa_backend import logger +from mirrors_qa_backend.db import Session +from mirrors_qa_backend.db.tests import create_test, expire_tests, list_tests +from mirrors_qa_backend.db.worker import get_idle_workers +from mirrors_qa_backend.enums import StatusEnum +from mirrors_qa_backend.settings.scheduler import SchedulerSettings + + +def main(): + while True: + with Session.begin() as session: + # expire tests whose results have not been reported + expired_tests = expire_tests( + session, + interval=datetime.timedelta(hours=SchedulerSettings.EXPIRE_TEST_HOURS), + ) + for expired_test in expired_tests: + logger.info( + f"Expired test {expired_test.id}, " + f"country: {expired_test.country_code}, " + f"worker: {expired_test.worker_id}" + ) + + idle_workers = get_idle_workers( + session, + interval=datetime.timedelta(hours=SchedulerSettings.IDLE_WORKER_HOURS), + ) + if not idle_workers: + logger.info("No idle workers found.") + + # Create tests for the countries the worker is responsible for.. + for idle_worker in idle_workers: + if not idle_worker.countries: + logger.info( + f"No countries registered for idle worker {idle_worker.id}" + ) + continue + for country in idle_worker.countries: + # While we have expired "unreported" tests, it is possible that + # a test for a country might still be PENDING as the interval + # for expiration and that of the scheduler might overlap. + # In such scenarios, we skip creating a test for that country. + pending_tests = list_tests( + session, + worker_id=idle_worker.id, + statuses=[StatusEnum.PENDING], + country_code=country.code, + ) + + if pending_tests.nb_tests: + logger.info( + "Skipping creation of new test entries for " + f"{idle_worker.id} as {pending_tests.nb_tests} " + f"tests are still pending for country {country.name}" + ) + continue + + new_test = create_test( + session=session, + worker_id=idle_worker.id, + country_code=country.code, + status=StatusEnum.PENDING, + ) + logger.info( + f"Created new test {new_test.id} for worker " + f"{idle_worker.id} in country {country.name}" + ) + + sleep_interval = datetime.timedelta( + hours=SchedulerSettings.SCHEDULER_SLEEP_HOURS + ).total_seconds() + + logger.info(f"Sleeping for {sleep_interval} seconds.") + time.sleep(sleep_interval) diff --git a/backend/src/mirrors_qa_backend/schemas.py b/backend/src/mirrors_qa_backend/schemas.py index 16d61cc..33583b5 100644 --- a/backend/src/mirrors_qa_backend/schemas.py +++ b/backend/src/mirrors_qa_backend/schemas.py @@ -39,7 +39,7 @@ class UpdateTestModel(BaseModel): isp: str | None = None ip_address: IPv4Address | None = None asn: str | None = None - country: str | None = None + country_code: str | None = None location: str | None = None latency: int | None = None download_size: int | None = None diff --git a/backend/src/mirrors_qa_backend/serializer.py b/backend/src/mirrors_qa_backend/serializer.py index c9b1b63..fd4d907 100644 --- a/backend/src/mirrors_qa_backend/serializer.py +++ b/backend/src/mirrors_qa_backend/serializer.py @@ -12,7 +12,7 @@ def serialize_test(test: models.Test) -> schemas.Test: isp=test.isp, ip_address=test.ip_address, asn=test.asn, - country=test.country, + country_code=test.country_code, location=test.location, latency=test.latency, download_size=test.download_size, diff --git a/backend/src/mirrors_qa_backend/settings.py b/backend/src/mirrors_qa_backend/settings/__init__.py similarity index 68% rename from backend/src/mirrors_qa_backend/settings.py rename to backend/src/mirrors_qa_backend/settings/__init__.py index d1cf20f..9e27132 100644 --- a/backend/src/mirrors_qa_backend/settings.py +++ b/backend/src/mirrors_qa_backend/settings/__init__.py @@ -15,6 +15,11 @@ class Settings: """Shared backend configuration""" DATABASE_URL: str = getenv("POSTGRES_URI", mandatory=True) + DEBUG = bool(getenv("DEBUG", default=False)) + # number of seconds before requests time out + REQUESTS_TIMEOUT_SECONDS = int(getenv("REQUESTS_TIMEOUT_SECONDS", default=5)) + # maximum number of items to return from a request/query + MAX_PAGE_SIZE = int(getenv("PAGE_SIZE", default=20)) # url to fetch the list of mirrors MIRRORS_URL: str = getenv( "MIRRORS_LIST_URL", default="https://download.kiwix.org/mirrors.html" @@ -23,13 +28,3 @@ class Settings: MIRRORS_EXCLUSION_LIST = getenv( "EXCLUDED_MIRRORS", default="mirror.isoc.org.il" ).split(",") - DEBUG = bool(getenv("DEBUG", default=False)) - # number of seconds before requests time out - REQUESTS_TIMEOUT = int(getenv("REQUESTS_TIMEOUT", default=5)) - # maximum number of items to return from a request - MAX_PAGE_SIZE = int(getenv("PAGE_SIZE", default=20)) - # number of seconds before a message expire - MESSAGE_VALIDITY = int(getenv("MESSAGE_VALIDITY", default=60)) - # number of hours before access tokens expire - TOKEN_EXPIRY = int(getenv("TOKEN_EXPIRY", default=24)) - JWT_SECRET: str = getenv("JWT_SECRET", mandatory=True) diff --git a/backend/src/mirrors_qa_backend/settings/api.py b/backend/src/mirrors_qa_backend/settings/api.py new file mode 100644 index 0000000..f61b76b --- /dev/null +++ b/backend/src/mirrors_qa_backend/settings/api.py @@ -0,0 +1,11 @@ +from mirrors_qa_backend.settings import Settings, getenv + + +class APISettings(Settings): + """Backend API settings""" + + JWT_SECRET: str = getenv("JWT_SECRET", mandatory=True) + # number of seconds before a message expire + MESSAGE_VALIDITY = int(getenv("MESSAGE_VALIDITY", default=60)) + # number of hours before access tokens expire + TOKEN_EXPIRY = int(getenv("TOKEN_EXPIRY", default=24)) diff --git a/backend/src/mirrors_qa_backend/settings/scheduler.py b/backend/src/mirrors_qa_backend/settings/scheduler.py new file mode 100644 index 0000000..afd39c7 --- /dev/null +++ b/backend/src/mirrors_qa_backend/settings/scheduler.py @@ -0,0 +1,12 @@ +from mirrors_qa_backend.settings import Settings, getenv + + +class SchedulerSettings(Settings): + """Scheduler settings""" + + # number of hours the scheduler sleeps before attempting to create tests + SCHEDULER_SLEEP_HOURS = int(getenv("SCHEDULER_SLEEP_INTERVAL", default=3)) + # number of hours into the past to determine if a worker is idle + IDLE_WORKER_HOURS = int(getenv("IDLE_WORKER_INTERVAL", default=1)) + # number of hours to wait before expiring a test whose data never arrived + EXPIRE_TEST_HOURS = int(getenv("EXPIRE_TEST_INTERVAL", default=24)) diff --git a/backend/src/mirrors_qa_backend/tokens.py b/backend/src/mirrors_qa_backend/tokens.py new file mode 100644 index 0000000..6ed6a79 --- /dev/null +++ b/backend/src/mirrors_qa_backend/tokens.py @@ -0,0 +1,17 @@ +import datetime + +import jwt + +from mirrors_qa_backend.settings.api import APISettings + + +def generate_access_token(worker_id: str) -> str: + issue_time = datetime.datetime.now(datetime.UTC) + expire_time = issue_time + datetime.timedelta(hours=APISettings.TOKEN_EXPIRY) + payload = { + "iss": "mirrors-qa-backend", # issuer + "exp": expire_time.timestamp(), # expiration time + "iat": issue_time.timestamp(), # issued at + "subject": worker_id, + } + return jwt.encode(payload, key=APISettings.JWT_SECRET, algorithm="HS256") diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 1c20b96..2ca6a5d 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -4,6 +4,7 @@ from typing import Any import paramiko +import pycountry import pytest from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa @@ -13,7 +14,9 @@ from sqlalchemy.orm import Session as OrmSession from mirrors_qa_backend.cryptography import sign_message -from mirrors_qa_backend.db import Session, models +from mirrors_qa_backend.db import Session +from mirrors_qa_backend.db.country import get_country_or_none +from mirrors_qa_backend.db.models import Base, Country, Test, Worker from mirrors_qa_backend.enums import StatusEnum @@ -22,8 +25,8 @@ def dbsession() -> Generator[OrmSession, None, None]: with Session.begin() as session: # Ensure we are starting with an empty database engine = session.get_bind() - models.Base.metadata.drop_all(bind=engine) - models.Base.metadata.create_all(bind=engine) + Base.metadata.drop_all(bind=engine) + Base.metadata.create_all(bind=engine) yield session session.rollback() @@ -32,9 +35,9 @@ def dbsession() -> Generator[OrmSession, None, None]: def data_gen(faker: Faker) -> Faker: """Adds additional providers to faker. - Registers test_country and test_status as providers. + Registers test_country_code and test_status as providers. data_gen.test_status() returns a status. - data_gen.test_country() returns a country. + data_gen.test_country_code() returns a country code. All other providers from Faker can be used accordingly. """ test_status_provider = DynamicProvider( @@ -42,11 +45,11 @@ def data_gen(faker: Faker) -> Faker: elements=list(StatusEnum), ) test_country_provider = DynamicProvider( - provider_name="test_country", + provider_name="test_country_code", elements=[ - "Nigeria", - "Canada", - "Brazil", + "ng", + "fr", + "us", ], ) faker.add_provider(test_status_provider) @@ -58,8 +61,8 @@ def data_gen(faker: Faker) -> Faker: @pytest.fixture def tests( - dbsession: OrmSession, data_gen: Faker, worker: models.Worker, request: Any -) -> list[models.Test]: + dbsession: OrmSession, data_gen: Faker, worker: Worker, request: Any +) -> list[Test]: """Adds tests to the database using the num_test mark.""" mark = request.node.get_closest_marker("num_tests") if mark and len(mark.args) > 0: @@ -67,15 +70,29 @@ def tests( else: num_tests = 10 - tests = [ - models.Test( - status=data_gen.test_status(), - country=data_gen.test_country(), + status = mark.kwargs.get("status", None) + country_code = mark.kwargs.get("country_code", None) + + for _ in range(num_tests): + test = Test(status=status if status else data_gen.test_status()) + selected_country_code = ( + country_code if country_code else data_gen.test_country_code() ) - for _ in range(num_tests) - ] - worker.tests = tests - dbsession.add_all(tests) + if country := get_country_or_none(dbsession, selected_country_code): + test.country = country + else: + country = Country( + code=selected_country_code.lower(), + name=pycountry.countries.get( + alpha_2=selected_country_code + ).name, # pyright: ignore [reportOptionalMemberAccess] + ) + dbsession.add(country) + test.country = country + + test.worker = worker + dbsession.add(test) + dbsession.flush() return worker.tests @@ -91,14 +108,23 @@ def public_key(private_key: RSAPrivateKey) -> RSAPublicKey: return private_key.public_key() +@pytest.fixture(scope="session") +def private_key_data(private_key: RSAPrivateKey) -> bytes: + return private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + @pytest.fixture -def worker(public_key: RSAPublicKey, dbsession: OrmSession) -> models.Worker: +def worker(public_key: RSAPublicKey, dbsession: OrmSession) -> Worker: pubkey_pkcs8 = public_key.public_bytes( serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo, ).decode(encoding="ascii") - worker = models.Worker( + worker = Worker( id="test", pubkey_fingerprint=paramiko.RSAKey(key=public_key).fingerprint, # type: ignore pubkey_pkcs8=pubkey_pkcs8, @@ -108,7 +134,7 @@ def worker(public_key: RSAPublicKey, dbsession: OrmSession) -> models.Worker: @pytest.fixture -def auth_message(worker: models.Worker) -> str: +def auth_message(worker: Worker) -> str: return f"{worker.id}:{datetime.datetime.now(datetime.UTC).isoformat()}" diff --git a/backend/tests/db/test_mirrors.py b/backend/tests/db/test_mirrors.py index 11c2b09..62f138d 100644 --- a/backend/tests/db/test_mirrors.py +++ b/backend/tests/db/test_mirrors.py @@ -2,9 +2,11 @@ from sqlalchemy import select from sqlalchemy.orm import Session as OrmSession -from mirrors_qa_backend import db, schemas, serializer -from mirrors_qa_backend.db import mirrors, models +from mirrors_qa_backend import schemas +from mirrors_qa_backend.db import count_from_stmt, models from mirrors_qa_backend.db.exceptions import EmptyMirrorsError +from mirrors_qa_backend.db.mirrors import create_mirrors, create_or_update_mirror_status +from mirrors_qa_backend.serializer import serialize_mirror @pytest.fixture(scope="session") @@ -29,7 +31,7 @@ def db_mirror() -> models.Mirror: @pytest.fixture(scope="session") def schema_mirror(db_mirror: models.Mirror) -> schemas.Mirror: - return serializer.serialize_mirror(db_mirror) + return serialize_mirror(db_mirror) @pytest.fixture(scope="session") @@ -55,20 +57,20 @@ def new_schema_mirror() -> schemas.Mirror: def test_db_empty(dbsession: OrmSession): - assert db.count_from_stmt(dbsession, select(models.Country)) == 0 + assert count_from_stmt(dbsession, select(models.Country)) == 0 def test_create_no_mirrors(dbsession: OrmSession): - assert mirrors.create_mirrors(dbsession, []) == 0 + assert create_mirrors(dbsession, []) == 0 def test_create_mirrors(dbsession: OrmSession, schema_mirror: schemas.Mirror): - assert mirrors.create_mirrors(dbsession, [schema_mirror]) == 1 + assert create_mirrors(dbsession, [schema_mirror]) == 1 def test_raises_empty_mirrors_error(dbsession: OrmSession): with pytest.raises(EmptyMirrorsError): - mirrors.create_or_update_status(dbsession, []) + create_or_update_mirror_status(dbsession, []) def test_register_new_mirror( @@ -78,7 +80,7 @@ def test_register_new_mirror( new_schema_mirror: schemas.Mirror, ): dbsession.add(db_mirror) - result = mirrors.create_or_update_status( + result = create_or_update_mirror_status( dbsession, [schema_mirror, new_schema_mirror] ) assert result.nb_mirrors_added == 1 @@ -90,7 +92,7 @@ def test_disable_old_mirror( new_schema_mirror: schemas.Mirror, ): dbsession.add(db_mirror) - result = mirrors.create_or_update_status(dbsession, [new_schema_mirror]) + result = create_or_update_mirror_status(dbsession, [new_schema_mirror]) assert result.nb_mirrors_disabled == 1 @@ -98,7 +100,7 @@ def test_no_mirrors_disabled( dbsession: OrmSession, db_mirror: models.Mirror, schema_mirror: schemas.Mirror ): dbsession.add(db_mirror) - result = mirrors.create_or_update_status(dbsession, [schema_mirror]) + result = create_or_update_mirror_status(dbsession, [schema_mirror]) assert result.nb_mirrors_disabled == 0 @@ -106,7 +108,7 @@ def test_no_mirrors_added( dbsession: OrmSession, db_mirror: models.Mirror, schema_mirror: schemas.Mirror ): dbsession.add(db_mirror) - result = mirrors.create_or_update_status(dbsession, [schema_mirror]) + result = create_or_update_mirror_status(dbsession, [schema_mirror]) assert result.nb_mirrors_added == 0 @@ -132,8 +134,8 @@ def test_re_enable_existing_mirror( dbsession.add(db_mirror) # Update the status of the mirror - schema_mirror = serializer.serialize_mirror(db_mirror) + schema_mirror = serialize_mirror(db_mirror) schema_mirror.enabled = True - result = mirrors.create_or_update_status(dbsession, [schema_mirror]) + result = create_or_update_mirror_status(dbsession, [schema_mirror]) assert result.nb_mirrors_added == 1 diff --git a/backend/tests/db/test_tests.py b/backend/tests/db/test_tests.py index 00d63ee..4adfbc5 100644 --- a/backend/tests/db/test_tests.py +++ b/backend/tests/db/test_tests.py @@ -6,20 +6,26 @@ from sqlalchemy.orm import Session as OrmSession from mirrors_qa_backend.db import models -from mirrors_qa_backend.db import tests as db_tests +from mirrors_qa_backend.db.tests import ( + create_or_update_test, + expire_tests, + filter_test, + get_test, + list_tests, +) from mirrors_qa_backend.enums import StatusEnum @pytest.mark.num_tests(1) def test_get_test(dbsession: OrmSession, tests: list[models.Test]): test = tests[0] - result = db_tests.get_test(dbsession, test.id) + result = get_test(dbsession, test.id) assert result is not None assert result.id == test.id @pytest.mark.parametrize( - ["worker_id", "country", "statuses", "expected"], + ["worker_id", "country_code", "statuses", "expected"], [ (None, None, None, True), ("worker_id", None, None, False), @@ -32,14 +38,14 @@ def test_basic_filter( *, dbsession: OrmSession, worker_id: str | None, - country: str | None, + country_code: str | None, statuses: list[StatusEnum] | None, expected: bool, ): - test = db_tests.create_or_update_test(dbsession, status=StatusEnum.PENDING) + test = create_or_update_test(dbsession, status=StatusEnum.PENDING) assert ( - db_tests.filter_test( - test, worker_id=worker_id, country=country, statuses=statuses + filter_test( + test, worker_id=worker_id, country_code=country_code, statuses=statuses ) == expected ) @@ -47,11 +53,11 @@ def test_basic_filter( @pytest.mark.num_tests @pytest.mark.parametrize( - ["worker_id", "country", "statuses"], + ["worker_id", "country_code", "statuses"], [ (None, None, None), - (None, "Nigeria", None), - (None, "Nigeria", [StatusEnum.PENDING]), + (None, "ng", None), + (None, "ng", [StatusEnum.PENDING]), (None, None, [StatusEnum.PENDING, StatusEnum.MISSED]), ], ) @@ -59,18 +65,18 @@ def test_list_tests( dbsession: OrmSession, tests: list[models.Test], worker_id: str | None, - country: str | None, + country_code: str | None, statuses: list[StatusEnum] | None, ): filtered_tests = [ test for test in tests - if db_tests.filter_test( - test, worker_id=worker_id, country=country, statuses=statuses + if filter_test( + test, worker_id=worker_id, country_code=country_code, statuses=statuses ) ] - result = db_tests.list_tests( - dbsession, worker_id=worker_id, country=country, statuses=statuses + result = list_tests( + dbsession, worker_id=worker_id, country_code=country_code, statuses=statuses ) assert len(filtered_tests) == result.nb_tests @@ -84,7 +90,6 @@ def test_update_test(dbsession: OrmSession, tests: list[models.Test], data_gen: speed = download_size / duration update_values = { "status": data_gen.test_status(), - "country": data_gen.test_country(), "download_size": download_size, "duration": duration, "speed": speed, @@ -92,7 +97,29 @@ def test_update_test(dbsession: OrmSession, tests: list[models.Test], data_gen: "started_on": data_gen.date_time(datetime.UTC), "latency": latency, } - updated_test = db_tests.create_or_update_test(dbsession, test_id, **update_values) # type: ignore + updated_test = create_or_update_test(dbsession, test_id, **update_values) # type: ignore for key, value in update_values.items(): if hasattr(updated_test, key): assert getattr(updated_test, key) == value + + +@pytest.mark.num_tests(10, status=StatusEnum.PENDING) +@pytest.mark.parametrize( + ["interval", "expected_status"], + [ + (datetime.timedelta(seconds=0), StatusEnum.MISSED), + (datetime.timedelta(days=7), StatusEnum.PENDING), + ], +) +def test_expire_tests( + dbsession: OrmSession, + tests: list[models.Test], + interval: datetime.timedelta, + expected_status: StatusEnum, +): + for test in tests: + assert test.status == StatusEnum.PENDING + + expire_tests(dbsession, interval) + for test in tests: + assert test.status == expected_status diff --git a/backend/tests/db/test_worker.py b/backend/tests/db/test_worker.py new file mode 100644 index 0000000..19f5e18 --- /dev/null +++ b/backend/tests/db/test_worker.py @@ -0,0 +1,34 @@ +from pathlib import Path + +from sqlalchemy.orm import Session as OrmSession + +from mirrors_qa_backend.db.models import Country +from mirrors_qa_backend.db.worker import create_worker + + +def test_create_worker(dbsession: OrmSession, tmp_path: Path, private_key_data: bytes): + worker_id = "test" + countries = [ + Country(code="ng", name="Nigeria"), + Country(code="fr", name="France"), + ] + dbsession.add_all(countries) + + private_key_fpath = tmp_path / "key.pem" + private_key_fpath.write_bytes(private_key_data) + + new_worker = create_worker( + dbsession, + worker_id=worker_id, + country_codes=[country.code for country in countries], + private_key_fpath=private_key_fpath, + ) + assert new_worker.id == worker_id + assert new_worker.pubkey_fingerprint != "" + assert len(new_worker.countries) == len(countries) + assert "BEGIN PUBLIC KEY" in new_worker.pubkey_pkcs8 + assert "END PUBLIC KEY" in new_worker.pubkey_pkcs8 + assert private_key_fpath.exists() + contents = private_key_fpath.read_text() + assert "BEGIN PRIVATE KEY" in contents + assert "END PRIVATE KEY" in contents diff --git a/backend/tests/routes/test_auth_endpoints.py b/backend/tests/routes/test_auth_endpoints.py index 543b977..04e595e 100644 --- a/backend/tests/routes/test_auth_endpoints.py +++ b/backend/tests/routes/test_auth_endpoints.py @@ -7,14 +7,14 @@ from fastapi.testclient import TestClient from mirrors_qa_backend.cryptography import sign_message -from mirrors_qa_backend.db import models +from mirrors_qa_backend.db.models import Worker @pytest.mark.parametrize( ["datetime_str", "expected_status", "expected_response_contents"], [ ( - datetime.datetime(1970, 1, 1, tzinfo=datetime.UTC).isoformat(), + datetime.datetime.fromtimestamp(0, tz=datetime.UTC).isoformat(), status.HTTP_401_UNAUTHORIZED, [], ), @@ -32,7 +32,7 @@ ) def test_authenticate_worker( client: TestClient, - worker: models.Worker, + worker: Worker, private_key: RSAPrivateKey, datetime_str: str, expected_status: int, diff --git a/dev/docker-compose.yaml b/dev/docker-compose.yaml index 1d05d50..eaf7097 100644 --- a/dev/docker-compose.yaml +++ b/dev/docker-compose.yaml @@ -40,6 +40,19 @@ services: BACKEND_ROOT_API: http://backend SLEEP_INTERVAL: 180 DEBUG: true + scheduler: + depends_on: + postgresdb: + condition: service_healthy + build: + context: ../backend + container_name: mirrors-qa-scheduler + environment: + POSTGRES_URI: postgresql+psycopg://mirrors_qa:mirrors_qa@postgresdb:5432/mirrors_qa + DEBUG: true + command: mirrors-qa-scheduler + networks: + - mirrors-qa-network volumes: pg-data-mirrors-qa: