diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index ea875d660f..7182c47d2a 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -28,7 +28,7 @@ from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP from llmfoundry.models.layers.attention import is_flash_v2_installed from llmfoundry.models.utils import init_empty_weights -from llmfoundry.utils.config_utils import pop_config +from llmfoundry.utils.config_utils import get_hf_config_value, pop_config if TYPE_CHECKING: from peft import PeftConfig, PeftModel @@ -247,9 +247,9 @@ def _autoset_attn_implementation_monkeypatch( else: setattr(config, k, v) - if hasattr(config, 'attn_config') and config.attn_config.get( + if hasattr(config, 'attn_config') and get_hf_config_value( + config.attn_config, 'seq_parallel_world_size', - None, ) is not None: raise NotImplementedError( 'Sequence Parallelism is not supported for HuggingFace models.', diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 0b05214565..db180e3168 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -12,6 +12,7 @@ from composer.utils import dist, parse_uri from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om +from transformers import PretrainedConfig from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.models.utils import init_empty_weights @@ -59,6 +60,21 @@ def pop_config( return default_value +def get_hf_config_value(config: Union[dict, PretrainedConfig], key: str) -> Any: + """Get a value from a Hugging Face config. + + Args: + config (Union[dict, PretrainedConfig]): The Hugging Face config object. + key (str): The key to get from the config. + + Returns: + Any: The value from the config. None if the key does not exist. + """ + if isinstance(config, dict): + return config.get(key) + return getattr(config, key, None) + + def calculate_batch_size_info( global_batch_size: int, device_microbatch_size: Union[int, Literal['auto']],