diff --git a/.github/workflows/MPI4DL-github-repo-stats.yml b/.github/workflows/MPI4DL-github-repo-stats.yml new file mode 100644 index 00000000..513486f8 --- /dev/null +++ b/.github/workflows/MPI4DL-github-repo-stats.yml @@ -0,0 +1,20 @@ +name: MPI4DL-github-repo-stats.yml + +on: + schedule: + # Run this once per day, towards the end of the day for keeping the most + # recent data point most meaningful (hours are interpreted in UTC). + - cron: "0 23 * * *" + workflow_dispatch: # Allow for running this manually. + +jobs: + j1: + name: MPI4DL-github-repo-stats.yml + runs-on: ubuntu-latest + steps: + - name: run-ghrs + # Use latest release. + uses: jgehrcke/github-repo-stats@RELEASE + with: + ghtoken: ${{ secrets.ghrs_github_api_token }} + diff --git a/README.md b/README.md index 1ed73914..4594d3f7 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +MPI4DL is a [HiDL](https://hidl.cse.ohio-state.edu/) project. We encourage you to visit the [HiDL website](https://hidl.cse.ohio-state.edu/) for additional information, the latest performance numbers, and similar projects on high-performance machine and deep learning. For the latest announcements on HiDL projects, [register for the HiDL mailing list](https://hidl.cse.ohio-state.edu/mailinglists/). + # MPI4DL v0.5 The size of image-based DL models regularly grows beyond the memory available on a single processor (we call such models **out-of-core**), and require advanced parallelism schemes to fit within device memory. Further, the massive image sizes required in specialized applications such as medical and satellite imaging can themselves place significant device memory pressure, and require parallelism schemes to process efficiently during training. Finally, the simplest parallelism scheme, [layer parallelism](#layer-parallelism), is highly inefficient. While there are several approaches that have been proposed to address some of the limitations of layer parallelism. 
However, most studies are performed for low-resolution images that exhibit different characteristics. Compared to low-resolution images, high-resolution images (e.g. digital pathology, satellite imaging) result in higher activation memory and larger tensors, which in turn lead to a larger communication overhead. diff --git a/benchmarks/gems_master_model/README.md b/benchmarks/gems_master_model/README.md new file mode 100644 index 00000000..2ad92904 --- /dev/null +++ b/benchmarks/gems_master_model/README.md @@ -0,0 +1,67 @@ +# GEMS: GPU-Enabled Memory-Aware Model-Parallelism System for Distributed DNN Training +Model parallelism is necessary for training out-of-core models; however, it can leave resources underutilized. Pipeline parallelism addresses this limitation but requires a batch size greater than 1. With very high-resolution images, however, certain state-of-the-art models can only be trained with a unit batch size. GEMS is a memory-efficient design for model parallelism that enables training models with any batch size while using the same resources. For more details, please refer to the original paper: [GEMS: GPU-Enabled Memory-Aware Model-Parallelism System for Distributed DNN Training](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=9355254). + +## Run GEMS-MASTER: + +#### Generic command: +```bash +$MV2_HOME/bin/mpirun_rsh --export-all -np $np --hostfile ${HOSTFILE} MV2_USE_GDRCOPY=0 MV2_ENABLE_AFFINITY=0 MV2_USE_CUDA=1 LD_PRELOAD=$MV2_HOME/lib/libmpi.so python ${gems_model_script} --split-size ${split_size} --image-size ${image_size} --batch-size ${batch_size} --times ${times} +``` +#### Examples + +- Example to run the AmoebaNet MASTER model on 1024 * 1024 images with a model split size of 4 (i.e. # of partitions for MP), a model replication factor of η = 2, and a batch size of 1 per model replica (i.e. effective batch size (EBS) = η × BS = 2). 
+ +```bash +$MV2_HOME/bin/mpirun_rsh --export-all -np $np --hostfile ${HOSTFILE} MV2_USE_GDRCOPY=0 MV2_ENABLE_AFFINITY=0 MV2_USE_CUDA=1 LD_PRELOAD=$MV2_HOME/lib/libmpi.so python benchmarks/gems_master_model/benchmark_amoebanet_gems_master.py --split-size 4 --image-size 1024 --batch-size 1 --times 2 +``` +- Similarly, we can run the benchmark for the ResNet MASTER model. +Below is an example that runs the ResNet MASTER model on 2048 * 2048 images with a model split size of 4 (i.e. # of partitions for MP), a model replication factor of η = 4, and a batch size of 1 per model replica (i.e. effective batch size (EBS) = η × BS = 4). +```bash +$MV2_HOME/bin/mpirun_rsh --export-all -np $np --hostfile ${HOSTFILE} MV2_USE_GDRCOPY=0 MV2_ENABLE_AFFINITY=0 MV2_USE_CUDA=1 LD_PRELOAD=$MV2_HOME/lib/libmpi.so python benchmarks/gems_master_model/benchmark_resnet_gems_master.py --split-size 4 --image-size 2048 --batch-size 1 --times 4 &>> $OUTFILE 2>&1 + +``` + +Below are the available configuration options: + +
+usage: benchmark_amoebanet_sp.py [-h] [-v] [--batch-size BATCH_SIZE] [--parts PARTS] [--split-size SPLIT_SIZE] [--num-spatial-parts NUM_SPATIAL_PARTS] + [--spatial-size SPATIAL_SIZE] [--times TIMES] [--image-size IMAGE_SIZE] [--num-epochs NUM_EPOCHS] [--num-layers NUM_LAYERS] + [--num-filters NUM_FILTERS] [--balance BALANCE] [--halo-D2] [--fused-layers FUSED_LAYERS] [--local-DP LOCAL_DP] [--slice-method SLICE_METHOD] + [--app APP] [--datapath DATAPATH] + +SP-MP-DP Configuration Script + +optional arguments: + -h, --help show this help message and exit + -v, --verbose Prints performance numbers or logs (default: False) + --batch-size BATCH_SIZE + input batch size (default: 32) + --parts PARTS Number of parts for MP (default: 1) + --split-size SPLIT_SIZE + Number of process for MP (default: 2) + --num-spatial-parts NUM_SPATIAL_PARTS + Number of partitions in spatial parallelism (default: 4) + --spatial-size SPATIAL_SIZE + Number splits for spatial parallelism (default: 1) + --times TIMES Number of times to repeat MASTER 1: 2 repications, 2: 4 replications (default: 1) + --image-size IMAGE_SIZE + Image size for synthetic benchmark (default: 32) + --num-epochs NUM_EPOCHS + Number of epochs (default: 1) + --num-layers NUM_LAYERS + Number of layers in amoebanet (default: 18) + --num-filters NUM_FILTERS + Number of layers in amoebanet (default: 416) + --balance BALANCE length of list equals to number of partitions and sum should be equal to num layers (default: None) + --halo-D2 Enable design2 (do halo exhange on few convs) for spatial conv. (default: False) + --fused-layers FUSED_LAYERS + When D2 design is enables for halo exchange, number of blocks to fuse in ResNet model (default: 1) + --local-DP LOCAL_DP LBANN intergration of SP with MP. MP can apply data parallelism. 
1: only one GPU for a given split, 2: two gpus for a given split (uses DP) + (default: 1) + --slice-method SLICE_METHOD + Slice method (square, vertical, and horizontal) in Spatial parallelism (default: square) + --app APP Application type (1.medical, 2.cifar, and synthetic) in Spatial parallelism (default: 3) + --datapath DATAPATH local Dataset path (default: ./train) ++ + *Note: "--times" is a GEMS-specific parameter; certain parameters such as "--num-spatial-parts", "--slice-method", and "--halo-D2" are not required by GEMS.* diff --git a/benchmarks/gems_model/benchmark_amoebanet_gems+spatial.py b/benchmarks/gems_master_model/benchmark_amoebanet_gems+spatial.py similarity index 100% rename from benchmarks/gems_model/benchmark_amoebanet_gems+spatial.py rename to benchmarks/gems_master_model/benchmark_amoebanet_gems+spatial.py diff --git a/benchmarks/gems_model/benchmark_amoebanet_gems.py b/benchmarks/gems_master_model/benchmark_amoebanet_gems_master.py similarity index 97% rename from benchmarks/gems_model/benchmark_amoebanet_gems.py rename to benchmarks/gems_master_model/benchmark_amoebanet_gems_master.py index 811d03b0..3d7a80f5 100644 --- a/benchmarks/gems_model/benchmark_amoebanet_gems.py +++ b/benchmarks/gems_master_model/benchmark_amoebanet_gems_master.py @@ -38,7 +38,7 @@ def __getattr__(self, attr): return getattr(self.stream, attr) -def init_processes(backend="tcp"): +def init_processes(backend="mpi"): """Initialize the distributed environment.""" dist.init_process_group(backend) size = dist.get_world_size() @@ -57,19 +57,20 @@ def init_processes(backend="tcp"): # 2: Cifar # 3: synthetic APP = args.app +times = args.times image_size = int(args.image_size) num_layers = args.num_layers num_filters = args.num_filters balance = args.balance mp_size = args.split_size datapath = args.datapath +num_workers = args.num_workers num_classes = args.num_classes ##################### AmoebaNet GEMS model specific parameters ##################### image_size_seq = 512 
ENABLE_ASYNC = True -times = 2 ############################################################################### mpi_comm = gems_comm.MPIComm(split_size=mp_size, ENABLE_MASTER=True) @@ -192,7 +193,7 @@ def init_processes(backend="tcp"): trainset, batch_size=times * batch_size, shuffle=True, - num_workers=0, + num_workers=num_workers, pin_memory=True, ) size_dataset = len(my_dataloader.dataset) @@ -212,10 +213,10 @@ def init_processes(backend="tcp"): trainset, batch_size=times * batch_size, shuffle=False, - num_workers=0, + num_workers=num_workers, pin_memory=True, ) - size_dataset = 50000 + size_dataset = len(my_dataloader.dataset) else: my_dataset = torchvision.datasets.FakeData( size=10 * batch_size, @@ -229,7 +230,7 @@ def init_processes(backend="tcp"): my_dataset, batch_size=batch_size * times, shuffle=False, - num_workers=0, + num_workers=num_workers, pin_memory=True, ) size_dataset = 10 * batch_size diff --git a/benchmarks/gems_model/benchmark_resnet_gems+spatial.py b/benchmarks/gems_master_model/benchmark_resnet_gems+spatial.py similarity index 100% rename from benchmarks/gems_model/benchmark_resnet_gems+spatial.py rename to benchmarks/gems_master_model/benchmark_resnet_gems+spatial.py diff --git a/benchmarks/gems_model/benchmark_resnet_gems.py b/benchmarks/gems_master_model/benchmark_resnet_gems_master.py similarity index 97% rename from benchmarks/gems_model/benchmark_resnet_gems.py rename to benchmarks/gems_master_model/benchmark_resnet_gems_master.py index 2f6e9fbd..bacde2d2 100644 --- a/benchmarks/gems_model/benchmark_resnet_gems.py +++ b/benchmarks/gems_master_model/benchmark_resnet_gems_master.py @@ -38,7 +38,7 @@ def __getattr__(self, attr): return getattr(self.stream, attr) -def init_processes(backend="tcp"): +def init_processes(backend="mpi"): """Initialize the distributed environment.""" dist.init_process_group(backend) size = dist.get_world_size() @@ -58,12 +58,14 @@ def init_processes(backend="tcp"): # 2: Cifar # 3: synthetic APP = args.app +times 
= args.times image_size = int(args.image_size) num_layers = args.num_layers num_filters = args.num_filters balance = args.balance mp_size = args.split_size datapath = args.datapath +num_workers = args.num_workers num_classes = args.num_classes ################## ResNet model specific parameters/functions ################## @@ -71,7 +73,6 @@ def init_processes(backend="tcp"): image_size_seq = 32 ENABLE_ASYNC = True resnet_n = 12 -times = 2 def get_depth(version, n): @@ -208,7 +209,7 @@ def get_depth(version, n): trainset, batch_size=times * batch_size, shuffle=True, - num_workers=0, + num_workers=num_workers, pin_memory=True, ) size_dataset = len(my_dataloader.dataset) @@ -220,7 +221,7 @@ def get_depth(version, n): trainset, batch_size=times * batch_size, shuffle=False, - num_workers=0, + num_workers=num_workers, pin_memory=True, ) size_dataset = len(my_dataloader.dataset) @@ -237,7 +238,7 @@ def get_depth(version, n): my_dataset, batch_size=batch_size * times, shuffle=False, - num_workers=0, + num_workers=num_workers, pin_memory=True, ) size_dataset = 10 * batch_size diff --git a/benchmarks/layer_parallelism/benchmark_amoebanet_lp.py b/benchmarks/layer_parallelism/benchmark_amoebanet_lp.py index ebad7692..3dc47968 100644 --- a/benchmarks/layer_parallelism/benchmark_amoebanet_lp.py +++ b/benchmarks/layer_parallelism/benchmark_amoebanet_lp.py @@ -70,6 +70,7 @@ def __getattr__(self, attr): mp_size = args.split_size times = args.times datapath = args.datapath +num_workers = args.num_workers # APP # 1: Medical # 2: Cifar @@ -186,7 +187,7 @@ def __getattr__(self, attr): trainset, batch_size=times * batch_size, shuffle=True, - num_workers=0, + num_workers=num_workers, pin_memory=True, ) size_dataset = len(my_dataloader.dataset) @@ -198,7 +199,7 @@ def __getattr__(self, attr): trainset, batch_size=times * batch_size, shuffle=False, - num_workers=0, + num_workers=num_workers, pin_memory=True, ) size_dataset = 50000 @@ -215,7 +216,7 @@ def __getattr__(self, attr): 
my_dataset, batch_size=batch_size * times, shuffle=False, - num_workers=0, + num_workers=num_workers, pin_memory=True, ) size_dataset = 10 * batch_size diff --git a/benchmarks/layer_parallelism/benchmark_resnet_lp.py b/benchmarks/layer_parallelism/benchmark_resnet_lp.py index 8d48f473..9cf1d13c 100644 --- a/benchmarks/layer_parallelism/benchmark_resnet_lp.py +++ b/benchmarks/layer_parallelism/benchmark_resnet_lp.py @@ -67,6 +67,7 @@ def __getattr__(self, attr): mp_size = args.split_size times = args.times datapath = args.datapath +num_workers = args.num_workers # APP # 1: Medical # 2: Cifar @@ -197,7 +198,7 @@ def get_depth(version, n): trainset, batch_size=times * batch_size, shuffle=True, - num_workers=0, + num_workers=num_workers, pin_memory=True, ) size_dataset = len(my_dataloader.dataset) @@ -209,7 +210,7 @@ def get_depth(version, n): trainset, batch_size=times * batch_size, shuffle=False, - num_workers=0, + num_workers=num_workers, pin_memory=True, ) size_dataset = 50000 @@ -226,7 +227,7 @@ def get_depth(version, n): my_dataset, batch_size=batch_size * times, shuffle=False, - num_workers=0, + num_workers=num_workers, pin_memory=True, ) size_dataset = 10 * batch_size diff --git a/benchmarks/spatial_parallelism/README.md b/benchmarks/spatial_parallelism/README.md index 474590dd..88a1df43 100644 --- a/benchmarks/spatial_parallelism/README.md +++ b/benchmarks/spatial_parallelism/README.md @@ -14,13 +14,13 @@ $MV2_HOME/bin/mpirun_rsh --export-all -np $np --hostfile {$HOSTFILE} MV2_USE_CU - With 5 GPUs [split size: 2, num_spatial_parts: 4, spatial_size: 1] -Example to run AmoebaNet model with 2 model split size(i.e. # of partitions for MP), spatial partition (# of image partitions) as 4 and 1 as spatial size (i.e. number of model partition which will use spatial partition). In this configuration, we split model into two parts where first part will use spatial parallelism. +Example to run AmoebaNet model with 2 model split size(i.e. 
# of partitions for MP), spatial partition (# of image partitions) as 4 and 1 as spatial size (i.e. number of model partition which will use spatial partition). In this configuration, we split model into two parts where first part will use spatial parallelism. ```bash $MV2_HOME/bin/mpirun_rsh --export-all -np 5 --hostfile {$HOSTFILE} MV2_USE_CUDA=1 MV2_HYBRID_BINDING_POLICY=spread MV2_CPU_BINDING_POLICY=hybrid MV2_USE_GDRCOPY=0 PYTHONNOUSERSITE=true LD_PRELOAD=$MV2_HOME/lib/libmpi.so python benchmarks/spatial_parallelism/benchmark_amoebanet_sp.py --image-size 512 --num-spatial-parts 4 --slice-method "vertical" --split-size 2 --spatial-size 1 ``` - With 9 GPUs [split size: 3, num_spatial_parts: 4, spatial_size: 2] -In this configuration, we split model int three parts where first two part will use spatial parallelism. +In this configuration, we split model int three parts where first two part will use spatial parallelism. ```bash $MV2_HOME/bin/mpirun_rsh --export-all -np 9 --hostfile {$HOSTFILE} MV2_USE_CUDA=1 MV2_HYBRID_BINDING_POLICY=spread MV2_CPU_BINDING_POLICY=hybrid MV2_USE_GDRCOPY=0 PYTHONNOUSERSITE=true LD_PRELOAD=$MV2_HOME/lib/libmpi.so python benchmarks/spatial_parallelism/benchmark_amoebanet_sp.py --image-size 512 --num-spatial-parts 4 --slice-method "vertical" --split-size 3 --spatial-size 2 @@ -30,7 +30,19 @@ $MV2_HOME/bin/mpirun_rsh --export-all -np 9 --hostfile {$HOSTFILE} MV2_USE_CUDA= Find the example to run ResNet with halo-D2 enabled to reduce communication opertaions. 
To learn more about halo-D2, refer [Hy-Fi: Hybrid Five-Dimensional Parallel DNN Training on High-Performance GPU Clusters](https://dl.acm.org/doi/abs/10.1007/978-3-031-07312-0_6) ```bash $MV2_HOME/bin/mpirun_rsh --export-all -np 5 --hostfile {$HOSTFILE} MV2_USE_CUDA=1 MV2_HYBRID_BINDING_POLICY=spread MV2_CPU_BINDING_POLICY=hybrid MV2_USE_GDRCOPY=0 PYTHONNOUSERSITE=true LD_PRELOAD=$MV2_HOME/lib/libmpi.so benchmarks/spatial_parallelism/benchmark_resnet_sp.py --halo-D2 --num-spatial-parts 4 --image-size 1024 --batch-size 2 --slice-method "square" -``` +``` + +## Run spatial + data parallelism: +Currently, SP + DP is supported for AmoebaNet. + +- Enable data parallelism with the "local-DP" argument. +- Example to run the AmoebaNet model with 2 data partitions, a model split size of 2 (i.e. # of partitions for MP), 4 spatial partitions (# of image partitions), and a spatial size of 1 (i.e. the number of model partitions that use spatial parallelism). In this configuration there are 2 data partitions; within each, the model is split into two parts and the first part uses spatial parallelism. 
+ + +```bash +$MV2_HOME/bin/mpirun_rsh --export-all -np $np --hostfile ${hostfile} MV2_USE_CUDA=1 MV2_HYBRID_BINDING_POLICY=spread MV2_CPU_BINDING_POLICY=hybrid MV2_USE_GDRCOPY=0 PYTHONNOUSERSITE=true LD_PRELOAD=$MV2_HOME/lib/libmpi.so python benchmarks/spatial_parallelism/benchmark_amoebanet_sp.py --local-DP 2 --image-size ${image_size} --batch-size ${batch_size} --slice-method ${partition} +``` + Below are the available configuration options : diff --git a/benchmarks/spatial_parallelism/benchmark_amoebanet_sp.py b/benchmarks/spatial_parallelism/benchmark_amoebanet_sp.py index 1080d6d9..cf99736a 100644 --- a/benchmarks/spatial_parallelism/benchmark_amoebanet_sp.py +++ b/benchmarks/spatial_parallelism/benchmark_amoebanet_sp.py @@ -27,7 +27,7 @@ import logging from torchgems import parser from torchgems.mp_pipeline import model_generator -from torchgems.train_spatial import train_model_spatial +from torchgems.train_spatial import train_model_spatial, split_input, get_shapes_spatial import torchgems.comm as gems_comm parser_obj = parser.get_parser() @@ -86,6 +86,7 @@ def init_processes(backend="tcp"): spatial_size = args.spatial_size times = args.times datapath = args.datapath +num_workers = args.num_workers LOCAL_DP_LP = args.local_DP # APP # 1: Medical @@ -111,11 +112,11 @@ def isPowerTwo(num): """ -For Amoebanet model, image size and image size after partitioning should be power of two. -As, Amoebanet performs summation of results of two convolution layers during training, -odd input size(i.e. image size which is not power of 2) will give different output sizes -for convolution operations present at same layer, thus it will throw error as addition -operation can not be performed with diffent size outputs. +For Amoebanet model, image size and image size after partitioning should be power of two. +As, Amoebanet performs summation of results of two convolution layers during training, +odd input size(i.e. 
image size which is not power of 2) will give different output sizes +for convolution operations present at same layer, thus it will throw error as addition +operation can not be performed with diffent size outputs. """ @@ -152,7 +153,7 @@ def verify_config(): ##################### AmoebaNet model specific parameters ##################### """ -"image_size_seq" is required to determine the output shape after spatial partitioning of images. +"image_size_seq" is required to determine the output shape after spatial partitioning of images. The shape of the output will be determined for each model partition based on the values in "image_size_seq." These values will then be used to calculate the output shape for a given input size and spatial partition. """ @@ -204,178 +205,13 @@ def verify_config(): # Get the shape of model on each split rank for image_size and number of spatial parts image_size_times = int(image_size / image_size_seq) -temp_count = 0 -if args.slice_method == "square": - amoebanet_shapes_list = [] - for output_shape in model_gen_seq.shape_list: - if isinstance(output_shape, list): - temp_shape = [] - for shape_tuple in output_shape: - if temp_count < spatial_size: - # reduce shape only when it is smaller than spatial size - x = ( - int(shape_tuple[0]), - shape_tuple[1], - int( - shape_tuple[2] - * image_size_times - / int(math.sqrt(spatial_part_size)) - ), - int( - shape_tuple[3] - * image_size_times - / int(math.sqrt(spatial_part_size)) - ), - ) - temp_shape.append(x) - else: - x = ( - int(shape_tuple[0]), - shape_tuple[1], - int(shape_tuple[2] * image_size_times), - int(shape_tuple[3] * image_size_times), - ) - temp_shape.append(x) - amoebanet_shapes_list.append(temp_shape) - else: - if len(output_shape) == 2: - x = (int(output_shape[0]), output_shape[1]) - amoebanet_shapes_list.append(x) - else: - if temp_count < spatial_size: - x = ( - int(output_shape[0]), - output_shape[1], - int( - output_shape[2] - * image_size_times - / 
int(math.sqrt(spatial_part_size)) - ), - int( - output_shape[3] - * image_size_times - / int(math.sqrt(spatial_part_size)) - ), - ) - amoebanet_shapes_list.append(x) - else: - x = ( - int(output_shape[0]), - output_shape[1], - int(output_shape[2] * image_size_times), - int(output_shape[3] * image_size_times), - ) - amoebanet_shapes_list.append(x) - temp_count += 1 - -elif args.slice_method == "vertical": - amoebanet_shapes_list = [] - for output_shape in model_gen_seq.shape_list: - if isinstance(output_shape, list): - temp_shape = [] - for shape_tuple in output_shape: - if temp_count < spatial_size: - x = ( - int(shape_tuple[0]), - shape_tuple[1], - int(shape_tuple[2] * image_size_times / 1), - int( - shape_tuple[3] - * image_size_times - / num_spatial_parts_list[temp_count] - ), - ) - temp_shape.append(x) - else: - x = ( - int(shape_tuple[0]), - shape_tuple[1], - int(shape_tuple[2] * image_size_times), - int(shape_tuple[3] * image_size_times), - ) - temp_shape.append(x) - amoebanet_shapes_list.append(temp_shape) - else: - if len(output_shape) == 2: - x = (int(output_shape[0]), output_shape[1]) - amoebanet_shapes_list.append(x) - else: - if temp_count < spatial_size: - x = ( - int(output_shape[0]), - output_shape[1], - int(output_shape[2] * image_size_times / 1), - int( - output_shape[3] - * image_size_times - / num_spatial_parts_list[temp_count] - ), - ) - amoebanet_shapes_list.append(x) - else: - x = ( - int(output_shape[0]), - output_shape[1], - int(output_shape[2] * image_size_times), - int(output_shape[3] * image_size_times), - ) - amoebanet_shapes_list.append(x) - temp_count += 1 - - -elif args.slice_method == "horizontal": - amoebanet_shapes_list = [] - for output_shape in model_gen_seq.shape_list: - if isinstance(output_shape, list): - temp_shape = [] - for shape_tuple in output_shape: - if temp_count < spatial_size: - x = ( - int(shape_tuple[0]), - shape_tuple[1], - int( - shape_tuple[2] - * image_size_times - / num_spatial_parts_list[temp_count] - ), - 
int(shape_tuple[3] * image_size_times / 1), - ) - temp_shape.append(x) - else: - x = ( - int(shape_tuple[0]), - shape_tuple[1], - int(shape_tuple[2] * image_size_times), - int(shape_tuple[3] * image_size_times), - ) - temp_shape.append(x) - amoebanet_shapes_list.append(temp_shape) - else: - if len(output_shape) == 2: - x = (int(output_shape[0]), output_shape[1]) - amoebanet_shapes_list.append(x) - else: - if temp_count < spatial_size: - x = ( - int(output_shape[0]), - output_shape[1], - int( - output_shape[2] - * image_size_times - / num_spatial_parts_list[temp_count] - ), - int(output_shape[3] * image_size_times / 1), - ) - amoebanet_shapes_list.append(x) - else: - x = ( - int(output_shape[0]), - output_shape[1], - int(output_shape[2] * image_size_times), - int(output_shape[3] * image_size_times), - ) - amoebanet_shapes_list.append(x) - temp_count += 1 +amoebanet_shapes_list = get_shapes_spatial( + model_gen_seq.shape_list, + args.slice_method, + spatial_size, + num_spatial_parts_list, + image_size_times, +) del model_seq del model_gen_seq @@ -470,7 +306,7 @@ def verify_config(): trainset, batch_size=times * batch_size, shuffle=True, - num_workers=0, + num_workers=num_workers, pin_memory=True, ) size_dataset = len(my_dataloader.dataset) @@ -482,7 +318,7 @@ def verify_config(): trainset, batch_size=times * batch_size, shuffle=False, - num_workers=0, + num_workers=num_workers, pin_memory=True, ) size_dataset = 50000 @@ -499,63 +335,13 @@ def verify_config(): my_dataset, batch_size=batch_size * times, shuffle=False, - num_workers=0, + num_workers=num_workers, pin_memory=True, ) size_dataset = 10 * batch_size ################################################################################ - -def split_input(inputs): - if args.slice_method == "square": - image_height_local = int(image_size / math.sqrt(spatial_part_size)) - image_width_local = int(image_size / math.sqrt(spatial_part_size)) - - total_rows = int(math.sqrt(spatial_part_size)) - total_cols = 
int(math.sqrt(spatial_part_size)) - - # current position of rank in matrix of math.sqrt(spatial_part_size) * math.sqrt(num_spatial_parts) - row = int(local_rank / total_cols) - col = int(local_rank % total_cols) - - start_left = col * image_width_local - end_right = (col + 1) * image_width_local - - start_top = row * image_height_local - end_bottom = (row + 1) * image_height_local - - return inputs[:, :, start_top:end_bottom, start_left:end_right] - - elif args.slice_method == "vertical": - image_height_local = int(image_size / spatial_part_size) - image_width_local = int(image_size / spatial_part_size) - - start_left = local_rank * image_width_local - end_right = (local_rank + 1) * image_width_local - - if local_rank == spatial_part_size - 1: - # In case of GPU count, partition size will be uneven and last - # rank will receive remaining image - return inputs[:, :, :, start_left:] - else: - return inputs[:, :, :, start_left:end_right] - - elif args.slice_method == "horizontal": - image_height_local = int(image_size / spatial_part_size) - image_width_local = int(image_size / spatial_part_size) - - start_top = local_rank * image_height_local - end_bottom = (local_rank + 1) * image_height_local - - if local_rank == spatial_part_size - 1: - # In case of odd GPU count, partition size will be uneven and last - # rank will receive remaining image - return inputs[:, :, start_top:, :] - else: - return inputs[:, :, start_top:end_bottom, :] - - ################################# Train Model ################################## perf = [] @@ -576,7 +362,9 @@ def run_epoch(): inputs, labels = data if local_rank < spatial_part_size: - x = split_input(inputs) + x = split_input( + inputs, args.slice_method, image_size, spatial_part_size, local_rank + ) else: x = inputs diff --git a/benchmarks/spatial_parallelism/benchmark_resnet_sp.py b/benchmarks/spatial_parallelism/benchmark_resnet_sp.py index 7029ade9..b1a60734 100644 --- a/benchmarks/spatial_parallelism/benchmark_resnet_sp.py +++ 
b/benchmarks/spatial_parallelism/benchmark_resnet_sp.py @@ -28,7 +28,7 @@ import logging from torchgems import parser from torchgems.mp_pipeline import model_generator -from torchgems.train_spatial import train_model_spatial +from torchgems.train_spatial import train_model_spatial, split_input, get_shapes_spatial import torchgems.comm as gems_comm parser_obj = parser.get_parser() @@ -84,6 +84,7 @@ def init_processes(backend="mpi"): spatial_size = args.spatial_size times = args.times datapath = args.datapath +num_workers = args.num_workers # APP # 1: Medical @@ -206,179 +207,13 @@ def verify_config(): # Get the shape of model on each split rank for image_size and number of spatial parts image_size_times = int(image_size / image_size_seq) -temp_count = 0 -if args.slice_method == "square": - resnet_shapes_list = [] - for output_shape in model_gen_seq.shape_list: - if isinstance(output_shape, list): - temp_shape = [] - for shape_tuple in output_shape: - if temp_count < spatial_size: - # reduce shape only when it is smaller than spatial size - x = ( - int(shape_tuple[0]), - shape_tuple[1], - int( - shape_tuple[2] - * image_size_times - / math.sqrt(spatial_part_size) - ), - int( - shape_tuple[3] - * image_size_times - / math.sqrt(spatial_part_size) - ), - ) - temp_shape.append(x) - else: - x = ( - int(shape_tuple[0]), - shape_tuple[1], - int(shape_tuple[2] * image_size_times), - int(shape_tuple[3] * image_size_times), - ) - temp_shape.append(x) - resnet_shapes_list.append(temp_shape) - else: - if len(output_shape) == 2: - x = (int(output_shape[0]), output_shape[1]) - resnet_shapes_list.append(x) - else: - if temp_count < spatial_size: - x = ( - int(output_shape[0]), - output_shape[1], - int( - output_shape[2] - * image_size_times - / math.sqrt(spatial_part_size) - ), - int( - output_shape[3] - * image_size_times - / math.sqrt(spatial_part_size) - ), - ) - resnet_shapes_list.append(x) - else: - x = ( - int(output_shape[0]), - output_shape[1], - int(output_shape[2] * 
image_size_times), - int(output_shape[3] * image_size_times), - ) - resnet_shapes_list.append(x) - temp_count += 1 - -elif args.slice_method == "vertical": - resnet_shapes_list = [] - for output_shape in model_gen_seq.shape_list: - if isinstance(output_shape, list): - temp_shape = [] - for shape_tuple in output_shape: - if temp_count < spatial_size: - x = ( - int(shape_tuple[0]), - shape_tuple[1], - int(shape_tuple[2] * image_size_times / 1), - int( - shape_tuple[3] - * image_size_times - / num_spatial_parts_list[temp_count] - ), - ) - temp_shape.append(x) - else: - x = ( - int(shape_tuple[0]), - shape_tuple[1], - int(shape_tuple[2] * image_size_times), - int(shape_tuple[3] * image_size_times), - ) - temp_shape.append(x) - resnet_shapes_list.append(temp_shape) - else: - if len(output_shape) == 2: - x = (int(output_shape[0]), output_shape[1]) - resnet_shapes_list.append(x) - else: - if temp_count < spatial_size: - x = ( - int(output_shape[0]), - output_shape[1], - int(output_shape[2] * image_size_times / 1), - int( - output_shape[3] - * image_size_times - / num_spatial_parts_list[temp_count] - ), - ) - resnet_shapes_list.append(x) - else: - x = ( - int(output_shape[0]), - output_shape[1], - int(output_shape[2] * image_size_times), - int(output_shape[3] * image_size_times), - ) - resnet_shapes_list.append(x) - temp_count += 1 - - -elif args.slice_method == "horizontal": - resnet_shapes_list = [] - for output_shape in model_gen_seq.shape_list: - if isinstance(output_shape, list): - temp_shape = [] - for shape_tuple in output_shape: - if temp_count < spatial_size: - x = ( - int(shape_tuple[0]), - shape_tuple[1], - int( - shape_tuple[2] - * image_size_times - / num_spatial_parts_list[temp_count] - ), - int(shape_tuple[3] * image_size_times / 1), - ) - temp_shape.append(x) - else: - x = ( - int(shape_tuple[0]), - shape_tuple[1], - int(shape_tuple[2] * image_size_times), - int(shape_tuple[3] * image_size_times), - ) - temp_shape.append(x) - 
resnet_shapes_list.append(temp_shape) - else: - if len(output_shape) == 2: - x = (int(output_shape[0]), output_shape[1]) - resnet_shapes_list.append(x) - else: - if temp_count < spatial_size: - x = ( - int(output_shape[0]), - output_shape[1], - int( - output_shape[2] - * image_size_times - / num_spatial_parts_list[temp_count] - ), - int(output_shape[3] * image_size_times / 1), - ) - resnet_shapes_list.append(x) - else: - x = ( - int(output_shape[0]), - output_shape[1], - int(output_shape[2] * image_size_times), - int(output_shape[3] * image_size_times), - ) - resnet_shapes_list.append(x) - temp_count += 1 - +resnet_shapes_list = get_shapes_spatial( + model_gen_seq.shape_list, + args.slice_method, + spatial_size, + num_spatial_parts_list, + image_size_times, +) del model_seq del model_gen_seq @@ -470,7 +305,7 @@ def verify_config(): trainset, batch_size=times * batch_size, shuffle=True, - num_workers=0, + num_workers=num_workers, pin_memory=True, ) size_dataset = len(my_dataloader.dataset) @@ -482,7 +317,7 @@ def verify_config(): trainset, batch_size=times * batch_size, shuffle=False, - num_workers=0, + num_workers=num_workers, pin_memory=True, ) size_dataset = 50000 @@ -499,7 +334,7 @@ def verify_config(): my_dataset, batch_size=batch_size * times, shuffle=False, - num_workers=0, + num_workers=num_workers, pin_memory=True, ) size_dataset = 10 * batch_size @@ -508,56 +343,6 @@ def verify_config(): sync_allreduce.sync_model_spatial(model_gen) - -def split_input(inputs): - if args.slice_method == "square": - image_height_local = int(image_size / math.sqrt(spatial_part_size)) - image_width_local = int(image_size / math.sqrt(spatial_part_size)) - - total_rows = int(math.sqrt(spatial_part_size)) - total_cols = int(math.sqrt(spatial_part_size)) - - # current position of rank in matrix of math.sqrt(spatial_part_size) * math.sqrt(spatial_part_size) - row = int(local_rank / total_cols) - col = int(local_rank % total_cols) - - start_left = col * image_width_local - end_right 
= (col + 1) * image_width_local - - start_top = row * image_height_local - end_bottom = (row + 1) * image_height_local - - return inputs[:, :, start_top:end_bottom, start_left:end_right] - - elif args.slice_method == "vertical": - image_height_local = int(image_size / spatial_part_size) - image_width_local = int(image_size / spatial_part_size) - - start_left = local_rank * image_width_local - end_right = (local_rank + 1) * image_width_local - - if local_rank == spatial_part_size - 1: - # In case of GPU count, partition size will be uneven and last - # rank will receive remaining image - return inputs[:, :, :, start_left:] - else: - return inputs[:, :, :, start_left:end_right] - - elif args.slice_method == "horizontal": - image_height_local = int(image_size / spatial_part_size) - image_width_local = int(image_size / spatial_part_size) - - start_top = local_rank * image_height_local - end_bottom = (local_rank + 1) * image_height_local - - if local_rank == spatial_part_size - 1: - # In case of odd GPU count, partition size will be uneven and last - # rank will receive remaining image - return inputs[:, :, start_top:, :] - else: - return inputs[:, :, start_top:end_bottom, :] - - ################################# Train Model ################################## perf = [] @@ -578,7 +363,9 @@ def run_epoch(): inputs, labels = data if local_rank < spatial_part_size: - x = split_input(inputs) + x = split_input( + inputs, args.slice_method, image_size, spatial_part_size, local_rank + ) else: x = inputs diff --git a/src/torchgems/parser.py b/src/torchgems/parser.py index ca8c8bd5..2cf27de7 100644 --- a/src/torchgems/parser.py +++ b/src/torchgems/parser.py @@ -133,4 +133,11 @@ def get_parser(): help="Enable communication optimization for MASTER in Spatial", ) + parser.add_argument( + "--num-workers", + type=int, + default=0, + help="Number of worker subprocesses for data loading (0 loads data in the main process)", + ) + return parser diff --git a/src/torchgems/train_spatial.py
b/src/torchgems/train_spatial.py index 7f09d18b..d4834049 100644 --- a/src/torchgems/train_spatial.py +++ b/src/torchgems/train_spatial.py @@ -67,8 +67,10 @@ def get_shapes_spatial( ): temp_count = 0 spatial_shapes_list = [] + spatial_part_size = num_spatial_parts_list[0] if slice_method == "square": + spatial_shapes_list = [] for output_shape in shape_list: if isinstance(output_shape, list): temp_shape = [] @@ -78,8 +80,16 @@ def get_shapes_spatial( x = ( int(shape_tuple[0]), shape_tuple[1], - int(shape_tuple[2] * image_size_times / 2), - int(shape_tuple[3] * image_size_times / 2), + int( + shape_tuple[2] + * image_size_times + / math.sqrt(spatial_part_size) + ), + int( + shape_tuple[3] + * image_size_times + / math.sqrt(spatial_part_size) + ), ) temp_shape.append(x) else: @@ -100,8 +110,16 @@ def get_shapes_spatial( x = ( int(output_shape[0]), output_shape[1], - int(output_shape[2] * image_size_times / 2), - int(output_shape[3] * image_size_times / 2), + int( + output_shape[2] + * image_size_times + / math.sqrt(spatial_part_size) + ), + int( + output_shape[3] + * image_size_times + / math.sqrt(spatial_part_size) + ), ) spatial_shapes_list.append(x) else: @@ -115,6 +133,7 @@ def get_shapes_spatial( temp_count += 1 elif slice_method == "vertical": + spatial_shapes_list = [] for output_shape in shape_list: if isinstance(output_shape, list): temp_shape = [] @@ -168,6 +187,7 @@ def get_shapes_spatial( temp_count += 1 elif slice_method == "horizontal": + spatial_shapes_list = [] for output_shape in shape_list: if isinstance(output_shape, list): temp_shape = [] @@ -222,66 +242,53 @@ def get_shapes_spatial( return spatial_shapes_list -def split_input_2(inputs, image_size, slice_method, local_rank): - image_height_local = int(image_size / 2) - image_width_local = int(image_size / 2) - - # square == vertical +def split_input(inputs, slice_method, image_size, spatial_part_size, local_rank): + if slice_method == "square": + image_height_local = int(image_size / 
math.sqrt(spatial_part_size)) + image_width_local = int(image_size / math.sqrt(spatial_part_size)) - if slice_method == "square" or slice_method == "vertical": - if local_rank == 0: - return inputs[:, :, :, :image_width_local] - elif local_rank == 1: - return inputs[:, :, :, image_width_local : 2 * image_width_local] + total_rows = int(math.sqrt(spatial_part_size)) + total_cols = int(math.sqrt(spatial_part_size)) - elif slice_method == "horizontal": - if local_rank == 0: - return inputs[:, :, :image_height_local, :] - elif local_rank == 1: - return inputs[:, :, image_height_local : 2 * image_height_local, :] + # current position of rank in matrix of math.sqrt(spatial_part_size) * math.sqrt(spatial_part_size) + row = int(local_rank / total_cols) + col = int(local_rank % total_cols) + start_left = col * image_width_local + end_right = (col + 1) * image_width_local -def split_input_4(inputs, image_size, slice_method, local_rank): - image_height_local = int(image_size / 4) - image_width_local = int(image_size / 4) + start_top = row * image_height_local + end_bottom = (row + 1) * image_height_local - if slice_method == "square": - if local_rank == 0: - return inputs[:, :, : int(image_size / 2), : int(image_size / 2)] - elif local_rank == 1: - return inputs[:, :, : int(image_size / 2), int(image_size / 2) :] - elif local_rank == 2: - return inputs[:, :, int(image_size / 2) :, : int(image_size / 2)] - elif local_rank == 3: - return inputs[:, :, int(image_size / 2) :, int(image_size / 2) :] + return inputs[:, :, start_top:end_bottom, start_left:end_right] elif slice_method == "vertical": - if local_rank == 0: - return inputs[:, :, :, :image_width_local] - elif local_rank == 1: - return inputs[:, :, :, image_width_local : 2 * image_width_local] - elif local_rank == 2: - return inputs[:, :, :, 2 * image_width_local : 3 * image_width_local] - elif local_rank == 3: - return inputs[:, :, :, 3 * image_width_local : 4 * image_width_local] + image_height_local = int(image_size / 
spatial_part_size) + image_width_local = int(image_size / spatial_part_size) + + start_left = local_rank * image_width_local + end_right = (local_rank + 1) * image_width_local + + if local_rank == spatial_part_size - 1: + # If image_size is not divisible by spatial_part_size, the partitions + # are uneven and the last rank receives the remaining columns + return inputs[:, :, :, start_left:] + else: + return inputs[:, :, :, start_left:end_right] elif slice_method == "horizontal": - if local_rank == 0: - return inputs[:, :, :image_height_local, :] - elif local_rank == 1: - return inputs[:, :, image_height_local : 2 * image_height_local, :] - elif local_rank == 2: - return inputs[:, :, 2 * image_height_local : 3 * image_height_local, :] - elif local_rank == 3: - return inputs[:, :, 3 * image_height_local : 4 * image_height_local, :] - - -def split_input(inputs, image_size, slice_method, local_rank, num_spatial_parts_list): - if num_spatial_parts_list[0] == 2: - return split_input_2(inputs, image_size, slice_method, local_rank) - - elif num_spatial_parts_list[0] == 4: - return split_input_4(inputs, image_size, slice_method, local_rank) + image_height_local = int(image_size / spatial_part_size) + image_width_local = int(image_size / spatial_part_size) + + start_top = local_rank * image_height_local + end_bottom = (local_rank + 1) * image_height_local + + if local_rank == spatial_part_size - 1: + # If image_size is not divisible by spatial_part_size, the partitions + # are uneven and the last rank receives the remaining rows + return inputs[:, :, start_top:, :] + else: + return inputs[:, :, start_top:end_bottom, :] class train_model_spatial(train_model):