Add bfloat16 support and Fix Pipeline Parallelism for SP
Signed-off-by: Radha Gulhane <[email protected]>
RadhaGulhane13 committed Dec 28, 2023
1 parent 0365979 commit eae2ec9
Showing 8 changed files with 115 additions and 34 deletions.
23 changes: 14 additions & 9 deletions benchmarks/gems_master_model/benchmark_resnet_gems_master.py
@@ -89,11 +89,15 @@ def init_processes(backend="mpi"):
precision = str(args.precision)
backend = args.backend


if precision == "bfp_16":
    assert torch.cuda.is_bf16_supported(), "Native system doesn't support bf16"

EVAL_MODE = args.enable_evaluation
CHECKPOINT = None
if EVAL_MODE and APP != 3:
# Note: MPI4DL_ImageNeteee.pth was trained with image_size 256 and num_classes 10
CHECKPOINT = "/home/gulhane.2/github_torch_gems/MPI4DL/benchmarks/MPI4DL_Checkpoints/MPI4DL_ImageNeteee.pth"
CHECKPOINT = "/users/PAS2312/rgulhane/nowlab/checkpoints/imagenetee_img_size_64/MPI4DL_ImageNeteee_TensorRT_model_temp.pth"


################## ResNet model specific parameters/functions ##################
@@ -254,7 +258,7 @@ def init_processes(backend="mpi"):
# root="/home/gulhane.2/GEMS_Inference/datasets/ImageNet/", split='val', transform=transform
# )
testset = torchvision.datasets.ImageFolder(
root="/home/gulhane.2/github_torch_gems/MPI4DL/benchmarks/single_gpu/imagenette2-320/val",
root=datapath,
transform=transform,
target_transform=None,
)
@@ -400,15 +404,16 @@ def run_eval():
start_event.record()
if precision == "fp_16":
inputs = inputs.half()
# inputs = inputs.to(torch.float16)
elif precision == "bfp_16":
inputs = inputs.to(torch.bfloat16)
# labels = labels.to(torch.float16)

if batch > math.floor(size_dataset / (times * batch_size)) - 1:
break
before_step = torch.cuda.max_memory_allocated(device="cuda")
print(
f"Max Memory before step {batch} on rank {local_rank} Using PyTorch CUDA: {before_step / (1024 ** 2):.2f} MB"
)
# print(
# f"Max Memory before step {batch} on rank {local_rank} Using PyTorch CUDA: {before_step / (1024 ** 2):.2f} MB"
# )

local_loss, local_correct = tm_master.run_step(
inputs, labels, eval_mode=EVAL_MODE
@@ -425,9 +430,9 @@ def run_eval():
f"Step :{batch}, LOSS: {local_loss}, Global loss: {loss/(batch+1)} Acc: {local_correct} [{batch * len(inputs):>5d}/{size:>5d}]"
)
after_step = torch.cuda.max_memory_allocated(device="cuda")
print(
f"Max Memory after step {batch} on rank {local_rank} Using PyTorch CUDA: {after_step / (1024 ** 2):.2f} MB"
)
# print(
# f"Max Memory after step {batch} on rank {local_rank} Using PyTorch CUDA: {after_step / (1024 ** 2):.2f} MB"
# )

if local_rank == 0:
print(f"images per sec:{batch_size / t}")
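For reference, a minimal sketch of the precision handling this benchmark now applies to evaluation inputs (the flag strings "fp_16" and "bfp_16" come from the diff above; the cast_inputs helper name is illustrative, not part of the repository):

import torch

def cast_inputs(inputs, precision):
    # Cast evaluation inputs to the model's dtype; "fp_16" selects float16,
    # "bfp_16" selects bfloat16, anything else leaves the tensor untouched.
    if precision == "fp_16":
        return inputs.half()
    elif precision == "bfp_16":
        # bfloat16 requires hardware support (e.g. NVIDIA Ampere or newer).
        assert torch.cuda.is_bf16_supported(), "native system doesn't support bf16"
        return inputs.to(torch.bfloat16)
    return inputs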
14 changes: 10 additions & 4 deletions benchmarks/spatial_parallelism/benchmark_amoebanet_sp.py
@@ -100,6 +100,8 @@ def init_processes(backend="mpi"):
# 3: synthetic
APP = args.app
num_classes = args.num_classes
backend = args.backend
EVAL_MODE = args.enable_evaluation

temp_num_spatial_parts = args.num_spatial_parts.split(",")

@@ -133,6 +135,7 @@ def init_processes(backend="mpi"):
num_spatial_parts=num_spatial_parts,
spatial_size=spatial_size,
LOCAL_DP_LP=LOCAL_DP_LP,
backend=backend
)
sync_allreduce = gems_comm.SyncAllreduce(mpi_comm)

@@ -149,7 +152,7 @@ def init_processes(backend="mpi"):
else:
balance = None


print(f"At start : Rank : {local_rank} : {torch.cuda.mem_get_info()}")
# Initialize AmoebaNet model
model_seq = amoebanet.amoebanetd(
num_layers=num_layers, num_filters=num_filters, num_classes=num_classes
@@ -181,7 +184,7 @@ def init_processes(backend="mpi"):
del model_seq
del model_gen_seq
torch.cuda.ipc_collect()

print(f"before model definition : Rank : {local_rank} : {torch.cuda.mem_get_info()}")
# Initialize AmoebaNet model with Spatial and Model Parallelism support
if args.halo_d2:
model = amoebanet_d2.amoebanetd_spatial(
@@ -222,7 +225,7 @@ def init_processes(backend="mpi"):
model_gen.DDP_model(mpi_comm, num_spatial_parts, spatial_size, bucket_size=0)

logging.info(f"Shape of model on local_rank {local_rank} : {model_gen.shape_list}")

print(f"After model alloc : Rank : {local_rank} : {torch.cuda.mem_get_info()}")

# Initialize parameters required for training the model with Spatial and Model
# Parallelism support
@@ -305,6 +308,8 @@ def init_processes(backend="mpi"):
)
size_dataset = 10 * batch_size


print(f"After dataloader : Rank : {local_rank} : {torch.cuda.mem_get_info()}")
################################################################################

################################# Train Model ##################################
@@ -319,6 +324,7 @@ def run_epoch():
size = len(my_dataloader.dataset)
t = time.time()
for batch, data in enumerate(my_dataloader, 0):
print(f"At batch {batch} Rank : {local_rank} : {torch.cuda.mem_get_info()}")
start_event = torch.cuda.Event(enable_timing=True, blocking=True)
end_event = torch.cuda.Event(enable_timing=True, blocking=True)
start_event.record()
@@ -337,7 +343,7 @@ def run_epoch():
else:
x = inputs

local_loss, local_correct = t_s.run_step(x, labels)
local_loss, local_correct = t_s.run_step(x, labels, eval_mode=EVAL_MODE)
loss += local_loss
correct += local_correct

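The per-batch prints added above track GPU memory with torch.cuda.mem_get_info(). A small hypothetical helper that reports the same information in MB (log_mem is not part of the repository):

import torch

def log_mem(tag, local_rank):
    # torch.cuda.mem_get_info() returns (free_bytes, total_bytes) for the
    # current CUDA device; useful for spotting leaks between pipeline stages.
    free_b, total_b = torch.cuda.mem_get_info()
    print(f"{tag} : Rank : {local_rank} : "
          f"free {free_b / 1024 ** 2:.0f} MB / total {total_b / 1024 ** 2:.0f} MB")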
25 changes: 22 additions & 3 deletions benchmarks/spatial_parallelism/benchmark_resnet_sp.py
@@ -92,6 +92,7 @@ def init_processes(backend="mpi"):
times = args.times
datapath = args.datapath
num_workers = args.num_workers
LOCAL_DP_LP = args.local_DP

# APP
# 1: Medical
@@ -102,12 +103,16 @@ def init_processes(backend="mpi"):
precision = str(args.precision)
backend = args.backend

if precision == "bfp_16":
    assert torch.cuda.is_bf16_supported(), "Native system doesn't support bf16"


EVAL_MODE = args.enable_evaluation
CHECKPOINT = None
if EVAL_MODE and APP != 3:
# Note: MPI4DL_ImageNeteee.pth was trained with image_size 256 and num_classes 10
CHECKPOINT = "/home/gulhane.2/github_torch_gems/MPI4DL/benchmarks/MPI4DL_Checkpoints/MPI4DL_ImageNeteee.pth"

# CHECKPOINT=f"/users/PAS2312/rgulhane/nowlab/checkpoints/sp_precision_32_gpu_5/checkpt_resnet_sp_{local_rank}.pth"

temp_num_spatial_parts = args.num_spatial_parts.split(",")

@@ -141,6 +146,7 @@ def init_processes(backend="mpi"):
ENABLE_SPATIAL=True,
num_spatial_parts=num_spatial_parts,
spatial_size=spatial_size,
LOCAL_DP_LP=LOCAL_DP_LP,
backend=backend,
)
sync_allreduce = gems_comm.SyncAllreduce(mpi_comm)
@@ -149,6 +155,9 @@ def init_processes(backend="mpi"):
local_rank = rank
split_rank = mpi_comm.split_rank

if EVAL_MODE and APP != 3:
# Note: MPI4DL_ImageNeteee.pth was trained with image_size 256 and num_classes 10
CHECKPOINT=f"/users/PAS2312/rgulhane/nowlab/checkpoints/sp_precision_32_gpu_5/checkpt_resnet_sp_{local_rank}.pth"

if balance != None:
balance = balance.split(",")
@@ -234,6 +243,8 @@ def init_processes(backend="mpi"):
checkpoint_path=CHECKPOINT,
precision=precision,
)
# model_gen.DDP_model(mpi_comm, num_spatial_parts, spatial_size, bucket_size=0)


logging.info(f"Shape of model on local_rank {local_rank} : {model_gen.shape_list}")

@@ -248,11 +259,12 @@ def init_processes(backend="mpi"):
num_spatial_parts=num_spatial_parts,
criterion=None,
optimizer=None,
parts=1,
parts=parts,
ASYNC=True,
GEMS_INVERSE=False,
slice_method=slice_method,
mpi_comm=mpi_comm,
LOCAL_DP_LP=LOCAL_DP_LP,
precision=precision,
eval_mode=EVAL_MODE
)
@@ -276,6 +288,10 @@ def init_processes(backend="mpi"):
torch.manual_seed(0)

if APP == 1:
transform = transforms.Compose(
[transforms.Resize((image_size, image_size)),
transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
trainset = torchvision.datasets.ImageFolder(
datapath, transform=transform, target_transform=None
)
@@ -319,7 +335,7 @@ def init_processes(backend="mpi"):

################################################################################

sync_allreduce.sync_model_spatial(model_gen)
# sync_allreduce.sync_model_spatial(model_gen)

################################# Train Model ##################################

@@ -354,6 +370,9 @@ def run_eval():

if precision == "fp_16":
x = x.half()
elif precision == "bfp_16":
x = x.to(torch.bfloat16)


local_loss, local_correct = t_s.run_step(x, labels, eval_mode=EVAL_MODE)
loss += local_loss
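This benchmark now loads a separate checkpoint per rank, since under spatial/model parallelism each rank only holds its own partition of the model. A minimal sketch of that pattern (the helper names and the base_dir argument are illustrative; the "model_state_dict" key matches the one used in mp_pipeline.py below):

import torch

def rank_checkpoint_path(base_dir, local_rank):
    # Each rank saves and restores its own partition of the model.
    return f"{base_dir}/checkpt_resnet_sp_{local_rank}.pth"

def load_rank_checkpoint(model, base_dir, local_rank):
    checkpoint = torch.load(rank_checkpoint_path(base_dir, local_rank), map_location="cuda")
    model.load_state_dict(checkpoint["model_state_dict"])
    return model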
1 change: 1 addition & 0 deletions src/models/amoebanet.py
@@ -523,6 +523,7 @@ def forward(
op1 = operations[i]
op2 = operations[i + 1]


h1 = op1(h1)
h2 = op2(h2)

15 changes: 12 additions & 3 deletions src/torchgems/comm.py
@@ -22,6 +22,9 @@
import math
import numpy as np

from torchgems import parser
parser_obj = parser.get_parser()
args = parser_obj.parse_args()

def env2int(env_list, default=-1):
for e in env_list:
@@ -165,7 +168,8 @@ def __init__(
self.LOCAL_DP_MP_Comm = None

self.allreduce_grp = self.create_allreduce_comm()
self.test_allreduce_comm(self.allreduce_grp)
if not args.enable_evaluation:
self.test_allreduce_comm(self.allreduce_grp)

def get_split_rank(self, num_spatial_parts_list, local_rank):
if isinstance(num_spatial_parts_list, list):
@@ -408,10 +412,15 @@ def sync_broadcast(self, model, src, grp_comm):
)

def sync_model_spatial(self, model_gen):
if self.local_rank < self.spatial_size * self.num_spatial_parts:
if isinstance(self.num_spatial_parts, list):
spatial_parts = self.num_spatial_parts[0]
else:
spatial_parts = self.num_spatial_parts

if self.local_rank < self.spatial_size * spatial_parts:
self.sync_broadcast(
model_gen.models,
src=math.floor(self.local_rank / self.num_spatial_parts),
src=math.floor(self.local_rank / spatial_parts),
grp_comm=self.spatial_allreduce_grp,
)

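The sync_model_spatial fix above accepts num_spatial_parts either as an int or as a per-split list, and derives the broadcast source rank from the first split's partition count. A standalone sketch of that computation (spatial_broadcast_src is illustrative, not a repository function):

import math

def spatial_broadcast_src(local_rank, num_spatial_parts):
    # num_spatial_parts may be a single int or a per-split list; use the
    # first split's partition count to pick the broadcast source rank.
    parts = num_spatial_parts[0] if isinstance(num_spatial_parts, list) else num_spatial_parts
    return math.floor(local_rank / parts)

# e.g. spatial_broadcast_src(3, [4, 2]) == 0, so ranks 0-3 receive weights from rank 0.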
26 changes: 18 additions & 8 deletions src/torchgems/mp_pipeline.py
@@ -54,9 +54,10 @@ def get_start_end_layer_index(self, split_rank):
end_layer = len(self.model)
else:
num_layers = len(self.model)
# print(f"NUM LAYERS : {num_layers}")
assert sum(self.balance) == len(
self.model
), "balance and number of layers differs"
), f"balance and number of layers differ, {sum(self.balance)} != {len(self.model)}"

if split_rank == 0:
start_layer = 0
@@ -91,8 +92,8 @@ def ModelQuantizaton(self, model, split_rank, model_no):

print("In ModelQuantization....")

quantized_model_path = f"/home/gulhane.2/MPI4DL_Copy/MPI4DL/benchmarks/MPI4DL_Checkpoints/gems/split_size{self.split_size}_rank{split_rank}_model{model_no}"
print(quantized_model_path)
# quantized_model_path = f"/home/gulhane.2/MPI4DL_Copy/MPI4DL/benchmarks/MPI4DL_Checkpoints/gems/split_size{self.split_size}_rank{split_rank}_model{model_no}"
# print(quantized_model_path)
# if os.path.exists(quantized_model_path):
if False:
model = torch.jit.load(quantized_model_path)
@@ -101,9 +102,9 @@ def ModelQuantizaton(self, model, split_rank, model_no):
return model

# Create and save quantized model
checkpoint_path = "/home/gulhane.2/github_torch_gems/MPI4DL/benchmarks/MPI4DL_Checkpoints/MPI4DL_ImageNeteee.pth"
checkpoint = torch.load(checkpoint_path)
model_state_dist_split_layer = {}
# checkpoint_path = "/home/gulhane.2/github_torch_gems/MPI4DL/benchmarks/MPI4DL_Checkpoints/MPI4DL_ImageNeteee.pth"
# checkpoint = torch.load(checkpoint_path)
# model_state_dist_split_layer = {}

# for name, _ in model.named_parameters():
# print(name)
@@ -187,6 +188,9 @@ def ready_model(
print(f"precision : {precision}")
if precision == "fp_16":
self.models.half()
elif precision == "bfp_16":
print("using bfloat16 for the main model")
self.models = self.models.to(dtype=torch.bfloat16)
elif precision == "int_8":
self.models = self.ModelQuantizaton(temp_model, split_rank, model_no=1)
self.models.to("cuda")
@@ -198,10 +202,11 @@ def ready_model(

# eval_mode is True
assert checkpoint_path is not None, "No checkpoints found"

self.models = temp_model
checkpoint = torch.load(checkpoint_path)

# load required layers from entire model
model_state_dist_split_layer = {}
self.models = temp_model

for name, _ in self.models.named_parameters():
model_state_dist_split_layer[name] = checkpoint["model_state_dict"][name]
@@ -217,8 +222,11 @@ def ready_model(
][running_var]

self.models.load_state_dict(model_state_dist_split_layer)

if precision == "fp_16":
self.models.half()
elif precision == "bfp_16":
self.models = self.models.to(dtype=torch.bfloat16)
self.models.eval()
print_model_size(self.models, split_rank, True)
self.models.to("cuda")
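In eval mode the code above restores only the layers owned by this split from a full-model checkpoint. A minimal sketch of that partial loading (load_split_state is illustrative and assumes the split's parameter and buffer names match the checkpoint keys):

import torch

def load_split_state(models, checkpoint_path):
    # Copy only the entries of the full checkpoint that this split owns.
    full_state = torch.load(checkpoint_path)["model_state_dict"]
    split_state = {name: full_state[name] for name, _ in models.named_parameters()}
    # BatchNorm running statistics are buffers, not parameters, so copy
    # them as well (running_mean / running_var).
    for name, _ in models.named_buffers():
        if name in full_state:
            split_state[name] = full_state[name]
    models.load_state_dict(split_state, strict=False)
    return models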
@@ -390,6 +398,8 @@ def initialize_recv_buffers(self):
datatype = torch.float32
if self.precision == "fp_16":
datatype = torch.float16
elif self.precision == "bfp_16":
datatype = torch.bfloat16

# initializing recv buffer for the input
# For parts we need different buffers as in the backward pass we use the grad variable to
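The receive buffers in initialize_recv_buffers must match the dtype of the activations sent between pipeline stages, which is what the branch above ensures. A small sketch of the flag-to-dtype mapping (dtype_for_precision is illustrative):

import torch

def dtype_for_precision(precision):
    # Map the benchmark's precision flag onto a torch dtype; the default is
    # full precision (float32).
    if precision == "fp_16":
        return torch.float16
    elif precision == "bfp_16":
        return torch.bfloat16
    return torch.float32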