diff --git a/cdmtaskservice/error_mapping.py b/cdmtaskservice/error_mapping.py index fbeb6db..5a74b22 100644 --- a/cdmtaskservice/error_mapping.py +++ b/cdmtaskservice/error_mapping.py @@ -6,6 +6,7 @@ from typing import NamedTuple from cdmtaskservice.errors import ErrorType +from cdmtaskservice.exceptions import UnauthorizedError, InvalidJobStateError from cdmtaskservice.http_bearer import MissingTokenError, InvalidAuthHeaderError from cdmtaskservice.images import NoEntrypointError from cdmtaskservice.image_remote_lookup import ImageNameParseError, ImageInfoFetchError @@ -17,7 +18,7 @@ NoSuchImageError, NoSuchJobError, ) -from cdmtaskservice.routes import UnauthorizedError, ClientLifeTimeError +from cdmtaskservice.routes import ClientLifeTimeError from cdmtaskservice.s3.client import ( S3BucketInaccessibleError, S3BucketNotFoundError, @@ -60,6 +61,7 @@ class ErrorMapping(NamedTuple): ImageDigestExistsError: ErrorMapping(ErrorType.IMAGE_DIGEST_EXISTS, _H400), NoSuchImageError: ErrorMapping(ErrorType.NO_SUCH_IMAGE, _H404), NoSuchJobError: ErrorMapping(ErrorType.NO_SUCH_JOB, _H404), + InvalidJobStateError: ErrorMapping(ErrorType.INVALID_JOB_STATE, _H400), } diff --git a/cdmtaskservice/errors.py b/cdmtaskservice/errors.py index c359f1c..839f903 100644 --- a/cdmtaskservice/errors.py +++ b/cdmtaskservice/errors.py @@ -59,6 +59,9 @@ class ErrorType(Enum): CLIENT_LIFETIME = (30200, "Client lifetime too short") # noqa: E222 @IgnorePep8 """ The client life time is shorter than requested. """ + + INVALID_JOB_STATE = (30300, "Invalid job state") # noqa: E222 @IgnorePep8 + """ The job is not in the correct state for the requested operation. """ NOT_FOUND = (40000, "Not Found") # noqa: E222 @IgnorePep8 """ The requested resource was not found. 
""" diff --git a/cdmtaskservice/exceptions.py b/cdmtaskservice/exceptions.py index 1263c54..7133cda 100644 --- a/cdmtaskservice/exceptions.py +++ b/cdmtaskservice/exceptions.py @@ -5,3 +5,7 @@ class UnauthorizedError(Exception): """ An exception thrown when a user attempts a forbidden action. """ + + +class InvalidJobStateError(Exception): + """ An exception thrown when a job is in an invalid state to perform an operation. """ diff --git a/cdmtaskservice/job_state.py b/cdmtaskservice/job_state.py index 43093d8..4968ba8 100644 --- a/cdmtaskservice/job_state.py +++ b/cdmtaskservice/job_state.py @@ -4,7 +4,6 @@ import logging -from cdmtaskservice import kb_auth from cdmtaskservice import models from cdmtaskservice.arg_checkers import not_falsy as _not_falsy, require_string as _require_string from cdmtaskservice.exceptions import UnauthorizedError @@ -24,7 +23,7 @@ def __init__(self, mongo: MongoDAO): async def get_job( self, job_id: str, - user: kb_auth.KBaseUser, + user: str, as_admin: bool = False ) -> models.Job | models.AdminJobDetails: """ @@ -36,13 +35,13 @@ async def get_job( as_admin - True if the user should always have access to the job and should access additional job details. 
""" - _not_falsy(user, "user") + _require_string(user, "user") job = await self._mongo.get_job(_require_string(job_id, "job_id"), as_admin=as_admin) - if not as_admin and job.user != user.user: + if not as_admin and job.user != user: # reveals the job ID exists in the system but I don't see a problem with that - raise UnauthorizedError(f"User {user.user} may not access job {job_id}") - msg = f"User {user.user} accessed job {job_id}" + raise UnauthorizedError(f"User {user} may not access job {job_id}") + msg = f"User {user} accessed job {job_id}" if as_admin: - msg = f"Admin user {user.user} accessed {job.user}'s job {job_id}" + msg = f"Admin user {user} accessed {job.user}'s job {job_id}" logging.getLogger(__name__).info(msg) return job diff --git a/cdmtaskservice/jobflows/nersc_jaws.py b/cdmtaskservice/jobflows/nersc_jaws.py index 0194825..a251407 100644 --- a/cdmtaskservice/jobflows/nersc_jaws.py +++ b/cdmtaskservice/jobflows/nersc_jaws.py @@ -8,6 +8,7 @@ from cdmtaskservice import timestamp from cdmtaskservice.arg_checkers import not_falsy as _not_falsy, require_string as _require_string from cdmtaskservice.callback_url_paths import get_download_complete_callback +from cdmtaskservice.exceptions import InvalidJobStateError from cdmtaskservice.job_state import JobState from cdmtaskservice.mongo import MongoDAO from cdmtaskservice.nersc.manager import NERSCManager @@ -70,7 +71,7 @@ async def start_job(self, job: models.Job, objmeta: list[S3ObjectMeta]): objmeta - the S3 object metadata for the files in the job. """ if _not_falsy(job, "job").state != models.JobState.CREATED: - raise ValueError("job must be in the created state") + raise InvalidJobStateError("Job must be in the created state") logr = logging.getLogger(__name__) # Could check that the s3 and job paths / etags match... YAGNI # TODO PERF this validates the file paths yet again. 
Maybe the way to go is just have @@ -103,7 +104,7 @@ async def start_job(self, job: models.Job, objmeta: list[S3ObjectMeta]): job.id, task_id, models.JobState.CREATED, - models.JobState.UPLOAD_SUBMITTED, + models.JobState.DOWNLOAD_SUBMITTED, timestamp.utcdatetime(), ) except Exception as e: @@ -111,3 +112,23 @@ async def start_job(self, job: models.Job, objmeta: list[S3ObjectMeta]): logr.exception(f"Error starting download for job {job.id}") # TODO IMPORTANT ERRORHANDLING update job state to ERROR w/ message and don't raise raise e + + async def download_complete(self, job: models.AdminJobDetails): + """ + Continue a job after the download is complete. The job is expected to be in the + download submitted state. + """ + if _not_falsy(job, "job").state != models.JobState.DOWNLOAD_SUBMITTED: + raise InvalidJobStateError("Job must be in the download submitted state") + # TODO ERRHANDLING IMPORTANT pull the task from the SFAPI. If it a) doesn't exist or b) has + # no errors, continue, otherwise put the job into an errored state. + # TODO ERRHANDLING IMPORTANT upload the output file from the download task and check for + # errors. If any exist, put the job into an errored state. + # TODO LOGGING Add any relevant logs from the task / download task output file in state + # call + await self._mongo.update_job_state( + job.id, + models.JobState.DOWNLOAD_SUBMITTED, + models.JobState.JOB_SUBMITTING, + timestamp.utcdatetime() + ) diff --git a/cdmtaskservice/models.py b/cdmtaskservice/models.py index f89310c..986a8d6 100644 --- a/cdmtaskservice/models.py +++ b/cdmtaskservice/models.py @@ -597,11 +597,11 @@ class JobState(str, Enum): The state of a job. 
""" CREATED = "created" - UPLOAD_SUBMITTED = "upload_submitted" + DOWNLOAD_SUBMITTED = "download_submitted" JOB_SUBMITTING = "job_submitting" JOB_SUBMITTED = "job_submitted" - DOWNLOAD_SUBMITTING = "download_submitting" - DOWNLOAD_SUBMITTED = "download_submitted" + UPLOAD_SUBMITTING = "upload_submitting" + UPLOAD_SUBMITTED = "upload_submitted" COMPLETE = "complete" ERROR = "error" diff --git a/cdmtaskservice/mongo.py b/cdmtaskservice/mongo.py index c0d34ab..461cf2a 100644 --- a/cdmtaskservice/mongo.py +++ b/cdmtaskservice/mongo.py @@ -8,6 +8,7 @@ from motor.motor_asyncio import AsyncIOMotorDatabase from pymongo import IndexModel, ASCENDING, DESCENDING from pymongo.errors import DuplicateKeyError +from typing import Any from cdmtaskservice import models from cdmtaskservice.arg_checkers import not_falsy as _not_falsy, require_string as _require_string @@ -135,7 +136,52 @@ async def get_job( raise NoSuchJobError(f"No job with ID '{job_id}' exists") # TODO PERF build up the job piece by piece to skip S3 path validation return models.AdminJobDetails(**doc) if as_admin else models.Job(**doc) + + async def _update_job_state( + self, + job_id: str, + current_state: models.JobState, + state: models.JobState, + time: datetime.datetime, + push: dict[str, Any] | None = None, + ): + res = await self._col_jobs.update_one( + { + models.FLD_JOB_ID: _require_string(job_id, "job_id"), + models.FLD_JOB_STATE: _not_falsy(current_state, "current_state").value, + }, + { + "$push": (push if push else {}) | { + models.FLD_JOB_TRANS_TIMES: + (_not_falsy(state, "state").value, _not_falsy(time, "time") + ) + }, + "$set": {models.FLD_JOB_STATE: state.value} + }, + ) + if not res.matched_count: + raise NoSuchJobError( + f"No job with ID '{job_id}' in state {current_state.value} exists" + ) + async def update_job_state( + self, + job_id: str, + current_state: models.JobState, + state: models.JobState, + time: datetime.datetime, + ): + """ + Update the job state. + + job_id - the job ID. 
+ current_state - the expected current state of the job. If the job is not in this state + an error is thrown. + state - the new state for the job. + time - the time at which the job transitioned to the new state. + """ + await self._update_job_state(job_id, current_state, state, time) + _FLD_NERSC_DL_TASK = f"{models.FLD_JOB_NERSC_DETAILS}.{models.FLD_NERSC_DETAILS_DL_TASK_ID}" async def add_NERSC_download_task_id( @@ -158,25 +204,9 @@ async def add_NERSC_download_task_id( """ # may need to make this more generic where the cluster is passed in and mapped to # a job structure location or something if we support more than NERSC - res = await self._col_jobs.update_one( - { - models.FLD_JOB_ID: _require_string(job_id, "job_id"), - models.FLD_JOB_STATE: _not_falsy(current_state, "current_state").value, - }, - { - "$push": { - self._FLD_NERSC_DL_TASK: _require_string(task_id, "task_id"), - models.FLD_JOB_TRANS_TIMES: - (_not_falsy(state, "state").value, _not_falsy(time, "time") - ) - }, - "$set": {models.FLD_JOB_STATE: state.value} - }, - ) - if not res.matched_count: - raise NoSuchJobError( - f"No job with ID '{job_id}' in state {current_state.value} exists" - ) + await self._update_job_state(job_id, current_state, state, time, push={ + self._FLD_NERSC_DL_TASK: _require_string(task_id, "task_id") + }) class NoSuchImageError(Exception): diff --git a/cdmtaskservice/routes.py b/cdmtaskservice/routes.py index 2b8c143..37d21b0 100644 --- a/cdmtaskservice/routes.py +++ b/cdmtaskservice/routes.py @@ -37,6 +37,9 @@ _AUTH = KBaseHTTPBearer() +# * isn't allowed in KBase user names +_SERVICE_USER = "***service***" + def _ensure_admin(user: kb_auth.KBaseUser, err_msg: str): if user.admin_perm != kb_auth.AdminPermission.FULL: @@ -66,7 +69,7 @@ class Root(BaseModel): response_model=Root, summary="General service info", description="General information about the service.") -async def root(): +async def root() -> Root: return { "service_name": SERVICE_NAME, "version": VERSION, @@ 
-89,7 +92,7 @@ class WhoAmI(BaseModel): summary="Who am I? What does it all mean?", description="Information about the current user." ) -async def whoami(user: kb_auth.KBaseUser=Depends(_AUTH)): +async def whoami(user: kb_auth.KBaseUser=Depends(_AUTH)) -> WhoAmI: return { "user": user.user, "is_service_admin": kb_auth.AdminPermission.FULL == user.admin_perm @@ -111,7 +114,7 @@ async def submit_job( r: Request, job_input: models.JobInput, user: kb_auth.KBaseUser=Depends(_AUTH), -): +) -> SubmitJobResponse: job_submit = app_state.get_app_state(r).job_submit return SubmitJobResponse(job_id=await job_submit.submit(job_input, user)) @@ -135,9 +138,9 @@ async def get_job( r: Request, job_id: _ANN_JOB_ID, user: kb_auth.KBaseUser=Depends(_AUTH), -): +) -> models.Job: job_state = app_state.get_app_state(r).job_state - return await job_state.get_job(job_id, user) + return await job_state.get_job(job_id, user.user) @ROUTER_ADMIN.post( @@ -162,7 +165,7 @@ async def approve_image( max_length=1000, )], user: kb_auth.KBaseUser=Depends(_AUTH) -): +) -> models.Image: _ensure_admin(user, "Only service administrators can approve images.") images = app_state.get_app_state(r).images return await images.register(image_id) @@ -178,10 +181,10 @@ async def get_job_admin( r: Request, job_id: _ANN_JOB_ID, user: kb_auth.KBaseUser=Depends(_AUTH), -): +) -> models.AdminJobDetails: _ensure_admin(user, "Only service administrators can get jobs as an admin.") job_state = app_state.get_app_state(r).job_state - return await job_state.get_job(job_id, user, as_admin=True) + return await job_state.get_job(job_id, user.user, as_admin=True) class NERSCClientInfo(BaseModel): @@ -215,7 +218,7 @@ async def get_nersc_client_info( ge=1 )] = None, user: kb_auth.KBaseUser=Depends(_AUTH) -): +) -> NERSCClientInfo: _ensure_admin(user, "Only service administrators may view NERSC client information.") nersc_cli = app_state.get_app_state(r).sfapi_client expires = nersc_cli.expiration() @@ -241,8 +244,9 @@ async def 
download_complete( job_id: _ANN_JOB_ID ): logging.getLogger(__name__).info(f"Download reported as complete for job {job_id}") - # TODO NOW implement - raise NotImplementedError() + appstate = app_state.get_app_state(r) + job = await appstate.job_state.get_job(job_id, _SERVICE_USER, as_admin=True) + await appstate.runners[job.job_input.cluster].download_complete(job) @ROUTER_CALLBACKS.get(