Commit: merging

Shashank Rajput committed Jan 5, 2024
2 parents 44eb30f + 7c415e3 commit c0501cc
Showing 7 changed files with 74 additions and 28 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -77,6 +77,9 @@ Tutorial videos from the community:
Something missing? Contribute with a PR!

# Latest News
* [Blog: LLM Training and Inference with Intel Gaudi2 AI Accelerators](https://www.databricks.com/blog/llm-training-and-inference-intel-gaudi2-ai-accelerators)
* [Blog: Training LLMs at Scale with AMD MI250 GPUs](https://www.databricks.com/blog/training-llms-scale-amd-mi250-gpus)
* [Blog: Training LLMs with AMD MI250 GPUs and MosaicML](https://www.mosaicml.com/blog/amd-mi250)
* [Blog: Announcing MPT-7B-8K: 8K Context Length for Document Understanding](https://www.mosaicml.com/blog/long-context-mpt-7b-8k)
* [Blog: Training LLMs with AMD MI250 GPUs and MosaicML](https://www.mosaicml.com/blog/amd-mi250)
* [Blog: MPT-30B: Raising the bar for open-source foundation models](https://www.mosaicml.com/blog/mpt-30b)
@@ -186,6 +189,12 @@ Notes:
1. `attn_impl: triton` does not work.
1. We don't yet have a Docker image where everything works perfectly. You might need to up/downgrade some packages (in our case, we needed to downgrade to `numpy==1.23.5`) before everything works without issue.

### Intel Gaudi
Support for LLM Foundry on Intel Gaudi devices is experimental. Please use the `habana_alpha` branch and see the [README on that branch](https://github.com/mosaicml/llm-foundry/blob/habana_alpha), which has [install instructions and known issues](https://github.com/mosaicml/llm-foundry/tree/habana_alpha?tab=readme-ov-file#intel-gaudi).

For training and inference performance results on Intel Gaudi2 accelerators, see our blog: https://www.databricks.com/blog/llm-training-and-inference-intel-gaudi2-ai-accelerators


# Quickstart

> **Note**
12 changes: 6 additions & 6 deletions llmfoundry/models/layers/ffn.py
@@ -117,7 +117,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(self.act(self.up_proj(x)))


class MPTGeGLU(MPTMLP):
class MPTGLU(MPTMLP):

def __init__(
self,
@@ -138,19 +138,19 @@ def __init__(
device=device,
bias=bias,
)
self.gate = FC_CLASS_REGISTRY[fc_type](
self.gate_proj = FC_CLASS_REGISTRY[fc_type](
d_model,
self.up_proj.out_features,
**self.fc_kwargs,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(self.act(self.up_proj(x)) * self.gate(x))
return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))


FFN_CLASS_REGISTRY = {
'mptmlp': MPTMLP,
'mptgeglu': MPTGeGLU,
'mptglu': MPTGLU,
}

if te is not None:
@@ -169,10 +169,10 @@ def build_ffn(
**kwargs: Any,
) -> nn.Module:
ffn_type = kwargs.pop('ffn_type')
if ffn_type in ['mptmlp', 'mptgeglu']:
if ffn_type in ['mptmlp', 'mptglu']:
if len(kwargs) > 0:
raise ValueError(
f'MPTMLP (or MPTGeGLU) got an unexpected keyword argument: {kwargs}'
f'MPTMLP (or MPTGLU) got an unexpected keyword argument: {kwargs}'
)
return FFN_CLASS_REGISTRY[ffn_type](
d_model=d_model,
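For context on the rename: `MPTGLU` implements a gated-linear-unit feed-forward block, in which the activation is applied to the gate projection and multiplied elementwise with the un-activated up projection before the down projection (with a GELU activation this is the "GeGLU" variant, hence the old name). Below is a minimal, self-contained sketch of the same computation without the `FC_CLASS_REGISTRY` plumbing; the class name and defaults are illustrative, not part of the repository.

```python
import torch
import torch.nn as nn


class GLUFeedForward(nn.Module):
    """Minimal GLU-style FFN sketch: down(act(gate(x)) * up(x))."""

    def __init__(self, d_model: int, expansion_ratio: int = 4):
        super().__init__()
        hidden = expansion_ratio * d_model
        self.up_proj = nn.Linear(d_model, hidden)
        self.gate_proj = nn.Linear(d_model, hidden)
        self.down_proj = nn.Linear(hidden, d_model)
        self.act = nn.GELU()  # swapping the activation changes the GLU flavor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Activation on the gate branch, elementwise product with the up
        # branch, then project back to d_model, mirroring MPTGLU.forward.
        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
```

Applied to a `(batch, seq, d_model)` tensor, the block returns a tensor of the same shape; only the hidden width changes with `expansion_ratio`.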
10 changes: 8 additions & 2 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -107,7 +107,7 @@ def __init__(
factor (float): Scaling factor to use if using 'linear' or 'dynamic' as rope_scaling.type.
kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
ffn_config (Dict): A dictionary used to configure the model's ffn module:
ffn_type (str): type of ffn to use. Options: mptmlp, mptgeglu, te_ln_mlp
ffn_type (str): type of ffn to use. Options: mptmlp, mptglu, te_ln_mlp
init_device (str): The device to use for parameter initialization.
logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
no_bias (bool): Whether to use bias in all layers.
@@ -293,7 +293,13 @@ def _validate_config(self) -> None:
+ 'pip install flash-attn==1.0.6 --no-build-isolation \n' +
'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156'
)
if self.ffn_config['ffn_type'] in ['mptmlp', 'mptgeglu']:
if self.ffn_config['ffn_type'] == 'mptgeglu':
raise ValueError(
'API CHANGE: `ffn_type=="mptgeglu"` changed to `ffn_type=="mptglu"`. '
+
'See [#829](https://github.com/mosaicml/llm-foundry/pull/829) for details.'
)
elif self.ffn_config['ffn_type'] in ['mptmlp', 'mptglu']:
self.ffn_config['fc_type'] = self.fc_type
elif self.ffn_config['ffn_type'] == 'te_ln_mlp':
self.ffn_config['bias'] = not self.no_bias
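The practical effect for existing configs: `ffn_type: mptgeglu` now fails fast with the error above instead of silently building the renamed module. A hedged migration sketch follows; the helper name and the example fragment are illustrative, and only the `ffn_type` key and its values come from the diff.

```python
def migrate_ffn_type(ffn_config: dict) -> dict:
    """Rename the retired 'mptgeglu' ffn_type to 'mptglu' before building
    an MPT config, matching the ValueError added in _validate_config."""
    if ffn_config.get('ffn_type') == 'mptgeglu':
        return {**ffn_config, 'ffn_type': 'mptglu'}
    return ffn_config


old_fragment = {'ffn_type': 'mptgeglu'}
print(migrate_ffn_type(old_fragment))  # {'ffn_type': 'mptglu'}
```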
51 changes: 37 additions & 14 deletions llmfoundry/utils/checkpoint_conversion_helpers.py
@@ -112,7 +112,8 @@ def load_tokenizer(


def _write_zero_bias(weight_name: str, weight_file_path: str,
bias_shape: Union[Tuple[int, ...], int]) -> None:
bias_shape: Union[Tuple[int, ...],
int], np_data_type: np.dtype) -> None:
"""Write zeros for bias when converting MPT to FasterTransformer weights.
MPT model might not have bias while FT expects bias.
@@ -121,20 +121,22 @@ def _write_zero_bias(weight_name: str, weight_file_path: str,
weight_name (str): Name of the weight tensor.
weight_file_path (str): Output path for storing the weight (NOT zero bias).
bias_shape (Union[Tuple[int, ...], int]): Shape of the bias array.
np_data_type (np.dtype): The data type for bias.
"""
if 'weight' not in weight_file_path:
raise RuntimeError(
f'Cannot write zero bias for {weight_name}. Input is not a weight tensor'
)
log.debug(f'zero bias for weight: {weight_name}')
bias_file_path = weight_file_path.replace('.weight', '.bias')
bias = np.zeros(bias_shape, dtype=np.float32)
bias = np.zeros(bias_shape, dtype=np_data_type)
bias.tofile(bias_file_path)


def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int,
tensor_name: str, config: Dict[str, Any],
data: np.ndarray) -> None:
data: np.ndarray,
np_weight_data_type: np.dtype) -> None:
"""Convert each MPT weight to a FasterTransformer compatible format.
Args:
@@ -155,7 +158,9 @@ def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int,
save_path = os.path.join(save_dir, f'model.{tensor_name}.bin')
data.tofile(save_path)
if 'weight' in tensor_name and config['no_bias']:
_write_zero_bias(tensor_name, save_path, data.shape[-1])
_write_zero_bias(tensor_name, save_path, data.shape[-1],
np_weight_data_type
) # pyright: ignore [reportGeneralTypeIssues]

elif tensor_name.find('attention.dense.weight') != -1:
assert data.shape == (
@@ -170,11 +175,13 @@ def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int,
if config['no_bias']:
fake_weight_path = os.path.join(save_dir,
f'model.{tensor_name}.bin')
_write_zero_bias(tensor_name, fake_weight_path, data.shape[-1])
_write_zero_bias(tensor_name, fake_weight_path, data.shape[-1],
np_weight_data_type
) # pyright: ignore [reportGeneralTypeIssues]

elif tensor_name.find('mlp.dense_4h_to_h.weight') != -1:
assert data.shape == (
config['d_model'], config['mlp_ratio'] *
config['d_model'], config['expansion_ratio'] *
config['d_model']), f'unexpected dim for {tensor_name}'
# nn.Linear weights are transposed
data = data.T
@@ -185,11 +192,13 @@ def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int,
if config['no_bias']:
fake_weight_path = os.path.join(save_dir,
f'model.{tensor_name}.bin')
_write_zero_bias(tensor_name, fake_weight_path, data.shape[-1])
_write_zero_bias(tensor_name, fake_weight_path, data.shape[-1],
np_weight_data_type
) # pyright: ignore [reportGeneralTypeIssues]

elif tensor_name.find('mlp.dense_h_to_4h.weight') != -1:
assert data.shape == (
config['mlp_ratio'] * config['d_model'],
config['expansion_ratio'] * config['d_model'],
config['d_model']), f'unexpected dim for {tensor_name}'
# nn.Linear weights are transposed
data = data.T
@@ -200,11 +209,12 @@ def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int,
split_vals[j].tofile(save_path)
if config['no_bias']:
_write_zero_bias(tensor_name, save_path,
split_vals[j].shape[-1])
split_vals[j].shape[-1], np_weight_data_type
) # pyright: ignore [reportGeneralTypeIssues]

elif tensor_name.find('mlp.dense_h_to_4h.bias') != -1:
assert data.shape == (
config['mlp_ratio'] *
config['expansion_ratio'] *
config['d_model'],), f'unexpected dim for {tensor_name}'
split_vals = np.split(data, infer_gpu_num, axis=-1)
for j in range(infer_gpu_num):
@@ -238,7 +248,9 @@ def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int,
split_vals[j].tofile(save_path)
if config['no_bias']:
_write_zero_bias(tensor_name, save_path,
(3, split_vals[j].shape[-1]))
(3, split_vals[j].shape[-1]),
np_weight_data_type
) # pyright: ignore [reportGeneralTypeIssues]

else:
raise RuntimeError(f'Tensor with name {tensor_name} is not handled')
@@ -306,7 +318,12 @@ def convert_and_save_ft_weights(named_params: dict,
'model.final_layernorm.weight.bin')
data.tofile(save_path)
if config['no_bias']:
_write_zero_bias(name, save_path, data.shape[-1])
_write_zero_bias(
name,
save_path,
data.shape[-1],
np_weight_data_type # pyright: ignore [reportGeneralTypeIssues]
)
elif name == 'transformer.lm_head.weight':
data.tofile(os.path.join(save_dir, 'model.lm_head.weight.bin'))
else:
@@ -315,5 +332,11 @@ def convert_and_save_ft_weights(named_params: dict,
new_name = name.replace('transformer.blocks.',
'layers.').replace(
mpt_pattern, ft_pattern)
_convert_weight_to_ft_each(save_dir, infer_gpu_num,
new_name, config, data)
_convert_weight_to_ft_each(
save_dir,
infer_gpu_num,
new_name,
config,
data,
np_weight_data_type # pyright: ignore [reportGeneralTypeIssues]
)
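
Two threads run through this file's changes: zero-bias files are now written with the caller-supplied `np_weight_data_type` instead of a hard-coded `np.float32`, and the FFN shape assertions use `expansion_ratio` rather than the removed `mlp_ratio` key. The dtype matters because the exported `.bin` files are raw buffers with no dtype metadata, so a float32 bias written next to float16 weights would be reinterpreted incorrectly at load time. A small sketch of the corrected behavior, with an illustrative function name and file name (not taken from the repository):

```python
import numpy as np


def write_zero_bias_sketch(weight_file_path: str, bias_shape: int,
                           np_data_type: type) -> None:
    """Write an all-zeros bias next to a weight file, in the same dtype as
    the weights. Sketch of the behavior of _write_zero_bias above."""
    bias_file_path = weight_file_path.replace('.weight', '.bias')
    np.zeros(bias_shape, dtype=np_data_type).tofile(bias_file_path)


# For an fp16 export, the bias must also be fp16; a float32 bias file would
# be read back by the fp16 loader as the wrong number of wrong values.
write_zero_bias_sketch('model.mlp.dense_4h_to_h.weight.bin',
                       bias_shape=4096, np_data_type=np.float16)
```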
8 changes: 8 additions & 0 deletions llmfoundry/utils/config_utils.py
@@ -120,6 +120,14 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]):
# Set defaults for mixed initialization
fsdp_config.setdefault('use_orig_params', False)
fsdp_config.setdefault('load_monolith_rank0_only', True)
# Always set `sync_module_states` to True when using hybrid sharding
if fsdp_config is not None and \
fsdp_config.get('sharding_strategy', 'FULL_SHARD') in ['HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'] \
and not fsdp_config.get('sync_module_states', False):
warnings.warn(
('Setting `sync_module_states = True` for FSDP. This is required '
'when using hybrid sharding.'))
fsdp_config['sync_module_states'] = True

# no mixed precision needed for weights when they're already 16 bits
master_dtype = model_cfg.get('master_weights_dtype')
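The new guard in `process_init_device` covers hybrid sharding: with `HYBRID_SHARD` or `_HYBRID_SHARD_ZERO2`, parameters are sharded within each replica group and replicated across groups, so every group needs to start from identical weights; `sync_module_states=True` makes FSDP broadcast module states from rank 0 to guarantee that, and the check enables it with a warning rather than letting training start from mismatched copies. A condensed, standalone sketch of that check (the function name is illustrative):

```python
import warnings
from typing import Optional


def ensure_sync_module_states(fsdp_config: Optional[dict]) -> Optional[dict]:
    """Standalone version of the check added above: hybrid sharding
    strategies require sync_module_states=True."""
    if (fsdp_config is not None
            and fsdp_config.get('sharding_strategy', 'FULL_SHARD') in (
                'HYBRID_SHARD', '_HYBRID_SHARD_ZERO2')
            and not fsdp_config.get('sync_module_states', False)):
        warnings.warn('Setting `sync_module_states = True` for FSDP. This is '
                      'required when using hybrid sharding.')
        fsdp_config['sync_module_states'] = True
    return fsdp_config


print(ensure_sync_module_states({'sharding_strategy': 'HYBRID_SHARD'}))
# {'sharding_strategy': 'HYBRID_SHARD', 'sync_module_states': True}
```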
10 changes: 5 additions & 5 deletions setup.py
@@ -47,15 +47,15 @@
]

install_requires = [
'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.17.1,<0.18',
'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.17.2,<0.18',
'accelerate>=0.25,<0.26', # for HF inference `device_map`
'transformers>=4.36,<4.37',
'mosaicml-streaming>=0.7.1,<0.8',
'mosaicml-streaming>=0.7.2,<0.8',
'torch>=2.1,<2.1.1',
'datasets==2.15.0',
'fsspec==2023.6.0', # newer version results in a bug in datasets that duplicates data
'sentencepiece==0.1.97',
'einops==0.5.0',
'einops==0.7.0',
'omegaconf>=2.2.3,<3',
'slack-sdk<4',
'mosaicml-cli>=0.5.27,<1',
@@ -84,11 +84,11 @@
]

extra_deps['databricks'] = [
'mosaicml[databricks]>=0.17.1,<0.18',
'mosaicml[databricks]>=0.17.2,<0.18',
]

extra_deps['tensorboard'] = [
'mosaicml[tensorboard]>=0.17.1,<0.18',
'mosaicml[tensorboard]>=0.17.2,<0.18',
]

extra_deps['gpu'] = [
2 changes: 1 addition & 1 deletion tests/models/test_model.py
@@ -350,7 +350,7 @@ def test_full_forward_and_backward_t5_small(batch_size: int = 2):
[('torch', torch.float16), ('torch', torch.bfloat16),
pytest.param('flash', torch.float16, marks=pytest.mark.gpu),
pytest.param('flash', torch.bfloat16, marks=pytest.mark.gpu)])
@pytest.mark.parametrize('ffn_type', ['mptmlp', 'mptgeglu'])
@pytest.mark.parametrize('ffn_type', ['mptmlp', 'mptglu'])
@pytest.mark.parametrize('ffn_act_fn', [
None,
{
