Skip to content

Commit

Permalink
not complete yet
Browse files Browse the repository at this point in the history
  • Loading branch information
mattiasakesson committed Jun 18, 2024
1 parent 718a37d commit 18101ec
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 8 deletions.
16 changes: 16 additions & 0 deletions examples/monai-2D-mednist/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,22 @@ If using pseudo-distributed mode with docker-compose:
- `Docker <https://docs.docker.com/get-docker>`__
- `Docker Compose <https://docs.docker.com/compose/install>`__

Download and Prepare the data
-------------------------------------------

Install monai

.. code-block::
pip install monai
Download and divide the data into parts. Pass the desired number of
data parts as a command-line argument: ``python prepare_data.py NR-OF-DATAPARTS``. In the
command below we divide the dataset into 10 parts.
.. code-block::
python prepare_data.py 10
Creating the compute package and seed model
-------------------------------------------

Expand Down
7 changes: 4 additions & 3 deletions examples/monai-2D-mednist/client/data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import random

import sys
import numpy as np
import PIL
import torch
Expand Down Expand Up @@ -33,7 +33,7 @@ def split_data(data_path="data/MedNIST", splits=100, validation_split=0.9):
yaml.dump(clients, file, default_flow_style=False)


def get_data(out_dir="data"):
def get_data(out_dir="data", data_splits=10):
"""Get data from the external repository.
:param out_dir: Path to data directory. If doesn't
Expand All @@ -58,7 +58,7 @@ def get_data(out_dir="data"):
else:
print("files already exist.")

split_data()
split_data(splits=data_splits)


def get_classes(data_path):
Expand Down Expand Up @@ -145,6 +145,7 @@ def __len__(self):
return len(self.image_files)

def __getitem__(self, index):
print("__getitem__ path: ", os.path.join(self.data_path, self.image_files[index]))
return (self.transforms(os.path.join(self.data_path, self.image_files[index])), DATA_CLASSES[os.path.dirname(self.image_files[index])])


Expand Down
12 changes: 7 additions & 5 deletions examples/monai-2D-mednist/client/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,21 @@ def train(in_model_path, out_model_path, data_path=None, client_settings_path=No
batch_size = client_settings["batch_size"]
max_epochs = client_settings["local_epochs"]
num_workers = client_settings["num_workers"]
split_index = client_settings["split_index"]
split_index = os.environ.get("FEDN_DATA_SPLIT_INDEX")#client_settings["split_index"]
print("split index: ", split_index)
lr = client_settings["lr"]

if data_path is None:
data_path = os.environ.get("FEDN_DATA_PATH")

print("os.path.join(os.path.dirname(data_path), data_splits.yaml: ", os.path.join(os.path.dirname(data_path), "data_splits.yaml"))
with open(os.path.join(os.path.dirname(data_path), "data_splits.yaml"), "r") as file:
clients = yaml.safe_load(file)

image_list = clients["client " + str(split_index)]["train"]

train_ds = MedNISTDataset(data_path="data/MedNIST", transforms=train_transforms, image_files=image_list)

print("image_list len: ", len(image_list))
train_ds = MedNISTDataset(data_path="app/data/MedNIST", transforms=train_transforms, image_files=image_list)
print("train_ds len: ", len(train_ds))
print("batch_size: ", batch_size, ", num_workers: ", num_workers)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)

# Load parmeters and initialize model
Expand Down
66 changes: 66 additions & 0 deletions examples/monai-2D-mednist/prepare_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os
import sys
import numpy as np

import yaml
from monai.apps import download_and_extract


def split_data(data_path="data/MedNIST", splits=100, validation_split=0.9):
    """Partition the MedNIST images into per-client train/validation splits.

    Walks every class subdirectory of *data_path*, shuffles that class's
    images, and deals them out evenly across *splits* clients. Within each
    client's share the first ``validation_split`` fraction is assigned to
    "train" and the remainder to "validation" (NOTE: despite its name,
    ``validation_split`` is the TRAIN fraction — kept for backward
    compatibility). The mapping is written to ``data_splits.yaml`` next to
    *data_path*.

    :param data_path: Directory containing one subdirectory per image class.
    :type data_path: str
    :param splits: Number of client partitions to create.
    :type splits: int
    :param validation_split: Fraction of each client's images placed in the
        "train" list; the rest go to "validation".
    :type validation_split: float
    """
    clients = {"client " + str(i): {"train": [], "validation": []} for i in range(splits)}
    for class_ in os.listdir(data_path):
        class_dir = os.path.join(data_path, class_)
        if not os.path.isdir(class_dir):
            continue
        images = [os.path.join(class_, name) for name in os.listdir(class_dir)]
        np.random.shuffle(images)
        # Evenly spaced cut points dividing this class's images into `splits` chunks.
        chops = np.int32(np.linspace(0, len(images), splits + 1))
        for split in range(splits):
            p = images[chops[split] : chops[split + 1]]
            valsplit = np.int32(len(p) * validation_split)
            clients["client " + str(split)]["train"] += p[:valsplit]
            clients["client " + str(split)]["validation"] += p[valsplit:]

    with open(os.path.join(os.path.dirname(data_path), "data_splits.yaml"), "w") as file:
        yaml.dump(clients, file, default_flow_style=False)


def get_data(out_dir="data", data_splits=10):
    """Download and extract the MedNIST dataset, then split it for clients.

    Downloads the MedNIST archive into *out_dir* (skipping the download when
    the archive is already present), extracts it, and calls ``split_data`` to
    produce ``data_splits.yaml``.

    :param out_dir: Path to the data directory. Created if it does not exist.
    :type out_dir: str
    :param data_splits: Number of client partitions passed to ``split_data``.
    :type data_splits: int
    """
    # Make dir if necessary
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    resource = "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MedNIST.tar.gz"
    md5 = "0bc7306e7427e00ad1c5526a6677552d"

    compressed_file = os.path.join(out_dir, "MedNIST.tar.gz")
    data_dir = os.path.abspath(out_dir)

    # The archive is the marker for a completed download; its MD5 is verified
    # by download_and_extract before extraction.
    if not os.path.exists(compressed_file):
        print("compressed file does not exist, downloading and extracting data.")
        download_and_extract(resource, compressed_file, data_dir, md5)
    else:
        print("files already exist.")

    split_data(splits=data_splits)


if __name__ == "__main__":
    # Prepare data if not already done. Requires the number of data
    # partitions as the single command-line argument; exit with a usage
    # message instead of an opaque IndexError when it is missing.
    if len(sys.argv) != 2:
        sys.exit("Usage: python prepare_data.py NR-OF-DATAPARTS")
    get_data(data_splits=int(sys.argv[1]))

0 comments on commit 18101ec

Please sign in to comment.