Fix config access for DBRX (#1177)
dakinggg authored May 6, 2024
1 parent bfbb8c5 commit ab9dde7
Showing 2 changed files with 19 additions and 3 deletions.
6 changes: 3 additions & 3 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -28,7 +28,7 @@
 from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP
 from llmfoundry.models.layers.attention import is_flash_v2_installed
 from llmfoundry.models.utils import init_empty_weights
-from llmfoundry.utils.config_utils import pop_config
+from llmfoundry.utils.config_utils import get_hf_config_value, pop_config
 
 if TYPE_CHECKING:
     from peft import PeftConfig, PeftModel
@@ -247,9 +247,9 @@ def _autoset_attn_implementation_monkeypatch(
             else:
                 setattr(config, k, v)
 
-        if hasattr(config, 'attn_config') and config.attn_config.get(
+        if hasattr(config, 'attn_config') and get_hf_config_value(
+            config.attn_config,
             'seq_parallel_world_size',
-            None,
         ) is not None:
             raise NotImplementedError(
                 'Sequence Parallelism is not supported for HuggingFace models.',
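Context for the hunk above: DBRX's attn_config is a PretrainedConfig sub-object with attribute access only, not a plain dict like MPT's, so the old config.attn_config.get(...) call raised an AttributeError (hence the commit title). The get_hf_config_value helper added below handles both shapes.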
16 changes: 16 additions & 0 deletions llmfoundry/utils/config_utils.py
@@ -12,6 +12,7 @@
 from composer.utils import dist, parse_uri
 from omegaconf import DictConfig, ListConfig
 from omegaconf import OmegaConf as om
+from transformers import PretrainedConfig
 
 from llmfoundry.layers_registry import ffns_with_megablocks
 from llmfoundry.models.utils import init_empty_weights
@@ -59,6 +60,21 @@ def pop_config(
     return default_value
 
 
+def get_hf_config_value(config: Union[dict, PretrainedConfig], key: str) -> Any:
+    """Get a value from a Hugging Face config.
+
+    Args:
+        config (Union[dict, PretrainedConfig]): The Hugging Face config object.
+        key (str): The key to get from the config.
+
+    Returns:
+        Any: The value from the config. None if the key does not exist.
+    """
+    if isinstance(config, dict):
+        return config.get(key)
+    return getattr(config, key, None)
+
+
 def calculate_batch_size_info(
     global_batch_size: int,
     device_microbatch_size: Union[int, Literal['auto']],
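For illustration, a minimal sketch of how the new helper behaves with both config shapes. This is not part of the commit; the seq_parallel_world_size values are made up, and it assumes transformers and llm-foundry are installed.

from transformers import PretrainedConfig

from llmfoundry.utils.config_utils import get_hf_config_value

# MPT-style configs store attn_config as a plain dict ...
dict_config = {'seq_parallel_world_size': 2}
# ... while DBRX-style configs store it as a PretrainedConfig object
# (extra kwargs are set as attributes).
object_config = PretrainedConfig(seq_parallel_world_size=2)

# The helper reads the key from either shape.
assert get_hf_config_value(dict_config, 'seq_parallel_world_size') == 2
assert get_hf_config_value(object_config, 'seq_parallel_world_size') == 2

# Missing keys fall back to None in both cases (dict.get / getattr default),
# which is what the `is not None` check in hf_causal_lm.py relies on.
assert get_hf_config_value(dict_config, 'nonexistent_key') is None
assert get_hf_config_value(object_config, 'nonexistent_key') is None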
