-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
2627772
commit 50557cd
Showing
10 changed files
with
329 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -182,3 +182,4 @@ config/settings-combiner.yaml | |
|
||
# CI | ||
client.yaml | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
data | ||
seed.npz | ||
*.tgz | ||
*.tar.gz |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
data | ||
*.npz | ||
*.tgz | ||
*.tar.gz | ||
.flower-example | ||
client.yaml |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
Using Flower clients in FEDn | ||
------------- | ||
|
||
Example of how a Flower client can be used in FEDn. Flowers quickstart-pytorch example is | ||
used in this example (see `flwr_client.py``). Study the `client/entrypoint` file for | ||
details of the implementation. | ||
|
||
|
||
Run details | ||
----------- | ||
|
||
See `https://fedn.readthedocs.io/en/stable/quickstart.html` for general run details. Note | ||
that the flower client handles data distribution programatically, so data related steps can be | ||
omitted. To run this example after initializing fedn with the `seed.npz` and `package` that | ||
can be generated through `bin/build`, continue with building a docker image containing the flower | ||
dependencies: | ||
|
||
.. code-block:: | ||
docker build --build-arg REQUIREMENTS=examples/flower-client/requirements.txt -t flower-client . | ||
In separate terminals, start clients and inject the `CLIENT_NUMBER` dependency, for example for client1: | ||
|
||
.. code-block:: | ||
docker run \ | ||
-v $PWD/client.yaml:/app/client.yaml \ | ||
--network=fedn_default \ | ||
-e CLIENT_NUMBER=0 \ | ||
flower-client run client -in client.yaml --name client1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
#!/bin/bash | ||
set -e | ||
|
||
# Init seed | ||
client/entrypoint init_seed | ||
|
||
# Make compute package | ||
tar -czvf package.tgz client |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
#!/bin/bash | ||
set -e | ||
|
||
# Init venv | ||
python3.9 -m venv .flower-example | ||
|
||
# Pip deps | ||
.flower-example/bin/pip install --upgrade pip | ||
.flower-example/bin/pip install -e ../../fedn | ||
.flower-example/bin/pip install -r requirements.txt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
#!./.flower-example/bin/python | ||
import os | ||
import sys | ||
|
||
import fire | ||
|
||
from fedn.utils.helpers.helpers import get_helper, save_metadata, save_metrics | ||
|
||
HELPER_MODULE = "numpyhelper" | ||
helper = get_helper(HELPER_MODULE) | ||
|
||
NUM_CLASSES = 10 | ||
|
||
|
||
def _get_node_id(): | ||
# Figure out FEDn client number from container name | ||
number = os.environ.get("CLIENT_NUMBER", "0") | ||
return number | ||
|
||
|
||
def _get_flower_client(): | ||
"""Instanziates flower client.""" | ||
original_argv = sys.argv | ||
# Set sys.argv to mimic command-line input required by flower_client.py | ||
sys.argv = ["flower_client.py", "--node-id", _get_node_id()] | ||
# Import flower client | ||
import flwr_client | ||
|
||
sys.argv = original_argv | ||
flower_client = flwr_client.FlowerClient() | ||
return flwr_client, flower_client | ||
|
||
|
||
flwr_client, flower_client = _get_flower_client() | ||
|
||
|
||
def save_parameters(out_path): | ||
"""Save model paramters to file. | ||
:param model: The model to serialize. | ||
:type model: torch.nn.Module | ||
:param out_path: The path to save to. | ||
:type out_path: str | ||
""" | ||
parameters_np = flower_client.get_parameters({}) | ||
helper.save(parameters_np, out_path) | ||
|
||
|
||
def init_seed(out_path="seed.npz"): | ||
"""Initialize seed model and save it to file. | ||
:param out_path: The path to save the seed model to. | ||
:type out_path: str | ||
""" | ||
# Init and save | ||
save_parameters(out_path) | ||
|
||
|
||
def train(in_model_path, out_model_path): | ||
"""Complete a model update. | ||
Load model paramters from in_model_path (managed by the FEDn client), | ||
perform a model update through the flower client, and write updated paramters | ||
to out_model_path (picked up by the FEDn client). | ||
:param in_model_path: The path to the input model. | ||
:type in_model_path: str | ||
:param out_model_path: The path to save the output model to. | ||
:type out_model_path: str | ||
""" | ||
parameters_np = helper.load(in_model_path) | ||
|
||
# Train on flower client | ||
_, num_examples, _ = flower_client.fit(parameters_np, {}) | ||
|
||
# Metadata needed for aggregation server side | ||
metadata = { | ||
# num_examples are mandatory | ||
"num_examples": num_examples, | ||
} | ||
|
||
# Save JSON metadata file (mandatory) | ||
save_metadata(metadata, out_model_path) | ||
|
||
# Save model update (mandatory) | ||
save_parameters(out_model_path) | ||
|
||
|
||
def validate(in_model_path, out_json_path, data_path=None): | ||
"""Validate model on the clients test dataset. | ||
:param in_model_path: The path to the input model. | ||
:type in_model_path: str | ||
:param out_json_path: The path to save the output JSON to. | ||
:type out_json_path: str | ||
:param data_path: The path to the data file. | ||
:type data_path: str | ||
""" | ||
parameters_np = helper.load(in_model_path) | ||
|
||
loss, _, accuracy = flower_client.evaluate(parameters_np, {}) | ||
accuracy = accuracy["accuracy"] | ||
|
||
# JSON schema | ||
report = { | ||
"test_loss": loss, | ||
"test_accuracy": accuracy, | ||
} | ||
print(f"Loss: {loss}, accuracy: {accuracy}") | ||
# Save JSON | ||
save_metrics(report, out_json_path) | ||
|
||
|
||
if __name__ == "__main__": | ||
fire.Fire( | ||
{ | ||
"init_seed": init_seed, | ||
"train": train, | ||
"validate": validate, | ||
} | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
entry_points: | ||
train: | ||
command: /venv/bin/python entrypoint train $ENTRYPOINT_OPTS | ||
validate: | ||
command: /venv/bin/python entrypoint validate $ENTRYPOINT_OPTS |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
import argparse | ||
import warnings | ||
from collections import OrderedDict | ||
|
||
import flwr as fl | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from flwr_datasets import FederatedDataset | ||
from torch.utils.data import DataLoader | ||
from torchvision.transforms import Compose, Normalize, ToTensor | ||
from tqdm import tqdm | ||
|
||
# ############################################################################# | ||
# 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader | ||
# ############################################################################# | ||
|
||
warnings.filterwarnings("ignore", category=UserWarning) | ||
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
|
||
|
||
class Net(nn.Module): | ||
"""Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" | ||
|
||
def __init__(self) -> None: | ||
super(Net, self).__init__() | ||
self.conv1 = nn.Conv2d(3, 6, 5) | ||
self.pool = nn.MaxPool2d(2, 2) | ||
self.conv2 = nn.Conv2d(6, 16, 5) | ||
self.fc1 = nn.Linear(16 * 5 * 5, 120) | ||
self.fc2 = nn.Linear(120, 84) | ||
self.fc3 = nn.Linear(84, 10) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
x = self.pool(F.relu(self.conv1(x))) | ||
x = self.pool(F.relu(self.conv2(x))) | ||
x = x.view(-1, 16 * 5 * 5) | ||
x = F.relu(self.fc1(x)) | ||
x = F.relu(self.fc2(x)) | ||
return self.fc3(x) | ||
|
||
|
||
def train(net, trainloader, epochs): | ||
"""Train the model on the training set.""" | ||
criterion = torch.nn.CrossEntropyLoss() | ||
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) | ||
for _ in range(epochs): | ||
for batch in tqdm(trainloader, "Training"): | ||
images = batch["img"] | ||
labels = batch["label"] | ||
optimizer.zero_grad() | ||
criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward() | ||
optimizer.step() | ||
|
||
|
||
def test(net, testloader): | ||
"""Validate the model on the test set.""" | ||
criterion = torch.nn.CrossEntropyLoss() | ||
correct, loss = 0, 0.0 | ||
with torch.no_grad(): | ||
for batch in tqdm(testloader, "Testing"): | ||
images = batch["img"].to(DEVICE) | ||
labels = batch["label"].to(DEVICE) | ||
outputs = net(images) | ||
loss += criterion(outputs, labels).item() | ||
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() | ||
accuracy = correct / len(testloader.dataset) | ||
return loss, accuracy | ||
|
||
|
||
def load_data(node_id): | ||
"""Load partition CIFAR10 data.""" | ||
fds = FederatedDataset(dataset="cifar10", partitioners={"train": 3}) | ||
partition = fds.load_partition(node_id) | ||
# Divide data on each node: 80% train, 20% test | ||
partition_train_test = partition.train_test_split(test_size=0.2) | ||
pytorch_transforms = Compose( | ||
[ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] | ||
) | ||
|
||
def apply_transforms(batch): | ||
"""Apply transforms to the partition from FederatedDataset.""" | ||
batch["img"] = [pytorch_transforms(img) for img in batch["img"]] | ||
return batch | ||
|
||
partition_train_test = partition_train_test.with_transform(apply_transforms) | ||
trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True) | ||
testloader = DataLoader(partition_train_test["test"], batch_size=32) | ||
return trainloader, testloader | ||
|
||
|
||
# ############################################################################# | ||
# 2. Federation of the pipeline with Flower | ||
# ############################################################################# | ||
|
||
# Get node id | ||
parser = argparse.ArgumentParser(description="Flower") | ||
parser.add_argument( | ||
"--node-id", | ||
choices=[0, 1, 2], | ||
required=True, | ||
type=int, | ||
help="Partition of the dataset divided into 3 iid partitions created artificially.", | ||
) | ||
node_id = parser.parse_args().node_id | ||
|
||
# Load model and data (simple CNN, CIFAR-10) | ||
net = Net().to(DEVICE) | ||
trainloader, testloader = load_data(node_id=node_id) | ||
|
||
|
||
# Define Flower client | ||
class FlowerClient(fl.client.NumPyClient): | ||
def get_parameters(self, config): | ||
return [val.cpu().numpy() for _, val in net.state_dict().items()] | ||
|
||
def set_parameters(self, parameters): | ||
params_dict = zip(net.state_dict().keys(), parameters) | ||
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) | ||
net.load_state_dict(state_dict, strict=True) | ||
|
||
def fit(self, parameters, config): | ||
self.set_parameters(parameters) | ||
train(net, trainloader, epochs=1) | ||
return self.get_parameters(config={}), len(trainloader.dataset), {} | ||
|
||
def evaluate(self, parameters, config): | ||
self.set_parameters(parameters) | ||
loss, accuracy = test(net, testloader) | ||
return loss, len(testloader.dataset), {"accuracy": accuracy} | ||
|
||
|
||
if __name__ == "__main__": | ||
# Start Flower client | ||
fl.client.start_client( | ||
server_address="127.0.0.1:8080", | ||
client=FlowerClient().to_client(), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
torch==1.13.1 | ||
torchvision==0.14.1 | ||
fire==0.3.1 | ||
docker==6.1.1 | ||
flwr==1.7.0 | ||
flwr_datasets==0.0.2 |