Rm strtobool #1964

Merged: 4 commits, Sep 12, 2023
4 changes: 2 additions & 2 deletions src/accelerate/test_utils/testing.py
@@ -20,7 +20,6 @@
import tempfile
import unittest
from contextlib import contextmanager
-from distutils.util import strtobool
from functools import partial
from pathlib import Path
from typing import List, Union
@@ -44,6 +43,7 @@
is_transformers_available,
is_wandb_available,
is_xpu_available,
+str_to_bool,
)


@@ -56,7 +56,7 @@ def parse_flag_from_env(key, default=False):
else:
# KEY is set, convert it to True or False.
try:
-_value = strtobool(value)
+_value = str_to_bool(value)
except ValueError:
# More values are supported, but let's keep the message simple.
raise ValueError(f"If set, {key} must be yes or no.")
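For context, the helper patched above is the test suite's flag parser. The snippet below is a minimal sketch of how it resolves environment flags after the swap; the `RUN_SLOW` key is only an illustration, and the import path simply mirrors the file being edited.

import os

from accelerate.test_utils.testing import parse_flag_from_env

os.environ["RUN_SLOW"] = "yes"              # any of y/yes/t/true/on/1 counts as truthy
assert parse_flag_from_env("RUN_SLOW")      # str_to_bool("yes") -> 1

os.environ["RUN_SLOW"] = "0"
assert not parse_flag_from_env("RUN_SLOW")  # str_to_bool("0") -> 0

os.environ["RUN_SLOW"] = "maybe"
# parse_flag_from_env("RUN_SLOW")           # would raise: "If set, RUN_SLOW must be yes or no."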
2 changes: 1 addition & 1 deletion src/accelerate/utils/__init__.py
@@ -35,7 +35,7 @@
TensorInformation,
TorchDynamoPlugin,
)
-from .environment import get_int_from_env, parse_choice_from_env, parse_flag_from_env
+from .environment import get_int_from_env, parse_choice_from_env, parse_flag_from_env, str_to_bool
from .imports import (
get_ccl_version,
is_4bit_bnb_available,
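With the re-export above, the helper is reachable from the public `accelerate.utils` namespace rather than only from the `environment` submodule; a quick sketch:

from accelerate.utils import str_to_bool

assert str_to_bool("True") == 1   # case-insensitive, returns an int
assert str_to_bool("off") == 0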
24 changes: 12 additions & 12 deletions src/accelerate/utils/dataclasses.py
@@ -26,12 +26,12 @@
from contextlib import contextmanager
from dataclasses import dataclass, field
from datetime import timedelta
-from distutils.util import strtobool
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

import torch

from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_STATE_DICT_TYPE
+from .environment import str_to_bool
from .versions import compare_versions


@@ -472,9 +472,9 @@ def __post_init__(self):
if self.mode is None:
self.mode = os.environ.get(prefix + "MODE", "default")
if self.fullgraph is None:
-self.fullgraph = strtobool(os.environ.get(prefix + "USE_FULLGRAPH", "False")) == 1
+self.fullgraph = str_to_bool(os.environ.get(prefix + "USE_FULLGRAPH", "False")) == 1
if self.dynamic is None:
-self.dynamic = strtobool(os.environ.get(prefix + "USE_DYNAMIC", "False")) == 1
+self.dynamic = str_to_bool(os.environ.get(prefix + "USE_DYNAMIC", "False")) == 1

def to_dict(self):
dynamo_config = copy.deepcopy(self.__dict__)
@@ -635,7 +635,7 @@ def __post_init__(self):
self.deepspeed_config["steps_per_print"] = float("inf") # this will stop deepspeed from logging @ stdout
if self.zero3_init_flag is None:
self.zero3_init_flag = (
-strtobool(os.environ.get("ACCELERATE_DEEPSPEED_ZERO3_INIT", str(self.hf_ds_config.is_zero3()))) == 1
+str_to_bool(os.environ.get("ACCELERATE_DEEPSPEED_ZERO3_INIT", str(self.hf_ds_config.is_zero3()))) == 1
)
if self.zero3_init_flag and not self.hf_ds_config.is_zero3():
warnings.warn("DeepSpeed Zero3 Init flag is only applicable for ZeRO Stage 3. Setting it to False.")
@@ -897,7 +897,7 @@ def __post_init__(self):
self.sharding_strategy = ShardingStrategy(int(os.environ.get(prefix + "SHARDING_STRATEGY", 1)))

if self.cpu_offload is None:
-if strtobool(os.environ.get(prefix + "OFFLOAD_PARAMS", "False")) == 1:
+if str_to_bool(os.environ.get(prefix + "OFFLOAD_PARAMS", "False")) == 1:
self.cpu_offload = CPUOffload(offload_params=True)
else:
self.cpu_offload = CPUOffload(offload_params=False)
@@ -910,10 +910,10 @@ def __post_init__(self):
if self.state_dict_type is None:
state_dict_type_policy = os.environ.get(prefix + "STATE_DICT_TYPE", "FULL_STATE_DICT")
self.set_state_dict_type(state_dict_type_policy)
-self.use_orig_params = strtobool(os.environ.get(prefix + "USE_ORIG_PARAMS", "False")) == 1
-self.sync_module_states = strtobool(os.environ.get(prefix + "SYNC_MODULE_STATES", "True")) == 1
-self.forward_prefetch = strtobool(os.environ.get(prefix + "FORWARD_PREFETCH", "False")) == 1
-self.activation_checkpointing = strtobool(os.environ.get(prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1
+self.use_orig_params = str_to_bool(os.environ.get(prefix + "USE_ORIG_PARAMS", "False")) == 1
+self.sync_module_states = str_to_bool(os.environ.get(prefix + "SYNC_MODULE_STATES", "True")) == 1
+self.forward_prefetch = str_to_bool(os.environ.get(prefix + "FORWARD_PREFETCH", "False")) == 1
+self.activation_checkpointing = str_to_bool(os.environ.get(prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1

if self.sync_module_states:
self.param_init_fn = lambda x: x.to_empty(device=torch.cuda.current_device(), recurse=False)
@@ -1169,13 +1169,13 @@ def __post_init__(self):
if self.gradient_clipping is None:
self.gradient_clipping = float(os.environ.get(prefix + "GRADIENT_CLIPPING", 1.0))
if self.recompute_activation is None:
-self.recompute_activation = strtobool(os.environ.get(prefix + "RECOMPUTE_ACTIVATION", "False")) == 1
+self.recompute_activation = str_to_bool(os.environ.get(prefix + "RECOMPUTE_ACTIVATION", "False")) == 1
if self.use_distributed_optimizer is None:
self.use_distributed_optimizer = (
-strtobool(os.environ.get(prefix + "USE_DISTRIBUTED_OPTIMIZER", "False")) == 1
+str_to_bool(os.environ.get(prefix + "USE_DISTRIBUTED_OPTIMIZER", "False")) == 1
)
if self.sequence_parallelism is None:
-self.sequence_parallelism = strtobool(os.environ.get(prefix + "SEQUENCE_PARALLELISM", "False")) == 1
+self.sequence_parallelism = str_to_bool(os.environ.get(prefix + "SEQUENCE_PARALLELISM", "False")) == 1

if self.pp_degree > 1 or self.use_distributed_optimizer:
self.DDP_impl = "local"
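All of the plugin changes above follow one idiom: read an environment override, funnel it through `str_to_bool`, and compare against 1 to get a plain boolean. A standalone sketch of that pattern (the prefix and flag name here are placeholders, not real accelerate variables):

import os

from accelerate.utils.environment import str_to_bool

prefix = "MY_PLUGIN_"                      # placeholder prefix
os.environ[prefix + "SOME_FLAG"] = "true"

# str_to_bool returns 1 or 0, so "== 1" converts the result into a bool,
# while unset variables fall back to the "False" default string.
some_flag = str_to_bool(os.environ.get(prefix + "SOME_FLAG", "False")) == 1
assert some_flag is True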
18 changes: 16 additions & 2 deletions src/accelerate/utils/environment.py
@@ -13,7 +13,21 @@
# limitations under the License.

import os
-from distutils.util import strtobool


+def str_to_bool(value) -> int:
+    """
+    Converts a string representation of truth to `True` (1) or `False` (0).
+
+    True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False values are `n`, `no`, `f`, `false`, `off`, and `0`.
+    """
+    value = value.lower()
+    if value in ("y", "yes", "t", "true", "on", "1"):
+        return 1
+    elif value in ("n", "no", "f", "false", "off", "0"):
+        return 0
+    else:
+        raise ValueError(f"invalid truth value {value}")


def get_int_from_env(env_keys, default):
@@ -28,7 +42,7 @@ def get_int_from_env(env_keys, default):
def parse_flag_from_env(key, default=False):
"""Returns truthy value for `key` from the env if available else the default."""
value = os.environ.get(key, str(default))
-return strtobool(value) == 1 # As its name indicates `strtobool` actually returns an int...
+return str_to_bool(value) == 1 # As its name indicates `str_to_bool` actually returns an int...
Review comment (Member):
This comment, which I assume is meant ironically, could be removed if we actually return a bool. But technically, a bool is an int, so not sure if it's really meant to be ironic :D
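The underlying observation is easy to verify: in Python, `bool` is a subclass of `int`, so a version that returned `True`/`False` would still compare equal to 1/0 at every call site above.

# bool subclasses int, so True == 1 and False == 0
assert issubclass(bool, int)
assert isinstance(True, int)
assert True == 1 and False == 0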



def parse_choice_from_env(key, default="no"):
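Taken together, the new helper's contract is exactly what its docstring states; a short self-contained check against the function as defined in this file:

from accelerate.utils.environment import str_to_bool

assert str_to_bool("yes") == 1
assert str_to_bool("NO") == 0            # input is lower-cased first
try:
    str_to_bool("maybe")
except ValueError as err:
    print(err)                           # invalid truth value maybe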
5 changes: 2 additions & 3 deletions src/accelerate/utils/imports.py
@@ -16,14 +16,13 @@
import importlib.metadata
import os
import warnings
-from distutils.util import strtobool
from functools import lru_cache

import torch
from packaging import version
from packaging.version import parse

-from .environment import parse_flag_from_env
+from .environment import parse_flag_from_env, str_to_bool
from .versions import compare_versions, is_torch_version


@@ -143,7 +142,7 @@ def is_bnb_available():


def is_megatron_lm_available():
-if strtobool(os.environ.get("ACCELERATE_USE_MEGATRON_LM", "False")) == 1:
+if str_to_bool(os.environ.get("ACCELERATE_USE_MEGATRON_LM", "False")) == 1:
package_exists = importlib.util.find_spec("megatron") is not None
if package_exists:
try: