Skip to content

Commit

Permalink
Adding new example: MonAI 2D classification using Mednist dataset.
Browse files Browse the repository at this point in the history
  • Loading branch information
root committed May 20, 2024
1 parent 5f54d04 commit a7a8f7f
Show file tree
Hide file tree
Showing 11 changed files with 756 additions and 0 deletions.
4 changes: 4 additions & 0 deletions examples/monai-2D-mednist/.dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
data
seed.npz
*.tgz
*.tar.gz
6 changes: 6 additions & 0 deletions examples/monai-2D-mednist/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
data
*.npz
*.tgz
*.tar.gz
.mnist-pytorch
client.yaml
169 changes: 169 additions & 0 deletions examples/monai-2D-mednist/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
FEDn Project: MNIST (PyTorch)
-----------------------------

This is an example FEDn Project based on the classic hand-written text recognition dataset MNIST.
The example is intented as a minimalistic quickstart and automates the handling of training data
by letting the client download and create its partition of the dataset as it starts up.

**Note: These instructions are geared towards users seeking to learn how to work
with FEDn in local development mode using Docker/docker-compose. We recommend all new users
to start by following the Quickstart Tutorial: https://fedn.readthedocs.io/en/stable/quickstart.html**

Prerequisites
-------------

Using FEDn Studio:

- `Python 3.8, 3.9, 3.10 or 3.11 <https://www.python.org/downloads>`__
- `A FEDn Studio account <https://fedn.scaleoutsystems.com/signup>`__

If using pseudo-distributed mode with docker-compose:

- `Docker <https://docs.docker.com/get-docker>`__
- `Docker Compose <https://docs.docker.com/compose/install>`__

Creating the compute package and seed model
-------------------------------------------

Install fedn:

.. code-block::
pip install fedn
Clone this repository, then locate into this directory:

.. code-block::
git clone https://github.com/scaleoutsystems/fedn.git
cd fedn/examples/mnist-pytorch
Create the compute package:

.. code-block::
fedn package create --path client
This should create a file 'package.tgz' in the project folder.

Next, generate a seed model (the first model in a global model trail):

.. code-block::
fedn run build --path client
This will create a seed model called 'seed.npz' in the root of the project. This step will take a few minutes, depending on hardware and internet connection (builds a virtualenv).

Using FEDn Studio
-----------------

Follow the guide here to set up your FEDn Studio project and learn how to connect clients (using token authentication): `Studio guide <https://fedn.readthedocs.io/en/stable/studio.html>`__.
On the step "Upload Files", upload 'package.tgz' and 'seed.npz' created above.


Modifing the data split:
========================

The default traning and test data for this example is downloaded and split direcly by the client when it starts up (see 'startup' entrypoint).
The number of splits and which split used by a client can be controlled via the environment variables ``FEDN_NUM_DATA_SPLITS`` and ``FEDN_DATA_PATH``.
For example, to split the data in 10 parts and start a client using the 8th partiton:

.. code-block::
export FEDN_PACKAGE_EXTRACT_DIR=package
export FEDN_NUM_DATA_SPLITS=10
export FEDN_DATA_PATH=./data/clients/8/mnist.pt
fedn client start -in client.yaml --secure=True --force-ssl
The default is to split the data into 2 partitions and use the first partition.


Connecting clients using Docker:
================================

For convenience, there is a Docker image hosted on ghrc.io with fedn preinstalled. To start a client using Docker:

.. code-block::
docker run \
-v $PWD/client.yaml:/app/client.yaml \
-e FEDN_PACKAGE_EXTRACT_DIR=package \
-e FEDN_NUM_DATA_SPLITS=2 \
-e FEDN_DATA_PATH=/app/package/data/clients/1/mnist.pt \
ghcr.io/scaleoutsystems/fedn/fedn:0.9.0 run client -in client.yaml --force-ssl --secure=True
Local development mode using Docker/docker compose
--------------------------------------------------

Follow the steps above to install FEDn, generate 'package.tgz' and 'seed.tgz'.

Start a pseudo-distributed FEDn network using docker-compose:

.. code-block::
docker compose \
-f ../../docker-compose.yaml \
-f docker-compose.override.yaml \
up
This starts up local services for MongoDB, Minio, the API Server, one Combiner and two clients.
You can verify the deployment using these urls:

- API Server: http://localhost:8092/get_controller_status
- Minio: http://localhost:9000
- Mongo Express: http://localhost:8081

Upload the package and seed model to FEDn controller using the APIClient. In Python:

.. code-block::
from fedn import APIClient
client = APIClient(host="localhost", port=8092)
client.set_active_package("package.tgz", helper="numpyhelper")
client.set_active_model("seed.npz")
You can now start a training session with 5 rounds (default):

.. code-block::
client.start_session()
Automate experimentation with several clients
=============================================

If you want to scale the number of clients, you can do so by modifying ``docker-compose.override.yaml``. For example,
in order to run with 3 clients, change the environment variable ``FEDN_NUM_DATA_SPLITS`` to 3, and add one more client
by copying ``client1`` and setting ``FEDN_DATA_PATH`` to ``/app/package/data/clients/3/mnist.pt``


Access message logs and validation data from MongoDB
====================================================

You can access and download event logs and validation data via the API, and you can also as a developer obtain
the MongoDB backend data using pymongo or via the MongoExpress interface:

- http://localhost:8081/db/fedn-network/

The credentials are as set in docker-compose.yaml in the root of the repository.

Access global models
====================

You can obtain global model updates from the 'fedn-models' bucket in Minio:

- http://localhost:9000

Reset the FEDn deployment
=========================

To purge all data from a deployment incuding all session and round data, access the MongoExpress UI interface and
delete the entire ``fedn-network`` collection. Then restart all services.

Clean up
========
You can clean up by running

.. code-block::
docker-compose -f ../../docker-compose.yaml -f docker-compose.override.yaml down -v
124 changes: 124 additions & 0 deletions examples/monai-2D-mednist/client/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import os
from math import floor
import random
import PIL
import numpy as np
import torch
import torchvision
from monai.apps import download_and_extract

dir_path = os.path.dirname(os.path.realpath(__file__))
abs_path = os.path.abspath(dir_path)


def get_data(out_dir="data"):
"""Get data from the external repository.
:param out_dir: Path to data directory. If doesn't
:type data_dir: str
"""

# Make dir if necessary
if not os.path.exists(out_dir):
os.mkdir(out_dir)

resource = "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MedNIST.tar.gz"
md5 = "0bc7306e7427e00ad1c5526a6677552d"

compressed_file = os.path.join(out_dir, "MedNIST.tar.gz")

data_dir = os.path.abspath(out_dir)
print('data_dir:', data_dir)
if os.path.exists(data_dir):
print('path exist.')
if not os.path.exists(compressed_file):
print('compressed file does not exist, downloading and extracting data.')
download_and_extract(resource, compressed_file, data_dir, md5)
else:
print('files already exist.')

def get_classes(data_path):
"""Get a list of classes from the dataset
:param data_path: Path to data directory.
:type data_path: str
"""

if data_path is None:
data_path = os.environ.get("FEDN_DATA_PATH", abs_path + "/data/MedNIST")

class_names = sorted(x for x in os.listdir(data_path) if os.path.isdir(os.path.join(data_path, x)))
return(class_names)

def load_data(data_path, is_train=True):
"""Load data from disk.
:param data_path: Path to data directory.
:type data_path: str
:param is_train: Whether to load training or test data.
:type is_train: bool
:return: Tuple of data and labels.
:rtype: tuple"""

if data_path is None:
data_path = os.environ.get("FEDN_DATA_PATH", abs_path + "/data/MedNIST")

class_names = get_classes(data_path)

num_class = len(class_names)

image_files_all = [
[os.path.join(data_path, class_names[i], x) for x in os.listdir(os.path.join(data_path, class_names[i]))]
for i in range(num_class)
]

# To make the dataset small, we are using 100 images of each class.
sample_size = 100
image_files = [random.sample(inner_list, sample_size) for inner_list in image_files_all]

num_each = [len(image_files[i]) for i in range(num_class)]
image_files_list = []
image_class = []
for i in range(num_class):
image_files_list.extend(image_files[i])
image_class.extend([i] * num_each[i])
num_total = len(image_class)
image_width, image_height = PIL.Image.open(image_files_list[0]).size

print(f"Total image count: {num_total}")
print(f"Image dimensions: {image_width} x {image_height}")
print(f"Label names: {class_names}")
print(f"Label counts: {num_each}")

val_frac = 0.1
test_frac = 0.1
length = len(image_files_list)
indices = np.arange(length)
np.random.shuffle(indices)

test_split = int(test_frac * length)
val_split = int(val_frac * length) + test_split
test_indices = indices[:test_split]
val_indices = indices[test_split:val_split]
train_indices = indices[val_split:]

train_x = [image_files_list[i] for i in train_indices]
train_y = [image_class[i] for i in train_indices]
val_x = [image_files_list[i] for i in val_indices]
val_y = [image_class[i] for i in val_indices]
test_x = [image_files_list[i] for i in test_indices]
test_y = [image_class[i] for i in test_indices]

print(f"Training count: {len(train_x)}, Validation count: " f"{len(val_x)}, Test count: {len(test_x)}")

if is_train:
return train_x, train_y, val_x, val_y
else:
return test_x, test_y


if __name__ == "__main__":
# Prepare data if not already done
if not os.path.exists(abs_path + "/data"):
get_data()
#load_data('./data')
10 changes: 10 additions & 0 deletions examples/monai-2D-mednist/client/fedn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
python_env: python_env.yaml
entry_points:
build:
command: python model.py
startup:
command: python data.py
train:
command: python train.py
validate:
command: python validate.py
68 changes: 68 additions & 0 deletions examples/monai-2D-mednist/client/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import collections

import torch
from monai.networks.nets import DenseNet121



from fedn.utils.helpers.helpers import get_helper

HELPER_MODULE = "numpyhelper"
helper = get_helper(HELPER_MODULE)


def compile_model():
"""Compile the MonAI model.
:return: The compiled model.
:rtype: torch.nn.Module
"""

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=num_classes).to(device)
model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=6).to(device)
return model


def save_parameters(model, out_path):
"""Save model paramters to file.
:param model: The model to serialize.
:type model: torch.nn.Module
:param out_path: The path to save to.
:type out_path: str
"""
parameters_np = [val.cpu().numpy() for _, val in model.state_dict().items()]
helper.save(parameters_np, out_path)


def load_parameters(model_path):
"""Load model parameters from file and populate model.
param model_path: The path to load from.
:type model_path: str
:return: The loaded model.
:rtype: torch.nn.Module
"""
model = compile_model()
parameters_np = helper.load(model_path)

params_dict = zip(model.state_dict().keys(), parameters_np)
state_dict = collections.OrderedDict({key: torch.tensor(x) for key, x in params_dict})
model.load_state_dict(state_dict, strict=True)
return model


def init_seed(out_path="seed.npz"):
"""Initialize seed model and save it to file.
:param out_path: The path to save the seed model to.
:type out_path: str
"""
# Init and save
model = compile_model()
save_parameters(model, out_path)


if __name__ == "__main__":
init_seed("../seed.npz")
Loading

0 comments on commit a7a8f7f

Please sign in to comment.