diff --git a/examples/monai-2D-mednist/client/train.py b/examples/monai-2D-mednist/client/train.py index 029e9e407..e3cb235c0 100644 --- a/examples/monai-2D-mednist/client/train.py +++ b/examples/monai-2D-mednist/client/train.py @@ -1,17 +1,11 @@ import os import sys +import numpy as np import torch import yaml from data import MedNISTDataset from model import load_parameters, save_parameters - -from fedn.utils.helpers.helpers import save_metadata - -dir_path = os.path.dirname(os.path.realpath(__file__)) -sys.path.append(os.path.abspath(dir_path)) - -import numpy as np from monai.data import DataLoader from monai.transforms import ( Compose, @@ -23,6 +17,12 @@ ScaleIntensity, ) +from fedn.utils.helpers.helpers import save_metadata + +dir_path = os.path.dirname(os.path.realpath(__file__)) +sys.path.append(os.path.abspath(dir_path)) + + train_transforms = Compose( [ LoadImage(image_only=True), @@ -58,7 +58,7 @@ def train(in_model_path, out_model_path, data_path=None, client_settings_path=No with open(client_settings_path, "r") as fh: # Used by CJG for local training try: client_settings = dict(yaml.safe_load(fh)) - except yaml.YAMLError as e: + except yaml.YAMLError: raise print("client settings: ", client_settings) @@ -111,7 +111,7 @@ def train(in_model_path, out_model_path, data_path=None, client_settings_path=No epoch_loss_values.append(epoch_loss) print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}") - print(f"training completed!") + print("training completed!") # Metadata needed for aggregation server side metadata = { diff --git a/examples/monai-2D-mednist/client/validate.py b/examples/monai-2D-mednist/client/validate.py index f94cd41e7..74292c34f 100644 --- a/examples/monai-2D-mednist/client/validate.py +++ b/examples/monai-2D-mednist/client/validate.py @@ -40,7 +40,7 @@ def validate(in_model_path, out_json_path, data_path=None, client_settings_path= with open(client_settings_path, "r") as fh: # Used by CJG for local training try: client_settings = dict(yaml.safe_load(fh)) - except yaml.YAMLError as e: + except yaml.YAMLError: raise num_workers = client_settings["num_workers"]