Skip to content

Commit

Permalink
Merge branch 'main' into human_eval_simple
Browse files Browse the repository at this point in the history
  • Loading branch information
bmosaicml committed Oct 3, 2023
2 parents 7968ca4 + 7f7c097 commit 0f8b160
Show file tree
Hide file tree
Showing 32 changed files with 833 additions and 70 deletions.
2 changes: 1 addition & 1 deletion llmfoundry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,4 @@
'TiktokenTokenizerWrapper',
]

__version__ = '0.2.0'
__version__ = '0.3.0'
11 changes: 8 additions & 3 deletions llmfoundry/data/finetuning/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,11 +293,14 @@ def _build_hf_dataset_from_remote(
FileNotFoundError: Raised if the dataset file cannot be found with any of the supported extensions.
"""
supported_extensions = ['jsonl', 'csv', 'parquet']
# HF datasets does not support a split with dashes, so we replace dashes
# with underscores in the destination split.
destination_split = cfg.dataset.split.replace('-', '_')
finetune_dir = os.path.join(
os.path.dirname(
os.path.dirname(os.path.dirname(os.path.realpath(__file__)))),
'downloaded_finetuning',
cfg.dataset.split if cfg.dataset.split != 'data' else 'data_not',
destination_split if destination_split != 'data' else 'data_not',
)
os.makedirs(finetune_dir, exist_ok=True)
for extension in supported_extensions:
Expand All @@ -306,10 +309,12 @@ def _build_hf_dataset_from_remote(
os.path.abspath(
os.path.join(
finetune_dir, 'data',
f'{cfg.dataset.split}-00000-of-00001.{extension}')))
f'{destination_split}-00000-of-00001.{extension}')))

# Since we don't know exactly what the extension will be, since it is one of a list
# use a signal file to wait for instead of the desired file
signal_file_path = os.path.join(finetune_dir, '.the_eagle_has_landed')
signal_file_path = os.path.join(
finetune_dir, f'.node_{dist.get_node_rank()}_local_rank0_completed')
if dist.get_local_rank() == 0:
try:
get_file(path=name, destination=destination, overwrite=True)
Expand Down
4 changes: 3 additions & 1 deletion llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,9 @@ def build_from_hf(
Dataset: The tokenized dataset.
"""
dataset_name = cfg.hf_name
split = cfg.split
# HF datasets does not support a split with dashes,so we replace split
# dashes with underscore.
split = cfg.split.replace('-', '_')
kwargs = cfg.get('hf_kwargs', {})
proto_preprocessing_fn = cfg.get('preprocessing_fn')
if isinstance(proto_preprocessing_fn, dict) or isinstance(
Expand Down
6 changes: 4 additions & 2 deletions llmfoundry/models/hf/hf_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

# required for loading a python model into composer
import transformers
from composer.metrics.nlp import (InContextLearningLMAccuracy,
from composer.metrics.nlp import (InContextLearningCodeEvalAccuracy,
InContextLearningLMAccuracy,
InContextLearningLMExpectedCalibrationError,
InContextLearningMCExpectedCalibrationError,
InContextLearningMultipleChoiceAccuracy,
Expand Down Expand Up @@ -74,6 +75,7 @@ def __init__(self, om_model_config: Union[DictConfig,
InContextLearningLMAccuracy(),
InContextLearningMultipleChoiceAccuracy(),
InContextLearningQAAccuracy(),
InContextLearningCodeEvalAccuracy(),
InContextLearningLMExpectedCalibrationError(),
InContextLearningMCExpectedCalibrationError()
]
Expand Down Expand Up @@ -164,7 +166,7 @@ def __init__(self, om_model_config: Union[DictConfig,
f'init_device="{init_device}" must be either "cpu" or "meta".'
)

signal_file_path = '.local_rank0_completed_autoresume'
signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed'
if dist.get_local_rank() == 0:
with open(signal_file_path, 'wb') as f:
f.write(b'local_rank0_completed_download')
Expand Down
60 changes: 47 additions & 13 deletions llmfoundry/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import math
import warnings
from typing import List, Optional, Tuple
from typing import Any, List, Optional, Tuple

import torch
import torch.nn as nn
Expand All @@ -31,6 +31,23 @@ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
return original_is_causal


def repeat_kv_for_gqa(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
"""Perform repeat of kv heads along a particular dimension.
hidden.shape expected to be: (batch size, seq len, kv_n_heads, head_dim)
n_rep: amount of repetitions of kv_n_heads
Unlike torch.repeat_interleave, this function avoids allocating new memory.
"""
if n_rep == 1:
return hidden

b, s, kv_n_heads, d = hidden.shape

hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)

return hidden.reshape(b, s, kv_n_heads * n_rep, d)


def scaled_multihead_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,
Expand Down Expand Up @@ -84,8 +101,11 @@ def scaled_multihead_dot_product_attention(

# grouped query case
if kv_n_heads > 1 and kv_n_heads < n_heads:
k = k.repeat_interleave(n_heads // kv_n_heads, dim=1)
v = v.repeat_interleave(n_heads // kv_n_heads, dim=1)
# necessary to do a transpose to swap (b h s d) -> (b s h d) for repeat_kv_for_gqa function
k = repeat_kv_for_gqa(k.transpose(1, 2),
n_heads // kv_n_heads).transpose(1, 2)
v = repeat_kv_for_gqa(v.transpose(1, 2),
n_heads // kv_n_heads).transpose(1, 2)

if softmax_scale is None:
softmax_scale = 1 / math.sqrt(d)
Expand Down Expand Up @@ -243,10 +263,16 @@ def flash_attn_fn(
elif kv_n_heads < n_heads:
# Each query belong to a group of kv heads of group size n_heads // kv_n_heads
# We repeat each kv head by the group size number to use the underlying MHA kernels
# done along the head dimension = 1
key_unpad = key_unpad.repeat_interleave(n_heads // kv_n_heads, dim=1)
value_unpad = value_unpad.repeat_interleave(n_heads // kv_n_heads,
dim=1)

# since repeat_kv_for_gqa expects input dims of (b, s, kv_n_heads, d)
# we use .view to modify {key, value}_unpad appropriately

key_unpad = repeat_kv_for_gqa(
key_unpad.view(batch_size, seqlen, kv_n_heads, -1),
n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)
value_unpad = repeat_kv_for_gqa(
value_unpad.view(batch_size, seqlen, kv_n_heads, -1),
n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)

dropout_p = dropout_p if training else 0.0

Expand Down Expand Up @@ -383,9 +409,8 @@ def triton_flash_attn_fn(
elif kv_n_heads < n_heads:
# Each query belong to a group of kv heads of group size n_heads // kv_n_heads
# We repeat each kv head by the group size number to use the underlying MHA kernels
# done along dim = 2, unlike the implementation for flash and torch attn
key = key.repeat_interleave(n_heads // kv_n_heads, dim=2)
value = value.repeat_interleave(n_heads // kv_n_heads, dim=2)
key = repeat_kv_for_gqa(key, n_heads // kv_n_heads)
value = repeat_kv_for_gqa(value, n_heads // kv_n_heads)

reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
attn_output = flash_attn_func( # type: ignore
Expand Down Expand Up @@ -419,6 +444,7 @@ def __init__(
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
):
super().__init__()

Expand Down Expand Up @@ -450,7 +476,9 @@ def __init__(
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
self.attn_dropout_p = attn_pdrop

fc_kwargs = {}
fc_kwargs: dict[str, Any] = {
'bias': bias,
}
if fc_type != 'te':
fc_kwargs['device'] = device
self.Wqkv = FC_CLASS_REGISTRY[fc_type](
Expand Down Expand Up @@ -557,6 +585,7 @@ def __init__(
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
):
super().__init__(
d_model=d_model,
Expand All @@ -569,7 +598,9 @@ def __init__(
attn_pdrop=attn_pdrop,
norm_type=norm_type,
fc_type=fc_type,
device=device)
device=device,
bias=bias,
)


class MultiQueryAttention(GroupedQueryAttention):
Expand All @@ -591,6 +622,7 @@ def __init__(
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
):
super().__init__(
d_model=d_model,
Expand All @@ -603,7 +635,9 @@ def __init__(
attn_pdrop=attn_pdrop,
norm_type=norm_type,
fc_type=fc_type,
device=device)
device=device,
bias=bias,
)


def attn_bias_shape(
Expand Down
15 changes: 10 additions & 5 deletions llmfoundry/models/layers/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
device: Optional[str] = None,
no_bias: bool = False,
**kwargs: Any,
):
if attn_config is None:
Expand Down Expand Up @@ -66,11 +67,14 @@ def __init__(
}

self.norm_1 = norm_class(d_model, device=device)
self.attn = attn_class(d_model=d_model,
n_heads=n_heads,
fc_type=fc_type,
device=device,
**attn_config_subset_for_attn_class)
self.attn = attn_class(
d_model=d_model,
n_heads=n_heads,
fc_type=fc_type,
device=device,
**attn_config_subset_for_attn_class,
bias=not no_bias,
)
self.norm_2 = None
if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm',
False):
Expand All @@ -79,6 +83,7 @@ def __init__(
d_model=d_model,
expansion_ratio=expansion_ratio,
device=device,
bias=not no_bias,
**ffn_config,
)
self.resid_attn_dropout = nn.Dropout(resid_pdrop)
Expand Down
8 changes: 7 additions & 1 deletion llmfoundry/models/layers/ffn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@ def __init__(
expansion_ratio: int,
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
):
super().__init__()
fc_kwargs = {}
fc_kwargs: dict[str, Any] = {
'bias': bias,
}
if fc_type != 'te':
fc_kwargs['device'] = device
self.up_proj = FC_CLASS_REGISTRY[fc_type](
Expand Down Expand Up @@ -60,6 +63,7 @@ def build_ffn(
expansion_ratio: int,
fc_type: str = 'torch',
device: Optional[str] = None,
bias: bool = True,
**kwargs: Any,
) -> nn.Module:
ffn_type = kwargs.pop('ffn_type')
Expand All @@ -72,12 +76,14 @@ def build_ffn(
expansion_ratio=expansion_ratio,
fc_type=fc_type,
device=device,
bias=bias,
)
elif ffn_type == 'te_ln_mlp':
assert te is not None
return te.LayerNormMLP(
hidden_size=d_model,
ffn_hidden_size=d_model * expansion_ratio,
bias=bias,
**kwargs,
)

Expand Down
9 changes: 8 additions & 1 deletion llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from composer.metrics import (InContextLearningLMAccuracy,
from composer.metrics import (InContextLearningCodeEvalAccuracy,
InContextLearningLMAccuracy,
InContextLearningLMExpectedCalibrationError,
InContextLearningMCExpectedCalibrationError,
InContextLearningMultipleChoiceAccuracy,
Expand Down Expand Up @@ -150,6 +151,11 @@ def __init__(self, config: MPTConfig):
log.info(f'Removing bias ({module.bias}) from {module}.')
module.register_parameter('bias', None)

# For transformer engine
if hasattr(module, 'use_bias'):
log.info(f'Setting use_bias=False for {module}.')
module.use_bias = False

log.debug(self)
log.debug(f'Using {self.config.init_config["name"]} initialization.')

Expand Down Expand Up @@ -695,6 +701,7 @@ def __init__(
InContextLearningLMAccuracy(),
InContextLearningMultipleChoiceAccuracy(),
InContextLearningQAAccuracy(),
InContextLearningCodeEvalAccuracy(),
InContextLearningLMExpectedCalibrationError(),
InContextLearningMCExpectedCalibrationError(),
]
Expand Down
8 changes: 8 additions & 0 deletions llmfoundry/utils/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ def _validate_cfg(icl_cfg: DictConfig):
]
elif icl_cfg.icl_task_type == 'question_answering':
icl_cfg.metric_names = ['InContextLearningQAAccuracy']
elif icl_cfg.icl_task_type == 'code_evaluation':
icl_cfg.metric_names = ['InContextLearningCodeEvalAccuracy']
else:
raise ValueError(
f'No metric_names defined, unable to build default metrics for icl_task_type={icl_cfg.icl_task_type}.'
Expand All @@ -244,6 +246,10 @@ def _validate_cfg(icl_cfg: DictConfig):
icl_cfg.max_seq_len = default_max_seq_len
if 'batch_size' not in icl_cfg:
icl_cfg.batch_size = default_batch_size
if 'pass_at_k' not in icl_cfg:
icl_cfg.pass_at_k = 1
if 'num_beams' not in icl_cfg:
icl_cfg.num_beams = 20

for icl_cfg in icl_tasks_list:
assert isinstance(icl_cfg, DictConfig)
Expand Down Expand Up @@ -274,6 +280,8 @@ def _validate_cfg(icl_cfg: DictConfig):
example_delimiter=icl_cfg.example_delimiter,
continuation_delimiter=icl_cfg.continuation_delimiter,
destination_path=destination_path,
pass_at_k=icl_cfg.pass_at_k,
generations_per_sample=icl_cfg.num_beams,
has_categories=icl_cfg.get('has_categories', False),
)
if hasattr(
Expand Down
6 changes: 3 additions & 3 deletions mcli/mcli-1b-eval.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
integrations:
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
# git_branch: # Specify your git branch
git_commit: 186dd19888a8c8874584f9e78619f3fb0348309f # TODO: repin after next release
git_branch: v0.3.0
# git_commit: # OR use your commit hash
pip_install: -e .[gpu]
ssh_clone: false # Should be true if using a private repo

Expand Down Expand Up @@ -33,7 +33,7 @@ parameters:
model_max_length: ${max_seq_len}
model:
name: mpt_causal_lm
init_device: meta
init_device: mixed
d_model: 2048
n_heads: 16 # Modified 24->16 so that d_head == 128 to satisfy FlashAttention
n_layers: 24
Expand Down
2 changes: 1 addition & 1 deletion mcli/mcli-1b-max-seq-len-8k.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
integrations:
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
git_branch: v0.2.0
git_branch: v0.3.0
# git_commit: # OR use your commit hash
pip_install: -e .[gpu]
ssh_clone: false # Should be true if using a private repo
Expand Down
2 changes: 1 addition & 1 deletion mcli/mcli-1b.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
integrations:
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
git_branch: v0.2.0
git_branch: v0.3.0
# git_commit: # OR use your commit hash
pip_install: -e .[gpu]
ssh_clone: false # Should be true if using a private repo
Expand Down
3 changes: 2 additions & 1 deletion mcli/mcli-benchmark-mpt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ image: mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
integrations:
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
git_branch: v0.2.0
git_branch: v0.3.0
# git_commit: # OR use your commit hash
pip_install: '.[gpu]'

command: |
Expand Down
Loading

0 comments on commit 0f8b160

Please sign in to comment.