Skip to content

Commit

Permalink
Cleaned a bit in examples and added documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
Andreas Hellander committed Jan 23, 2024
1 parent 5868624 commit d214884
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 30 deletions.
77 changes: 50 additions & 27 deletions examples/mnist-keras/client/entrypoint
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import tensorflow as tf
from fedn.utils.helpers import get_helper, save_metadata, save_metrics

HELPER_MODULE = 'numpyhelper'
helper = get_helper(HELPER_MODULE)

NUM_CLASSES = 10


Expand All @@ -23,6 +25,15 @@ def _get_data_path():


def compile_model(img_rows=28, img_cols=28):
""" Compile the TF model.
param: img_rows: The number of rows in the image
type: img_rows: int
param: img_cols: The number of rows in the image
type: img_cols: int
return: The compiled model
type: keras.model.Sequential
"""
# Set input shape
input_shape = (img_rows, img_cols, 1)

Expand All @@ -36,6 +47,7 @@ def compile_model(img_rows=28, img_cols=28):
model.compile(loss=tf.keras.losses.categorical_crossentropy,
optimizer=tf.keras.optimizers.Adam(),
metrics=['accuracy'])

return model


Expand Down Expand Up @@ -63,41 +75,71 @@ def load_data(data_path, is_train=True):


def init_seed(out_path='seed.npz'):
""" Initialize seed model and save it to file.
:param out_path: The path to save the seed model to.
:type out_path: str
"""
weights = compile_model().get_weights()
helper = get_helper(HELPER_MODULE)
helper.save(weights, out_path)


def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1):
""" Complete a model update.
Load model paramters from in_model_path (managed by the FEDn client),
perform a model update, and write updated paramters
to out_model_path (picked up by the FEDn client).
:param in_model_path: The path to the input model.
:type in_model_path: str
:param out_model_path: The path to save the output model to.
:type out_model_path: str
:param data_path: The path to the data file.
:type data_path: str
:param batch_size: The batch size to use.
:type batch_size: int
:param epochs: The number of epochs to train.
:type epochs: int
"""
# Load data
x_train, y_train = load_data(data_path)

# Load model
model = compile_model()
helper = get_helper(HELPER_MODULE)
weights = helper.load(in_model_path)
model.set_weights(weights)

# Train
model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs)

# Save
weights = model.get_weights()
helper.save(weights, out_model_path)

# Metadata about local training to be passed on to aggregator
# Metadata needed for aggregation server side
metadata = {
# num_examples are mandatory
'num_examples': len(x_train),
'batch_size': batch_size,
'epochs': epochs,
}

# Save JSON metadata file
# Save JSON metadata file (mandatory)
save_metadata(metadata, out_model_path)

# Save model update (mandatory)
weights = model.get_weights()
helper.save(weights, out_model_path)


def validate(in_model_path, out_json_path, data_path=None):
""" Validate model.
:param in_model_path: The path to the input model.
:type in_model_path: str
:param out_json_path: The path to save the output JSON to.
:type out_json_path: str
:param data_path: The path to the data file.
:type data_path: str
"""

# Load data
x_train, y_train = load_data(data_path)
x_test, y_test = load_data(data_path, is_train=False)
Expand Down Expand Up @@ -126,25 +168,6 @@ def validate(in_model_path, out_json_path, data_path=None):
save_metrics(report, out_json_path)


def infer(in_model_path, out_json_path, data_path=None):
# Using test data for inference but another dataset could be loaded
x_test, _ = load_data(data_path, is_train=False)

# Load model
model = compile_model()
helper = get_helper(HELPER_MODULE)
weights = helper.load(in_model_path)
model.set_weights(weights)

# Infer
y_pred = model.predict(x_test)
y_pred = np.argmax(y_pred, axis=1)

# Save JSON
with open(out_json_path, "w") as fh:
fh.write(json.dumps({'predictions': y_pred.tolist()}))


if __name__ == '__main__':
fire.Fire({
'init_seed': init_seed,
Expand Down
6 changes: 3 additions & 3 deletions examples/mnist-pytorch/client/entrypoint
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def compile_model():
x = torch.nn.functional.log_softmax(self.fc3(x), dim=1)
return x

# Return model
return Net()


Expand Down Expand Up @@ -166,16 +165,17 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1

# Metadata needed for aggregation server side
metadata = {
# num_examples are mandatory
'num_examples': len(x_train),
'batch_size': batch_size,
'epochs': epochs,
'lr': lr
}

# Save JSON metadata file
# Save JSON metadata file (mandatory)
save_metadata(metadata, out_model_path)

# Save model update
# Save model update (mandatory)
save_parameters(model, out_model_path)


Expand Down

0 comments on commit d214884

Please sign in to comment.