From 1ace241db4107f9a1a40e97f31d6053cb23778eb Mon Sep 17 00:00:00 2001 From: huismiling Date: Thu, 24 Oct 2024 21:30:59 +0800 Subject: [PATCH] MLU devices : Checks if mlu is available via an cndev-based check which won't trigger the drivers and leave mlu (#3187) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add Cambricon MLU accelerator support * up mlu support for test * fix mlu device MULTI_MLU * Update src/accelerate/utils/imports.py it's beautiful ! Co-authored-by: Zach Mueller * up mlu for quality check * fix mlu device longTensor error * fix mlu device tensor dtype check * fix mlu device send_to_device with torch dynamo error * Refactor AcceleratorState * Should be near complete now * Last missing piece * Make my way to the acceleratorstate * Include update to global var * Don't use global * gpu -> cuda * Don't use update for dict, easier to read * Fix tests * stash * Getting closer... * Needed to spawn at the very end after env was setup * Explain set_device before deepspeed * Make docstring more accurate * Early return instead * Delineate blocks * Make prepare_backend return state + backend for clarity/less magic * fix mlu longtensor.to() bugs. * fix MLU devices rng state save and load. * Cambricon MLU features, Checks if `mlu` is available via a `cndev-based` check which won't trigger the drivers and leave mlu uninitialized. 
* MLU devices : Checks if mlu is available via an cndev-based check which won't trigger the drivers and leave mlu * fix code style and quality * fix is_cuda_available error --------- Co-authored-by: Zach Mueller --- src/accelerate/utils/__init__.py | 4 +- src/accelerate/utils/dataclasses.py | 2 +- src/accelerate/utils/environment.py | 73 ++++++++++++++++++++++++++++ src/accelerate/utils/imports.py | 28 ++++------- src/accelerate/utils/other.py | 74 ----------------------------- tests/deepspeed/test_deepspeed.py | 2 +- tests/fsdp/test_fsdp.py | 2 +- 7 files changed, 88 insertions(+), 97 deletions(-) diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py index f1fa57fbdb9..aace236a6c9 100644 --- a/src/accelerate/utils/__init__.py +++ b/src/accelerate/utils/__init__.py @@ -63,12 +63,14 @@ are_libraries_initialized, check_cuda_p2p_ib_support, check_fp8_capability, + clear_environment, convert_dict_to_env_variables, get_cpu_distributed_information, get_gpu_info, get_int_from_env, parse_choice_from_env, parse_flag_from_env, + patch_environment, set_numa_affinity, str_to_bool, ) @@ -245,14 +247,12 @@ from .other import ( check_os_kernel, clean_state_dict_for_safetensors, - clear_environment, convert_bytes, extract_model_from_parallel, get_pretty_name, is_port_in_use, load, merge_dicts, - patch_environment, recursive_getattr, save, wait_for_everyone, diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index edf678eedbc..0c86c796d22 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -61,7 +61,7 @@ def to_kwargs(self): Returns a dictionary containing the attributes with values different from the default of this class. 
""" # import clear_environment here to avoid circular import problem - from .other import clear_environment + from .environment import clear_environment with clear_environment(): default_dict = self.__class__().to_dict() diff --git a/src/accelerate/utils/environment.py b/src/accelerate/utils/environment.py index 34fa7e2ae29..04a41fd78aa 100644 --- a/src/accelerate/utils/environment.py +++ b/src/accelerate/utils/environment.py @@ -18,6 +18,7 @@ import platform import subprocess import sys +from contextlib import contextmanager from dataclasses import dataclass, field from functools import lru_cache from shutil import which @@ -272,3 +273,75 @@ def set_numa_affinity(local_process_index: int, verbose: Optional[bool] = None) default to True. """ override_numa_affinity(local_process_index=local_process_index, verbose=verbose) + + +@contextmanager +def clear_environment(): + """ + A context manager that will temporarily clear environment variables. + + When this context exits, the previous environment variables will be back. + + Example: + + ```python + >>> import os + >>> from accelerate.utils import clear_environment + + >>> os.environ["FOO"] = "bar" + >>> with clear_environment(): + ... print(os.environ) + ... os.environ["FOO"] = "new_bar" + ... print(os.environ["FOO"]) + {} + new_bar + + >>> print(os.environ["FOO"]) + bar + ``` + """ + _old_os_environ = os.environ.copy() + os.environ.clear() + + try: + yield + finally: + os.environ.clear() # clear any added keys, + os.environ.update(_old_os_environ) # then restore previous environment + + +@contextmanager +def patch_environment(**kwargs): + """ + A context manager that will add each keyword argument passed to `os.environ` and remove them when exiting. + + Will convert the values in `kwargs` to strings and upper-case all the keys. + + Example: + + ```python + >>> import os + >>> from accelerate.utils import patch_environment + + >>> with patch_environment(FOO="bar"): + ... 
print(os.environ["FOO"]) # prints "bar" + >>> print(os.environ["FOO"]) # raises KeyError + ``` + """ + existing_vars = {} + for key, value in kwargs.items(): + key = key.upper() + if key in os.environ: + existing_vars[key] = os.environ[key] + os.environ[key] = str(value) + + try: + yield + finally: + for key in kwargs: + key = key.upper() + if key in existing_vars: + # restore previous value + os.environ[key] = existing_vars[key] + else: + os.environ.pop(key, None) diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py index 6c0f5d38a22..453042b27a0 100644 --- a/src/accelerate/utils/imports.py +++ b/src/accelerate/utils/imports.py @@ -22,7 +22,7 @@ from packaging import version from packaging.version import parse -from .environment import parse_flag_from_env, str_to_bool +from .environment import parse_flag_from_env, patch_environment, str_to_bool from .versions import compare_versions, is_torch_version @@ -118,15 +118,8 @@ def is_cuda_available(): Checks if `cuda` is available via an `nvml-based` check which won't trigger the drivers and leave cuda uninitialized. """ - pytorch_nvml_based_cuda_check_previous_value = os.environ.get("PYTORCH_NVML_BASED_CUDA_CHECK") - try: - os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = str(1) + with patch_environment(PYTORCH_NVML_BASED_CUDA_CHECK="1"): available = torch.cuda.is_available() - finally: - if pytorch_nvml_based_cuda_check_previous_value: - os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = pytorch_nvml_based_cuda_check_previous_value - else: - os.environ.pop("PYTORCH_NVML_BASED_CUDA_CHECK", None) return available @@ -327,20 +320,19 @@ def get_major_and_minor_from_version(full_version): @lru_cache def is_mlu_available(check_device=False): - "Checks if `torch_mlu` is installed and potentially if a MLU is in the environment" + """ + Checks if `mlu` is available via an `cndev-based` check which won't trigger the drivers and leave mlu + uninitialized. 
+ """ if importlib.util.find_spec("torch_mlu") is None: return False import torch_mlu # noqa: F401 - if check_device: - try: - # Will raise a RuntimeError if no MLU is found - _ = torch.mlu.device_count() - return torch.mlu.is_available() - except RuntimeError: - return False - return hasattr(torch, "mlu") and torch.mlu.is_available() + with patch_environment(PYTORCH_CNDEV_BASED_MLU_CHECK="1"): + available = torch.mlu.is_available() + + return available @lru_cache diff --git a/src/accelerate/utils/other.py b/src/accelerate/utils/other.py index 520a726a774..ad1118966be 100644 --- a/src/accelerate/utils/other.py +++ b/src/accelerate/utils/other.py @@ -13,12 +13,10 @@ # limitations under the License. import collections -import os import platform import re import socket from codecs import encode -from contextlib import contextmanager from functools import partial, reduce from types import MethodType from typing import OrderedDict @@ -262,78 +260,6 @@ def load(f, map_location=None, **kwargs): return loaded_obj -@contextmanager -def clear_environment(): - """ - A context manager that will temporarily clear environment variables. - - When this context exits, the previous environment variables will be back. - - Example: - - ```python - >>> import os - >>> from accelerate.utils import clear_environment - - >>> os.environ["FOO"] = "bar" - >>> with clear_environment(): - ... print(os.environ) - ... os.environ["FOO"] = "new_bar" - ... print(os.environ["FOO"]) - {} - new_bar - - >>> print(os.environ["FOO"]) - bar - ``` - """ - _old_os_environ = os.environ.copy() - os.environ.clear() - - try: - yield - finally: - os.environ.clear() # clear any added keys, - os.environ.update(_old_os_environ) # then restore previous environment - - -@contextmanager -def patch_environment(**kwargs): - """ - A context manager that will add each keyword argument passed to `os.environ` and remove them when exiting. - - Will convert the values in `kwargs` to strings and upper-case all the keys. 
- - Example: - - ```python - >>> import os - >>> from accelerate.utils import patch_environment - - >>> with patch_environment(FOO="bar"): - ... print(os.environ["FOO"]) # prints "bar" - >>> print(os.environ["FOO"]) # raises KeyError - ``` - """ - existing_vars = {} - for key, value in kwargs.items(): - key = key.upper() - if key in os.environ: - existing_vars[key] = os.environ[key] - os.environ[key] = str(value) - - try: - yield - finally: - for key in kwargs: - key = key.upper() - if key in existing_vars: - # restore previous value - os.environ[key] = existing_vars[key] - else: - os.environ.pop(key, None) - - def get_pretty_name(obj): """ Gets a pretty name from `obj`. diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py index 62de65e9618..368a44675ff 100644 --- a/tests/deepspeed/test_deepspeed.py +++ b/tests/deepspeed/test_deepspeed.py @@ -42,6 +42,7 @@ slow, ) from accelerate.test_utils.training import RegressionDataset, RegressionModel +from accelerate.utils import patch_environment from accelerate.utils.dataclasses import DeepSpeedPlugin from accelerate.utils.deepspeed import ( DeepSpeedEngineWrapper, @@ -50,7 +51,6 @@ DummyOptim, DummyScheduler, ) -from accelerate.utils.other import patch_environment from accelerate.utils.versions import compare_versions diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py index 11dc5e3510c..243c4738c33 100644 --- a/tests/fsdp/test_fsdp.py +++ b/tests/fsdp/test_fsdp.py @@ -35,6 +35,7 @@ require_non_torch_xla, slow, ) +from accelerate.utils import patch_environment from accelerate.utils.constants import ( FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, @@ -43,7 +44,6 @@ ) from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin from accelerate.utils.fsdp_utils import disable_fsdp_ram_efficient_loading, enable_fsdp_ram_efficient_loading -from accelerate.utils.other import patch_environment set_seed(42)