adding readme and refactoring client code
FrankJonasmoelle committed May 13, 2024
1 parent a976201 commit d31cb6a
Showing 5 changed files with 85 additions and 116 deletions.
76 changes: 76 additions & 0 deletions examples/FedSimSiam/README.rst
@@ -0,0 +1,76 @@
FEDn Project: FedSimSiam (PyTorch)
----------------------------------

This is an example FEDn project that runs the federated self-supervised learning algorithm FedSimSiam on
the CIFAR-10 dataset, a standard dataset that is often used for benchmarking. Running this example requires
access to a GPU.

**Note: We recommend all new users to start by following the Quickstart Tutorial: https://fedn.readthedocs.io/en/stable/quickstart.html**

Prerequisites
-------------

- `Python 3.8, 3.9, 3.10 or 3.11 <https://www.python.org/downloads>`__
- `A FEDn Studio account <https://fedn.scaleoutsystems.com/signup>`__
- Adjust the dependencies in the 'client/python_env.yaml' file to match your CUDA version.

Creating the compute package and seed model
-------------------------------------------

Install fedn:

.. code-block::

   pip install fedn

Clone this repository, then navigate to this directory:

.. code-block::

   git clone https://github.com/scaleoutsystems/fedn.git
   cd fedn/examples/FedSimSiam

Create the compute package:

.. code-block::

   fedn package create --path client

This should create a file 'package.tgz' in the project folder.

Next, generate a seed model (the first model in a global model trail):

.. code-block::

   fedn run build --path client

This will create a seed model called 'seed.npz' in the root of the project. This step can take a few minutes, depending on hardware and internet connection, because it builds a virtualenv.

FEDn Studio
-----------

Follow the instructions to register for FEDn Studio and start a project (https://fedn.readthedocs.io/en/stable/studio.html).

In your Studio project:

- Go to the 'Sessions' menu, click on 'New session', and upload the compute package (package.tgz) and seed model (seed.npz).
- In the 'Clients' menu, click on 'Connect client' and download the client configuration file (client.yaml).
- Save the client configuration file to the FedSimSiam example directory (fedn/examples/FedSimSiam).

To connect a client, run the following command in your terminal:

.. code-block::

   fedn client start -in client.yaml --secure=True --force-ssl

Running the example
-------------------

After everything is set up, go to 'Sessions' and click on 'New Session'. Click on 'Start run' and the example will execute. You can follow the training progress
in the 'Events' and 'Models' views. Progress is monitored with a kNN classifier that is fitted on the feature embeddings of the training images produced by
FedSimSiam's encoder and evaluated on the feature embeddings of the test images. This process is repeated after each training round.

This is a common method to track FedSimSiam's training progress, as FedSimSiam aims to minimize the distance between the embeddings of similar images.
A high accuracy implies that the feature embeddings for images within the same class are indeed close to each other in the
embedding space, i.e., FedSimSiam learned useful feature embeddings.
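
The kNN monitoring described above can be sketched in a few lines of plain Python. This is a simplified illustration, not the code in ``client/monitoring.py``: the helper names (``cosine_similarity``, ``knn_predict_label``, ``knn_accuracy``), the toy two-dimensional embeddings, and the small ``k`` are assumptions made for this example. The temperature-weighted vote is meant to mirror the ``k`` and ``t`` parameters of ``knn_monitor``.

```python
import math
from collections import defaultdict


def cosine_similarity(a, b):
    """Cosine similarity between two embedding vectors given as lists of floats."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def knn_predict_label(query, bank, bank_labels, k=3, temperature=0.1):
    """Predict a label for `query` from its k most similar bank embeddings.

    Each neighbour votes for its own label with weight exp(similarity / temperature),
    so closer neighbours count more.
    """
    sims = [cosine_similarity(query, emb) for emb in bank]
    top_k = sorted(range(len(bank)), key=lambda i: sims[i], reverse=True)[:k]
    votes = defaultdict(float)
    for i in top_k:
        votes[bank_labels[i]] += math.exp(sims[i] / temperature)
    return max(votes, key=votes.get)


def knn_accuracy(test_embeddings, test_labels, bank, bank_labels, k=3, temperature=0.1):
    """Fraction of test embeddings whose kNN prediction matches the true label."""
    correct = sum(
        knn_predict_label(q, bank, bank_labels, k, temperature) == y
        for q, y in zip(test_embeddings, test_labels)
    )
    return correct / len(test_labels)


# Toy data: the "feature bank" stands in for encoder embeddings of training images.
bank = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
bank_labels = [0, 0, 1, 1]
test_embeddings = [[0.8, 0.2], [0.2, 0.8]]
test_labels = [0, 1]

accuracy = knn_accuracy(test_embeddings, test_labels, bank, bank_labels, k=3)
print(f"kNN accuracy: {accuracy:.2f}")
```

In this toy setup the two classes are well separated in the embedding space, so the kNN accuracy is high. With a poorly trained encoder the embeddings of different classes would overlap and the accuracy would drop, which is what makes the metric useful for tracking training progress.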
7 changes: 4 additions & 3 deletions examples/FedSimSiam/client/monitoring.py
@@ -1,5 +1,6 @@
-""" knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
-implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR
+"""
+knn monitor as in InstDisc https://arxiv.org/abs/1805.01978.
+This implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR
 """
 import torch.nn.functional as F
 import torch
@@ -36,7 +37,7 @@ def knn_monitor(net, memory_data_loader, test_data_loader, epoch, k=200, t=0.1,

             total_num += data.size(0)
             total_top1 += (pred_labels[:, 0] == target).float().sum().item()
-    return total_top1 / total_num # * 100
+    return total_top1 / total_num


 def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
4 changes: 2 additions & 2 deletions examples/FedSimSiam/client/python_env.yaml
@@ -4,6 +4,6 @@ build_dependencies:
 - setuptools
 - wheel==0.37.1
 dependencies:
-- torch==2.2.1
-- torchvision==0.17.1
+- torch==2.2.0
+- torchvision==0.17.0
 - fedn==0.9.0
9 changes: 2 additions & 7 deletions examples/FedSimSiam/client/train.py
@@ -95,15 +95,12 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1
     device = torch.device(
         'cuda') if torch.cuda.is_available() else torch.device('cpu')
     model = model.to(device)
+    model.train()

-    # optimizer = optim.SGD(model.parameters(), lr=0.03,
-    # momentum=0.9, weight_decay=0.0005)
-    model.train()

     optimizer, lr_scheduler = init_lrscheduler(
-        model, 500, trainloader) # TODO: Change num epochs
+        model, 500, trainloader)

-    print("starting training with lr ", optimizer.param_groups[0]['lr'])
     for epoch in range(epochs):
         for idx, data in enumerate(trainloader):
             images = data[0]
@@ -116,8 +113,6 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1
             optimizer.step()
             lr_scheduler.step()

-    print('last learning rate: ', optimizer.param_groups[0]['lr'])
-
     # Metadata needed for aggregation server side
     metadata = {
         # num_examples are mandatory
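
The `num_examples` entry in the metadata above is the kind of value a server can use to weight each client's update during aggregation. The following is a minimal sketch of such a weighted average, assuming a FedAvg-style aggregator. The function name `weighted_average` and the dict-of-lists weight layout are hypothetical and are not FEDn's API.

```python
def weighted_average(updates):
    """Average client weights, weighting each client by its number of examples.

    `updates` is a list of (weights, num_examples) pairs, where `weights` maps a
    parameter name to a flat list of floats. This layout is a simplification
    chosen for the example.
    """
    total_examples = sum(num_examples for _, num_examples in updates)
    if total_examples == 0:
        raise ValueError("at least one client must report num_examples > 0")

    averaged = {}
    for weights, num_examples in updates:
        share = num_examples / total_examples  # this client's contribution
        for name, values in weights.items():
            acc = averaged.setdefault(name, [0.0] * len(values))
            for i, value in enumerate(values):
                acc[i] += share * value
    return averaged


# Two toy clients: the second has three times as much data as the first.
client_a = ({"w": [1.0, 2.0]}, 100)
client_b = ({"w": [3.0, 6.0]}, 300)

global_weights = weighted_average([client_a, client_b])
print(global_weights)
```

Clients that trained on more data pull the global model further towards their own weights, which is why a client that omits or misreports `num_examples` would distort an aggregation of this kind.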
105 changes: 1 addition & 104 deletions examples/FedSimSiam/client/validate.py
@@ -38,110 +38,7 @@ def __getitem__(self, idx):
         return x, y


-class LinearEvaluationSimSiam(nn.Module):
-    def __init__(self, in_model_path):
-        super(LinearEvaluationSimSiam, self).__init__()
-        model = load_parameters(in_model_path)
-        device = torch.device(
-            'cuda') if torch.cuda.is_available() else torch.device('cpu')
-        self.encoder = model.encoder.to(device)
-
-        # freeze parameters
-        for param in self.encoder.parameters():
-            param.requires_grad = False
-
-        self.classifier = nn.Linear(2048, 10).to(device)
-
-    def forward(self, x):
-        x = self.encoder(x)
-        x = self.classifier(x)
-        return x
-
-
-def linear_evaluation(in_model_path, out_json_path, data_path=None, train_data_percentage=0.1, epochs=5):
-    model = LinearEvaluationSimSiam(in_model_path)
-
-    device = torch.device(
-        'cuda') if torch.cuda.is_available() else torch.device('cpu')
-
-    x_train, y_train = load_data(data_path)
-    x_test, y_test = load_data(data_path, is_train=False)
-
-    # for linear evaluation, train only on small subset of training data
-    n_training_data = train_data_percentage * len(x_train)
-    print("number of training points: ", n_training_data)
-
-    x_train = x_train[:int(n_training_data)]
-    y_train = y_train[:int(n_training_data)]
-    print(len(x_train))
-
-    traindata = Cifar10(x_train, y_train)
-    trainloader = DataLoader(traindata, batch_size=4, shuffle=True)
-
-    testdata = Cifar10(x_test, y_test)
-    testloader = DataLoader(testdata, batch_size=4, shuffle=False)
-
-    criterion = nn.CrossEntropyLoss()
-    optimizer = torch.optim.Adam(model.classifier.parameters(), lr=0.001)
-
-    model.encoder.eval() # this is linear evaluation, only train the classifier
-    model.classifier.train()
-
-    for epoch in range(epochs):
-        correct = 0
-        total = 0
-        total_loss = 0.0
-        for i, data in enumerate(trainloader):
-            inputs, labels = data[0].to(device), data[1].to(device)
-            optimizer.zero_grad()
-
-            with torch.no_grad():
-                features = model.encoder(inputs)
-            outputs = model.classifier(features)
-
-            loss = criterion(outputs, labels)
-            print(loss)
-            loss.backward()
-            optimizer.step()
-
-            _, predicted = torch.max(outputs.data, 1)
-            total += labels.size(0)
-            correct += (predicted == labels).sum().item()
-            total_loss += loss.item() * labels.size(0)
-
-        training_accuracy = correct / total
-        print(f"Accuracy: {training_accuracy * 100:.2f}%")
-
-        training_loss = total_loss / total
-        print("train loss: ", training_loss)
-
-    # test on test_set
-    model.eval()
-    total_loss = 0.0
-    correct_preds = 0
-    total_samples = 0
-
-    with torch.no_grad():
-        for i, data in enumerate(testloader):
-            inputs, labels = data[0].to(device), data[1].to(device)
-            outputs = model(inputs)
-            loss = criterion(outputs, labels)
-
-            _, predicted = torch.max(outputs.data, 1)
-            total_loss += loss.item() * inputs.size(0) # Multiply by batch size
-            total_samples += labels.size(0)
-            correct_preds += (predicted == labels).sum().item()
-
-    test_accuracy = correct_preds / total_samples
-    print(f"Test accuracy: {test_accuracy * 100:.2f}%")
-
-    test_loss = total_loss / total_samples
-    print("test loss: ", test_loss)
-
-    return training_loss, training_accuracy, test_loss, test_accuracy
-
-
-def validate(in_model_path, out_json_path, data_path=None, train_data_percentage=1, epochs=3):
+def validate(in_model_path, out_json_path, data_path=None):

     memory_loader, test_loader = load_knn_monitoring_dataset(data_path)
