Skip to content

Commit

Permalink
add flower client example
Browse files Browse the repository at this point in the history
  • Loading branch information
viktorvaladi committed Feb 28, 2024
1 parent 2627772 commit 50557cd
Show file tree
Hide file tree
Showing 10 changed files with 329 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,4 @@ config/settings-combiner.yaml

# CI
client.yaml

4 changes: 4 additions & 0 deletions examples/flower-client/.dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
data
seed.npz
*.tgz
*.tar.gz
6 changes: 6 additions & 0 deletions examples/flower-client/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
data
*.npz
*.tgz
*.tar.gz
.flower-example
client.yaml
30 changes: 30 additions & 0 deletions examples/flower-client/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
Using Flower clients in FEDn
-------------

Example of how a Flower client can be used in FEDn. Flowers quickstart-pytorch example is
used in this example (see `flwr_client.py``). Study the `client/entrypoint` file for
details of the implementation.


Run details
-----------

See `https://fedn.readthedocs.io/en/stable/quickstart.html` for general run details. Note
that the flower client handles data distribution programatically, so data related steps can be
omitted. To run this example after initializing fedn with the `seed.npz` and `package` that
can be generated through `bin/build`, continue with building a docker image containing the flower
dependencies:

.. code-block::
docker build --build-arg REQUIREMENTS=examples/flower-client/requirements.txt -t flower-client .
In separate terminals, start clients and inject the `CLIENT_NUMBER` dependency, for example for client1:

.. code-block::
docker run \
-v $PWD/client.yaml:/app/client.yaml \
--network=fedn_default \
-e CLIENT_NUMBER=0 \
flower-client run client -in client.yaml --name client1
8 changes: 8 additions & 0 deletions examples/flower-client/bin/build.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#!/bin/bash
set -e

# Init seed
client/entrypoint init_seed

# Make compute package
tar -czvf package.tgz client
10 changes: 10 additions & 0 deletions examples/flower-client/bin/init_venv.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/bin/bash
set -e

# Init venv
python3.9 -m venv .flower-example

# Pip deps
.flower-example/bin/pip install --upgrade pip
.flower-example/bin/pip install -e ../../fedn
.flower-example/bin/pip install -r requirements.txt
121 changes: 121 additions & 0 deletions examples/flower-client/client/entrypoint
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
#!./.flower-example/bin/python
import os
import sys

import fire

from fedn.utils.helpers.helpers import get_helper, save_metadata, save_metrics

HELPER_MODULE = "numpyhelper"
helper = get_helper(HELPER_MODULE)

NUM_CLASSES = 10


def _get_node_id():
# Figure out FEDn client number from container name
number = os.environ.get("CLIENT_NUMBER", "0")
return number


def _get_flower_client():
"""Instanziates flower client."""
original_argv = sys.argv
# Set sys.argv to mimic command-line input required by flower_client.py
sys.argv = ["flower_client.py", "--node-id", _get_node_id()]
# Import flower client
import flwr_client

sys.argv = original_argv
flower_client = flwr_client.FlowerClient()
return flwr_client, flower_client


flwr_client, flower_client = _get_flower_client()


def save_parameters(out_path):
"""Save model paramters to file.
:param model: The model to serialize.
:type model: torch.nn.Module
:param out_path: The path to save to.
:type out_path: str
"""
parameters_np = flower_client.get_parameters({})
helper.save(parameters_np, out_path)


def init_seed(out_path="seed.npz"):
"""Initialize seed model and save it to file.
:param out_path: The path to save the seed model to.
:type out_path: str
"""
# Init and save
save_parameters(out_path)


def train(in_model_path, out_model_path):
"""Complete a model update.
Load model paramters from in_model_path (managed by the FEDn client),
perform a model update through the flower client, and write updated paramters
to out_model_path (picked up by the FEDn client).
:param in_model_path: The path to the input model.
:type in_model_path: str
:param out_model_path: The path to save the output model to.
:type out_model_path: str
"""
parameters_np = helper.load(in_model_path)

# Train on flower client
_, num_examples, _ = flower_client.fit(parameters_np, {})

# Metadata needed for aggregation server side
metadata = {
# num_examples are mandatory
"num_examples": num_examples,
}

# Save JSON metadata file (mandatory)
save_metadata(metadata, out_model_path)

# Save model update (mandatory)
save_parameters(out_model_path)


def validate(in_model_path, out_json_path, data_path=None):
"""Validate model on the clients test dataset.
:param in_model_path: The path to the input model.
:type in_model_path: str
:param out_json_path: The path to save the output JSON to.
:type out_json_path: str
:param data_path: The path to the data file.
:type data_path: str
"""
parameters_np = helper.load(in_model_path)

loss, _, accuracy = flower_client.evaluate(parameters_np, {})
accuracy = accuracy["accuracy"]

# JSON schema
report = {
"test_loss": loss,
"test_accuracy": accuracy,
}
print(f"Loss: {loss}, accuracy: {accuracy}")
# Save JSON
save_metrics(report, out_json_path)


if __name__ == "__main__":
fire.Fire(
{
"init_seed": init_seed,
"train": train,
"validate": validate,
}
)
5 changes: 5 additions & 0 deletions examples/flower-client/client/fedn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
entry_points:
train:
command: /venv/bin/python entrypoint train $ENTRYPOINT_OPTS
validate:
command: /venv/bin/python entrypoint validate $ENTRYPOINT_OPTS
138 changes: 138 additions & 0 deletions examples/flower-client/client/flwr_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import argparse
import warnings
from collections import OrderedDict

import flwr as fl
import torch
import torch.nn as nn
import torch.nn.functional as F
from flwr_datasets import FederatedDataset
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize, ToTensor
from tqdm import tqdm

# #############################################################################
# 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader
# #############################################################################

warnings.filterwarnings("ignore", category=UserWarning)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Net(nn.Module):
"""Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""

def __init__(self) -> None:
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.fc3(x)


def train(net, trainloader, epochs):
"""Train the model on the training set."""
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
for _ in range(epochs):
for batch in tqdm(trainloader, "Training"):
images = batch["img"]
labels = batch["label"]
optimizer.zero_grad()
criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward()
optimizer.step()


def test(net, testloader):
"""Validate the model on the test set."""
criterion = torch.nn.CrossEntropyLoss()
correct, loss = 0, 0.0
with torch.no_grad():
for batch in tqdm(testloader, "Testing"):
images = batch["img"].to(DEVICE)
labels = batch["label"].to(DEVICE)
outputs = net(images)
loss += criterion(outputs, labels).item()
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
accuracy = correct / len(testloader.dataset)
return loss, accuracy


def load_data(node_id):
"""Load partition CIFAR10 data."""
fds = FederatedDataset(dataset="cifar10", partitioners={"train": 3})
partition = fds.load_partition(node_id)
# Divide data on each node: 80% train, 20% test
partition_train_test = partition.train_test_split(test_size=0.2)
pytorch_transforms = Compose(
[ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

def apply_transforms(batch):
"""Apply transforms to the partition from FederatedDataset."""
batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
return batch

partition_train_test = partition_train_test.with_transform(apply_transforms)
trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
testloader = DataLoader(partition_train_test["test"], batch_size=32)
return trainloader, testloader


# #############################################################################
# 2. Federation of the pipeline with Flower
# #############################################################################

# Get node id
parser = argparse.ArgumentParser(description="Flower")
parser.add_argument(
"--node-id",
choices=[0, 1, 2],
required=True,
type=int,
help="Partition of the dataset divided into 3 iid partitions created artificially.",
)
node_id = parser.parse_args().node_id

# Load model and data (simple CNN, CIFAR-10)
net = Net().to(DEVICE)
trainloader, testloader = load_data(node_id=node_id)


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
def get_parameters(self, config):
return [val.cpu().numpy() for _, val in net.state_dict().items()]

def set_parameters(self, parameters):
params_dict = zip(net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
net.load_state_dict(state_dict, strict=True)

def fit(self, parameters, config):
self.set_parameters(parameters)
train(net, trainloader, epochs=1)
return self.get_parameters(config={}), len(trainloader.dataset), {}

def evaluate(self, parameters, config):
self.set_parameters(parameters)
loss, accuracy = test(net, testloader)
return loss, len(testloader.dataset), {"accuracy": accuracy}


if __name__ == "__main__":
# Start Flower client
fl.client.start_client(
server_address="127.0.0.1:8080",
client=FlowerClient().to_client(),
)
6 changes: 6 additions & 0 deletions examples/flower-client/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
torch==1.13.1
torchvision==0.14.1
fire==0.3.1
docker==6.1.1
flwr==1.7.0
flwr_datasets==0.0.2

0 comments on commit 50557cd

Please sign in to comment.