Skip to content

Commit

Permalink
more fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Wrede committed Jun 10, 2024
1 parent a13f1fc commit 6c79a65
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
18 changes: 9 additions & 9 deletions examples/monai-2D-mednist/client/train.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
import os
import sys

import numpy as np
import torch
import yaml
from data import MedNISTDataset
from model import load_parameters, save_parameters

from fedn.utils.helpers.helpers import save_metadata

dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.abspath(dir_path))

import numpy as np
from monai.data import DataLoader
from monai.transforms import (
Compose,
Expand All @@ -23,6 +17,12 @@
ScaleIntensity,
)

from fedn.utils.helpers.helpers import save_metadata

dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.abspath(dir_path))


train_transforms = Compose(
[
LoadImage(image_only=True),
Expand Down Expand Up @@ -58,7 +58,7 @@ def train(in_model_path, out_model_path, data_path=None, client_settings_path=No
with open(client_settings_path, "r") as fh: # Used by CJG for local training
try:
client_settings = dict(yaml.safe_load(fh))
except yaml.YAMLError as e:
except yaml.YAMLError:
raise

print("client settings: ", client_settings)
Expand Down Expand Up @@ -111,7 +111,7 @@ def train(in_model_path, out_model_path, data_path=None, client_settings_path=No
epoch_loss_values.append(epoch_loss)
print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

print(f"training completed!")
print("training completed!")

# Metadata needed for aggregation server side
metadata = {
Expand Down
2 changes: 1 addition & 1 deletion examples/monai-2D-mednist/client/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def validate(in_model_path, out_json_path, data_path=None, client_settings_path=
with open(client_settings_path, "r") as fh: # Used by CJG for local training
try:
client_settings = dict(yaml.safe_load(fh))
except yaml.YAMLError as e:
except yaml.YAMLError:
raise

num_workers = client_settings["num_workers"]
Expand Down

0 comments on commit 6c79a65

Please sign in to comment.