Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make it compatible with Pydantic V2 #1

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions rdbbeat/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
from rdbbeat.exceptions import PeriodicTaskNotFound


def get_crontab_schedule(session: Session, schedule: Schedule) -> CrontabSchedule:
def get_crontab_schedule(
session: Session, schedule: Schedule
) -> CrontabSchedule:
crontab = (
session.query(CrontabSchedule)
.where(
Expand All @@ -28,7 +30,7 @@ def get_crontab_schedule(session: Session, schedule: Schedule) -> CrontabSchedul
)
.one_or_none()
)
return crontab or CrontabSchedule(**schedule.dict())
return crontab or CrontabSchedule(**schedule.model_dump())


def schedule_task(
Expand All @@ -42,7 +44,9 @@ def schedule_task(
"""
Schedule a task by adding a periodic task entry.
"""
crontab = get_crontab_schedule(session=session, schedule=scheduled_task.schedule)
crontab = get_crontab_schedule(
session=session, schedule=scheduled_task.schedule
)
task = PeriodicTask(
crontab=crontab,
name=scheduled_task.name,
Expand All @@ -66,7 +70,11 @@ def update_task_enabled_status(
Update task enabled status (if task is enabled or disabled).
"""
try:
task = session.query(PeriodicTask).filter(PeriodicTask.id == periodic_task_id).one()
task = (
session.query(PeriodicTask)
.filter(PeriodicTask.id == periodic_task_id)
.one()
)
task.enabled = enabled_status # type: ignore [assignment]
session.add(task)

Expand All @@ -85,7 +93,11 @@ def update_task(
Update the details of a task including the crontab schedule
"""
try:
task = session.query(PeriodicTask).filter(PeriodicTask.id == periodic_task_id).one()
task = (
session.query(PeriodicTask)
.filter(PeriodicTask.id == periodic_task_id)
.one()
)

task.crontab = get_crontab_schedule(session, scheduled_task.schedule)
task.name = scheduled_task.name # type: ignore [assignment]
Expand All @@ -98,14 +110,22 @@ def update_task(
return task


def is_crontab_used(session: Session, crontab_schedule: CrontabSchedule) -> bool:
schedules = session.query(PeriodicTask).filter_by(crontab=crontab_schedule).all()
def is_crontab_used(
session: Session, crontab_schedule: CrontabSchedule
) -> bool:
schedules = (
session.query(PeriodicTask).filter_by(crontab=crontab_schedule).all()
)
return True if schedules else False


def delete_task(session: Session, periodic_task_id: int) -> PeriodicTask:
try:
task = session.query(PeriodicTask).where(PeriodicTask.id == periodic_task_id).one()
task = (
session.query(PeriodicTask)
.where(PeriodicTask.id == periodic_task_id)
.one()
)
session.delete(task)
session.flush()
if not is_crontab_used(session, task.crontab):
Expand Down
32 changes: 21 additions & 11 deletions rdbbeat/data_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) 2023 Hewlett Packard Enterprise Development LP
# MIT License

from pydantic import BaseModel, validator
from pydantic import BaseModel, field_validator


class Schedule(BaseModel):
Expand All @@ -12,49 +12,59 @@ class Schedule(BaseModel):
month_of_year: str = "*"
timezone: str = "UTC"

@validator("minute")
@field_validator("minute")
def minute_validation(cls, v: str) -> str:
if "*" == v:
return v
elif not v.isdigit():
raise ValueError(f"Minute: '{v}' is not a valid int")
assert int(v) >= 0 and int(v) < 60, "Minute value must range between 0 and 59"
assert (
int(v) >= 0 and int(v) < 60
), "Minute value must range between 0 and 59"
return v

@validator("hour")
@field_validator("hour")
def hour_validation(cls, v: str) -> str:
if "*" == v:
return v
elif not v.isdigit():
raise ValueError(f"Hour: '{v}' is not a valid int")
assert int(v) >= 0 and int(v) < 24, "Hour value must range between 0 and 23"
assert (
int(v) >= 0 and int(v) < 24
), "Hour value must range between 0 and 23"
return v

@validator("day_of_week")
@field_validator("day_of_week")
def day_of_week_validation(cls, v: str) -> str:
if "*" == v:
return v
elif not v.isdigit():
raise ValueError(f"Day of week: '{v}' is not a valid int")
assert int(v) >= 0 and int(v) < 7, "Day of the week value must range between 0 and 6"
assert (
int(v) >= 0 and int(v) < 7
), "Day of the week value must range between 0 and 6"
return v

@validator("day_of_month")
@field_validator("day_of_month")
def day_of_month_validation(cls, v: str) -> str:
if "*" == v:
return v
elif not v.isdigit():
raise ValueError(f"Day of month: '{v}' is not a valid int")
assert int(v) > 0 and int(v) < 32, "Day of the month value must range between 1 and 31"
assert (
int(v) > 0 and int(v) < 32
), "Day of the month value must range between 1 and 31"
return v

@validator("month_of_year")
@field_validator("month_of_year")
def month_of_year_validation(cls, v: str) -> str:
if "*" == v:
return v
elif not v.isdigit():
raise ValueError(f"Month: '{v}' is not a valid int")
assert int(v) > 0 and int(v) < 13, "Month of year value must range between 0 and 12"
assert (
int(v) > 0 and int(v) < 13
), "Month of year value must range between 0 and 12"
return v


Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
install_requires=[
"celery~=5.2",
"sqlalchemy",
"SQLAlchemy-Utils",
"alembic",
"pydantic",
"python-dotenv",
Expand Down
101 changes: 76 additions & 25 deletions tests/test_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Dict

import pytest
from mock import patch
from unittest.mock import patch
from sqlalchemy.orm.exc import NoResultFound

from rdbbeat.controller import (
Expand All @@ -20,78 +20,121 @@

def test_get_new_crontab_schedule(scheduled_task):
with patch("sqlalchemy.orm.Session") as mock_session:
mock_session.query(CrontabSchedule).where().one_or_none.return_value = None
crontab = get_crontab_schedule(mock_session, Schedule(**scheduled_task.get("schedule")))
mock_session.query(
CrontabSchedule
).where().one_or_none.return_value = None
crontab = get_crontab_schedule(
mock_session, Schedule(**scheduled_task.get("schedule"))
)
assert crontab.minute == scheduled_task.get("schedule")["minute"]
assert crontab.hour == scheduled_task.get("schedule")["hour"]
assert crontab.day_of_week == scheduled_task.get("schedule")["day_of_week"]
assert crontab.day_of_month == scheduled_task.get("schedule")["day_of_month"]
assert crontab.month_of_year == scheduled_task.get("schedule")["month_of_year"]
assert (
crontab.day_of_week
== scheduled_task.get("schedule")["day_of_week"]
)
assert (
crontab.day_of_month
== scheduled_task.get("schedule")["day_of_month"]
)
assert (
crontab.month_of_year
== scheduled_task.get("schedule")["month_of_year"]
)
assert crontab.timezone == scheduled_task.get("schedule")["timezone"]


def test_get_existing_crontab_schedule(scheduled_task, scheduled_task_db_object):
def test_get_existing_crontab_schedule(
scheduled_task, scheduled_task_db_object
):
existing_crontab = scheduled_task_db_object.crontab
with patch("sqlalchemy.orm.Session") as mock_session:
mock_session.query(CrontabSchedule).where().one_or_none.return_value = existing_crontab
crontab = get_crontab_schedule(mock_session, Schedule(**scheduled_task.get("schedule")))
mock_session.query(
CrontabSchedule
).where().one_or_none.return_value = existing_crontab
crontab = get_crontab_schedule(
mock_session, Schedule(**scheduled_task.get("schedule"))
)
assert existing_crontab == crontab


def test_schedule_task(scheduled_task_db_object, scheduled_task):
with patch("sqlalchemy.orm.Session") as mock_session:
mock_session.add.return_value = None
mock_session.query(CrontabSchedule).where().one_or_none.return_value = None
mock_session.query(
CrontabSchedule
).where().one_or_none.return_value = None

actual_scheduled_task = schedule_task(mock_session, ScheduledTask.parse_obj(scheduled_task))
actual_scheduled_task = schedule_task(
mock_session, ScheduledTask.model_validate(scheduled_task)
)

expected_scheduled_task = scheduled_task_db_object

assert actual_scheduled_task.name == expected_scheduled_task.name
assert actual_scheduled_task.task == expected_scheduled_task.task
assert actual_scheduled_task.schedule == expected_scheduled_task.schedule
assert (
actual_scheduled_task.schedule == expected_scheduled_task.schedule
)


def test_schedule_task_kwargs(scheduled_task_db_object, scheduled_task):
with patch("sqlalchemy.orm.Session") as mock_session:
mock_session.add.return_value = None
mock_session.query(CrontabSchedule).where().one_or_none.return_value = None
mock_session.query(
CrontabSchedule
).where().one_or_none.return_value = None

actual_scheduled_task = schedule_task(
mock_session, ScheduledTask.parse_obj(scheduled_task), report_metadata_uid="some_uid"
mock_session,
ScheduledTask.model_validate(scheduled_task),
report_metadata_uid="some_uid",
)

expected_scheduled_task = scheduled_task_db_object

assert actual_scheduled_task.name == expected_scheduled_task.name
assert actual_scheduled_task.task == expected_scheduled_task.task
assert actual_scheduled_task.schedule == expected_scheduled_task.schedule
assert actual_scheduled_task.kwargs == json.dumps({"report_metadata_uid": "some_uid"})
assert (
actual_scheduled_task.schedule == expected_scheduled_task.schedule
)
assert actual_scheduled_task.kwargs == json.dumps(
{"report_metadata_uid": "some_uid"}
)


def test_update_task_enabled_status(scheduled_task_db_object):
with patch("sqlalchemy.orm.Session") as mock_session:
mock_session.query(PeriodicTask).get.return_value = scheduled_task_db_object
mock_session.query(PeriodicTask).get.return_value = (
scheduled_task_db_object
)

periodic_task_id = 1
updated_task = update_task_enabled_status(mock_session, False, periodic_task_id)
updated_task = update_task_enabled_status(
mock_session, False, periodic_task_id
)

assert updated_task.enabled is False


def test_update_task_enabled_status_fail():
with patch("sqlalchemy.orm.Session") as mock_session:
with pytest.raises(PeriodicTaskNotFound):
mock_session.query(PeriodicTask).filter().one.side_effect = NoResultFound()
mock_session.query(PeriodicTask).filter().one.side_effect = (
NoResultFound()
)

periodic_task_id = -1
update_task_enabled_status(mock_session, False, periodic_task_id)


def test_update_task(scheduled_task_db_object):
with patch("sqlalchemy.orm.Session") as mock_session:
mock_session.query(PeriodicTask).filter().one.return_value = scheduled_task_db_object
mock_session.query(CrontabSchedule).where().one_or_none.return_value = None
mock_session.query(PeriodicTask).filter().one.return_value = (
scheduled_task_db_object
)
mock_session.query(
CrontabSchedule
).where().one_or_none.return_value = None

new_schedule: Dict = {
"minute": "24",
Expand All @@ -116,20 +159,26 @@ def test_update_task(scheduled_task_db_object):

periodic_task_id = 1
actual_updated_db_task = update_task(
mock_session, ScheduledTask.parse_obj(new_scheduled_task), periodic_task_id
mock_session,
ScheduledTask.model_validate(new_scheduled_task),
periodic_task_id,
)

assert mock_session.query(PeriodicTask).filter().one.call_count == 1
assert actual_updated_db_task.name == expected_updated_task.name
assert actual_updated_db_task.task == expected_updated_task.task
assert actual_updated_db_task.schedule == expected_updated_task.schedule
assert (
actual_updated_db_task.schedule == expected_updated_task.schedule
)


def test_delete_task(scheduled_task_db_object):
with patch("sqlalchemy.orm.Session") as mock_session:
# Set up the mock_session
periodic_task_id = 1
mock_session.query(PeriodicTask).where().one.return_value = scheduled_task_db_object
mock_session.query(PeriodicTask).where().one.return_value = (
scheduled_task_db_object
)
mock_session.delete.return_value = None
# Delete task
with patch("rdbbeat.controller.is_crontab_used") as is_crontab_used:
Expand All @@ -147,5 +196,7 @@ def test_is_crontab_used(scheduled_task_db_object):
with patch("sqlalchemy.orm.Session") as mock_session:
mock_session.query(PeriodicTask).filter_by().all.return_value = None

result = is_crontab_used(mock_session, scheduled_task_db_object.schedule)
result = is_crontab_used(
mock_session, scheduled_task_db_object.schedule
)
assert result is False
Loading