Skip to content

Commit

Permalink
Refactor
Browse files — browse the repository at this point in the history
Signed-off-by: Radha Guhane <[email protected]>
  • Loading branch information
RadhaGulhane13 committed Nov 8, 2023
1 parent da24acd commit d7579aa
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 53 deletions.
18 changes: 18 additions & 0 deletions benchmarks/gems_master_model/benchmark_amoebanet_gems_master.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
# Copyright 2023, The Ohio State University. All rights reserved.
# The MPI4DL software package is developed by the team members of
# The Ohio State University's Network-Based Computing Laboratory (NBCL),
# headed by Professor Dhabaleswar K. (DK) Panda.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.distributed as dist
import torchvision.transforms as transforms
Expand Down
18 changes: 18 additions & 0 deletions benchmarks/gems_master_model/benchmark_resnet_gems_master.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
# Copyright 2023, The Ohio State University. All rights reserved.
# The MPI4DL software package is developed by the team members of
# The Ohio State University's Network-Based Computing Laboratory (NBCL),
# headed by Professor Dhabaleswar K. (DK) Panda.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.distributed as dist
import torchvision.transforms as transforms
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def get_depth(version, n):
num_workers=0,
pin_memory=True,
)
size_dataset = 1030
size_dataset = len(my_dataloader.dataset)
elif APP == 2:
transform = transforms.Compose(
[
Expand All @@ -358,7 +358,7 @@ def get_depth(version, n):
pin_memory=True,
)
size_dataset = len(my_dataloader.dataset)
elif APP == 3:
else:
my_dataset = torchvision.datasets.FakeData(
size=10 * batch_size * args.times,
image_size=(3, image_size, image_size),
Expand All @@ -375,28 +375,6 @@ def get_depth(version, n):
pin_memory=True,
)
size_dataset = 10 * batch_size
else:
transform = transforms.Compose(
[
transforms.Resize((512, 512)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)
trainset = torchvision.datasets.ImageFolder(
datapath,
transform=transform,
target_transform=None,
)
my_dataloader = torch.utils.data.DataLoader(
trainset,
batch_size=times * batch_size,
shuffle=True,
num_workers=0,
pin_memory=True,
)
size_dataset = len(my_dataloader.dataset)

################################################################################

sync_comm = gems_comm.SyncAllreduce(mpi_comm_first)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def get_depth(version, n):
num_workers=0,
pin_memory=True,
)
size_dataset = 1030
size_dataset = len(my_dataloader.dataset)
elif APP == 2:
trainset = torchvision.datasets.CIFAR10(
root="./data", train=True, download=True, transform=transform
Expand All @@ -349,8 +349,8 @@ def get_depth(version, n):
num_workers=0,
pin_memory=True,
)
size_dataset = 50000
elif APP == 3:
size_dataset = len(my_dataloader.dataset)
else:
my_dataset = torchvision.datasets.FakeData(
size=10 * batch_size * args.times,
image_size=(3, image_size, image_size),
Expand All @@ -367,27 +367,6 @@ def get_depth(version, n):
pin_memory=True,
)
size_dataset = 10 * batch_size
else:
transform = transforms.Compose(
[
transforms.Resize((64, 64)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)
trainset = torchvision.datasets.ImageFolder(
datapath,
transform=transform,
target_transform=None,
)
my_dataloader = torch.utils.data.DataLoader(
trainset,
batch_size=times * batch_size,
shuffle=True,
num_workers=0,
pin_memory=True,
)
size_dataset = len(my_dataloader.dataset)

################################################################################

Expand Down
1 change: 0 additions & 1 deletion src/torchgems/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,6 @@ def create_allreduce_comm_spatial(self):

if self.ENABLE_MASTER:
for i in range(len(ranks)):
# ranks[i] = (self.mp_size - 1 - ranks[i])
ranks.append(self.mp_size - 1 - ranks[i])

temp_spatial_allreduce_grp = torch.distributed.new_group(ranks=ranks)
Expand Down
8 changes: 4 additions & 4 deletions src/torchgems/train_spatial_master.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,21 +462,21 @@ def run_step(self, inputs, labels):
loss, correct = 0, 0
# torch.cuda.empty_cache()

# self.train_model1.models = self.train_model1.models.to("cuda")
# self.train_model1.models = self.train_model1.models.to('cuda')
temp_loss, temp_correct = self.train_model1.run_step(
inputs[: self.batch_size], labels[: self.batch_size]
)
loss += temp_loss
correct += temp_correct

# torch.cuda.empty_cache()
# self.train_model1.models = self.train_model1.models.to("cpu")
# self.train_model2.models = self.train_model2.models.to("cuda")
# self.train_model1.models = self.train_model1.models.to('cpu')
# self.train_model2.models = self.train_model2.models.to('cuda')
temp_loss, temp_correct = self.train_model2.run_step(
inputs[self.batch_size : 2 * self.batch_size],
labels[self.batch_size : 2 * self.batch_size],
)
# self.train_model2.models = self.train_model2.models.to("cpu")
# self.train_model2.models = self.train_model2.models.to('cpu')

# torch.cuda.empty_cache()

Expand Down

0 comments on commit d7579aa

Please sign in to comment.