diff --git a/experiment/Dockerfile-centos b/experiment/Dockerfile-centos
index 31ffc194..71c63e41 100644
--- a/experiment/Dockerfile-centos
+++ b/experiment/Dockerfile-centos
@@ -133,6 +133,7 @@ RUN /opt/conda/bin/pip --no-cache-dir install \
     botocore \
     torch-scatter \
     pyecharts \
+    py-libnuma \
     -f https://data.pyg.org/whl/torch-${PYTORCH_VERSION}+cu117.html \
     && /opt/conda/bin/pip --no-cache-dir install \
     --extra-index-url https://download.pytorch.org/whl/cu117 \
diff --git a/experiment/Dockerfile-ubuntu b/experiment/Dockerfile-ubuntu
index 230a3b50..0675c2cc 100644
--- a/experiment/Dockerfile-ubuntu
+++ b/experiment/Dockerfile-ubuntu
@@ -114,6 +114,7 @@ RUN /opt/conda/bin/pip --no-cache-dir install \
     botocore \
     torch-scatter \
     pyecharts \
+    py-libnuma \
     -f https://data.pyg.org/whl/torch-${PYTORCH_VERSION}+cu117.html \
     && /opt/conda/bin/pip --no-cache-dir install \
     --extra-index-url https://download.pytorch.org/whl/cu117 \
diff --git a/internlm/initialize/__init__.py b/internlm/initialize/__init__.py
index ae94e0a2..14fe06bb 100644
--- a/internlm/initialize/__init__.py
+++ b/internlm/initialize/__init__.py
@@ -4,6 +4,7 @@
     initialize_distributed_env,
     launch_from_slurm,
     launch_from_torch,
+    try_bind_numa,
 )
 
 __all__ = [
@@ -12,4 +13,5 @@
     "launch_from_slurm",
     "launch_from_torch",
     "initialize_distributed_env",
+    "try_bind_numa",
 ]
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 429beef8..ddd01ef2 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -16,6 +16,16 @@
 from internlm.utils.logger import get_logger
 from internlm.utils.timeout import llm_timeout
 
+# check whether the packages needed for NUMA binding are available
+try:
+    import numa
+    from numa import memory, schedule
+    from pynvml.smi import nvidia_smi
+except (AttributeError, ImportError):
+    get_numa = False
+else:
+    get_numa = True
+
 logger = get_logger(__file__)
 
 
@@ -385,6 +395,8 @@ def launch_from_slurm(
     except KeyError as e:
         raise RuntimeError(f"Could not find {e} in the SLURM environment")
 
+    try_bind_numa(global_rank=rank, world_size=world_size)
+
     launch(
         config=config,
         rank=rank,
@@ -418,6 +430,8 @@ def launch_from_torch(
     except KeyError as e:
         raise RuntimeError(f"Could not find {e} in the torch environment")
 
+    try_bind_numa(global_rank=rank, world_size=world_size, local_rank=local_rank)
+
     launch(
         config=config,
         local_rank=local_rank,
@@ -447,6 +461,7 @@ def initialize_distributed_env(
         master_port (str): The master port for distributed training. 8888 by default.
         seed (int, optional): Specified random seed for every process. 1024 by default.
     """
+    # disable automatic garbage collection
     gc.disable()
 
@@ -484,3 +499,45 @@ def get_config_value(config, key, defalut):
     except KeyError:
         value = defalut
     return value
+
+
+def try_bind_numa(global_rank, world_size, local_rank=None):
+    # Early return if the numa module is not available
+    if not get_numa:
+        if global_rank == 0:
+            logger.info(
+                "Try bind numa failed! Package import error. If numa is not installed, "
+                "please run: pip install --upgrade py-libnuma (Ref: https://pypi.org/project/py-libnuma/)"
+            )
+        return
+
+    # get the numa node number
+    try:
+        numa_node_num = numa.info.get_max_node() + 1
+        # get the total gpu number of the current node
+        nvsmi = nvidia_smi.getInstance()
+        total_GPU_per_node = len(nvsmi.DeviceQuery("memory.total")["gpu"])
+
+        # return if total_GPU_per_node is no larger than numa_node_num or is not divisible by it
+        if total_GPU_per_node <= numa_node_num:
+            return
+        if total_GPU_per_node % numa_node_num != 0:
+            return
+        # return if the number of processes is smaller than the number of GPUs on one node
+        if world_size < total_GPU_per_node:
+            return
+
+        if local_rank is None:
+            devices_per_node = torch.cuda.device_count()
+            local_rank = global_rank % devices_per_node
+
+        # compute the numa id for each local rank
+        per_numa = total_GPU_per_node // numa_node_num
+        numa_id = local_rank // per_numa
+
+        # bind the process and its memory allocations to the numa node
+        schedule.run_on_nodes(numa_id)
+        memory.set_membind_nodes(numa_id)
+    except Exception:
+        return  # try_bind_numa should not raise an exception
+    else:
+        logger.info(f"Rank: {global_rank} successfully bound process to numa node: {numa_id}")
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index f46d7ad9..2fbef4a3 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -13,4 +13,5 @@ boto3
 botocore
 torch-scatter
 pyecharts
+py-libnuma
 -f https://data.pyg.org/whl/torch-1.13.1+cu117.html
\ No newline at end of file
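Note on the binding logic added above: `try_bind_numa` assigns each local rank to a NUMA node by integer division, so consecutive GPUs on a node share the same NUMA node. A minimal sketch of that mapping, assuming an 8-GPU node with 2 NUMA nodes (illustrative numbers, not values taken from the patch):

```python
# Illustrative only: how the patch maps local ranks to NUMA nodes.
# total_gpu_per_node and numa_node_num are assumed example values; the real code
# queries them via nvidia_smi and numa.info.get_max_node().
total_gpu_per_node = 8
numa_node_num = 2

per_numa = total_gpu_per_node // numa_node_num  # 4 GPUs share each NUMA node
for local_rank in range(total_gpu_per_node):
    numa_id = local_rank // per_numa
    print(f"local_rank {local_rank} -> numa node {numa_id}")
# local_rank 0-3 -> numa node 0, local_rank 4-7 -> numa node 1
```

After computing `numa_id`, the patch calls `schedule.run_on_nodes(numa_id)` to restrict the process to CPUs on that node and `memory.set_membind_nodes(numa_id)` to keep its allocations on that node's memory, which helps avoid cross-node memory traffic for host-side work such as data loading.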