From ca6da08b57c84b7f6c05db2a68378facca02194c Mon Sep 17 00:00:00 2001
From: Yan Cheng <58191769+yanchengnv@users.noreply.github.com>
Date: Thu, 7 Dec 2023 11:30:27 -0500
Subject: [PATCH] Fix meta file processing in storage and improve schedule job
 retrieval (#2193)

* Fix meta file processing in storage and improve schedule job retrieval

* changed update_unfinished_jobs to use one get_jobs_by_status

Usage sketches for the new job manager and storage APIs are appended after
the diffs.
---
 nvflare/apis/impl/job_def_manager.py          | 86 ++++++++++++++-----
 nvflare/apis/job_def_manager_spec.py          | 21 ++++-
 nvflare/apis/storage.py                       | 14 ++-
 .../app_common/storages/filesystem_storage.py | 69 +++++++++++----
 nvflare/private/fed/client/communicator.py    |  2 +-
 nvflare/private/fed/server/job_runner.py      | 18 ++--
 .../app_common/storages/storage_test.py       |  7 +-
 7 files changed, 159 insertions(+), 58 deletions(-)

diff --git a/nvflare/apis/impl/job_def_manager.py b/nvflare/apis/impl/job_def_manager.py
index 9f47c43769..c08a2aed11 100644
--- a/nvflare/apis/impl/job_def_manager.py
+++ b/nvflare/apis/impl/job_def_manager.py
@@ -20,7 +20,7 @@
 import time
 import uuid
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 from nvflare.apis.fl_context import FLContext
 from nvflare.apis.job_def import Job, JobDataKey, JobMetaKey, job_from_meta
@@ -30,21 +30,34 @@
 from nvflare.fuel.utils import fobs
 from nvflare.fuel.utils.zip_utils import unzip_all_from_bytes, zip_directory_to_bytes
 
+_OBJ_TAG_SCHEDULED = "scheduled"
+
+
+class JobInfo:
+    def __init__(self, meta: dict, job_id: str, uri: str):
+        self.meta = meta
+        self.job_id = job_id
+        self.uri = uri
+
 
 class _JobFilter(ABC):
     @abstractmethod
-    def filter_job(self, meta: dict) -> bool:
+    def filter_job(self, info: JobInfo) -> bool:
         pass
 
 
 class _StatusFilter(_JobFilter):
     def __init__(self, status_to_check):
         self.result = []
+        if not isinstance(status_to_check, list):
+            # turning to list
+            status_to_check = [status_to_check]
         self.status_to_check = status_to_check
 
-    def filter_job(self, meta: dict):
-        if meta[JobMetaKey.STATUS] == self.status_to_check:
-            self.result.append(job_from_meta(meta))
+    def filter_job(self, info: JobInfo):
+        status = info.meta.get(JobMetaKey.STATUS.value)
+        if status in self.status_to_check:
+            self.result.append(job_from_meta(info.meta))
         return True
 
 
@@ -52,25 +65,42 @@ class _AllJobsFilter(_JobFilter):
     def __init__(self):
         self.result = []
 
-    def filter_job(self, meta: dict):
-        self.result.append(job_from_meta(meta))
+    def filter_job(self, info: JobInfo):
+        self.result.append(job_from_meta(info.meta))
         return True
 
 
 class _ReviewerFilter(_JobFilter):
-    def __init__(self, reviewer_name, fl_ctx: FLContext):
+    def __init__(self, reviewer_name):
         """Not used yet, for use in future implementations."""
         self.result = []
         self.reviewer_name = reviewer_name
 
-    def filter_job(self, meta: dict):
-        approvals = meta.get(JobMetaKey.APPROVALS)
+    def filter_job(self, info: JobInfo):
+        approvals = info.meta.get(JobMetaKey.APPROVALS.value)
         if not approvals or self.reviewer_name not in approvals:
-            self.result.append(job_from_meta(meta))
+            self.result.append(job_from_meta(info.meta))
         return True
 
 
-# TODO:: use try block around storage calls
+class _ScheduleJobFilter(_JobFilter):
+
+    """
+    This filter is optimized for selecting jobs to schedule since it is used so frequently (every 1 sec).
+    """
+
+    def __init__(self, store):
+        self.store = store
+        self.result = []
+
+    def filter_job(self, info: JobInfo):
+        status = info.meta.get(JobMetaKey.STATUS.value)
+        if status == RunStatus.SUBMITTED.value:
+            self.result.append(job_from_meta(info.meta))
+        elif status:
+            # skip this job in all future calls (so the meta file of this job won't be read)
+            self.store.tag_object(uri=info.uri, tag=_OBJ_TAG_SCHEDULED)
+        return True
 
 
 class SimpleJobDefManager(JobDefManagerSpec):
@@ -239,28 +269,40 @@ def get_all_jobs(self, fl_ctx: FLContext) -> List[Job]:
         self._scan(job_filter, fl_ctx)
         return job_filter.result
 
-    def _scan(self, job_filter: _JobFilter, fl_ctx: FLContext):
+    def get_jobs_to_schedule(self, fl_ctx: FLContext) -> List[Job]:
+        job_filter = _ScheduleJobFilter(self._get_job_store(fl_ctx))
+        self._scan(job_filter, fl_ctx, skip_tag=_OBJ_TAG_SCHEDULED)
+        return job_filter.result
+
+    def _scan(self, job_filter: _JobFilter, fl_ctx: FLContext, skip_tag=None):
         store = self._get_job_store(fl_ctx)
-        jid_paths = store.list_objects(self.uri_root)
-        if not jid_paths:
+        obj_uris = store.list_objects(self.uri_root, without_tag=skip_tag)
+        self.log_debug(fl_ctx, f"objects to scan: {len(obj_uris)}")
+        if not obj_uris:
             return
 
-        for jid_path in jid_paths:
-            jid = pathlib.PurePath(jid_path).name
-
-            meta = store.get_meta(self.job_uri(jid))
+        for uri in obj_uris:
+            jid = pathlib.PurePath(uri).name
+            job_uri = self.job_uri(jid)
+            meta = store.get_meta(job_uri)
             if meta:
-                ok = job_filter.filter_job(meta)
+                ok = job_filter.filter_job(JobInfo(meta, jid, job_uri))
                 if not ok:
                     break
 
-    def get_jobs_by_status(self, status, fl_ctx: FLContext) -> List[Job]:
+    def get_jobs_by_status(self, status: Union[RunStatus, List[RunStatus]], fl_ctx: FLContext) -> List[Job]:
+        """Get jobs that are in the specified status
+        Args:
+            status: a single status value or a list of status values
+            fl_ctx: the FL context
+        Returns: list of jobs that are in specified status
+        """
         job_filter = _StatusFilter(status)
         self._scan(job_filter, fl_ctx)
         return job_filter.result
 
     def get_jobs_waiting_for_review(self, reviewer_name: str, fl_ctx: FLContext) -> List[Job]:
-        job_filter = _ReviewerFilter(reviewer_name, fl_ctx)
+        job_filter = _ReviewerFilter(reviewer_name)
         self._scan(job_filter, fl_ctx)
         return job_filter.result
 
diff --git a/nvflare/apis/job_def_manager_spec.py b/nvflare/apis/job_def_manager_spec.py
index 235acd87e8..aee02b0195 100644
--- a/nvflare/apis/job_def_manager_spec.py
+++ b/nvflare/apis/job_def_manager_spec.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 from nvflare.apis.fl_component import FLComponent
 from nvflare.apis.fl_context import FLContext
@@ -103,7 +103,8 @@ def get_job_data(self, jid: str, fl_ctx: FLContext) -> dict:
             fl_ctx (FLContext): FLContext information
 
         Returns:
-            a dict to hold the job data and workspace. With the format: {JobDataKey.JOB_DATA.value: stored_data, JobDataKey.WORKSPACE_DATA: workspace_data}
+            a dict to hold the job data and workspace. With the format:
+            {JobDataKey.JOB_DATA.value: stored_data, JobDataKey.WORKSPACE_DATA: workspace_data}
 
         """
         pass
@@ -145,6 +146,18 @@ def set_status(self, jid: str, status: RunStatus, fl_ctx: FLContext):
         """
         pass
 
+    @abstractmethod
+    def get_jobs_to_schedule(self, fl_ctx: FLContext) -> List[Job]:
+        """Get job candidates for scheduling.
+
+        Args:
+            fl_ctx: FL context
+
+        Returns: list of jobs for scheduling
+
+        """
+        pass
+
     @abstractmethod
     def get_all_jobs(self, fl_ctx: FLContext) -> List[Job]:
         """Gets all Jobs in the system.
@@ -158,11 +171,11 @@ def get_all_jobs(self, fl_ctx: FLContext) -> List[Job]:
         pass
 
     @abstractmethod
-    def get_jobs_by_status(self, run_status: RunStatus, fl_ctx: FLContext) -> List[Job]:
+    def get_jobs_by_status(self, run_status: Union[RunStatus, List[RunStatus]], fl_ctx: FLContext) -> List[Job]:
         """Gets Jobs of a specified status.
 
         Args:
-            run_status (RunStatus): status to filter for
+            run_status: status to filter for: a single or a list of status values
             fl_ctx (FLContext): FLContext information
 
         Returns:
diff --git a/nvflare/apis/storage.py b/nvflare/apis/storage.py
index 953ec7e103..95ada59a87 100644
--- a/nvflare/apis/storage.py
+++ b/nvflare/apis/storage.py
@@ -91,11 +91,12 @@ def update_data(self, uri: str, data: bytes):
         pass
 
     @abstractmethod
-    def list_objects(self, path: str) -> List[str]:
+    def list_objects(self, path: str, without_tag=None) -> List[str]:
         """Lists all objects in the specified path.
 
         Args:
             path: the path to the objects
+            without_tag: skip the objects with this specified tag
 
         Returns:
             list of URIs of objects
@@ -163,3 +164,14 @@ def delete_object(self, uri: str):
 
         """
         pass
+
+    @abstractmethod
+    def tag_object(self, uri: str, tag: str, data=None):
+        """Tag an object with specified tag and data.
+        Args:
+            uri: URI of the object
+            tag: tag to be placed on the object
+            data: data associated with the tag.
+        Returns: None
+        """
+        pass
diff --git a/nvflare/app_common/storages/filesystem_storage.py b/nvflare/app_common/storages/filesystem_storage.py
index 9d807763ab..9fc71a1999 100644
--- a/nvflare/app_common/storages/filesystem_storage.py
+++ b/nvflare/app_common/storages/filesystem_storage.py
@@ -59,6 +59,21 @@ def _object_exists(uri: str):
     return all((os.path.isabs(uri), os.path.isdir(uri), data_exists, meta_exists))
 
 
+def _encode_meta(meta: dict) -> bytes:
+    return json.dumps(meta).encode("utf-8")
+
+
+def _decode_meta(data: bytes) -> dict:
+    s = data.decode("utf-8")
+    if s.startswith('"'):
+        # this is in old format
+        result = ast.literal_eval(json.loads(s))
+    else:
+        # this is json string
+        result = json.loads(s)
+    return result
+
+
 @validate_class_methods_args
 class FilesystemStorage(StorageSpec):
     def __init__(self, root_dir=os.path.abspath(os.sep), uri_root="/"):
@@ -79,6 +94,9 @@ def __init__(self, root_dir=os.path.abspath(os.sep), uri_root="/"):
         self.root_dir = root_dir
         self.uri_root = uri_root
 
+    def _object_path(self, uri: str):
+        return os.path.join(self.root_dir, uri.lstrip(self.uri_root))
+
     def create_object(self, uri: str, data: bytes, meta: dict, overwrite_existing: bool = False):
         """Creates an object.
@@ -97,7 +115,7 @@ def create_object(self, uri: str, data: bytes, meta: dict, overwrite_existing: b
             IOError: if error writing the object
 
         """
-        full_uri = os.path.join(self.root_dir, uri.lstrip(self.uri_root))
+        full_uri = self._object_path(uri)
 
         if _object_exists(full_uri) and not overwrite_existing:
             raise StorageException("object {} already exists and overwrite_existing is False".format(uri))
@@ -111,7 +129,7 @@ def create_object(self, uri: str, data: bytes, meta: dict, overwrite_existing: b
         tmp_data_path = data_path + "_" + str(uuid.uuid4())
         _write(tmp_data_path, data)
         try:
-            _write(meta_path, json.dumps(str(meta)).encode("utf-8"))
+            _write(meta_path, _encode_meta(meta))
         except Exception as e:
             os.remove(tmp_data_path)
             raise e
@@ -133,17 +151,17 @@ def update_meta(self, uri: str, meta: dict, replace: bool):
             IOError: if error writing the object
 
         """
-        full_uri = os.path.join(self.root_dir, uri.lstrip(self.uri_root))
+        full_uri = self._object_path(uri)
 
         if not _object_exists(full_uri):
             raise StorageException("object {} does not exist".format(uri))
 
         if replace:
-            _write(os.path.join(full_uri, "meta"), json.dumps(str(meta)).encode("utf-8"))
+            _write(os.path.join(full_uri, "meta"), _encode_meta(meta))
         else:
             prev_meta = self.get_meta(uri)
             prev_meta.update(meta)
-            _write(os.path.join(full_uri, "meta"), json.dumps(str(prev_meta)).encode("utf-8"))
+            _write(os.path.join(full_uri, "meta"), _encode_meta(prev_meta))
 
     def update_data(self, uri: str, data: bytes):
         """Updates the data of the specified object.
@@ -158,18 +176,19 @@ def update_data(self, uri: str, data: bytes):
             IOError: if error writing the object
 
         """
-        full_uri = os.path.join(self.root_dir, uri.lstrip(self.uri_root))
+        full_uri = self._object_path(uri)
 
         if not _object_exists(full_uri):
             raise StorageException("object {} does not exist".format(uri))
 
         _write(os.path.join(full_uri, "data"), data)
 
-    def list_objects(self, path: str) -> List[str]:
+    def list_objects(self, path: str, without_tag=None) -> List[str]:
         """List all objects in the specified path.
 
         Args:
             path: the path uri to the objects
+            without_tag: if set, skip the objects with this specified tag
 
         Returns:
             list of URIs of objects
@@ -179,13 +198,23 @@ def list_objects(self, path: str) -> List[str]:
             StorageException: if path does not exist or is not a valid directory.
 
         """
-        full_dir_path = os.path.join(self.root_dir, path.lstrip(self.uri_root))
+        full_dir_path = self._object_path(path)
         if not os.path.isdir(full_dir_path):
             raise StorageException(f"path {full_dir_path} is not a valid directory.")
 
-        return [
-            os.path.join(path, f) for f in os.listdir(full_dir_path) if _object_exists(os.path.join(full_dir_path, f))
-        ]
+        result = []
+
+        # Use scandir instead of listdir.
+        # According to https://peps.python.org/pep-0471/#os-scandir, scandir is more memory-efficient than listdir
+        # when iterating very large directories.
+        gen = os.scandir(full_dir_path)
+        for e in gen:
+            # assert isinstance(e, os.DirEntry)
+            obj_dir = os.path.join(full_dir_path, e.name)
+            if _object_exists(obj_dir):
+                if not without_tag or not os.path.exists(os.path.join(obj_dir, without_tag)):
+                    result.append(os.path.join(path, e.name))
+        return result
 
     def get_meta(self, uri: str) -> dict:
         """Gets meta of the specified object.
@@ -201,12 +230,11 @@ def get_meta(self, uri: str) -> dict:
             StorageException: if object does not exist
 
         """
-        full_uri = os.path.join(self.root_dir, uri.lstrip(self.uri_root))
+        full_uri = self._object_path(uri)
 
         if not _object_exists(full_uri):
             raise StorageException("object {} does not exist".format(uri))
-
-        return ast.literal_eval(json.loads(_read(os.path.join(full_uri, "meta")).decode("utf-8")))
+        return _decode_meta(_read(os.path.join(full_uri, "meta")))
 
     def get_data(self, uri: str) -> bytes:
         """Gets data of the specified object.
@@ -222,7 +250,7 @@ def get_data(self, uri: str) -> bytes:
             StorageException: if object does not exist
 
         """
-        full_uri = os.path.join(self.root_dir, uri.lstrip(self.uri_root))
+        full_uri = self._object_path(uri)
 
         if not _object_exists(full_uri):
             raise StorageException("object {} does not exist".format(uri))
@@ -243,7 +271,7 @@ def get_detail(self, uri: str) -> Tuple[dict, bytes]:
             StorageException: if object does not exist
 
         """
-        full_uri = os.path.join(self.root_dir, uri.lstrip(self.uri_root))
+        full_uri = self._object_path(uri)
 
         if not _object_exists(full_uri):
             raise StorageException("object {} does not exist".format(uri))
@@ -261,7 +289,7 @@ def delete_object(self, uri: str):
             StorageException: if object does not exist
 
         """
-        full_uri = os.path.join(self.root_dir, uri.lstrip(self.uri_root))
+        full_uri = self._object_path(uri)
 
         if not _object_exists(full_uri):
             raise StorageException("object {} does not exist".format(uri))
@@ -269,3 +297,10 @@ def delete_object(self, uri: str):
         shutil.rmtree(full_uri)
 
         return full_uri
+
+    def tag_object(self, uri: str, tag: str, data=None):
+        full_path = self._object_path(uri)
+        mark_file = os.path.join(full_path, tag)
+        with open(mark_file, "w") as f:
+            if data:
+                f.write(data)
diff --git a/nvflare/private/fed/client/communicator.py b/nvflare/private/fed/client/communicator.py
index 3a0a66d267..5fe982aebc 100644
--- a/nvflare/private/fed/client/communicator.py
+++ b/nvflare/private/fed/client/communicator.py
@@ -349,7 +349,7 @@ def send_heartbeat(self, servers, task_name, token, ssid, client_name, engine: C
 
             num_heartbeats_sent += 1
             if num_heartbeats_sent % heartbeats_log_interval == 0:
-                self.logger.info(f"Client: {client_name} has sent {num_heartbeats_sent} heartbeats.")
+                self.logger.debug(f"Client: {client_name} has sent {num_heartbeats_sent} heartbeats.")
 
             if not simulate_mode:
                 # server_message = result.get_header(CellMessageHeaderKeys.MESSAGE)
diff --git a/nvflare/private/fed/server/job_runner.py b/nvflare/private/fed/server/job_runner.py
index 3d123abcb5..fd02501429 100644
--- a/nvflare/private/fed/server/job_runner.py
+++ b/nvflare/private/fed/server/job_runner.py
@@ -381,10 +381,16 @@ def run(self, fl_ctx: FLContext):
             thread.start()
 
             while not self.ask_to_stop:
+                time.sleep(1.0)
+
                 if not isinstance(engine.server.server_state, HotState):
-                    time.sleep(1.0)
                     continue
-                approved_jobs = job_manager.get_jobs_by_status(RunStatus.SUBMITTED, fl_ctx)
+
+                if not engine.get_clients():
+                    # no clients registered yet - don't try to schedule!
+                    continue
+
+                approved_jobs = job_manager.get_jobs_to_schedule(fl_ctx)
                 self.log_debug(
                     fl_ctx, f"{fl_ctx.get_identity_name()} Got approved_jobs: {approved_jobs} from the job_manager"
                 )
@@ -464,8 +470,6 @@ def run(self, fl_ctx: FLContext):
                                 fl_ctx, f"Failed to run the Job ({ready_job.job_id}): {secure_format_exception(e)}"
                             )
 
-                time.sleep(1.0)
-
             thread.join()
         else:
             self.log_error(fl_ctx, "There's no Job Manager defined. Won't be able to run the jobs.")
 
@@ -497,11 +501,7 @@ def restore_running_job(self, run_number: str, job_id: str, job_clients, snapsho
     def update_unfinished_jobs(self, fl_ctx: FLContext):
         engine = fl_ctx.get_engine()
         job_manager = engine.get_component(SystemComponents.JOB_MANAGER)
-        all_jobs = []
-        dispatched_jobs = job_manager.get_jobs_by_status(RunStatus.DISPATCHED, fl_ctx)
-        all_jobs.extend(dispatched_jobs)
-        running_jobs = job_manager.get_jobs_by_status(RunStatus.RUNNING, fl_ctx)
-        all_jobs.extend(running_jobs)
+        all_jobs = job_manager.get_jobs_by_status([RunStatus.RUNNING, RunStatus.DISPATCHED], fl_ctx)
 
         for job in all_jobs:
             try:
diff --git a/tests/unit_test/app_common/storages/storage_test.py b/tests/unit_test/app_common/storages/storage_test.py
index 55e7fe9869..32496e5f36 100644
--- a/tests/unit_test/app_common/storages/storage_test.py
+++ b/tests/unit_test/app_common/storages/storage_test.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import ast
 import json
 import os
 import random
@@ -42,7 +41,7 @@ def random_data():
 
 
 def random_meta():
-    return {random.getrandbits(8): random.getrandbits(8) for _ in range(32)}
+    return {random_string(20): random.getrandbits(8) for _ in range(32)}
 
 
 ROOT_DIR = os.path.abspath(os.sep)
@@ -95,7 +94,7 @@ def test_large_storage(self, storage, n_folders, n_files, path_depth):
 
                 with open(os.path.join(test_filepath, "meta"), "wb") as f:
                     meta = random_meta()
-                    f.write(json.dumps(str(meta)).encode("utf-8"))
+                    f.write(json.dumps(meta).encode("utf-8"))
 
                 storage.create_object(filepath, data, meta, overwrite_existing=True)
 
@@ -109,7 +108,7 @@ def test_large_storage(self, storage, n_folders, n_files, path_depth):
             with open(os.path.join(test_dir_path, "data"), "rb") as f:
                 data = f.read()
             with open(os.path.join(test_dir_path, "meta"), "rb") as f:
-                meta = ast.literal_eval(json.loads(f.read().decode("utf-8")))
+                meta = json.loads(f.read().decode("utf-8"))
 
             assert storage.get_data(dir_path) == data
             assert storage.get_detail(dir_path)[1] == data
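
Usage sketches

The sketches below illustrate how the APIs touched by this patch fit
together. They are illustrative only and not part of the diff above: the
"engine" object, the FLContext wiring, and the constructor arguments are
all placeholders.

The scheduler-facing query works roughly like this (a minimal sketch,
assuming a SimpleJobDefManager backed by the default job store):

    from nvflare.apis.impl.job_def_manager import SimpleJobDefManager

    job_manager = SimpleJobDefManager(uri_root="jobs")  # illustrative config
    fl_ctx = engine.new_context()  # "engine" is assumed to be a server engine

    # Called by the JobRunner roughly once per second. Jobs in SUBMITTED
    # status are returned; any job found in another status is tagged
    # "scheduled" in storage, so later scans never re-read its meta file.
    for job in job_manager.get_jobs_to_schedule(fl_ctx):
        print(job.job_id)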
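
get_jobs_by_status now accepts either a single RunStatus or a list of them,
which is what lets update_unfinished_jobs collapse its two scans into one.
A minimal sketch, reusing the job_manager and fl_ctx placeholders from above:

    from nvflare.apis.job_def import RunStatus

    # One storage scan instead of two; _StatusFilter normalizes a single
    # value to a list internally, so both call styles behave the same.
    unfinished = job_manager.get_jobs_by_status(
        [RunStatus.RUNNING, RunStatus.DISPATCHED], fl_ctx
    )
    submitted = job_manager.get_jobs_by_status(RunStatus.SUBMITTED, fl_ctx)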
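
On FilesystemStorage, a tag is simply a marker file created inside the
object's directory, so tag filtering never opens or parses the meta file.
A sketch against a throwaway directory (all paths are placeholders; it
assumes the storage root is writable):

    from nvflare.app_common.storages.filesystem_storage import FilesystemStorage

    storage = FilesystemStorage(root_dir="/tmp/demo_store", uri_root="/")
    storage.create_object("/jobs/j1", b"blob", {"status": "SUBMITTED"}, overwrite_existing=True)
    storage.tag_object("/jobs/j1", tag="scheduled")

    # The tagged object is skipped by a scan that excludes the tag ...
    assert "/jobs/j1" not in storage.list_objects("/jobs", without_tag="scheduled")
    # ... but it still shows up in an unfiltered listing.
    assert "/jobs/j1" in storage.list_objects("/jobs")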
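
The meta file fix is backward compatible. The old writer stored
json.dumps(str(meta)), a JSON string wrapping a Python dict repr, while the
new writer stores plain JSON; _decode_meta dispatches on the leading quote
character. The standard-library round trip below mirrors the patched helpers:

    import ast
    import json

    def decode_meta(data: bytes) -> dict:  # mirrors _decode_meta in this patch
        s = data.decode("utf-8")
        if s.startswith('"'):
            return ast.literal_eval(json.loads(s))  # old on-disk format
        return json.loads(s)  # new JSON format

    meta = {"status": "SUBMITTED"}
    old_bytes = json.dumps(str(meta)).encode("utf-8")  # legacy writer
    new_bytes = json.dumps(meta).encode("utf-8")       # writer after this patch
    assert decode_meta(old_bytes) == decode_meta(new_bytes) == meta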