Skip to content

Commit

Permalink
Merge branch 'main' into safe-load
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea authored Jan 4, 2024
2 parents 03e2af4 + 083b4b2 commit efccf8d
Show file tree
Hide file tree
Showing 14 changed files with 195 additions and 88 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ Tutorial videos from the community:
Something missing? Contribute with a PR!

# Latest News
* [Blog: LLM Training and Inference with Intel Gaudi2 AI Accelerators](https://www.databricks.com/blog/llm-training-and-inference-intel-gaudi2-ai-accelerators)
* [Blog: Training LLMs at Scale with AMD MI250 GPUs](https://www.databricks.com/blog/training-llms-scale-amd-mi250-gpus)
* [Blog: Training LLMs with AMD MI250 GPUs and MosaicML](https://www.mosaicml.com/blog/amd-mi250)
* [Blog: Announcing MPT-7B-8K: 8K Context Length for Document Understanding](https://www.mosaicml.com/blog/long-context-mpt-7b-8k)
* [Blog: Training LLMs with AMD MI250 GPUs and MosaicML](https://www.mosaicml.com/blog/amd-mi250)
* [Blog: MPT-30B: Raising the bar for open-source foundation models](https://www.mosaicml.com/blog/mpt-30b)
Expand Down Expand Up @@ -186,6 +189,12 @@ Notes:
1. `attn_impl: triton` does not work.
1. We don't yet have a Docker image where everything works perfectly. You might need to up/downgrade some packages (in our case, we needed to downgrade to `numpy==1.23.5`) before everything works without issue.

### Intel Gaudi
Support for LLM Foundry on Intel Gaudi devices is experimental, please use the branch `habana_alpha` and see the [README on that branch](https://github.com/mosaicml/llm-foundry/blob/habana_alpha) which has [install instructions and known issues.](https://github.com/mosaicml/llm-foundry/tree/habana_alpha?tab=readme-ov-file#intel-gaudi)

For training and inference performance results on Intel Gaudi2 accelerators, see our blog: https://www.databricks.com/blog/llm-training-and-inference-intel-gaudi2-ai-accelerators


# Quickstart

> **Note**
Expand Down
57 changes: 35 additions & 22 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import os
import tempfile
from pathlib import Path
from typing import Optional, Union
from typing import Optional, Sequence, Union

import torch
from composer.core import Callback, Event, State, Time, TimeUnit
Expand All @@ -32,31 +32,40 @@ class HuggingFaceCheckpointer(Callback):
"""Save a huggingface formatted checkpoint during training.
Args:
save_folder (str): Top level folder to save checkpoints to (can be a URI). It is likely that
this would be the same as your save_folder.
save_interval: Union[str, int, Time]: The interval describing how often checkpoints should be
saved. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`.
Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
save_folder (str): Top level folder to save checkpoints to (can be a
URI). It is likely that this would be the same as your save_folder.
save_interval: Union[str, int, Time]: The interval describing how often
checkpoints should be saved. If an integer, it will be assumed to be
in :attr:`.TimeUnit.EPOCH`. Otherwise, the unit must be either
:attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
:attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
huggingface_folder_name (str): Folder to save each checkpoint under (can be a format string). Default is ``ba{batch}``.
precision: The precision to save the model in. Default is ``float32``. Options are ``bfloat16``, ``float16``, or ``float32``.
huggingface_folder_name (str): Folder to save each checkpoint under (can
be a format string). Default is ``ba{batch}``.
precision: The precision to save the model in. Default is ``float32``.
Options are ``bfloat16``, ``float16``, or ``float32``.
overwrite (bool): Whether to overwrite previous checkpoints.
mlflow_registered_model_name (Optional[str]): The name to register the model under in the MLflow model registry. If ``None``, the model will not
be registered. Default is ``None``.
mlflow_logging_config (Optional[dict]): A dictionary of config arguments that will get passed along to the MLflow ``save_model`` call.
Expected to contain ``metadata`` and ``task`` keys. If either is unspecified, the defaults are ``'text-generation'`` and
mlflow_registered_model_name (Optional[str]): The name to register the
model under in the MLflow model registry. If ``None``, the model
will not be registered. Default is ``None``.
mlflow_logging_config (Optional[dict]): A dictionary of config arguments
that will get passed along to the MLflow ``save_model`` call.
Expected to contain ``metadata`` and ``task`` keys. If either is
unspecified, the defaults are ``'text-generation'`` and
``{'task': 'llm/v1/completions'}`` respectively.
flatten_imports (Sequence[str]): A sequence of import prefixes that will
be flattened when editing MPT files.
"""

def __init__(
self,
save_folder: str,
save_interval: Union[str, int, Time],
huggingface_folder_name: str = 'ba{batch}',
precision: str = 'float32',
overwrite: bool = True,
mlflow_registered_model_name: Optional[str] = None,
mlflow_logging_config: Optional[dict] = None,
self,
save_folder: str,
save_interval: Union[str, int, Time],
huggingface_folder_name: str = 'ba{batch}',
precision: str = 'float32',
overwrite: bool = True,
mlflow_registered_model_name: Optional[str] = None,
mlflow_logging_config: Optional[dict] = None,
flatten_imports: Sequence[str] = ('llmfoundry',),
):
_, _, self.save_dir_format_str = parse_uri(save_folder)
self.overwrite = overwrite
Expand All @@ -66,6 +75,7 @@ def __init__(
'float16': torch.float16,
'bfloat16': torch.bfloat16,
}[precision]
self.flatten_imports = flatten_imports

# mlflow config setup
self.mlflow_registered_model_name = mlflow_registered_model_name
Expand All @@ -91,7 +101,7 @@ def __init__(
if isinstance(save_interval, int):
save_interval = Time(save_interval, TimeUnit.EPOCH)

self.save_interval = save_interval
self.save_interval: Time = save_interval
self.check_interval = create_interval_scheduler(
save_interval, include_end_of_training=True)
self.remote_ud = maybe_create_remote_uploader_downloader_from_uri(
Expand Down Expand Up @@ -229,7 +239,10 @@ def _save_checkpoint(self, state: State, logger: Logger):
# Only need to edit files for MPT because it has custom code
if original_model.config.model_type == 'mpt':
log.debug('Editing MPT files for HuggingFace compatibility')
edit_files_for_hf_compatibility(temp_save_dir)
edit_files_for_hf_compatibility(
temp_save_dir,
self.flatten_imports,
)

if self.remote_ud is not None:
log.info(f'Uploading HuggingFace formatted checkpoint')
Expand Down
10 changes: 7 additions & 3 deletions llmfoundry/data/text_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,9 +325,10 @@ def get_tokens_per_batch_func(
"""

def get_num_samples_in_batch(batch: Batch) -> int:
if not isinstance(batch, Mapping) or 'attention_mask' not in batch:
if not isinstance(batch, Mapping) or ('attention_mask' not in batch and
'input_ids' not in batch):
raise ValueError(
'get_tokens_per_batch_func() requires a batch with an attention_mask key'
'get_tokens_per_batch_func() requires a batch with an attention_mask key or an input_ids key'
)

if not decoder_only and 'decoder_attention_mask' not in batch:
Expand All @@ -336,7 +337,10 @@ def get_num_samples_in_batch(batch: Batch) -> int:
)

# Count number of non padding tokens in batch
input_ids_tokens = int(torch.sum(batch['attention_mask']).item())
if 'attention_mask' in batch:
input_ids_tokens = int(torch.sum(batch['attention_mask']).item())
else:
input_ids_tokens = batch['input_ids'].numel()

# For encoder decoder models only
decoder_input_ids_tokens = 0
Expand Down
12 changes: 6 additions & 6 deletions llmfoundry/models/layers/ffn.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(self.act(self.up_proj(x)))


class MPTGeGLU(MPTMLP):
class MPTGLU(MPTMLP):

def __init__(
self,
Expand All @@ -138,19 +138,19 @@ def __init__(
device=device,
bias=bias,
)
self.gate = FC_CLASS_REGISTRY[fc_type](
self.gate_proj = FC_CLASS_REGISTRY[fc_type](
d_model,
self.up_proj.out_features,
**self.fc_kwargs,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(self.act(self.up_proj(x)) * self.gate(x))
return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))


FFN_CLASS_REGISTRY = {
'mptmlp': MPTMLP,
'mptgeglu': MPTGeGLU,
'mptglu': MPTGLU,
}

if te is not None:
Expand All @@ -169,10 +169,10 @@ def build_ffn(
**kwargs: Any,
) -> nn.Module:
ffn_type = kwargs.pop('ffn_type')
if ffn_type in ['mptmlp', 'mptgeglu']:
if ffn_type in ['mptmlp', 'mptglu']:
if len(kwargs) > 0:
raise ValueError(
f'MPTMLP (or MPTGeGLU) got an unexpected keyword argument: {kwargs}'
f'MPTMLP (or MPTGLU) got an unexpected keyword argument: {kwargs}'
)
return FFN_CLASS_REGISTRY[ffn_type](
d_model=d_model,
Expand Down
10 changes: 8 additions & 2 deletions llmfoundry/models/mpt/configuration_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(
factor (float): Scaling factor to use if using 'linear' or 'dynamic' as rope_scaling.type.
kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
ffn_config (Dict): A dictionary used to configure the model's ffn module:
ffn_type (str): type of ffn to use. Options: mptmlp, mptgeglu, te_ln_mlp
ffn_type (str): type of ffn to use. Options: mptmlp, mptglu, te_ln_mlp
init_device (str): The device to use for parameter initialization.
logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
no_bias (bool): Whether to use bias in all layers.
Expand Down Expand Up @@ -291,7 +291,13 @@ def _validate_config(self) -> None:
+ 'pip install flash-attn==1.0.6 --no-build-isolation \n' +
'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156'
)
if self.ffn_config['ffn_type'] in ['mptmlp', 'mptgeglu']:
if self.ffn_config['ffn_type'] == 'mptgeglu':
raise ValueError(
'API CHANGE: `ffn_type=="mptgeglu"` changed to `ffn_type=="mptglu"`. '
+
'See [#829](https://github.com/mosaicml/llm-foundry/pull/829) for details.'
)
elif self.ffn_config['ffn_type'] in ['mptmlp', 'mptglu']:
self.ffn_config['fc_type'] = self.fc_type
elif self.ffn_config['ffn_type'] == 'te_ln_mlp':
self.ffn_config['bias'] = not self.no_bias
Expand Down
2 changes: 1 addition & 1 deletion llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
# Otherwise, certain modules are missing.
# isort: off
from llmfoundry.models.utils.adapt_tokenizer import (
AutoTokenizerForMOD, # type: ignore (see note),
AutoTokenizerForMOD, # type: ignore (see note)
adapt_tokenizer_for_denoising, # type: ignore (see note)
)
from llmfoundry.models.utils.hf_prefixlm_converter import (
Expand Down
8 changes: 4 additions & 4 deletions llmfoundry/optim/lion8b.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,9 @@ def __init__(self, data: Optional[torch.Tensor], try_quantize: bool = True):
self._f_encode = None
self._f_decode = None
if self._try_quantize:
from turbo import dequantize8b, quantize8b
self._f_encode = quantize8b
self._f_decode = dequantize8b
from turbo import dequantize_signed, quantize_signed
self._f_encode = quantize_signed
self._f_decode = dequantize_signed

if data is not None:
self.set_data(data)
Expand Down Expand Up @@ -277,7 +277,7 @@ def set_data(self, data: torch.Tensor) -> None:
f'on device {data.device} with shape {data.shape}.')
self.data = None
assert self._f_encode is not None # pyright
self.quantized, self.scales = self._f_encode(data)
self.quantized, self.scales, _ = self._f_encode(data)
else:
self.data = data.to(dtype=torch.float32)
self.quantized = None
Expand Down
51 changes: 37 additions & 14 deletions llmfoundry/utils/checkpoint_conversion_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ def load_tokenizer(


def _write_zero_bias(weight_name: str, weight_file_path: str,
bias_shape: Union[Tuple[int, ...], int]) -> None:
bias_shape: Union[Tuple[int, ...],
int], np_data_type: np.dtype) -> None:
"""Write zeros for bias when converting MPT to FasterTransformer weights.
MPT model might not have bias while FT expects bias.
Expand All @@ -121,20 +122,22 @@ def _write_zero_bias(weight_name: str, weight_file_path: str,
weight_name (str): Name of the weight tensor.
weight_file_path (str): Output path for storing the weight (NOT zero bias).
bias_shape (Union[Tuple[int, ...], int]): Shape of the bias array.
np_data_type (np.dtype): The data type for bias.
"""
if 'weight' not in weight_file_path:
raise RuntimeError(
f'Cannot write zero bias for {weight_name}. Input is not a weight tensor'
)
log.debug(f'zero bias for weight: {weight_name}')
bias_file_path = weight_file_path.replace('.weight', '.bias')
bias = np.zeros(bias_shape, dtype=np.float32)
bias = np.zeros(bias_shape, dtype=np_data_type)
bias.tofile(bias_file_path)


def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int,
tensor_name: str, config: Dict[str, Any],
data: np.ndarray) -> None:
data: np.ndarray,
np_weight_data_type: np.dtype) -> None:
"""Convert each MPT weight to a FasterTransformer compatible format.
Args:
Expand All @@ -155,7 +158,9 @@ def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int,
save_path = os.path.join(save_dir, f'model.{tensor_name}.bin')
data.tofile(save_path)
if 'weight' in tensor_name and config['no_bias']:
_write_zero_bias(tensor_name, save_path, data.shape[-1])
_write_zero_bias(tensor_name, save_path, data.shape[-1],
np_weight_data_type
) # pyright: ignore [reportGeneralTypeIssues]

elif tensor_name.find('attention.dense.weight') != -1:
assert data.shape == (
Expand All @@ -170,11 +175,13 @@ def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int,
if config['no_bias']:
fake_weight_path = os.path.join(save_dir,
f'model.{tensor_name}.bin')
_write_zero_bias(tensor_name, fake_weight_path, data.shape[-1])
_write_zero_bias(tensor_name, fake_weight_path, data.shape[-1],
np_weight_data_type
) # pyright: ignore [reportGeneralTypeIssues]

elif tensor_name.find('mlp.dense_4h_to_h.weight') != -1:
assert data.shape == (
config['d_model'], config['mlp_ratio'] *
config['d_model'], config['expansion_ratio'] *
config['d_model']), f'unexpected dim for {tensor_name}'
# nn.Linear weights are transposed
data = data.T
Expand All @@ -185,11 +192,13 @@ def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int,
if config['no_bias']:
fake_weight_path = os.path.join(save_dir,
f'model.{tensor_name}.bin')
_write_zero_bias(tensor_name, fake_weight_path, data.shape[-1])
_write_zero_bias(tensor_name, fake_weight_path, data.shape[-1],
np_weight_data_type
) # pyright: ignore [reportGeneralTypeIssues]

elif tensor_name.find('mlp.dense_h_to_4h.weight') != -1:
assert data.shape == (
config['mlp_ratio'] * config['d_model'],
config['expansion_ratio'] * config['d_model'],
config['d_model']), f'unexpected dim for {tensor_name}'
# nn.Linear weights are transposed
data = data.T
Expand All @@ -200,11 +209,12 @@ def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int,
split_vals[j].tofile(save_path)
if config['no_bias']:
_write_zero_bias(tensor_name, save_path,
split_vals[j].shape[-1])
split_vals[j].shape[-1], np_weight_data_type
) # pyright: ignore [reportGeneralTypeIssues]

elif tensor_name.find('mlp.dense_h_to_4h.bias') != -1:
assert data.shape == (
config['mlp_ratio'] *
config['expansion_ratio'] *
config['d_model'],), f'unexpected dim for {tensor_name}'
split_vals = np.split(data, infer_gpu_num, axis=-1)
for j in range(infer_gpu_num):
Expand Down Expand Up @@ -238,7 +248,9 @@ def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int,
split_vals[j].tofile(save_path)
if config['no_bias']:
_write_zero_bias(tensor_name, save_path,
(3, split_vals[j].shape[-1]))
(3, split_vals[j].shape[-1]),
np_weight_data_type
) # pyright: ignore [reportGeneralTypeIssues]

else:
raise RuntimeError(f'Tensor with name {tensor_name} is not handled')
Expand Down Expand Up @@ -306,7 +318,12 @@ def convert_and_save_ft_weights(named_params: dict,
'model.final_layernorm.weight.bin')
data.tofile(save_path)
if config['no_bias']:
_write_zero_bias(name, save_path, data.shape[-1])
_write_zero_bias(
name,
save_path,
data.shape[-1],
np_weight_data_type # pyright: ignore [reportGeneralTypeIssues]
)
elif name == 'transformer.lm_head.weight':
data.tofile(os.path.join(save_dir, 'model.lm_head.weight.bin'))
else:
Expand All @@ -315,5 +332,11 @@ def convert_and_save_ft_weights(named_params: dict,
new_name = name.replace('transformer.blocks.',
'layers.').replace(
mpt_pattern, ft_pattern)
_convert_weight_to_ft_each(save_dir, infer_gpu_num,
new_name, config, data)
_convert_weight_to_ft_each(
save_dir,
infer_gpu_num,
new_name,
config,
data,
np_weight_data_type # pyright: ignore [reportGeneralTypeIssues]
)
8 changes: 8 additions & 0 deletions llmfoundry/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,14 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]):
# Set defaults for mixed initialization
fsdp_config.setdefault('use_orig_params', False)
fsdp_config.setdefault('load_monolith_rank0_only', True)
# Always set `sync_module_states` to True when using hybrid sharding
if fsdp_config is not None and \
fsdp_config.get('sharding_strategy', 'FULL_SHARD') in ['HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'] \
and not fsdp_config.get('sync_module_states', False):
warnings.warn(
('Setting `sync_module_states = True` for FSDP. This is required '
'when using hybrid sharding.'))
fsdp_config['sync_module_states'] = True

# no mixed precision needed for weights when they're already 16 bits
master_dtype = model_cfg.get('master_weights_dtype')
Expand Down
Loading

0 comments on commit efccf8d

Please sign in to comment.