Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
Signed-off-by: Radha Guhane <[email protected]>
  • Loading branch information
RadhaGulhane13 committed Nov 8, 2023
1 parent 2aa0262 commit 8624851
Show file tree
Hide file tree
Showing 9 changed files with 66 additions and 148 deletions.
9 changes: 1 addition & 8 deletions benchmarks/gems_master_model/benchmark_resnet_gems_master.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from torchgems.mp_pipeline import model_generator
from torchgems.gems_master import train_model_master
import torchgems.comm as gems_comm
from torchgems.utils import get_depth

parser_obj = parser.get_parser()
args = parser_obj.parse_args()
Expand Down Expand Up @@ -92,14 +93,6 @@ def init_processes(backend="mpi"):
ENABLE_ASYNC = True
resnet_n = 12


def get_depth(version, n):
    """Return the ResNet depth for the given variant and size factor.

    Args:
        version: ResNet variant; 1 yields ``n * 6 + 2`` and 2 yields
            ``n * 9 + 2``.
        n: Integer scaling factor used in the depth formula.

    Returns:
        The depth computed from the formula for ``version``.

    Raises:
        ValueError: If ``version`` is neither 1 nor 2. Previously this case
            fell through and silently returned ``None``.
    """
    if version == 1:
        return n * 6 + 2
    if version == 2:
        return n * 9 + 2
    raise ValueError(f"Unsupported ResNet version: {version!r} (expected 1 or 2)")


###############################################################################
mpi_comm = gems_comm.MPIComm(split_size=mp_size, ENABLE_MASTER=True)
rank = mpi_comm.rank
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,21 +66,14 @@ def __getattr__(self, attr):
return getattr(self.stream, attr)


def init_processes(backend="mpi"):
    """Initialize the distributed environment.

    Args:
        backend: Name of the ``torch.distributed`` backend handed to
            ``dist.init_process_group``.

    Returns:
        A ``(size, rank)`` tuple holding the world size and this process's
        rank as reported by ``torch.distributed``.
    """
    dist.init_process_group(backend)
    size = dist.get_world_size()
    rank = dist.get_rank()
    return size, rank


def get_depth(version, n):
    """Return the ResNet depth for the given variant and size factor.

    Args:
        version: ResNet variant; 1 yields ``n * 6 + 2`` and 2 yields
            ``n * 9 + 2``.
        n: Integer scaling factor used in the depth formula.

    Returns:
        The depth computed from the formula for ``version``.

    Raises:
        ValueError: If ``version`` is neither 1 nor 2. Previously this case
            fell through and silently returned ``None``.
    """
    if version == 1:
        return n * 6 + 2
    if version == 2:
        return n * 9 + 2
    raise ValueError(f"Unsupported ResNet version: {version!r} (expected 1 or 2)")


sys.stdout = Unbuffered(sys.stdout)

# Example of GEMS + SPATIAL split_size = 2, spatial_size = 1, num_spatial_parts = 4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
)
import torchgems.comm as gems_comm
from models import resnet
from torchgems.utils import get_depth

# Example of GEMS + SPATIAL split_size = 2, spatial_size = 1, num_spatial_parts = 4
#
Expand Down Expand Up @@ -85,7 +86,7 @@ def __getattr__(self, attr):
return getattr(self.stream, attr)


def init_processes(backend="tcp"):
def init_processes(backend="mpi"):
"""Initialize the distributed environment."""
dist.init_process_group(backend)
size = dist.get_world_size()
Expand Down Expand Up @@ -136,14 +137,6 @@ def init_processes(backend="tcp"):
image_size_seq = 32
resnet_n = 12


def get_depth(version, n):
    """Return the ResNet depth for the given variant and size factor.

    Args:
        version: ResNet variant; 1 yields ``n * 6 + 2`` and 2 yields
            ``n * 9 + 2``.
        n: Integer scaling factor used in the depth formula.

    Returns:
        The depth computed from the formula for ``version``.

    Raises:
        ValueError: If ``version`` is neither 1 nor 2. Previously this case
            fell through and silently returned ``None``.
    """
    if version == 1:
        return n * 6 + 2
    if version == 2:
        return n * 9 + 2
    raise ValueError(f"Unsupported ResNet version: {version!r} (expected 1 or 2)")


###############################################################################

mpi_comm_first = gems_comm.MPIComm(
Expand Down
10 changes: 1 addition & 9 deletions benchmarks/layer_parallelism/benchmark_resnet_lp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from torchgems.mp_pipeline import model_generator, train_model
from models import resnet
import torchgems.comm as gems_comm

from torchgems.utils import get_depth

parser_obj = parser.get_parser()
args = parser_obj.parse_args()
Expand Down Expand Up @@ -79,14 +79,6 @@ def __getattr__(self, attr):
image_size_seq = 32
resnet_n = 12


def get_depth(version, n):
    """Return the ResNet depth for the given variant and size factor.

    Args:
        version: ResNet variant; 1 yields ``n * 6 + 2`` and 2 yields
            ``n * 9 + 2``.
        n: Integer scaling factor used in the depth formula.

    Returns:
        The depth computed from the formula for ``version``.

    Raises:
        ValueError: If ``version`` is neither 1 nor 2. Previously this case
            fell through and silently returned ``None``.
    """
    if version == 1:
        return n * 6 + 2
    if version == 2:
        return n * 9 + 2
    raise ValueError(f"Unsupported ResNet version: {version!r} (expected 1 or 2)")


###############################################################################

mpi_comm = gems_comm.MPIComm(split_size=mp_size, ENABLE_MASTER=False)
Expand Down
59 changes: 12 additions & 47 deletions benchmarks/spatial_parallelism/benchmark_amoebanet_sp.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@
import logging
from torchgems import parser
from torchgems.mp_pipeline import model_generator
from torchgems.train_spatial import train_model_spatial, split_input, get_shapes_spatial
from torchgems.train_spatial import (
train_model_spatial,
split_input,
get_shapes_spatial,
verify_spatial_config,
)
import torchgems.comm as gems_comm

parser_obj = parser.get_parser()
Expand Down Expand Up @@ -62,7 +67,7 @@ def __getattr__(self, attr):
return getattr(self.stream, attr)


def init_processes(backend="tcp"):
def init_processes(backend="mpi"):
"""Initialize the distributed environment."""
dist.init_process_group(backend)
size = dist.get_world_size()
Expand All @@ -84,6 +89,7 @@ def init_processes(backend="tcp"):
balance = args.balance
split_size = args.split_size
spatial_size = args.spatial_size
slice_method = args.slice_method
times = args.times
datapath = args.datapath
num_workers = args.num_workers
Expand All @@ -107,48 +113,7 @@ def init_processes(backend="tcp"):
spatial_part_size = num_spatial_parts_list[0] # Partition size for spatial parallelism


def isPowerTwo(num):
    """Return True when ``num`` is a positive power of two.

    A power of two has exactly one set bit, so ``num & (num - 1)`` (which
    clears the lowest set bit) is zero. The ``num > 0`` guard is needed
    because the bit trick alone reports 0 as a power of two
    (``0 & -1 == 0``).

    Args:
        num: Integer to test.

    Returns:
        ``True`` if ``num`` is 1, 2, 4, 8, ...; ``False`` otherwise.
    """
    return num > 0 and not (num & (num - 1))


"""
For the AmoebaNet model, both the image size and the image size after partitioning
must be powers of two. AmoebaNet sums the results of two convolution layers during
training, so an input size that is not a power of two (for example, an odd size)
produces different output sizes for convolution operations in the same layer, and
the addition then fails because its operands have different sizes.
"""


def verify_config():
    """Validate the spatial-parallelism configuration read from module globals.

    Uses the module-level ``args``, ``image_size``, ``spatial_part_size`` and
    ``num_spatial_parts_list`` and asserts that:

    * the slice method is ``square``, ``vertical`` or ``horizontal``;
    * the application id is 1, 2 or 3;
    * the image size is a power of two;
    * the image divides exactly into partitions whose size is a power of two;
    * every entry of ``num_spatial_parts_list`` equals ``spatial_part_size``.

    Raises:
        AssertionError: If any of the checks above fails.
    """
    assert args.slice_method in [
        "square",
        "vertical",
        "horizontal",
    ], "Possible slice methods are ['square', 'vertical', 'horizontal']"

    assert args.app in range(
        1, 4
    ), "Possible Application values should be 1, 2, or 3 i.e. 1.medical, 2.cifar, and 3.synthetic"

    assert isPowerTwo(int(image_size)), "Image size should be power of Two"

    # For "square" the image size is divided by sqrt(parts) (presumably
    # because both axes are split); otherwise by the part count itself.
    if args.slice_method == "square":
        divisor = math.sqrt(spatial_part_size)
    else:
        divisor = spatial_part_size
    part_image_size = image_size / divisor

    # Require an exact division: truncating with int() before the check
    # would accept sizes that merely round down to a power of two
    # (e.g. 32 / 7 -> 4).
    assert part_image_size == int(part_image_size) and isPowerTwo(
        int(part_image_size)
    ), "Image size of each partition should be power of Two"

    for each_part_size in num_spatial_parts_list:
        assert (
            each_part_size == spatial_part_size
        ), "Size of each SP partition should be same"


verify_config()
verify_spatial_config(slice_method, image_size, num_spatial_parts_list)

##################### AmoebaNet model specific parameters #####################

Expand Down Expand Up @@ -207,7 +172,7 @@ def verify_config():
image_size_times = int(image_size / image_size_seq)
amoebanet_shapes_list = get_shapes_spatial(
model_gen_seq.shape_list,
args.slice_method,
slice_method,
spatial_size,
num_spatial_parts_list,
image_size_times,
Expand Down Expand Up @@ -273,7 +238,7 @@ def verify_config():
parts=parts,
ASYNC=True,
GEMS_INVERSE=False,
slice_method=args.slice_method,
slice_method=slice_method,
LOCAL_DP_LP=LOCAL_DP_LP,
mpi_comm=mpi_comm,
)
Expand Down Expand Up @@ -365,7 +330,7 @@ def run_epoch():
x = split_input(
inputs,
image_size,
args.slice_method,
slice_method,
local_rank,
num_spatial_parts_list,
)
Expand Down
76 changes: 18 additions & 58 deletions benchmarks/spatial_parallelism/benchmark_resnet_sp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,14 @@
import logging
from torchgems import parser
from torchgems.mp_pipeline import model_generator
from torchgems.train_spatial import train_model_spatial, split_input, get_shapes_spatial
from torchgems.train_spatial import (
train_model_spatial,
split_input,
get_shapes_spatial,
verify_spatial_config,
)
import torchgems.comm as gems_comm
from torchgems.utils import get_depth

parser_obj = parser.get_parser()
args = parser_obj.parse_args()
Expand Down Expand Up @@ -82,6 +88,7 @@ def init_processes(backend="mpi"):
balance = args.balance
split_size = args.split_size
spatial_size = args.spatial_size
slice_method = args.slice_method
times = args.times
datapath = args.datapath
num_workers = args.num_workers
Expand Down Expand Up @@ -114,58 +121,10 @@ def init_processes(backend="mpi"):
image_size_seq = 32
resnet_n = 12


def get_depth(version, n):
    """Return the ResNet depth for the given variant and size factor.

    Args:
        version: ResNet variant; 1 yields ``n * 6 + 2`` and 2 yields
            ``n * 9 + 2``.
        n: Integer scaling factor used in the depth formula.

    Returns:
        The depth computed from the formula for ``version``.

    Raises:
        ValueError: If ``version`` is neither 1 nor 2. Previously this case
            fell through and silently returned ``None``.
    """
    if version == 1:
        return n * 6 + 2
    if version == 2:
        return n * 9 + 2
    raise ValueError(f"Unsupported ResNet version: {version!r} (expected 1 or 2)")


###############################################################################


def isPowerTwo(num):
    """Return True when ``num`` is a positive power of two.

    A power of two has exactly one set bit, so ``num & (num - 1)`` (which
    clears the lowest set bit) is zero. The ``num > 0`` guard is needed
    because the bit trick alone reports 0 as a power of two
    (``0 & -1 == 0``).

    Args:
        num: Integer to test.

    Returns:
        ``True`` if ``num`` is 1, 2, 4, 8, ...; ``False`` otherwise.
    """
    return num > 0 and not (num & (num - 1))


"""
For the ResNet model, both the image size and the image size after partitioning
must be powers of two. ResNet performs convolution operations at several layers,
so an input size that is not a power of two (for example, an odd size) gets
truncated, and other GPU devices then receive input of an unexpected, truncated size.
"""


def verify_config():
    """Validate the spatial-parallelism configuration read from module globals.

    Uses the module-level ``args``, ``image_size``, ``spatial_part_size`` and
    ``num_spatial_parts_list`` and asserts that:

    * the slice method is ``square``, ``vertical`` or ``horizontal``;
    * the application id is 1, 2 or 3;
    * the image size is a power of two;
    * the image divides exactly into partitions whose size is a power of two;
    * every entry of ``num_spatial_parts_list`` equals ``spatial_part_size``.

    Raises:
        AssertionError: If any of the checks above fails.
    """
    assert args.slice_method in [
        "square",
        "vertical",
        "horizontal",
    ], "Possible slice methods are ['square', 'vertical', 'horizontal']"

    assert args.app in range(
        1, 4
    ), "Possible Application values should be 1, 2, or 3 i.e. 1.medical, 2.cifar, and 3.synthetic"

    assert isPowerTwo(int(image_size)), "Image size should be power of Two"

    # For "square" the image size is divided by sqrt(parts) (presumably
    # because both axes are split); otherwise by the part count itself.
    if args.slice_method == "square":
        divisor = math.sqrt(spatial_part_size)
    else:
        divisor = spatial_part_size
    part_image_size = image_size / divisor

    # Require an exact division: truncating with int() before the check
    # would accept sizes that merely round down to a power of two
    # (e.g. 32 / 7 -> 4).
    assert part_image_size == int(part_image_size) and isPowerTwo(
        int(part_image_size)
    ), "Image size of each partition should be power of Two"

    for each_part_size in num_spatial_parts_list:
        assert (
            each_part_size == spatial_part_size
        ), "Size of each SP partition should be same"


verify_config()
verify_spatial_config(slice_method, image_size, num_spatial_parts_list)

mpi_comm = gems_comm.MPIComm(
split_size=split_size,
Expand All @@ -189,7 +148,8 @@ def verify_config():

# Initialize ResNet model
model_seq = resnet.get_resnet_v2(
(int(batch_size / parts), 3, image_size_seq, image_size_seq), depth=get_depth(2, 12)
(int(batch_size / parts), 3, image_size_seq, image_size_seq),
depth=get_depth(2, resnet_n),
)

model_gen_seq = model_generator(
Expand All @@ -209,7 +169,7 @@ def verify_config():
image_size_times = int(image_size / image_size_seq)
resnet_shapes_list = get_shapes_spatial(
model_gen_seq.shape_list,
args.slice_method,
slice_method,
spatial_size,
num_spatial_parts_list,
image_size_times,
Expand All @@ -223,28 +183,28 @@ def verify_config():
if args.halo_d2:
model, balance = resnet_spatial.get_resnet_v2(
input_shape=(batch_size / parts, 3, image_size, image_size),
depth=get_depth(2, 12),
depth=get_depth(2, resnet_n),
local_rank=local_rank % spatial_part_size,
mp_size=split_size,
balance=balance,
spatial_size=spatial_size,
num_spatial_parts=num_spatial_parts,
num_classes=num_classes,
fused_layers=args.fused_layers,
slice_method=args.slice_method,
slice_method=slice_method,
)
else:
model = resnet_spatial.get_resnet_v2(
input_shape=(batch_size / parts, 3, image_size, image_size),
depth=get_depth(2, 12),
depth=get_depth(2, resnet_n),
local_rank=local_rank % spatial_part_size,
mp_size=split_size,
balance=balance,
spatial_size=spatial_size,
num_spatial_parts=num_spatial_parts,
num_classes=num_classes,
fused_layers=args.fused_layers,
slice_method=args.slice_method,
slice_method=slice_method,
)


Expand Down Expand Up @@ -275,7 +235,7 @@ def verify_config():
parts=1,
ASYNC=True,
GEMS_INVERSE=False,
slice_method=args.slice_method,
slice_method=slice_method,
mpi_comm=mpi_comm,
)

Expand Down Expand Up @@ -366,7 +326,7 @@ def run_epoch():
x = split_input(
inputs,
image_size,
args.slice_method,
slice_method,
local_rank,
num_spatial_parts_list,
)
Expand Down
6 changes: 1 addition & 5 deletions src/torchgems/train_spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,7 @@
import torch
import math
import torch.distributed as dist


def isPowerTwo(num):
    """Return True when ``num`` is a positive power of two.

    A power of two has exactly one set bit, so ``num & (num - 1)`` (which
    clears the lowest set bit) is zero. The ``num > 0`` guard is needed
    because the bit trick alone reports 0 as a power of two
    (``0 & -1 == 0``).

    Args:
        num: Integer to test.

    Returns:
        ``True`` if ``num`` is 1, 2, 4, 8, ...; ``False`` otherwise.
    """
    return num > 0 and not (num & (num - 1))

# Import through the package path, as the benchmarks do with
# `from torchgems.utils import get_depth`; a bare `from utils import ...`
# is an absolute import of a top-level `utils` module in Python 3.
from torchgems.utils import isPowerTwo

"""
For SP, image size and image size after partitioning should be power of two.
Expand Down
4 changes: 0 additions & 4 deletions src/torchgems/train_spatial_master.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@
import torch.distributed as dist


def isPowerTwo(num):
    """Return True when ``num`` is a positive power of two.

    A power of two has exactly one set bit, so ``num & (num - 1)`` (which
    clears the lowest set bit) is zero. The ``num > 0`` guard is needed
    because the bit trick alone reports 0 as a power of two
    (``0 & -1 == 0``).

    Args:
        num: Integer to test.

    Returns:
        ``True`` if ``num`` is 1, 2, 4, 8, ...; ``False`` otherwise.
    """
    return num > 0 and not (num & (num - 1))


"""
For SP, image size and image size after partitioning should be power of two.
As, while performing convolution operations at different layers, odd input size
Expand Down
Loading

0 comments on commit 8624851

Please sign in to comment.