
Commit

MLU devices: Checks if mlu is available via a cndev-based check which won't trigger the drivers and leaves mlu uninitialized (#3187)

* Add Cambricon MLU accelerator support

* up mlu support for test

* fix mlu device MULTI_MLU

* Update src/accelerate/utils/imports.py

it's beautiful!

Co-authored-by: Zach Mueller <[email protected]>

* up mlu for quality check

* fix mlu device longTensor error

* fix mlu device tensor dtype check

* fix mlu device send_to_device with torch dynamo error

* Refactor AcceleratorState

* Should be near complete now

* Last missing piece

* Make my way to the AcceleratorState

* Include update to global var

* Don't use global

* gpu -> cuda

* Don't use update for dict, easier to read

* Fix tests

* stash

* Getting closer...

* Needed to spawn at the very end after env was setup

* Explain set_device before deepspeed

* Make docstring more accurate

* Early return instead

* Delineate blocks

* Make prepare_backend return state + backend for clarity/less magic

* fix mlu longtensor.to() bugs.

* fix MLU devices rng state save and load.

* Cambricon MLU features: Checks if `mlu` is available via a `cndev-based` check which won't trigger the drivers and leaves mlu uninitialized.

* MLU devices: Checks if mlu is available via a cndev-based check which won't trigger the drivers and leaves mlu uninitialized

* fix code style and quality

* fix is_cuda_available error

---------

Co-authored-by: Zach Mueller <[email protected]>
huismiling and muellerzr authored Oct 24, 2024
1 parent 78e1bdd commit 1ace241
Showing 7 changed files with 88 additions and 97 deletions.
4 changes: 2 additions & 2 deletions src/accelerate/utils/__init__.py
@@ -63,12 +63,14 @@
     are_libraries_initialized,
     check_cuda_p2p_ib_support,
     check_fp8_capability,
+    clear_environment,
     convert_dict_to_env_variables,
     get_cpu_distributed_information,
     get_gpu_info,
     get_int_from_env,
     parse_choice_from_env,
     parse_flag_from_env,
+    patch_environment,
     set_numa_affinity,
     str_to_bool,
 )
@@ -245,14 +247,12 @@
 from .other import (
     check_os_kernel,
     clean_state_dict_for_safetensors,
-    clear_environment,
     convert_bytes,
     extract_model_from_parallel,
     get_pretty_name,
     is_port_in_use,
     load,
     merge_dicts,
-    patch_environment,
     recursive_getattr,
     save,
     wait_for_everyone,
2 changes: 1 addition & 1 deletion src/accelerate/utils/dataclasses.py
@@ -61,7 +61,7 @@ def to_kwargs(self):
         Returns a dictionary containing the attributes with values different from the default of this class.
         """
         # import clear_environment here to avoid circular import problem
-        from .other import clear_environment
+        from .environment import clear_environment
 
         with clear_environment():
             default_dict = self.__class__().to_dict()
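The relocated import above feeds `to_kwargs`, which computes a class's default values inside `clear_environment()` so that ambient environment variables cannot skew the comparison. A minimal sketch of that pattern (the `ExampleHandler` class is illustrative, not part of this commit):

```python
from dataclasses import dataclass

from accelerate.utils import clear_environment


@dataclass
class ExampleHandler:
    flag: bool = False

    def to_dict(self):
        return self.__dict__.copy()

    def to_kwargs(self):
        # Instantiate the defaults with env vars cleared, so environment
        # state cannot leak into the reference values being compared.
        with clear_environment():
            default_dict = ExampleHandler().to_dict()
        return {k: v for k, v in self.to_dict().items() if default_dict[k] != v}


print(ExampleHandler(flag=True).to_kwargs())  # {'flag': True}
```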
73 changes: 73 additions & 0 deletions src/accelerate/utils/environment.py
@@ -18,6 +18,7 @@
 import platform
 import subprocess
 import sys
+from contextlib import contextmanager
 from dataclasses import dataclass, field
 from functools import lru_cache
 from shutil import which
@@ -272,3 +273,75 @@ def set_numa_affinity(local_process_index: int, verbose: Optional[bool] = None):
            default to True.
     """
     override_numa_affinity(local_process_index=local_process_index, verbose=verbose)
+
+
+@contextmanager
+def clear_environment():
+    """
+    A context manager that will temporarily clear environment variables.
+    When this context exits, the previous environment variables will be back.
+    Example:
+    ```python
+    >>> import os
+    >>> from accelerate.utils import clear_environment
+    >>> os.environ["FOO"] = "bar"
+    >>> with clear_environment():
+    ...     print(os.environ)
+    ...     os.environ["FOO"] = "new_bar"
+    ...     print(os.environ["FOO"])
+    {}
+    new_bar
+    >>> print(os.environ["FOO"])
+    bar
+    ```
+    """
+    _old_os_environ = os.environ.copy()
+    os.environ.clear()
+
+    try:
+        yield
+    finally:
+        os.environ.clear()  # clear any added keys,
+        os.environ.update(_old_os_environ)  # then restore previous environment
+
+
+@contextmanager
+def patch_environment(**kwargs):
+    """
+    A context manager that will add each keyword argument passed to `os.environ` and remove them when exiting.
+    Will convert the values in `kwargs` to strings and upper-case all the keys.
+    Example:
+    ```python
+    >>> import os
+    >>> from accelerate.utils import patch_environment
+    >>> with patch_environment(FOO="bar"):
+    ...     print(os.environ["FOO"])  # prints "bar"
+    >>> print(os.environ["FOO"])  # raises KeyError
+    ```
+    """
+    existing_vars = {}
+    for key, value in kwargs.items():
+        key = key.upper()
+        if key in os.environ:
+            existing_vars[key] = os.environ[key]
+        os.environ[key] = str(value)
+
+    try:
+        yield
+    finally:
+        for key in kwargs:
+            key = key.upper()
+            if key in existing_vars:
+                # restore previous value
+                os.environ[key] = existing_vars[key]
+            else:
+                os.environ.pop(key, None)
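Together, the two context managers give symmetric set/restore semantics. A short usage sketch of the behavior guaranteed by the code above (standard library plus the two utilities only):

```python
import os

from accelerate.utils import clear_environment, patch_environment

os.environ["FOO"] = "bar"

# patch_environment upper-cases keys and stringifies values on entry,
# then restores or removes each key on exit, even on exceptions.
with patch_environment(foo=1, new_var=True):
    assert os.environ["FOO"] == "1"         # existing value temporarily overridden
    assert os.environ["NEW_VAR"] == "True"  # added only for the block

assert os.environ["FOO"] == "bar"           # previous value restored
assert "NEW_VAR" not in os.environ          # added key removed again

# clear_environment empties os.environ and restores the snapshot on exit,
# discarding any keys set inside the block.
with clear_environment():
    assert "FOO" not in os.environ
    os.environ["FOO"] = "temporary"

assert os.environ["FOO"] == "bar"
```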
28 changes: 10 additions & 18 deletions src/accelerate/utils/imports.py
@@ -22,7 +22,7 @@
 from packaging import version
 from packaging.version import parse
 
-from .environment import parse_flag_from_env, str_to_bool
+from .environment import parse_flag_from_env, patch_environment, str_to_bool
 from .versions import compare_versions, is_torch_version
 
 
@@ -118,15 +118,8 @@ def is_cuda_available():
     Checks if `cuda` is available via an `nvml-based` check which won't trigger the drivers and leave cuda
     uninitialized.
     """
-    pytorch_nvml_based_cuda_check_previous_value = os.environ.get("PYTORCH_NVML_BASED_CUDA_CHECK")
-    try:
-        os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = str(1)
+    with patch_environment(PYTORCH_NVML_BASED_CUDA_CHECK="1"):
         available = torch.cuda.is_available()
-    finally:
-        if pytorch_nvml_based_cuda_check_previous_value:
-            os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = pytorch_nvml_based_cuda_check_previous_value
-        else:
-            os.environ.pop("PYTORCH_NVML_BASED_CUDA_CHECK", None)
 
     return available
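The simplification above is behavior-preserving: `patch_environment` now handles the save/set/restore dance that the removed `try`/`finally` block spelled out by hand. As a hedged aside, `PYTORCH_NVML_BASED_CUDA_CHECK=1` is the upstream PyTorch switch that answers `torch.cuda.is_available()` through NVML (the driver's monitoring library) rather than the CUDA runtime; a minimal sketch:

```python
import torch

from accelerate.utils import patch_environment

# Query availability via NVML so no CUDA context is created in this
# process; forked workers can still initialize CUDA on their own.
with patch_environment(PYTORCH_NVML_BASED_CUDA_CHECK="1"):
    cuda_ok = torch.cuda.is_available()
print(cuda_ok)  # True on a machine with a visible NVIDIA GPU
```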

@@ -327,20 +320,19 @@ def get_major_and_minor_from_version(full_version):

 @lru_cache
 def is_mlu_available(check_device=False):
-    "Checks if `torch_mlu` is installed and potentially if a MLU is in the environment"
+    """
+    Checks if `mlu` is available via a `cndev-based` check which won't trigger the drivers and leaves mlu
+    uninitialized.
+    """
     if importlib.util.find_spec("torch_mlu") is None:
         return False
 
     import torch_mlu  # noqa: F401
 
-    if check_device:
-        try:
-            # Will raise a RuntimeError if no MLU is found
-            _ = torch.mlu.device_count()
-            return torch.mlu.is_available()
-        except RuntimeError:
-            return False
-    return hasattr(torch, "mlu") and torch.mlu.is_available()
+    with patch_environment(PYTORCH_CNDEV_BASED_MLU_CHECK="1"):
+        available = torch.mlu.is_available()
+
+    return available
 
 
 @lru_cache
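The MLU path mirrors the CUDA one: `PYTORCH_CNDEV_BASED_MLU_CHECK=1` asks `torch_mlu` to answer through CNDev, Cambricon's monitoring library, without initializing the device. A hedged usage sketch (falls back to CPU when the `torch_mlu` extension is absent; the device-selection logic is illustrative):

```python
import torch

from accelerate.utils import is_mlu_available

# is_mlu_available() returns False cheaply when torch_mlu is not installed;
# otherwise it runs the CNDev-based check, leaving the MLU uninitialized.
device = torch.device("mlu" if is_mlu_available() else "cpu")
x = torch.ones(2, 2, device=device)
print(x.device)
```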
74 changes: 0 additions & 74 deletions src/accelerate/utils/other.py
@@ -13,12 +13,10 @@
 # limitations under the License.
 
 import collections
-import os
 import platform
 import re
 import socket
 from codecs import encode
-from contextlib import contextmanager
 from functools import partial, reduce
 from types import MethodType
 from typing import OrderedDict
@@ -262,78 +260,6 @@ def load(f, map_location=None, **kwargs):
     return loaded_obj
 
 
-@contextmanager
-def clear_environment():
-    """
-    A context manager that will temporarily clear environment variables.
-    When this context exits, the previous environment variables will be back.
-    Example:
-    ```python
-    >>> import os
-    >>> from accelerate.utils import clear_environment
-    >>> os.environ["FOO"] = "bar"
-    >>> with clear_environment():
-    ...     print(os.environ)
-    ...     os.environ["FOO"] = "new_bar"
-    ...     print(os.environ["FOO"])
-    {}
-    new_bar
-    >>> print(os.environ["FOO"])
-    bar
-    ```
-    """
-    _old_os_environ = os.environ.copy()
-    os.environ.clear()
-
-    try:
-        yield
-    finally:
-        os.environ.clear()  # clear any added keys,
-        os.environ.update(_old_os_environ)  # then restore previous environment
-
-
-@contextmanager
-def patch_environment(**kwargs):
-    """
-    A context manager that will add each keyword argument passed to `os.environ` and remove them when exiting.
-    Will convert the values in `kwargs` to strings and upper-case all the keys.
-    Example:
-    ```python
-    >>> import os
-    >>> from accelerate.utils import patch_environment
-    >>> with patch_environment(FOO="bar"):
-    ...     print(os.environ["FOO"])  # prints "bar"
-    >>> print(os.environ["FOO"])  # raises KeyError
-    ```
-    """
-    existing_vars = {}
-    for key, value in kwargs.items():
-        key = key.upper()
-        if key in os.environ:
-            existing_vars[key] = os.environ[key]
-        os.environ[key] = str(value)
-
-    try:
-        yield
-    finally:
-        for key in kwargs:
-            key = key.upper()
-            if key in existing_vars:
-                # restore previous value
-                os.environ[key] = existing_vars[key]
-            else:
-                os.environ.pop(key, None)
-
-
 def get_pretty_name(obj):
     """
     Gets a pretty name from `obj`.
2 changes: 1 addition & 1 deletion tests/deepspeed/test_deepspeed.py
@@ -42,6 +42,7 @@
     slow,
 )
 from accelerate.test_utils.training import RegressionDataset, RegressionModel
+from accelerate.utils import patch_environment
 from accelerate.utils.dataclasses import DeepSpeedPlugin
 from accelerate.utils.deepspeed import (
     DeepSpeedEngineWrapper,
@@ -50,7 +51,6 @@
     DummyOptim,
     DummyScheduler,
 )
-from accelerate.utils.other import patch_environment
 from accelerate.utils.versions import compare_versions
 
 
2 changes: 1 addition & 1 deletion tests/fsdp/test_fsdp.py
@@ -35,6 +35,7 @@
     require_non_torch_xla,
     slow,
 )
+from accelerate.utils import patch_environment
 from accelerate.utils.constants import (
     FSDP_AUTO_WRAP_POLICY,
     FSDP_BACKWARD_PREFETCH,
@@ -43,7 +44,6 @@
 )
 from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin
 from accelerate.utils.fsdp_utils import disable_fsdp_ram_efficient_loading, enable_fsdp_ram_efficient_loading
-from accelerate.utils.other import patch_environment
 
 
 set_seed(42)
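Both test modules now import `patch_environment` from the public `accelerate.utils` namespace instead of the private `other` module. A minimal sketch of the test pattern this supports (the test name and environment variable are illustrative):

```python
import os

from accelerate.utils import patch_environment


def test_env_does_not_leak():
    # Configuration set for one test is rolled back on exit, even if the
    # block raises, so later tests start from a clean environment.
    with patch_environment(ACCELERATE_USE_DEEPSPEED="true"):
        assert os.environ["ACCELERATE_USE_DEEPSPEED"] == "true"
    assert "ACCELERATE_USE_DEEPSPEED" not in os.environ
```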
