Skip to content

Commit

Permalink
mods
Browse files Browse the repository at this point in the history
  • Loading branch information
mattiasakesson committed May 28, 2024
1 parent d70a9bb commit ba822a6
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 30 deletions.
18 changes: 6 additions & 12 deletions examples/monai-2D-mednist/client/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __getitem__(self, index):
train_ds = MedNISTDataset(train_x, train_y, train_transforms)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers )

return train_ds, train_loader
return train_loader


def train(in_model_path, out_model_path, data_path=None, client_settings_path=None):
Expand Down Expand Up @@ -94,25 +94,20 @@ def train(in_model_path, out_model_path, data_path=None, client_settings_path=No
sample_size = client_settings['sample_size']
lr = client_settings['lr']

#val_interval = 1
num_class = len(get_classes(data_path))

# Load data
x_train, y_train = load_data(data_path, sample_size)
train_ds, train_loader = pre_training_settings(num_class, batch_size, x_train, y_train, num_workers)

x_train, y_train = load_data(data_path, sample_size)
train_loader = pre_training_settings(num_class, batch_size, x_train, y_train, num_workers)

# Load parmeters and initialize model
model = load_parameters(in_model_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer = torch.optim.Adam(model.parameters(), 1e-5)
loss_function = torch.nn.CrossEntropyLoss()

# Train

best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
# writer = SummaryWriter()

for epoch in range(max_epochs):
Expand All @@ -130,9 +125,8 @@ def train(in_model_path, out_model_path, data_path=None, client_settings_path=No
loss.backward()
optimizer.step()
epoch_loss += loss.item()
print(f"{step}/{len(train_ds) // train_loader.batch_size}, " f"train_loss: {loss.item():.4f}")
epoch_len = len(train_ds) // train_loader.batch_size
# writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
print(f"{step}/{len(sample_size) // train_loader.batch_size}, " f"train_loss: {loss.item():.4f}")

epoch_loss /= step
epoch_loss_values.append(epoch_loss)
print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
Expand Down
22 changes: 4 additions & 18 deletions examples/monai-2D-mednist/client/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,7 @@
sys.path.append(os.path.abspath(dir_path))


def pre_validation_settings(num_class, batch_size, train_x, train_y, val_x, val_y, num_workers=2):

train_transforms = Compose(
[
LoadImage(image_only=True),
EnsureChannelFirst(),
ScaleIntensity(),
RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
RandFlip(spatial_axis=0, prob=0.5),
RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
]
)
def pre_validation_settings(batch_size, train_x, train_y, val_x, val_y, num_workers=2):

val_transforms = Compose([LoadImage(image_only=True), EnsureChannelFirst(), ScaleIntensity()])

Expand All @@ -63,13 +52,11 @@ def __len__(self):
def __getitem__(self, index):
return self.transforms(self.image_files[index]), self.labels[index]

train_ds = MedNISTDataset(train_x, train_y, val_transforms)
train_loader = DataLoader(train_ds, batch_size=batch_size, num_workers=num_workers)

val_ds = MedNISTDataset(val_x, val_y, val_transforms)
val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=num_workers)

return train_ds, train_loader, val_ds, val_loader
return val_loader



Expand Down Expand Up @@ -103,9 +90,8 @@ def validate(in_model_path, out_json_path, data_path=None, client_settings_path=
# Load data
x_train, y_train = load_data(data_path, sample_size)
x_val, y_val = load_data(data_path, sample_size, is_train=False)

num_class = len(get_classes(data_path))
train_ds, train_loader, val_ds, val_loader = pre_validation_settings(num_class, batch_size, x_train, y_train, x_val, y_val, num_workers)

val_loader = pre_validation_settings(batch_size, x_train, y_train, x_val, y_val, num_workers)

# Load model
model = load_parameters(in_model_path)
Expand Down

0 comments on commit ba822a6

Please sign in to comment.