Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/SK-954 | Updated file structure in mnist-keras example #678

Merged
merged 3 commits into from
Aug 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 50 additions & 12 deletions examples/mnist-keras/client/get_data.py → examples/mnist-keras/client/data.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,52 @@
import numpy as np
import tensorflow as tf

# Directory that contains this file (symlinks resolved via realpath).
dir_path = os.path.dirname(os.path.realpath(__file__))
# Absolute form of dir_path; used below to build the default data locations.
# NOTE(review): dir_path already comes from realpath, so abspath is presumably
# a no-op here -- confirm before simplifying.
abs_path = os.path.abspath(dir_path)

# Number of label classes; passed to to_categorical when encoding labels.
NUM_CLASSES = 10


def get_data(out_dir="data"):
    """Download the MNIST dataset and store it as a single ``.npz`` archive.

    The archive is written to ``<out_dir>/mnist.npz`` with the keys
    ``x_train``, ``y_train``, ``x_test`` and ``y_test``.

    :param out_dir: Directory to write the archive to. It is created,
        including any missing parent directories, if it does not exist.
    :type out_dir: str
    """
    # makedirs with exist_ok=True replaces the check-then-mkdir sequence: it
    # cannot race with another process creating the directory, and it also
    # works when out_dir is a nested path whose parents are missing.
    os.makedirs(out_dir, exist_ok=True)

    # Download data
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    np.savez(f"{out_dir}/mnist.npz", x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test)


def load_data(data_path=None, is_train=True):
    """Load data from disk.

    :param data_path: Path to data file. If None, the path is read from the
        ``FEDN_DATA_PATH`` environment variable, falling back to
        ``data/clients/1/mnist.npz`` under this module's directory.
    :type data_path: str
    :param is_train: Whether to load train or test data.
    :type is_train: bool
    :return: Tuple of data and labels.
    :rtype: tuple
    """
    if data_path is None:
        data_path = os.environ.get("FEDN_DATA_PATH", abs_path + "/data/clients/1/mnist.npz")

    if is_train:
        x_key, y_key = "x_train", "y_train"
    else:
        x_key, y_key = "x_test", "y_test"

    # np.load on an .npz archive keeps the underlying file open until the
    # returned object is closed; the context manager releases the handle as
    # soon as the two arrays we need have been read into memory.
    with np.load(data_path) as data:
        X = data[x_key]
        y = data[y_key]

    # Normalize: cast to float32, append a trailing axis and divide by 255.
    X = X.astype("float32")
    X = np.expand_dims(X, -1)
    X = X / 255
    # Encode the labels as categorical vectors with NUM_CLASSES entries.
    y = tf.keras.utils.to_categorical(y, NUM_CLASSES)

    return X, y


def splitset(dataset, parts):
n = dataset.shape[0]
Expand Down Expand Up @@ -33,16 +79,8 @@ def split(dataset="data/mnist.npz", outdir="data", n_splits=2):
np.savez(f"{subdir}/mnist.npz", x_train=data["x_train"][i], y_train=data["y_train"][i], x_test=data["x_test"][i], y_test=data["y_test"][i])


def get_data(out_dir="data"):
    """Download the MNIST dataset and store it as a single ``.npz`` archive.

    The archive is written to ``<out_dir>/mnist.npz`` with the keys
    ``x_train``, ``y_train``, ``x_test`` and ``y_test``.

    :param out_dir: Directory to write the archive to. It is created,
        including any missing parent directories, if it does not exist.
    :type out_dir: str
    """
    # makedirs with exist_ok=True replaces the check-then-mkdir sequence: it
    # cannot race with another process creating the directory, and it also
    # works when out_dir is a nested path whose parents are missing.
    os.makedirs(out_dir, exist_ok=True)

    # Download data
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    np.savez(f"{out_dir}/mnist.npz", x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test)


if __name__ == "__main__":
get_data()
split()
# Prepare data if not already done
if not os.path.exists(abs_path + "/data/clients/1"):
get_data()
split()
194 changes: 0 additions & 194 deletions examples/mnist-keras/client/entrypoint.py

This file was deleted.

10 changes: 5 additions & 5 deletions examples/mnist-keras/client/fedn.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
python_env: python_env.yaml
entry_points:
build:
command: python entrypoint.py init_seed
command: python model.py
startup:
command: python get_data.py
command: python data.py
train:
command: python entrypoint.py train $ENTRYPOINT_OPTS
command: python train.py
validate:
command: python entrypoint.py validate $ENTRYPOINT_OPTS
command: python validate.py
predict:
command: python entrypoint.py predict $ENTRYPOINT_OPTS
command: python predict.py
71 changes: 71 additions & 0 deletions examples/mnist-keras/client/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import tensorflow as tf

from fedn.utils.helpers.helpers import get_helper

# Number of output classes; sets the width of the model's final Dense layer.
NUM_CLASSES = 10
# Name of the FEDn helper implementation requested from get_helper.
HELPER_MODULE = "numpyhelper"
# Module-level helper instance; the functions below use its save()/load()
# methods to write and read model weights.
helper = get_helper(HELPER_MODULE)


def compile_model(img_rows=28, img_cols=28):
    """Compile the TF model.

    :param img_rows: The number of rows in the image.
    :type img_rows: int
    :param img_cols: The number of columns in the image.
    :type img_cols: int
    :return: The compiled model.
    :rtype: tf.keras.models.Sequential
    """
    # Set input shape: rows x cols with a single trailing channel axis.
    input_shape = (img_rows, img_cols, 1)

    # Define model
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Flatten(input_shape=input_shape))
    model.add(tf.keras.layers.Dense(64, activation="relu"))
    model.add(tf.keras.layers.Dropout(0.5))
    model.add(tf.keras.layers.Dense(32, activation="relu"))
    # One softmax output per class.
    model.add(tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"))
    model.compile(loss=tf.keras.losses.categorical_crossentropy, optimizer=tf.keras.optimizers.Adam(), metrics=["accuracy"])
    return model


def save_parameters(model, out_path):
    """Write a model's weights to a file using the module-level helper.

    :param model: The model whose weights are serialized.
    :type model: tf.keras.models.Sequential
    :param out_path: Destination path for the serialized weights.
    :type out_path: str
    """
    helper.save(model.get_weights(), out_path)


def load_parameters(model_path):
    """Build a freshly compiled model and fill it with weights read from file.

    :param model_path: Path of the weights file to read.
    :type model_path: str
    :return: The compiled model with the loaded weights applied.
    :rtype: tf.keras.models.Sequential
    """
    restored = compile_model()
    restored.set_weights(helper.load(model_path))
    return restored


def init_seed(out_path="../seed.npz"):
    """Create a newly compiled model and store its weights as the seed model.

    :param out_path: Destination path for the seed model weights.
    :type out_path: str
    """
    seed_model = compile_model()
    helper.save(seed_model.get_weights(), out_path)


if __name__ == "__main__":
    # Script entry point: write the seed model to ../seed.npz
    # (resolved against the current working directory).
    init_seed(out_path="../seed.npz")
30 changes: 30 additions & 0 deletions examples/mnist-keras/client/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import json
import os
import sys

import numpy as np
from data import load_data
from model import load_parameters

# Directory that contains this file (symlinks resolved via realpath).
dir_path = os.path.dirname(os.path.realpath(__file__))
# Make that directory importable.
# NOTE(review): this runs after the `from data import ...` and
# `from model import ...` statements above, so it cannot help resolve those
# imports -- confirm whether it should precede them or can be dropped.
sys.path.append(os.path.abspath(dir_path))


def predict(in_model_path, out_json_path, data_path=None):
    """Run inference with a stored model and write the predictions as JSON.

    The output file holds one JSON object with a single ``"predictions"`` key
    whose value is the list of argmax indices, one per input sample.

    :param in_model_path: Path of the model weights file to load.
    :type in_model_path: str
    :param out_json_path: Path of the JSON file to write.
    :type out_json_path: str
    :param data_path: Path to the data file, forwarded to ``load_data``.
    :type data_path: str
    """
    # Using test data for inference but another dataset could be loaded
    samples, _ = load_data(data_path, is_train=False)

    # Restore the model from the given weights file.
    net = load_parameters(in_model_path)

    # Infer, then reduce each output row to the index of its largest entry.
    scores = net.predict(samples)
    predicted = np.argmax(scores, axis=1)

    # Save JSON
    with open(out_json_path, "w") as out_file:
        json.dump({"predictions": predicted.tolist()}, out_file)


if __name__ == "__main__":
    # Script entry point: argv[1] is the input model path, argv[2] the output JSON path.
    model_file = sys.argv[1]
    json_file = sys.argv[2]
    predict(model_file, json_file)
Loading
Loading