Skip to content

Commit

Permalink
Merge pull request #22 from kiwix/task-worker
Browse files Browse the repository at this point in the history
Upload task worker results to Backend API
  • Loading branch information
elfkuzco authored Jul 25, 2024
2 parents 0776842 + 81c60e9 commit 8e7ad26
Show file tree
Hide file tree
Showing 40 changed files with 942 additions and 383 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,4 @@ dev/data/**
!dev/data/README.md
!dev/.env
id_rsa
*.json
12 changes: 3 additions & 9 deletions backend/src/mirrors_qa_backend/cli/mirrors.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
import sys

from mirrors_qa_backend import logger
from mirrors_qa_backend.db import Session
from mirrors_qa_backend.db.mirrors import create_or_update_mirror_status
from mirrors_qa_backend.exceptions import MirrorsRequestError
from mirrors_qa_backend.extract import get_current_mirrors


def update_mirrors() -> None:
"""Update the list of active mirrors in the DB."""
logger.info("Updating mirrors list.")
try:
with Session.begin() as session:
results = create_or_update_mirror_status(session, get_current_mirrors())
except MirrorsRequestError as exc:
logger.info(f"error while updating mirrors: {exc}")
sys.exit(1)
with Session.begin() as session:
results = create_or_update_mirror_status(session, get_current_mirrors())
logger.info(
f"Updated mirrors list. Added {results.nb_mirrors_added} mirror(s), "
f"disabled {results.nb_mirrors_disabled} mirror(s)"
Expand Down
56 changes: 30 additions & 26 deletions backend/src/mirrors_qa_backend/cli/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from mirrors_qa_backend import logger
from mirrors_qa_backend.db import Session
from mirrors_qa_backend.db.mirrors import get_enabled_mirrors
from mirrors_qa_backend.db.tests import create_test, expire_tests, list_tests
from mirrors_qa_backend.db.worker import get_idle_workers
from mirrors_qa_backend.enums import StatusEnum
Expand All @@ -16,6 +17,7 @@ def main(
):
while True:
with Session.begin() as session:
mirrors = get_enabled_mirrors(session)
# expire tests whose results have not been reported
expired_tests = expire_tests(
session,
Expand Down Expand Up @@ -44,36 +46,38 @@ def main(
f"No countries registered for idle worker {idle_worker.id}"
)
continue
for country in idle_worker.countries:
# While we have expired "unreported" tests, it is possible that
# a test for a country might still be PENDING as the interval
# for expiration and that of the scheduler might overlap.
# In such scenarios, we skip creating a test for that country.
pending_tests = list_tests(
session,
worker_id=idle_worker.id,
statuses=[StatusEnum.PENDING],
country_code=country.code,
# While we have expired "unreported" tests, it is possible that
# a test for a mirror might still be PENDING as the interval
# for expiration and that of the scheduler might overlap.
# In such scenarios, we skip creating a test for such workers.
pending_tests = list_tests(
session,
worker_id=idle_worker.id,
statuses=[StatusEnum.PENDING],
)

if pending_tests.nb_tests:
logger.info(
"Skipping creation of new test entries for "
f"{idle_worker.id} as {pending_tests.nb_tests} "
f"tests are still pending."
)
continue

if pending_tests.nb_tests:
# Create a test for each mirror from the countries the worker registered
for country in idle_worker.countries:
for mirror in mirrors:
new_test = create_test(
session=session,
worker=idle_worker,
country_code=country.code,
mirror=mirror,
)
logger.info(
"Skipping creation of new test entries for "
f"{idle_worker.id} as {pending_tests.nb_tests} "
f"tests are still pending for country {country.name}"
f"Created new test {new_test.id} for worker "
f"{idle_worker.id} in location {country.name} "
f"for mirror {mirror.id}"
)
continue

new_test = create_test(
session=session,
worker_id=idle_worker.id,
country_code=country.code,
status=StatusEnum.PENDING,
)
logger.info(
f"Created new test {new_test.id} for worker "
f"{idle_worker.id} in country {country.name}"
)

logger.info(f"Sleeping for {sleep_seconds} seconds.")
time.sleep(sleep_seconds)
79 changes: 51 additions & 28 deletions backend/src/mirrors_qa_backend/cli/worker.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,65 @@
import sys

import pycountry
from cryptography.hazmat.primitives import serialization

from mirrors_qa_backend import logger
from mirrors_qa_backend.db import Session
from mirrors_qa_backend.db.country import update_countries as update_db_countries
from mirrors_qa_backend.db.worker import create_worker as create_db_worker
from mirrors_qa_backend.db.worker import update_worker as update_db_worker


def get_country_mapping(country_codes: list[str]) -> dict[str, str]:
"""Fetch the country names from the country codes.
def create_worker(worker_id: str, private_key_data: bytes, country_codes: list[str]):
Maps the country code to the country name.
"""
country_mapping: dict[str, str] = {}
# Ensure all the countries are valid country codes
for country_code in country_codes:
if len(country_code) != 2: # noqa: PLR2004
logger.info(f"Country code '{country_code}' must be two characters long")
sys.exit(1)

if not pycountry.countries.get(alpha_2=country_code):
logger.info(f"'{country_code}' is not valid country code")
sys.exit(1)

try:
private_key = serialization.load_pem_private_key(
private_key_data, password=None
) # pyright: ignore[reportReturnType]
except Exception as exc:
logger.info(f"Unable to load private key: {exc}")
sys.exit(1)

try:
with Session.begin() as session:
create_db_worker(
session,
worker_id,
country_codes,
private_key, # pyright: ignore [reportGeneralTypeIssues, reportArgumentType]
raise ValueError(
f"Country code '{country_code}' must be two characters long"
)
except Exception as exc:
logger.info(f"error while creating worker: {exc}")
sys.exit(1)

if country := pycountry.countries.get(alpha_2=country_code):
country_mapping[country_code] = country.name
else:
raise ValueError(f"'{country_code}' is not valid country code")
return country_mapping


def create_worker(
    worker_id: str, private_key_data: bytes, initial_country_codes: list[str]
):
    """Create a worker in the DB and assign its initial countries.

    The country codes are first resolved to names through
    ``get_country_mapping`` and the PEM private key is loaded (no password)
    before any database work starts. The countries and the worker are then
    written inside a single ``Session.begin()`` block.

    Args:
        worker_id: identifier of the worker to create.
        private_key_data: PEM-encoded private key bytes.
        initial_country_codes: country codes the worker will run tests from.
    """
    names_by_code = get_country_mapping(initial_country_codes)
    worker_key = serialization.load_pem_private_key(
        private_key_data, password=None
    )  # pyright: ignore[reportReturnType]

    with Session.begin() as db_session:
        # Make sure the countries are in the database before the worker
        # that refers to them is created.
        update_db_countries(db_session, names_by_code)
        create_db_worker(
            db_session,
            worker_id,
            initial_country_codes,
            worker_key,  # pyright: ignore [reportGeneralTypeIssues, reportArgumentType]
        )

    logger.info(f"Created worker {worker_id} successfully")


def update_worker(worker_id: str, country_codes: list[str]):
    """Update a worker's data.

    Updates the countries for a worker to run tests from. The codes are
    resolved to names through ``get_country_mapping`` first; the countries
    and the worker are then updated inside one ``Session.begin()`` block.

    Args:
        worker_id: identifier of the worker to update.
        country_codes: country codes the worker should run tests from.
    """
    names_by_code = get_country_mapping(country_codes)
    with Session.begin() as db_session:
        # Countries are written before the worker that refers to them.
        update_db_countries(db_session, names_by_code)
        update_db_worker(db_session, worker_id, country_codes)
38 changes: 36 additions & 2 deletions backend/src/mirrors_qa_backend/db/country.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,50 @@
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session as OrmSession

from mirrors_qa_backend.db.exceptions import RecordDoesNotExistError
from mirrors_qa_backend.db.models import Country


def get_countries(session: OrmSession, *country_codes: str) -> list[Country]:
def get_countries(session: OrmSession, country_codes: list[str]) -> list[Country]:
"""Get countries with the provided country codes.
Gets all available countries if no country codes are provided.
"""
return list(
session.scalars(select(Country).where(Country.code.in_(country_codes))).all()
session.scalars(
select(Country).where(
(Country.code.in_(country_codes)) | (country_codes == [])
)
).all()
)


def get_country_or_none(session: OrmSession, country_code: str) -> Country | None:
    """Return the country whose code equals ``country_code``, or ``None``."""
    query = select(Country).where(Country.code == country_code)
    return session.scalars(query).one_or_none()


def get_country(session: OrmSession, country_code: str) -> Country:
    """Return the country with the given code.

    Raises:
        RecordDoesNotExistError: if ``get_country_or_none`` finds no country.
    """
    found = get_country_or_none(session, country_code)
    if not found:
        raise RecordDoesNotExistError(
            f"Country with code {country_code} does not exist."
        )
    return found


def create_country(
    session: OrmSession, *, country_code: str, country_name: str
) -> Country:
    """Insert a country into the database and return its row.

    The insert does nothing on a conflict on ``code``; the row is then
    fetched again through ``get_country`` and returned.
    """
    insert_stmt = insert(Country).values(code=country_code, name=country_name)
    # A conflict on "code" is skipped rather than raised.
    session.execute(insert_stmt.on_conflict_do_nothing(index_elements=["code"]))
    return get_country(session, country_code)


def update_countries(session: OrmSession, country_mapping: dict[str, str]) -> None:
    """Call ``create_country`` for every code/name pair in ``country_mapping``.

    Args:
        session: database session the inserts are executed on.
        country_mapping: maps each country code to its country name.
    """
    for code, name in country_mapping.items():
        create_country(session, country_code=code, country_name=name)
53 changes: 25 additions & 28 deletions backend/src/mirrors_qa_backend/db/mirrors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

from sqlalchemy import select
from sqlalchemy.orm import Session as OrmSession
from sqlalchemy.orm import selectinload

from mirrors_qa_backend import logger, schemas
from mirrors_qa_backend.db import models
from mirrors_qa_backend.db.exceptions import EmptyMirrorsError
from mirrors_qa_backend.db.exceptions import EmptyMirrorsError, RecordDoesNotExistError
from mirrors_qa_backend.db.models import Mirror


@dataclass
Expand All @@ -24,7 +23,7 @@ def create_mirrors(session: OrmSession, mirrors: list[schemas.Mirror]) -> int:
"""
nb_created = 0
for mirror in mirrors:
db_mirror = models.Mirror(
db_mirror = Mirror(
id=mirror.id,
base_url=mirror.base_url,
enabled=mirror.enabled,
Expand All @@ -38,20 +37,8 @@ def create_mirrors(session: OrmSession, mirrors: list[schemas.Mirror]) -> int:
as_only=mirror.as_only,
other_countries=mirror.other_countries,
)
# Ensure the country exists for the mirror
country = session.scalars(
select(models.Country).where(models.Country.code == mirror.country.code)
).one_or_none()

if country is None:
country = models.Country(code=mirror.country.code, name=mirror.country.name)
session.add(country)

db_mirror.country = country
session.add(db_mirror)
logger.debug(
f"Registered new mirror: {db_mirror.id} for country: {country.name}"
)
logger.debug(f"Registered new mirror: {db_mirror.id}.")
nb_created += 1
return nb_created

Expand Down Expand Up @@ -79,9 +66,8 @@ def create_or_update_mirror_status(
# Map the id (hostname) of each mirror from the database for comparison
# against the id of mirrors in current_mirrors. To be used in determining
# if this mirror should be disabled
query = select(models.Mirror).options(selectinload(models.Mirror.country))
db_mirrors: dict[str, models.Mirror] = {
mirror.id: mirror for mirror in session.scalars(query).all()
db_mirrors: dict[str, Mirror] = {
mirror.id: mirror for mirror in session.scalars(select(Mirror)).all()
}

# Create any mirror that doesn't exist on the database
Expand All @@ -95,19 +81,30 @@ def create_or_update_mirror_status(
# exists in the list, re-enable it
for db_mirror_id, db_mirror in db_mirrors.items():
if db_mirror_id not in current_mirrors:
logger.debug(
f"Disabling mirror: {db_mirror.id} for "
f"country: {db_mirror.country.name}"
)
logger.debug(f"Disabling mirror: {db_mirror.id}")
db_mirror.enabled = False
session.add(db_mirror)
result.nb_mirrors_disabled += 1
elif not db_mirror.enabled: # re-enable mirror if it was disabled
logger.debug(
f"Re-enabling mirror: {db_mirror.id} for "
f"country: {db_mirror.country.name}"
)
logger.debug(f"Re-enabling mirror: {db_mirror.id}")
db_mirror.enabled = True
session.add(db_mirror)
result.nb_mirrors_added += 1
return result


def get_mirror(session: OrmSession, mirror_id: str) -> Mirror:
    """Fetch the mirror with the given id from the DB.

    Raises:
        RecordDoesNotExistError: if no mirror has this id.
    """
    query = select(Mirror).where(Mirror.id == mirror_id)
    if (found := session.scalars(query).one_or_none()) is not None:
        return found
    raise RecordDoesNotExistError(f"Mirror with id: {mirror_id} does not exist.")


def get_enabled_mirrors(session: OrmSession) -> list[Mirror]:
    """Return every mirror in the DB whose ``enabled`` column is true."""
    # `== True` is intentional: it builds a SQL comparison on the column.
    query = select(Mirror).where(Mirror.enabled == True)  # noqa: E712
    enabled_mirrors = session.scalars(query).all()
    return list(enabled_mirrors)
Loading

0 comments on commit 8e7ad26

Please sign in to comment.