Skip to content

Commit

Permalink
Merge pull request #133 from kbase/dev-service
Browse files Browse the repository at this point in the history
Implement download callback through job state change
  • Loading branch information
MrCreosote authored Dec 17, 2024
2 parents e1733da + 9b19ade commit 30978ff
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 43 deletions.
4 changes: 3 additions & 1 deletion cdmtaskservice/error_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import NamedTuple

from cdmtaskservice.errors import ErrorType
from cdmtaskservice.exceptions import UnauthorizedError, InvalidJobStateError
from cdmtaskservice.http_bearer import MissingTokenError, InvalidAuthHeaderError
from cdmtaskservice.images import NoEntrypointError
from cdmtaskservice.image_remote_lookup import ImageNameParseError, ImageInfoFetchError
Expand All @@ -17,7 +18,7 @@
NoSuchImageError,
NoSuchJobError,
)
from cdmtaskservice.routes import UnauthorizedError, ClientLifeTimeError
from cdmtaskservice.routes import ClientLifeTimeError
from cdmtaskservice.s3.client import (
S3BucketInaccessibleError,
S3BucketNotFoundError,
Expand Down Expand Up @@ -60,6 +61,7 @@ class ErrorMapping(NamedTuple):
ImageDigestExistsError: ErrorMapping(ErrorType.IMAGE_DIGEST_EXISTS, _H400),
NoSuchImageError: ErrorMapping(ErrorType.NO_SUCH_IMAGE, _H404),
NoSuchJobError: ErrorMapping(ErrorType.NO_SUCH_JOB, _H404),
InvalidJobStateError: ErrorMapping(ErrorType.INVALID_JOB_STATE, _H400),
}


Expand Down
3 changes: 3 additions & 0 deletions cdmtaskservice/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ class ErrorType(Enum):

CLIENT_LIFETIME = (30200, "Client lifetime too short") # noqa: E222 @IgnorePep8
""" The client life time is shorter than requested. """

INVALID_JOB_STATE = (30300, "Invalid job state") # noqa: E222 @IgnorePep8
""" The job is not in the correct state for the requested operation. """

NOT_FOUND = (40000, "Not Found") # noqa: E222 @IgnorePep8
""" The requested resource was not found. """
Expand Down
4 changes: 4 additions & 0 deletions cdmtaskservice/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,7 @@

class UnauthorizedError(Exception):
    """Raised when a user attempts an action that is forbidden to them."""


class InvalidJobStateError(Exception):
    """Raised when a job is not in a valid state for the requested operation."""
13 changes: 6 additions & 7 deletions cdmtaskservice/job_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import logging

from cdmtaskservice import kb_auth
from cdmtaskservice import models
from cdmtaskservice.arg_checkers import not_falsy as _not_falsy, require_string as _require_string
from cdmtaskservice.exceptions import UnauthorizedError
Expand All @@ -24,7 +23,7 @@ def __init__(self, mongo: MongoDAO):
async def get_job(
self,
job_id: str,
user: kb_auth.KBaseUser,
user: str,
as_admin: bool = False
) -> models.Job | models.AdminJobDetails:
"""
Expand All @@ -36,13 +35,13 @@ async def get_job(
as_admin - True if the user should always have access to the job and should access
additional job details.
"""
_not_falsy(user, "user")
_require_string(user, "user")
job = await self._mongo.get_job(_require_string(job_id, "job_id"), as_admin=as_admin)
if not as_admin and job.user != user.user:
if not as_admin and job.user != user:
# reveals the job ID exists in the system but I don't see a problem with that
raise UnauthorizedError(f"User {user.user} may not access job {job_id}")
msg = f"User {user.user} accessed job {job_id}"
raise UnauthorizedError(f"User {user} may not access job {job_id}")
msg = f"User {user} accessed job {job_id}"
if as_admin:
msg = f"Admin user {user.user} accessed {job.user}'s job {job_id}"
msg = f"Admin user {user} accessed {job.user}'s job {job_id}"
logging.getLogger(__name__).info(msg)
return job
25 changes: 23 additions & 2 deletions cdmtaskservice/jobflows/nersc_jaws.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from cdmtaskservice import timestamp
from cdmtaskservice.arg_checkers import not_falsy as _not_falsy, require_string as _require_string
from cdmtaskservice.callback_url_paths import get_download_complete_callback
from cdmtaskservice.exceptions import InvalidJobStateError
from cdmtaskservice.job_state import JobState
from cdmtaskservice.mongo import MongoDAO
from cdmtaskservice.nersc.manager import NERSCManager
Expand Down Expand Up @@ -70,7 +71,7 @@ async def start_job(self, job: models.Job, objmeta: list[S3ObjectMeta]):
objmeta - the S3 object metadata for the files in the job.
"""
if _not_falsy(job, "job").state != models.JobState.CREATED:
raise ValueError("job must be in the created state")
raise InvalidJobStateError("Job must be in the created state")
logr = logging.getLogger(__name__)
# Could check that the s3 and job paths / etags match... YAGNI
# TODO PERF this validates the file paths yet again. Maybe the way to go is just have
Expand Down Expand Up @@ -103,11 +104,31 @@ async def start_job(self, job: models.Job, objmeta: list[S3ObjectMeta]):
job.id,
task_id,
models.JobState.CREATED,
models.JobState.UPLOAD_SUBMITTED,
models.JobState.DOWNLOAD_SUBMITTED,
timestamp.utcdatetime(),
)
except Exception as e:
# TODO LOGGING figure out how logging is going to work etc.
logr.exception(f"Error starting download for job {job.id}")
# TODO IMPORTANT ERRORHANDLING update job state to ERROR w/ message and don't raise
raise e

async def download_complete(self, job: models.AdminJobDetails):
    """
    Continue a job once its download has been reported as complete.

    job - the job to continue. It must be in the download submitted state.

    Raises InvalidJobStateError if the job is in any other state. Otherwise the job is
    moved from the download submitted state to the job submitting state in the database,
    stamped with the current UTC time.
    """
    required_state = models.JobState.DOWNLOAD_SUBMITTED
    # presumably _not_falsy raises when job is None / falsy and otherwise returns it - confirm
    checked_job = _not_falsy(job, "job")
    if checked_job.state != required_state:
        raise InvalidJobStateError("Job must be in the download submitted state")
    # TODO ERRHANDLING IMPORTANT pull the task from the SFAPI. If it a) doesn't exist or
    #      b) has no errors, continue, otherwise put the job into an errored state.
    # TODO ERRHANDLING IMPORTANT upload the output file from the download task and check for
    #      errors. If any exist, put the job into an errored state.
    # TODO LOGGING Add any relevant logs from the task / download task output file in state
    #      call
    await self._mongo.update_job_state(
        job.id,
        required_state,
        models.JobState.JOB_SUBMITTING,
        timestamp.utcdatetime(),
    )
6 changes: 3 additions & 3 deletions cdmtaskservice/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,11 +597,11 @@ class JobState(str, Enum):
The state of a job.
"""
CREATED = "created"
UPLOAD_SUBMITTED = "upload_submitted"
DOWNLOAD_SUBMITTED = "download_submitted"
JOB_SUBMITTING = "job_submitting"
JOB_SUBMITTED = "job_submitted"
DOWNLOAD_SUBMITTING = "download_submitting"
DOWNLOAD_SUBMITTED = "download_submitted"
UPLOAD_SUBMITTING = "upload_submitting"
UPLOAD_SUBMITTED = "upload_submitted"
COMPLETE = "complete"
ERROR = "error"

Expand Down
68 changes: 49 additions & 19 deletions cdmtaskservice/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from motor.motor_asyncio import AsyncIOMotorDatabase
from pymongo import IndexModel, ASCENDING, DESCENDING
from pymongo.errors import DuplicateKeyError
from typing import Any

from cdmtaskservice import models
from cdmtaskservice.arg_checkers import not_falsy as _not_falsy, require_string as _require_string
Expand Down Expand Up @@ -135,7 +136,52 @@ async def get_job(
raise NoSuchJobError(f"No job with ID '{job_id}' exists")
# TODO PERF build up the job piece by piece to skip S3 path validation
return models.AdminJobDetails(**doc) if as_admin else models.Job(**doc)

async def _update_job_state(
    self,
    job_id: str,
    current_state: models.JobState,
    state: models.JobState,
    time: datetime.datetime,
    push: dict[str, Any] | None = None,
):
    """
    Move a job from one state to another, recording the transition time.

    job_id - the job ID.
    current_state - the state the job is expected to be in; the update only matches a job
        with this ID that is currently in this state.
    state - the new state for the job.
    time - the time at which the job transitioned to the new state.
    push - optional extra field -> value entries to add to the same $push operation.

    Raises NoSuchJobError if no job with the given ID in the expected state was matched.
    """
    # Arguments are checked in the order job_id, current_state, state, time.
    job_filter = {
        models.FLD_JOB_ID: _require_string(job_id, "job_id"),
        models.FLD_JOB_STATE: _not_falsy(current_state, "current_state").value,
    }
    extra_pushes = push if push else {}
    transition = (_not_falsy(state, "state").value, _not_falsy(time, "time"))
    changes = {
        # the transition entry is merged last so it always ends up in the $push document
        "$push": extra_pushes | {models.FLD_JOB_TRANS_TIMES: transition},
        "$set": {models.FLD_JOB_STATE: state.value},
    }
    result = await self._col_jobs.update_one(job_filter, changes)
    if result.matched_count:
        return
    raise NoSuchJobError(
        f"No job with ID '{job_id}' in state {current_state.value} exists"
    )

async def update_job_state(
    self,
    job_id: str,
    current_state: models.JobState,
    state: models.JobState,
    time: datetime.datetime,
):
    """
    Transition a job to a new state.

    job_id - the ID of the job to update.
    current_state - the state the job is expected to be in right now. An error is thrown
        if the job is not in this state.
    state - the state to move the job to.
    time - when the job transitioned to the new state.
    """
    # Plain state change: delegate to the shared implementation.
    await self._update_job_state(
        job_id=job_id,
        current_state=current_state,
        state=state,
        time=time,
    )

_FLD_NERSC_DL_TASK = f"{models.FLD_JOB_NERSC_DETAILS}.{models.FLD_NERSC_DETAILS_DL_TASK_ID}"

async def add_NERSC_download_task_id(
Expand All @@ -158,25 +204,9 @@ async def add_NERSC_download_task_id(
"""
# may need to make this more generic where the cluster is passed in and mapped to
# a job structure location or something if we support more than NERSC
res = await self._col_jobs.update_one(
{
models.FLD_JOB_ID: _require_string(job_id, "job_id"),
models.FLD_JOB_STATE: _not_falsy(current_state, "current_state").value,
},
{
"$push": {
self._FLD_NERSC_DL_TASK: _require_string(task_id, "task_id"),
models.FLD_JOB_TRANS_TIMES:
(_not_falsy(state, "state").value, _not_falsy(time, "time")
)
},
"$set": {models.FLD_JOB_STATE: state.value}
},
)
if not res.matched_count:
raise NoSuchJobError(
f"No job with ID '{job_id}' in state {current_state.value} exists"
)
await self._update_job_state(job_id, current_state, state, time, push={
self._FLD_NERSC_DL_TASK: _require_string(task_id, "task_id")
})


class NoSuchImageError(Exception):
Expand Down
26 changes: 15 additions & 11 deletions cdmtaskservice/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@

_AUTH = KBaseHTTPBearer()

# * isn't allowed in KBase user names
_SERVICE_USER = "***service***"


def _ensure_admin(user: kb_auth.KBaseUser, err_msg: str):
if user.admin_perm != kb_auth.AdminPermission.FULL:
Expand Down Expand Up @@ -66,7 +69,7 @@ class Root(BaseModel):
response_model=Root,
summary="General service info",
description="General information about the service.")
async def root():
async def root() -> Root:
return {
"service_name": SERVICE_NAME,
"version": VERSION,
Expand All @@ -89,7 +92,7 @@ class WhoAmI(BaseModel):
summary="Who am I? What does it all mean?",
description="Information about the current user."
)
async def whoami(user: kb_auth.KBaseUser=Depends(_AUTH)):
async def whoami(user: kb_auth.KBaseUser=Depends(_AUTH)) -> WhoAmI:
return {
"user": user.user,
"is_service_admin": kb_auth.AdminPermission.FULL == user.admin_perm
Expand All @@ -111,7 +114,7 @@ async def submit_job(
r: Request,
job_input: models.JobInput,
user: kb_auth.KBaseUser=Depends(_AUTH),
):
) -> SubmitJobResponse:
job_submit = app_state.get_app_state(r).job_submit
return SubmitJobResponse(job_id=await job_submit.submit(job_input, user))

Expand All @@ -135,9 +138,9 @@ async def get_job(
r: Request,
job_id: _ANN_JOB_ID,
user: kb_auth.KBaseUser=Depends(_AUTH),
):
) -> models.Job:
job_state = app_state.get_app_state(r).job_state
return await job_state.get_job(job_id, user)
return await job_state.get_job(job_id, user.user)


@ROUTER_ADMIN.post(
Expand All @@ -162,7 +165,7 @@ async def approve_image(
max_length=1000,
)],
user: kb_auth.KBaseUser=Depends(_AUTH)
):
) -> models.Image:
_ensure_admin(user, "Only service administrators can approve images.")
images = app_state.get_app_state(r).images
return await images.register(image_id)
Expand All @@ -178,10 +181,10 @@ async def get_job_admin(
r: Request,
job_id: _ANN_JOB_ID,
user: kb_auth.KBaseUser=Depends(_AUTH),
):
) -> models.AdminJobDetails:
_ensure_admin(user, "Only service administrators can get jobs as an admin.")
job_state = app_state.get_app_state(r).job_state
return await job_state.get_job(job_id, user, as_admin=True)
return await job_state.get_job(job_id, user.user, as_admin=True)


class NERSCClientInfo(BaseModel):
Expand Down Expand Up @@ -215,7 +218,7 @@ async def get_nersc_client_info(
ge=1
)] = None,
user: kb_auth.KBaseUser=Depends(_AUTH)
):
) -> NERSCClientInfo:
_ensure_admin(user, "Only service administrators may view NERSC client information.")
nersc_cli = app_state.get_app_state(r).sfapi_client
expires = nersc_cli.expiration()
Expand All @@ -241,8 +244,9 @@ async def download_complete(
job_id: _ANN_JOB_ID
):
logging.getLogger(__name__).info(f"Download reported as complete for job {job_id}")
# TODO NOW implement
raise NotImplementedError()
appstate = app_state.get_app_state(r)
job = await appstate.job_state.get_job(job_id, _SERVICE_USER, as_admin=True)
await appstate.runners[job.job_input.cluster].download_complete(job)


@ROUTER_CALLBACKS.get(
Expand Down

0 comments on commit 30978ff

Please sign in to comment.