initial inference + RoundConfig
Wrede committed May 23, 2024
1 parent 36b8241 commit 512a4ca
Showing 7 changed files with 213 additions and 61 deletions.
8 changes: 5 additions & 3 deletions fedn/network/combiner/aggregators/aggregatorbase.py
@@ -86,9 +86,11 @@ def _validate_model_update(self, model_update):
:return: True if the model update is valid, False otherwise.
:rtype: bool
"""
data = json.loads(model_update.meta)["training_metadata"]
if "num_examples" not in data.keys():
logger.error("AGGREGATOR({}): Model validation failed, num_examples missing in metadata.".format(self.name))
try:
data = json.loads(model_update.meta)["training_metadata"]
num_examples = data["num_examples"]
except KeyError as e:
logger.error("AGGREGATOR({}): Invalid model update, missing metadata.".format(self.name))
return False
return True

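The guard above now treats any missing key as an invalid update instead of only checking for num_examples. A standalone sketch of the same pattern (the helper name and the payloads are illustrative, not part of FEDn):

import json

def has_training_metadata(meta: str) -> bool:
    # Mirrors the new guard: a missing "training_metadata" section or a
    # missing "num_examples" key both raise KeyError and mark the update invalid.
    try:
        data = json.loads(meta)["training_metadata"]
        _ = data["num_examples"]
    except KeyError:
        return False
    return True

print(has_training_metadata(json.dumps({"training_metadata": {"num_examples": 128}})))  # True
print(has_training_metadata(json.dumps({"training_metadata": {}})))                      # False
print(has_training_metadata(json.dumps({})))                                             # False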
97 changes: 63 additions & 34 deletions fedn/network/combiner/combiner.py
@@ -12,10 +12,11 @@

import fedn.network.grpc.fedn_pb2 as fedn
import fedn.network.grpc.fedn_pb2_grpc as rpc
from fedn.common.log_config import logger, set_log_level_from_string, set_log_stream
from fedn.common.log_config import (logger, set_log_level_from_string,
set_log_stream)
from fedn.network.combiner.connect import ConnectorCombiner, Status
from fedn.network.combiner.modelservice import ModelService
from fedn.network.combiner.roundhandler import RoundHandler
from fedn.network.combiner.roundhandler import RoundConfig, RoundHandler
from fedn.network.grpc.server import Server
from fedn.network.storage.s3.repository import Repository
from fedn.network.storage.statestore.mongostatestore import MongoStateStore
@@ -161,7 +162,7 @@ def __whoami(self, client, instance):
client.role = role_to_proto_role(instance.role)
return client

def request_model_update(self, config, clients=[]):
def request_model_update(self, session_id, model_id, config, clients=[]):
"""Ask clients to update the current global model.
:param config: the model configuration to send to clients
@@ -170,32 +171,14 @@ def request_model_update(self, config, clients=[]):
:type clients: list
"""
# The request to be added to the client queue
request = fedn.TaskRequest()
request.model_id = config["model_id"]
request.correlation_id = str(uuid.uuid4())
request.timestamp = str(datetime.now())
request.data = json.dumps(config)
request.type = fedn.StatusType.MODEL_UPDATE
request.session_id = config["session_id"]

request.sender.name = self.id
request.sender.role = fedn.COMBINER

if len(clients) == 0:
clients = self.get_active_trainers()

for client in clients:
request.receiver.name = client
request.receiver.role = fedn.WORKER
self._put_request_to_client_queue(request, fedn.Queue.TASK_QUEUE)
request, clients = self._send_request_type(fedn.StatusType.MODEL_UPDATE, session_id, model_id, config, clients)

if len(clients) < 20:
logger.info("Sent model update request for model {} to clients {}".format(request.model_id, clients))
else:
logger.info("Sent model update request for model {} to {} clients".format(request.model_id, len(clients)))

def request_model_validation(self, model_id, config, clients=[]):
def request_model_validation(self, session_id, model_id, clients=[]):
"""Ask clients to validate the current global model.
:param model_id: the model id to validate
@@ -206,30 +189,76 @@ def request_model_validation(self, model_id, config, clients=[]):
:type clients: list
"""
# The request to be added to the client queue
request, clients = self._send_request_type(fedn.StatusType.MODEL_VALIDATION, session_id, model_id, clients=clients)

if len(clients) < 20:
logger.info("Sent model validation request for model {} to clients {}".format(request.model_id, clients))
else:
logger.info("Sent model validation request for model {} to {} clients".format(request.model_id, len(clients)))

def request_model_inference(self, session_id: str, model_id: str, clients: list=[]) -> None:
"""Ask clients to perform inference on the model.
:param model_id: the model id to perform inference on
:type model_id: str
:param config: the model configuration to send to clients
:type config: dict
:param clients: the clients to send the request to
:type clients: list
"""
request, clients = self._send_request_type(fedn.StatusType.INFERENCE, session_id, model_id, clients=clients)

if len(clients) < 20:
logger.info("Sent model inference request for model {} to clients {}".format(request.model_id, clients))
else:
logger.info("Sent model inference request for model {} to {} clients".format(request.model_id, len(clients)))

def _send_request_type(self, request_type, session_id, model_id, config=None, clients=[]):
"""Send a request of a specific type to clients.
:param request_type: the type of request
:type request_type: :class:`fedn.network.grpc.fedn_pb2.StatusType`
:param model_id: the model id to send in the request
:type model_id: str
:param config: the model configuration to send to clients
:type config: dict
:param clients: the clients to send the request to
:type clients: list
:return: the request and the clients
:rtype: tuple
"""
request = fedn.TaskRequest()
request.model_id = model_id
request.correlation_id = str(uuid.uuid4())
request.timestamp = str(datetime.now())
# request.is_inference = (config['task'] == 'inference')
request.type = fedn.StatusType.MODEL_VALIDATION
request.type = request_type
request.session_id = session_id

request.sender.name = self.id
request.sender.role = fedn.COMBINER
request.session_id = config["session_id"]

if len(clients) == 0:
clients = self.get_active_validators()
if request_type == fedn.StatusType.MODEL_UPDATE:
request.data = json.dumps(config)
if len(clients) == 0:
clients = self.get_active_trainers()
elif request_type == fedn.StatusType.MODEL_VALIDATION:
if len(clients) == 0:
clients = self.get_active_validators()
elif request_type == fedn.StatusType.INFERENCE:
request.data = json.dumps(config)
if len(clients) == 0:
# TODO: add inference clients type
clients = self.get_active_validators()

# TODO: if inference, request.data should be user-defined data/parameters

for client in clients:
request.receiver.name = client
request.receiver.role = fedn.WORKER
self._put_request_to_client_queue(request, fedn.Queue.TASK_QUEUE)

if len(clients) < 20:
logger.info("Sent model validation request for model {} to clients {}".format(request.model_id, clients))
else:
logger.info("Sent model validation request for model {} to {} clients".format(request.model_id, len(clients)))
return request, clients

def get_active_trainers(self):
"""Get a list of active trainers.
@@ -410,7 +439,7 @@ def Start(self, control: fedn.ControlRequest, context):
"""
logger.info("grpc.Combiner.Start: Starting round")

config = {}
config = RoundConfig()
for parameter in control.parameter:
config.update({parameter.key: parameter.value})

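For orientation, the refactor above routes request_model_update, request_model_validation and request_model_inference through the single _send_request_type helper, and only update and inference requests carry a JSON config payload. The sketch below imitates that dispatch with plain-Python stand-ins (the enum and dataclass are illustrative substitutes for fedn.StatusType and fedn.TaskRequest, not the real protobuf types):

import json
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum, auto

class StatusType(Enum):            # stand-in for fedn.StatusType
    MODEL_UPDATE = auto()
    MODEL_VALIDATION = auto()
    INFERENCE = auto()

@dataclass
class TaskRequest:                 # stand-in for the fedn.TaskRequest protobuf
    model_id: str
    session_id: str
    type: StatusType
    correlation_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    timestamp: str = field(default_factory=lambda: str(datetime.now()))
    data: str = ""

def build_request(request_type, session_id, model_id, config=None):
    request = TaskRequest(model_id=model_id, session_id=session_id, type=request_type)
    # Training (model update) and inference tasks attach the round config as JSON;
    # validation requests carry no extra payload.
    if request_type in (StatusType.MODEL_UPDATE, StatusType.INFERENCE):
        request.data = json.dumps(config or {})
    return request

req = build_request(StatusType.MODEL_UPDATE, "session-42", "model-abc", {"round_id": "1"})
print(req.type.name, req.data)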
3 changes: 2 additions & 1 deletion fedn/network/combiner/interfaces.py
@@ -8,6 +8,7 @@

import fedn.network.grpc.fedn_pb2 as fedn
import fedn.network.grpc.fedn_pb2_grpc as rpc
from fedn.network.combiner.roundhandler import RoundConfig


class CombinerUnavailableError(Exception):
@@ -202,7 +203,7 @@ def set_aggregator(self, aggregator):
else:
raise

def submit(self, config):
def submit(self, config: RoundConfig):
"""Submit a compute plan to the combiner.
:param config: The job configuration.
100 changes: 90 additions & 10 deletions fedn/network/combiner/roundhandler.py
@@ -4,14 +4,64 @@
import sys
import time
import uuid
from typing import TypedDict

from fedn.common.log_config import logger
from fedn.network.combiner.aggregators.aggregatorbase import get_aggregator
from fedn.network.combiner.modelservice import load_model_from_BytesIO, serialize_model_to_BytesIO
from fedn.network.combiner.modelservice import (load_model_from_BytesIO,
serialize_model_to_BytesIO)
from fedn.utils.helpers.helpers import get_helper
from fedn.utils.parameters import Parameters


class RoundConfig(TypedDict):
"""Round configuration.
:param _job_id: A universally unique identifier for the round. Set by Combiner.
:type _job_id: str
:param committed_at: The time the round was committed. Set by Controller.
:type committed_at: str
:param task: The task to perform in the round. Set by Controller. Supported tasks are "training", "validation", and "inference".
:type task: str
:param round_id: The round identifier as str(int)
:type round_id: str
:param round_timeout: The round timeout in seconds. Set by user interfaces or Controller.
:type round_timeout: str
:param rounds: The number of rounds. Set by user interfaces.
:type rounds: int
:param model_id: The model identifier. Set by user interfaces or Controller (get_latest_model).
:type model_id: str
:param model_version: The model version. Currently not used.
:type model_version: str
:param model_type: The model type. Currently not used.
:type model_type: str
:param model_size: The size of the model. Currently not used.
:type model_size: int
:param model_parameters: The model parameters. Currently not used.
:type model_parameters: dict
:param model_metadata: The model metadata. Currently not used.
:type model_metadata: dict
:param session_id: The session identifier. Set by (Controller?).
:type session_id: str
:param helper_type: The helper type.
:type helper_type: str
:param aggregator: The aggregator type.
:type aggregator: str
"""
_job_id: str
committed_at: str
task: str
round_id: str
round_timeout: str
rounds: int
model_id: str
model_version: str
model_type: str
model_size: int
model_parameters: dict
model_metadata: dict
session_id: str
helper_type: str
aggregator: str
class ModelUpdateError(Exception):
pass

@@ -42,7 +92,7 @@ def __init__(self, storage, server, modelservice):
def set_aggregator(self, aggregator):
self.aggregator = get_aggregator(aggregator, self.storage, self.server, self.modelservice, self)

def push_round_config(self, round_config):
def push_round_config(self, round_config: RoundConfig) -> str:
"""Add a round_config (job description) to the inbox.
:param round_config: A dict containing the round configuration (from global controller).
@@ -144,8 +194,11 @@ def _training_round(self, config, clients):
meta["nr_required_updates"] = int(config["clients_required"])
meta["timeout"] = float(config["round_timeout"])

session_id = config["session_id"]
model_id = config["model_id"]

# Request model updates from all active clients.
self.server.request_model_update(config, clients=clients)
self.server.request_model_update(session_id=session_id, model_id=model_id, config=config, clients=clients)

# If buffer_size is -1 (default), the round terminates when/if all clients have completed.
if int(config["buffer_size"]) == -1:
@@ -182,7 +235,7 @@ def _training_round(self, config, clients):
meta["aggregation_time"] = data
return model, meta

def _validation_round(self, config, clients, model_id):
def _validation_round(self, session_id, model_id, clients):
"""Send model validation requests to clients.
:param config: The round config object (passed to the client).
@@ -192,7 +245,19 @@ def _validation_round(self, config, clients, model_id):
:param model_id: The ID of the model to validate
:type model_id: str
"""
self.server.request_model_validation(model_id, config, clients)
self.server.request_model_validation(session_id, model_id, clients=clients)

def _inference_round(self, session_id: str, model_id: str, clients: list):
"""Send model inference requests to clients.
:param config: The round config object (passed to the client).
:type config: dict
:param clients: clients to send inference requests to
:type clients: list
:param model_id: The ID of the model to use for inference
:type model_id: str
"""
self.server.request_model_inference(session_id, model_id, clients=clients)

def stage_model(self, model_id, timeout_retry=3, retry=2):
"""Download a model from persistent storage and set in modelservice.
@@ -271,17 +336,28 @@ def _check_nr_round_clients(self, config):
logger.info("Too few clients to start round.")
return False

def execute_validation_round(self, round_config):
def execute_validation_round(self, session_id, model_id):
"""Coordinate validation rounds as specified in config.
:param round_config: The round config object.
:type round_config: dict
"""
model_id = round_config["model_id"]
logger.info("COMBINER orchestrating validation of model {}".format(model_id))
self.stage_model(model_id)
validators = self._assign_round_clients(self.server.max_clients, type="validators")
self._validation_round(round_config, validators, model_id)
self._validation_round(session_id, model_id, validators)

def execute_inference_round(self, session_id: str, model_id: str) -> None:
"""Coordinate inference rounds as specified in config.
:param round_config: The round config object.
:type round_config: dict
"""
logger.info("COMBINER orchestrating inference using model {}".format(model_id))
self.stage_model(model_id)
# TODO: Implement inference client type
clients = self._assign_round_clients(self.server.max_clients, type="validators")
self._inference_round(session_id, model_id, clients)

def execute_training_round(self, config):
"""Coordinates clients to execute training tasks.
@@ -330,6 +406,8 @@ def run(self, polling_interval=1.0):
while True:
try:
round_config = self.round_configs.get(block=False)
session_id = round_config["session_id"]
model_id = round_config["model_id"]

# Check that the minimum allowed number of clients are connected
ready = self._check_nr_round_clients(round_config)
@@ -343,8 +421,10 @@
round_meta["status"] = "Success"
round_meta["name"] = self.server.id
self.server.statestore.set_round_combiner_data(round_meta)
elif round_config["task"] == "validation" or round_config["task"] == "inference":
self.execute_validation_round(round_config)
elif round_config["task"] == "validation":
self.execute_validation_round(session_id, model_id)
elif round_config["task"] == "inference":
logger.info("Inference task not yet implemented.")
else:
logger.warning("config contains unkown task type.")
else:
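To make the new flow concrete, the sketch below builds a RoundConfig-shaped dict for a validation task and shows the keys run() now reads before dispatching. The TypedDict here is a trimmed local copy of the new class and every value is a made-up example, not taken from a real deployment:

from typing import TypedDict

class RoundConfig(TypedDict, total=False):
    # Trimmed local copy of fedn.network.combiner.roundhandler.RoundConfig,
    # kept optional (total=False) so the example only fills the keys it needs.
    task: str            # "training", "validation" or "inference"
    round_id: str
    round_timeout: str
    rounds: int
    model_id: str
    session_id: str
    helper_type: str
    aggregator: str

config: RoundConfig = {
    "task": "validation",
    "round_id": "7",
    "round_timeout": "180",
    "model_id": "e6f4a1c0-0000-0000-0000-000000000000",   # made-up model id
    "session_id": "session-42",                           # made-up session id
}

# run() now pulls session_id/model_id out of the config once and dispatches on task:
session_id, model_id = config["session_id"], config["model_id"]
if config["task"] == "training":
    print("execute_training_round(config)")
elif config["task"] == "validation":
    print(f"execute_validation_round({session_id!r}, {model_id!r})")
elif config["task"] == "inference":
    print("inference task not yet wired up in run()")
else:
    print("unknown task type")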