Skip to content

Commit

Permalink
Client api use exchange task (#2070)
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh authored Oct 16, 2023
1 parent b5ec419 commit 2b10b9f
Show file tree
Hide file tree
Showing 18 changed files with 439 additions and 271 deletions.
27 changes: 27 additions & 0 deletions nvflare/app_common/abstract/exchange_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional


class ExchangeTask:
    """Plain container bundling a task's identity, payload and status.

    Attributes are stored exactly as passed in; no validation or copying
    is performed.

    Args:
        task_name: Name of the task.
        task_id: Identifier of the task.
        data: Payload carried by the task (any object).
        meta: Optional dictionary of additional information. Defaults to None.
        return_code: Status string associated with the task. Defaults to "ok".
    """

    def __init__(self, task_name: str, task_id: str, data: Any, meta: Optional[dict] = None, return_code: str = "ok"):
        self.task_name = task_name
        self.task_id = task_id
        self.meta = meta
        self.data = data
        self.return_code = return_code

    def __str__(self):
        # Short form for log output: only name and id, never the payload.
        return f"Task(name:{self.task_name},id:{self.task_id})"
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from enum import Enum


class ModelExchangeFormat(str, Enum):
class ExchangeFormat(str, Enum):
RAW = "raw"
PYTORCH = "pytorch"
NUMPY = "numpy"
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import logging
import time
from typing import Any, Optional, Tuple
from typing import Any, List, Optional, Tuple

from nvflare.fuel.utils.pipe.pipe import Message, Pipe
from nvflare.fuel.utils.pipe.pipe_handler import PipeHandler, Topic
Expand All @@ -40,31 +40,33 @@ class ExchangePeerGoneException(DataExchangeException):
pass


class ModelExchanger:
class DataExchanger:
def __init__(
self,
supported_topics: List[str],
pipe: Pipe,
pipe_name: str = "pipe",
topic: str = "data",
get_poll_interval: float = 0.5,
read_interval: float = 0.1,
heartbeat_interval: float = 5.0,
heartbeat_timeout: float = 30.0,
):
"""Initializes the ModelExchanger.
"""Initializes the DataExchanger.
Args:
supported_topics (list[str]): Supported topics for data exchange. This allows the sender and receiver to identify
the purpose or content of the data being exchanged.
pipe (Pipe): The pipe used for data exchange.
pipe_name (str): Name of the pipe. Defaults to "pipe".
topic (str): Topic for data exchange. Defaults to "data".
get_poll_interval (float): Interval for checking if the other side has sent data. Defaults to 0.5.
read_interval (float): Interval for reading from the pipe. Defaults to 0.1.
heartbeat_interval (float): Interval for sending heartbeat to the peer. Defaults to 5.0.
heartbeat_timeout (float): Timeout for waiting for a heartbeat from the peer. Defaults to 30.0.
"""
self.logger = logging.getLogger(self.__class__.__name__)
self._req_id: Optional[str] = None
self._topic = topic
self.current_topic: Optional[str] = None
self._supported_topics = supported_topics

pipe.open(pipe_name)
self.pipe_handler = PipeHandler(
Expand All @@ -76,51 +78,56 @@ def __init__(
self.pipe_handler.start()
self._get_poll_interval = get_poll_interval

def submit_model(self, model: Any) -> None:
"""Submits a model for exchange.
def submit_data(self, data: Any) -> None:
"""Submits a data for exchange.
Args:
model (Any): The model to be submitted.
data (Any): The data to be submitted.
Raises:
DataExchangeException: If there is no request ID available (needs to pull model from server first).
DataExchangeException: If there is no request ID available (needs to pull data from server first).
"""
if self._req_id is None:
raise DataExchangeException("need to pull a model first.")
self._send_reply(data=model, req_id=self._req_id)
raise DataExchangeException("Missing req_id, need to pull a data first.")

def receive_model(self, timeout: Optional[float] = None) -> Any:
"""Receives a model.
if self.current_topic is None:
raise DataExchangeException("Missing current_topic, need to pull a data first.")

self._send_reply(data=data, topic=self.current_topic, req_id=self._req_id)

def receive_data(self, timeout: Optional[float] = None) -> Tuple[str, Any]:
"""Receives a data.
Args:
timeout (Optional[float]): Timeout for waiting to receive a model. Defaults to None.
timeout (Optional[float]): Timeout for waiting to receive a data. Defaults to None.
Returns:
Any: The received model.
A tuple of (topic, data): The received data.
Raises:
ExchangeTimeoutException: If the data cannot be received within the specified timeout.
ExchangeAbortException: If the other endpoint of the pipe requests to abort.
ExchangeEndException: If the other endpoint has ended.
ExchangePeerGoneException: If the other endpoint is gone.
"""
model, req_id = self._receive_request(timeout)
self._req_id = req_id
return model
msg = self._receive_request(timeout)
self._req_id = msg.msg_id
self.current_topic = msg.topic
return msg.topic, msg.data

def finalize(self, close_pipe: bool = True) -> None:
if self.pipe_handler is None:
raise RuntimeError("PipeMonitor is not initialized.")
self.pipe_handler.stop(close_pipe=close_pipe)

def _receive_request(self, timeout: Optional[float] = None) -> Tuple[Any, str]:
def _receive_request(self, timeout: Optional[float] = None) -> Message:
"""Receives a request.
Args:
timeout: how long to wait for the request to come.
Returns:
A tuple of (data, request id).
A Message.
Raises:
ExchangeTimeoutException: If can't receive data within timeout seconds.
Expand All @@ -138,24 +145,21 @@ def _receive_request(self, timeout: Optional[float] = None) -> Tuple[Any, str]:
self.pipe_handler.notify_abort(msg)
raise ExchangeTimeoutException(f"get data timeout after {timeout} secs")
elif msg.topic == Topic.ABORT:
raise ExchangeAbortException("the other end is aborted")
raise ExchangeAbortException("the other end ask to abort")
elif msg.topic == Topic.END:
raise ExchangeEndException(
f"received {msg.topic}: {msg.data} while waiting for result for {self._topic}"
)
raise ExchangeEndException(f"received msg: '{msg}' while waiting for requests")
elif msg.topic == Topic.PEER_GONE:
raise ExchangePeerGoneException(
f"received {msg.topic}: {msg.data} while waiting for result for {self._topic}"
)
elif msg.topic == self._topic:
return msg.data, msg.msg_id
raise ExchangePeerGoneException(f"received msg: '{msg}' while waiting for requests")
elif msg.topic in self._supported_topics:
return msg
time.sleep(self._get_poll_interval)

def _send_reply(self, data: Any, req_id: str, timeout: Optional[float] = None) -> bool:
def _send_reply(self, data: Any, topic: str, req_id: str, timeout: Optional[float] = None) -> bool:
"""Sends a reply.
Args:
data: The data exchange object to be sent.
topic: message topic.
req_id: request ID.
timeout: how long to wait for the peer to read the data.
If not specified, return False immediately.
Expand All @@ -165,6 +169,6 @@ def _send_reply(self, data: Any, req_id: str, timeout: Optional[float] = None) -
"""
if self.pipe_handler is None:
raise RuntimeError("PipeMonitor is not initialized.")
msg = Message.new_reply(topic=self._topic, data=data, req_msg_id=req_id)
msg = Message.new_reply(topic=topic, data=data, req_msg_id=req_id)
has_been_read = self.pipe_handler.send_to_peer(msg, timeout)
return has_been_read
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,23 @@
# limitations under the License.

import os
from typing import Optional
from typing import List, Optional

from nvflare.apis.utils.decomposers import flare_decomposers
from nvflare.app_common.data_exchange.data_exchanger import DataExchanger
from nvflare.app_common.decomposers import common_decomposers as app_common_decomposers
from nvflare.app_common.model_exchange.model_exchanger import ModelExchanger
from nvflare.fuel.utils.constants import Mode
from nvflare.fuel.utils.pipe.file_accessor import FileAccessor
from nvflare.fuel.utils.pipe.file_pipe import FilePipe


class FilePipeModelExchanger(ModelExchanger):
class FilePipeDataExchanger(DataExchanger):
def __init__(
self,
data_exchange_path: str,
supported_topics: List[str],
file_accessor: Optional[FileAccessor] = None,
pipe_name: str = "pipe",
topic: str = "data",
get_poll_interval: float = 0.5,
read_interval: float = 0.1,
heartbeat_interval: float = 5.0,
Expand All @@ -40,14 +40,14 @@ def __init__(
Args:
data_exchange_path (str): The path for data exchange. This is the location where the data
will be read from or written to.
supported_topics (list[str]): Supported topics for data exchange. This allows the sender and receiver to identify
the purpose or content of the data being exchanged.
file_accessor (Optional[FileAccessor]): The file accessor for reading and writing files.
If not provided, the default file accessor (FobsFileAccessor) will be used.
Please refer to the docstring of the FileAccessor class for more information
on implementing a custom file accessor. Defaults to None.
pipe_name (str): The name of the pipe to be used for communication. This pipe will be used
for transmitting data between the sender and receiver. Defaults to "pipe".
topic (str): The topic for data exchange. This allows the sender and receiver to identify
the purpose or content of the data being exchanged. Defaults to "data".
get_poll_interval (float): The interval (in seconds) for checking if the other side has sent data.
This determines how often the receiver checks for incoming data. Defaults to 0.5.
read_interval (float): The interval (in seconds) for reading from the pipe. This determines
Expand All @@ -66,9 +66,9 @@ def __init__(
file_pipe.set_file_accessor(file_accessor)

super().__init__(
supported_topics=supported_topics,
pipe=file_pipe,
pipe_name=pipe_name,
topic=topic,
get_poll_interval=get_poll_interval,
read_interval=read_interval,
heartbeat_interval=heartbeat_interval,
Expand Down
28 changes: 28 additions & 0 deletions nvflare/app_common/decomposers/common_decomposers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import numpy as np

from nvflare.app_common.abstract.exchange_task import ExchangeTask
from nvflare.app_common.abstract.fl_model import FLModel
from nvflare.app_common.abstract.learnable import Learnable
from nvflare.app_common.abstract.model import ModelLearnable
Expand All @@ -29,6 +30,33 @@
from nvflare.fuel.utils.fobs.decomposer import Decomposer, DictDecomposer, Externalizer, Internalizer


class ExchangeTaskDecomposer(fobs.Decomposer):
    """Decomposer that converts an ExchangeTask to and from a flat tuple.

    The tuple layout is (task_id, task_name, data, meta, return_code);
    decompose() and recompose() must agree on this order.
    """

    def supported_type(self):
        """Return the type handled by this decomposer (ExchangeTask)."""
        return ExchangeTask

    def decompose(self, b: ExchangeTask, manager: DatumManager = None) -> Any:
        """Break an ExchangeTask into a tuple of its fields.

        Args:
            b: The task to decompose.
            manager: Datum manager handed to the Externalizer. Defaults to None.

        Returns:
            A 5-tuple (task_id, task_name, data, meta, return_code) where
            data and meta have been passed through the Externalizer.
        """
        externalizer = Externalizer(manager)
        return (
            b.task_id,
            b.task_name,
            externalizer.externalize(b.data),
            externalizer.externalize(b.meta),
            b.return_code,
        )

    def recompose(self, data: tuple, manager: DatumManager = None) -> ExchangeTask:
        """Rebuild an ExchangeTask from the tuple produced by decompose().

        Args:
            data: The 5-tuple (task_id, task_name, data, meta, return_code).
            manager: Datum manager handed to the Internalizer. Defaults to None.

        Returns:
            The reconstructed ExchangeTask.

        Raises:
            TypeError: If data is not a tuple.
        """
        # Explicit check instead of `assert`, which is removed under `python -O`.
        if not isinstance(data, tuple):
            raise TypeError(f"data must be a tuple but got {type(data)}")
        task_id, task_name, task_data, meta, return_code = data
        internalizer = Internalizer(manager)
        return ExchangeTask(
            task_name=task_name,
            task_id=task_id,
            data=internalizer.internalize(task_data),
            meta=internalizer.internalize(meta),
            return_code=return_code,
        )


class FLModelDecomposer(fobs.Decomposer):
def supported_type(self):
return FLModel
Expand Down
25 changes: 13 additions & 12 deletions nvflare/app_opt/lightning/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from torch import Tensor

from nvflare.app_common.abstract.fl_model import FLModel, MetaKey
from nvflare.client.api import clear, get_config, init, receive, send
from nvflare.client.api import clear, get_config, init, is_evaluate, is_train, receive, send
from nvflare.client.config import ConfigKey

FL_META_KEY = "__fl_meta__"
Expand All @@ -41,8 +41,7 @@ class FLCallback(Callback):
def __init__(self, rank: int = 0):
super(FLCallback, self).__init__()
init(rank=str(rank))
self.has_global_eval = get_config().get(ConfigKey.GLOBAL_EVAL, False)
self.has_training = get_config().get(ConfigKey.TRAINING, False)
self.train_with_evaluation = get_config().get(ConfigKey.TRAIN_WITH_EVAL, False)
self.current_round = None
self.metrics = None
self.total_local_epochs = 0
Expand All @@ -59,15 +58,15 @@ def reset_state(self, trainer):
"""
# set states for next round
if self.current_round is not None:
if self.current_round == 0:
if self.max_epochs_per_round is None:
if trainer.max_epochs and trainer.max_epochs > 0:
self.max_epochs_per_round = trainer.max_epochs
if trainer.max_steps and trainer.max_steps > 0:
self.max_steps_per_round = trainer.max_steps

# record total local epochs/steps
self.total_local_epochs = trainer.current_epoch
self.total_local_steps += trainer.estimated_stepping_batches
self.total_local_steps = trainer.estimated_stepping_batches

# for next round
trainer.num_sanity_val_steps = 0 # Turn off sanity validation steps in following rounds of FL
Expand All @@ -82,11 +81,11 @@ def reset_state(self, trainer):

def on_train_start(self, trainer, pl_module):
# receive the global model and update the local model with global model
if self.has_training:
if is_train():
self._receive_and_update_model(trainer, pl_module)

def on_train_end(self, trainer, pl_module):
if self.has_training:
if is_train():
if hasattr(pl_module, FL_META_KEY):
fl_meta = getattr(pl_module, FL_META_KEY)
if not isinstance(fl_meta, dict):
Expand All @@ -105,13 +104,15 @@ def on_validation_start(self, trainer, pl_module):
# the metrics will be set.
# The subsequence validate() calls will not trigger the receive update model.
# Hence the validate() will be validating the local model.
if pl_module and self.has_global_eval and self.metrics is None:
self._receive_and_update_model(trainer, pl_module)
if (is_train() and self.train_with_evaluation) or is_evaluate():
if pl_module and self.metrics is None:
self._receive_and_update_model(trainer, pl_module)

def on_validation_end(self, trainer, pl_module):
if pl_module and self.has_global_eval and self.metrics is None:
self.metrics = _extract_metrics(trainer.callback_metrics)
self._send_model(FLModel(metrics=self.metrics))
if (is_train() and self.train_with_evaluation) or is_evaluate():
if pl_module and self.metrics is None:
self.metrics = _extract_metrics(trainer.callback_metrics)
self._send_model(FLModel(metrics=self.metrics))

def _receive_and_update_model(self, trainer, pl_module):
model = self._receive_model(trainer)
Expand Down
5 changes: 5 additions & 0 deletions nvflare/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,17 @@
from nvflare.app_common.abstract.fl_model import FLModel as FLModel
from nvflare.app_common.abstract.fl_model import ParamsType as ParamsType

from .api import DataExchangeException as DataExchangeException
from .api import clear as clear
from .api import get_config as get_config
from .api import get_job_id as get_job_id
from .api import get_site_name as get_site_name
from .api import get_total_rounds as get_total_rounds
from .api import init as init
from .api import is_evaluate as is_evaluate
from .api import is_running as is_running
from .api import is_submit_model as is_submit_model
from .api import is_train as is_train
from .api import params_diff as params_diff
from .api import receive as receive
from .api import receive_global_model as receive_global_model
Expand Down
Loading

0 comments on commit 2b10b9f

Please sign in to comment.