Skip to content

Commit

Permalink
Feature/SK-954 | Updated file structure in mnist-keras example (#678)
Browse files Browse the repository at this point in the history
* restrucuturing files in client folder

* fixing imports

* ruff linting
  • Loading branch information
FrankJonasmoelle authored Aug 24, 2024
1 parent 23d0e98 commit ac76702
Show file tree
Hide file tree
Showing 7 changed files with 262 additions and 211 deletions.
62 changes: 50 additions & 12 deletions examples/mnist-keras/client/get_data.py → examples/mnist-keras/client/data.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,52 @@
import numpy as np
import tensorflow as tf

dir_path = os.path.dirname(os.path.realpath(__file__))
abs_path = os.path.abspath(dir_path)

NUM_CLASSES = 10


def get_data(out_dir="data"):
# Make dir if necessary
if not os.path.exists(out_dir):
os.mkdir(out_dir)

# Download data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
np.savez(f"{out_dir}/mnist.npz", x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test)


def load_data(data_path, is_train=True):
"""Load data from disk.
:param data_path: Path to data file.
:type data_path: str
:param is_train: Whether to load train or test data.
:type is_train: bool
:return: Tuple of data and labels.
:rtype: tuple
"""
if data_path is None:
data_path = os.environ.get("FEDN_DATA_PATH", abs_path + "/data/clients/1/mnist.npz")

data = np.load(data_path)

if is_train:
X = data["x_train"]
y = data["y_train"]
else:
X = data["x_test"]
y = data["y_test"]

# Normalize
X = X.astype("float32")
X = np.expand_dims(X, -1)
X = X / 255
y = tf.keras.utils.to_categorical(y, NUM_CLASSES)

return X, y


def splitset(dataset, parts):
n = dataset.shape[0]
Expand Down Expand Up @@ -33,16 +79,8 @@ def split(dataset="data/mnist.npz", outdir="data", n_splits=2):
np.savez(f"{subdir}/mnist.npz", x_train=data["x_train"][i], y_train=data["y_train"][i], x_test=data["x_test"][i], y_test=data["y_test"][i])


def get_data(out_dir="data"):
# Make dir if necessary
if not os.path.exists(out_dir):
os.mkdir(out_dir)

# Download data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
np.savez(f"{out_dir}/mnist.npz", x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test)


if __name__ == "__main__":
get_data()
split()
# Prepare data if not already done
if not os.path.exists(abs_path + "/data/clients/1"):
get_data()
split()
194 changes: 0 additions & 194 deletions examples/mnist-keras/client/entrypoint.py

This file was deleted.

10 changes: 5 additions & 5 deletions examples/mnist-keras/client/fedn.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
python_env: python_env.yaml
entry_points:
build:
command: python entrypoint.py init_seed
command: python model.py
startup:
command: python get_data.py
command: python data.py
train:
command: python entrypoint.py train $ENTRYPOINT_OPTS
command: python train.py
validate:
command: python entrypoint.py validate $ENTRYPOINT_OPTS
command: python validate.py
predict:
command: python entrypoint.py predict $ENTRYPOINT_OPTS
command: python predict.py
71 changes: 71 additions & 0 deletions examples/mnist-keras/client/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import tensorflow as tf

from fedn.utils.helpers.helpers import get_helper

NUM_CLASSES = 10
HELPER_MODULE = "numpyhelper"
helper = get_helper(HELPER_MODULE)


def compile_model(img_rows=28, img_cols=28):
"""Compile the TF model.
param: img_rows: The number of rows in the image
type: img_rows: int
param: img_cols: The number of rows in the image
type: img_cols: int
return: The compiled model
type: keras.model.Sequential
"""
# Set input shape
input_shape = (img_rows, img_cols, 1)

# Define model
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=input_shape))
model.add(tf.keras.layers.Dense(64, activation="relu"))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(32, activation="relu"))
model.add(tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"))
model.compile(loss=tf.keras.losses.categorical_crossentropy, optimizer=tf.keras.optimizers.Adam(), metrics=["accuracy"])
return model


def save_parameters(model, out_path):
"""Save model parameters to file.
:param model: The model to serialize.
:type model: keras.model.Sequential
:param out_path: The path to save the model to.
:type out_path: str
"""
weights = model.get_weights()
helper.save(weights, out_path)


def load_parameters(model_path):
"""Load model parameters from file and populate model.
:param model_path: The path to load from.
:type model_path: str
:return: The loaded model.
:rtype: keras.model.Sequential
"""
model = compile_model()
weights = helper.load(model_path)
model.set_weights(weights)
return model


def init_seed(out_path="../seed.npz"):
"""Initialize seed model and save it to file.
:param out_path: The path to save the seed model to.
:type out_path: str
"""
weights = compile_model().get_weights()
helper.save(weights, out_path)


if __name__ == "__main__":
init_seed("../seed.npz")
30 changes: 30 additions & 0 deletions examples/mnist-keras/client/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import json
import os
import sys

import numpy as np
from data import load_data
from model import load_parameters

dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.abspath(dir_path))


def predict(in_model_path, out_json_path, data_path=None):
# Using test data for inference but another dataset could be loaded
x_test, _ = load_data(data_path, is_train=False)

# Load model
model = load_parameters(in_model_path)

# Infer
y_pred = model.predict(x_test)
y_pred = np.argmax(y_pred, axis=1)

# Save JSON
with open(out_json_path, "w") as fh:
fh.write(json.dumps({"predictions": y_pred.tolist()}))


if __name__ == "__main__":
predict(sys.argv[1], sys.argv[2])
Loading

0 comments on commit ac76702

Please sign in to comment.