Commit

Merge branch 'main' into ft-tokenized-pr
dakinggg authored Oct 5, 2023
2 parents 9cceb70 + cb1d94a commit 2bf2698
Showing 33 changed files with 834 additions and 65 deletions.
3 changes: 2 additions & 1 deletion .github/mcp/mcp_pytest.py
@@ -86,11 +86,12 @@
if len(name) > 56:
name = name[:56]

+ clear_tmp_path_flag = '-o tmp_path_retention_policy=none'
command += f'''
pip install --upgrade --user .[all]
- export COMMON_ARGS="-v --durations=20 -m '{args.pytest_markers}'"
+ export COMMON_ARGS="-v --durations=20 -m '{args.pytest_markers}' {clear_tmp_path_flag}"
make test PYTEST='{args.pytest_command}' EXTRA_ARGS="$COMMON_ARGS --codeblocks"
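The new `clear_tmp_path_flag` threads pytest's `tmp_path_retention_policy` option (available since pytest 7.3) into the CI test command: with `none`, the per-test `tmp_path` directories are deleted after the session instead of retained, so repeated CI runs don't fill the disk. A minimal sketch of the behavior the flag controls; the test file and names here are illustrative, not from the repo:

```python
# With the default tmp_path_retention_policy=all, pytest keeps the most recent
# tmp_path base directories on disk after the run; with 'none' they are removed.
# Run as: pytest -o tmp_path_retention_policy=none test_scratch.py
def test_writes_scratch_file(tmp_path):
    scratch = tmp_path / 'scratch.txt'  # tmp_path is a per-test pathlib.Path
    scratch.write_text('hello')
    assert scratch.read_text() == 'hello'
```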
2 changes: 1 addition & 1 deletion .github/workflows/pytest-cpu.yaml
@@ -33,7 +33,7 @@ jobs:
run: |
set -ex
export PATH=/composer-python:$PATH
- export COMMON_ARGS="-v --durations=20 -m '${{ inputs.pytest-markers }}'"
+ export COMMON_ARGS="-v --durations=20 -m '${{ inputs.pytest-markers }}' -o tmp_path_retention_policy=none"
# Necessary to run git diff for doctests
git config --global --add safe.directory /__w/llm-foundry/llm-foundry
2 changes: 1 addition & 1 deletion llmfoundry/__init__.py
@@ -68,4 +68,4 @@
'TiktokenTokenizerWrapper',
]

- __version__ = '0.2.0'
+ __version__ = '0.3.0'
3 changes: 2 additions & 1 deletion llmfoundry/data/finetuning/dataloader.py
@@ -313,7 +313,8 @@ def _build_hf_dataset_from_remote(

# Since we don't know exactly what the extension will be, since it is one of a list
# use a signal file to wait for instead of the desired file
- signal_file_path = os.path.join(finetune_dir, '.the_eagle_has_landed')
+ signal_file_path = os.path.join(
+     finetune_dir, f'.node_{dist.get_node_rank()}_local_rank0_completed')
if dist.get_local_rank() == 0:
try:
get_file(path=name, destination=destination, overwrite=True)
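This change keys the signal file by node rank, so in multi-node runs each node's local rank 0 downloads the data while the other ranks on that node wait for their own node's marker rather than a single shared one. A sketch of the coordination pattern, assuming Composer's `dist` utilities (`get_local_rank`, `get_node_rank`, `barrier`, all of which appear in the diff) and a hypothetical `download_fn`:

```python
import os
import time

from composer.utils import dist


def download_once_per_node(download_fn, finetune_dir: str) -> None:
    # Marker is keyed by node rank so each node coordinates independently.
    signal_file_path = os.path.join(
        finetune_dir, f'.node_{dist.get_node_rank()}_local_rank0_completed')
    if dist.get_local_rank() == 0:
        download_fn()  # only one process per node touches the remote store
        with open(signal_file_path, 'wb') as f:
            f.write(b'local_rank0_completed_download')
    else:
        # Remaining local ranks poll for their own node's marker.
        while not os.path.exists(signal_file_path):
            time.sleep(1)
    dist.barrier()  # sync all ranks before cleaning up the marker
    if dist.get_local_rank() == 0:
        os.remove(signal_file_path)
```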
19 changes: 13 additions & 6 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -9,7 +9,8 @@

# required for loading a python model into composer
import transformers
- from composer.metrics.nlp import (InContextLearningLMAccuracy,
+ from composer.metrics.nlp import (InContextLearningCodeEvalAccuracy,
+                                   InContextLearningLMAccuracy,
InContextLearningLMExpectedCalibrationError,
InContextLearningMCExpectedCalibrationError,
InContextLearningMultipleChoiceAccuracy,
@@ -64,16 +65,14 @@ def __init__(self, om_model_config: Union[DictConfig,
nn.Module],
tokenizer: PreTrainedTokenizerBase):
# set up training and eval metrics
- train_metrics = [
-     LanguageCrossEntropy(),
-     LanguagePerplexity(),
- ]
+ train_metrics = [LanguageCrossEntropy(), LanguagePerplexity()]
eval_metrics = [
LanguageCrossEntropy(),
LanguagePerplexity(),
InContextLearningLMAccuracy(),
InContextLearningMultipleChoiceAccuracy(),
InContextLearningQAAccuracy(),
+ InContextLearningCodeEvalAccuracy(),
InContextLearningLMExpectedCalibrationError(),
InContextLearningMCExpectedCalibrationError()
]
@@ -90,6 +89,9 @@ def __init__(self, om_model_config: Union[DictConfig,
'which is not significantly slower and not compatible with the LLM foundry training code, rather than the code release by MosaicML.'
)

+ if not om_model_config.get('use_train_metrics', True):
+     train_metrics = []

# load the model config
trust_remote_code = om_model_config.get('trust_remote_code', True)
use_auth_token = om_model_config.get('use_auth_token', False)
@@ -107,6 +109,7 @@ def __init__(self, om_model_config: Union[DictConfig,
)

attr = getattr(config, k)
+ # attempt to disallow typos in nested configs
if isinstance(attr, Mapping):
extra_keys = [
_k for _k in v.keys() if _k not in attr.keys()
@@ -118,6 +121,10 @@ def __init__(self, om_model_config: Union[DictConfig,
f'Expected (a subset of) keys: {list(attr.keys())}.'
)
getattr(config, k).update(v)
+ # necessary case to allow for rope_scaling to be overridden in llama config
+ elif attr is None and isinstance(v, Mapping):
+     setattr(config, k, {})
+     getattr(config, k).update(v)
else:
setattr(config, k, v)

@@ -164,7 +171,7 @@ def __init__(self, om_model_config: Union[DictConfig,
f'init_device="{init_device}" must be either "cpu" or "meta".'
)

- signal_file_path = '.local_rank0_completed_autoresume'
+ signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed'
if dist.get_local_rank() == 0:
with open(signal_file_path, 'wb') as f:
f.write(b'local_rank0_completed_download')
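The new `elif` branch exists because some HF configs (notably Llama's `rope_scaling`) default to `None`, so a nested dict override would previously miss both the typo check and the `.update()` path. A standalone sketch of the override logic as it now behaves; the function name is ours, not the repo's:

```python
from collections.abc import Mapping
from typing import Any, Dict


def apply_config_overrides(config: Any, overrides: Dict[str, Any]) -> None:
    for k, v in overrides.items():
        attr = getattr(config, k)
        if isinstance(attr, Mapping):
            # disallow typos: override keys must already exist in the nested dict
            extra_keys = [_k for _k in v.keys() if _k not in attr.keys()]
            if extra_keys:
                raise ValueError(
                    f'Config dict override got unknown keys: {extra_keys}. '
                    f'Expected (a subset of) keys: {list(attr.keys())}.')
            getattr(config, k).update(v)
        elif attr is None and isinstance(v, Mapping):
            # e.g. LlamaConfig ships with rope_scaling=None, so seed an empty
            # dict before applying the user's nested override
            setattr(config, k, {})
            getattr(config, k).update(v)
        else:
            setattr(config, k, v)
```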
43 changes: 34 additions & 9 deletions llmfoundry/models/layers/attention.py
@@ -31,6 +31,23 @@ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
return original_is_causal


+ def repeat_kv_for_gqa(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """Perform repeat of kv heads along a particular dimension.
+
+     hidden.shape expected to be: (batch size, seq len, kv_n_heads, head_dim)
+     n_rep: amount of repetitions of kv_n_heads
+
+     Unlike torch.repeat_interleave, this function avoids allocating new memory.
+     """
+     if n_rep == 1:
+         return hidden
+
+     b, s, kv_n_heads, d = hidden.shape
+
+     hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)
+
+     return hidden.reshape(b, s, kv_n_heads * n_rep, d)


def scaled_multihead_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,
@@ -84,8 +101,11 @@ def scaled_multihead_dot_product_attention(

# grouped query case
if kv_n_heads > 1 and kv_n_heads < n_heads:
- k = k.repeat_interleave(n_heads // kv_n_heads, dim=1)
- v = v.repeat_interleave(n_heads // kv_n_heads, dim=1)
+ # necessary to do a transpose to swap (b h s d) -> (b s h d) for repeat_kv_for_gqa function
+ k = repeat_kv_for_gqa(k.transpose(1, 2),
+                       n_heads // kv_n_heads).transpose(1, 2)
+ v = repeat_kv_for_gqa(v.transpose(1, 2),
+                       n_heads // kv_n_heads).transpose(1, 2)

if softmax_scale is None:
softmax_scale = 1 / math.sqrt(d)
@@ -243,10 +263,16 @@ def flash_attn_fn(
elif kv_n_heads < n_heads:
# Each query belongs to a group of kv heads of group size n_heads // kv_n_heads
# We repeat each kv head by the group size number to use the underlying MHA kernels
- # done along the head dimension = 1
- key_unpad = key_unpad.repeat_interleave(n_heads // kv_n_heads, dim=1)
- value_unpad = value_unpad.repeat_interleave(n_heads // kv_n_heads,
-                                             dim=1)
+
+ # since repeat_kv_for_gqa expects input dims of (b, s, kv_n_heads, d)
+ # we use .view to modify {key, value}_unpad appropriately
+
+ key_unpad = repeat_kv_for_gqa(
+     key_unpad.view(batch_size, seqlen, kv_n_heads, -1),
+     n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)
+ value_unpad = repeat_kv_for_gqa(
+     value_unpad.view(batch_size, seqlen, kv_n_heads, -1),
+     n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)

dropout_p = dropout_p if training else 0.0

@@ -383,9 +409,8 @@ def triton_flash_attn_fn(
elif kv_n_heads < n_heads:
# Each query belongs to a group of kv heads of group size n_heads // kv_n_heads
# We repeat each kv head by the group size number to use the underlying MHA kernels
- # done along dim = 2, unlike the implementation for flash and torch attn
- key = key.repeat_interleave(n_heads // kv_n_heads, dim=2)
- value = value.repeat_interleave(n_heads // kv_n_heads, dim=2)
+ key = repeat_kv_for_gqa(key, n_heads // kv_n_heads)
+ value = repeat_kv_for_gqa(value, n_heads // kv_n_heads)

reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
attn_output = flash_attn_func( # type: ignore
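Replacing `repeat_interleave` with `repeat_kv_for_gqa` is a memory optimization: `expand` produces a stride-0 view, so the repeated heads are only materialized once by the final `reshape` rather than eagerly copied first. A quick equivalence check (a sketch, assuming the `repeat_kv_for_gqa` definition from the hunk above is in scope):

```python
import torch

b, s, kv_n_heads, d, n_rep = 2, 5, 4, 8, 3
hidden = torch.randn(b, s, kv_n_heads, d)

# torch.repeat_interleave on the head dim eagerly allocates the repeated tensor
expected = hidden.repeat_interleave(n_rep, dim=2)

# repeat_kv_for_gqa routes through expand(), a zero-copy view, before reshape
actual = repeat_kv_for_gqa(hidden, n_rep)

assert torch.equal(actual, expected)  # same values, shape (b, s, kv_n_heads * n_rep, d)
```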
8 changes: 6 additions & 2 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -14,7 +14,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
- from composer.metrics import (InContextLearningLMAccuracy,
+ from composer.metrics import (InContextLearningCodeEvalAccuracy,
+                               InContextLearningLMAccuracy,
InContextLearningLMExpectedCalibrationError,
InContextLearningMCExpectedCalibrationError,
InContextLearningMultipleChoiceAccuracy,
@@ -693,13 +694,16 @@ def __init__(
hf_config = MPTConfig.from_dict(resolved_om_model_config)
model = MPTForCausalLM(hf_config)

- train_metrics = [LanguageCrossEntropy(), LanguagePerplexity()]
+ use_train_metrics = om_model_config.get('use_train_metrics', True)
+ train_metrics = [LanguageCrossEntropy(),
+                  LanguagePerplexity()] if use_train_metrics else []
eval_metrics = [
LanguageCrossEntropy(),
LanguagePerplexity(),
InContextLearningLMAccuracy(),
InContextLearningMultipleChoiceAccuracy(),
InContextLearningQAAccuracy(),
+ InContextLearningCodeEvalAccuracy(),
InContextLearningLMExpectedCalibrationError(),
InContextLearningMCExpectedCalibrationError(),
]
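With the new `use_train_metrics` flag, both the HF and MPT model builders skip constructing per-batch train metrics when the config asks for it; eval metrics are unaffected. A small sketch of how the flag might be set in a model config (the surrounding keys are illustrative; the `.get` default mirrors the logic added above):

```python
from omegaconf import OmegaConf

om_model_config = OmegaConf.create({
    'name': 'mpt_causal_lm',
    'use_train_metrics': False,  # disable train-time metric computation
})

use_train_metrics = om_model_config.get('use_train_metrics', True)
train_metrics = (['LanguageCrossEntropy', 'LanguagePerplexity']
                 if use_train_metrics else [])
assert train_metrics == []  # metrics are skipped when the flag is False
```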
8 changes: 8 additions & 0 deletions llmfoundry/utils/builders.py
@@ -229,6 +229,8 @@ def _validate_cfg(icl_cfg: DictConfig):
]
elif icl_cfg.icl_task_type == 'question_answering':
icl_cfg.metric_names = ['InContextLearningQAAccuracy']
+ elif icl_cfg.icl_task_type == 'code_evaluation':
+     icl_cfg.metric_names = ['InContextLearningCodeEvalAccuracy']
else:
raise ValueError(
f'No metric_names defined, unable to build default metrics for icl_task_type={icl_cfg.icl_task_type}.'
@@ -244,6 +246,10 @@ def _validate_cfg(icl_cfg: DictConfig):
icl_cfg.max_seq_len = default_max_seq_len
if 'batch_size' not in icl_cfg:
icl_cfg.batch_size = default_batch_size
+ if 'pass_at_k' not in icl_cfg:
+     icl_cfg.pass_at_k = 1
+ if 'num_beams' not in icl_cfg:
+     icl_cfg.num_beams = 20

for icl_cfg in icl_tasks_list:
assert isinstance(icl_cfg, DictConfig)
@@ -274,6 +280,8 @@ def _validate_cfg(icl_cfg: DictConfig):
example_delimiter=icl_cfg.example_delimiter,
continuation_delimiter=icl_cfg.continuation_delimiter,
destination_path=destination_path,
+ pass_at_k=icl_cfg.pass_at_k,
+ generations_per_sample=icl_cfg.num_beams,
has_categories=icl_cfg.get('has_categories', False),
)
if hasattr(
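Together these hunks add a `code_evaluation` ICL task type with defaults of `pass_at_k=1` and `num_beams=20`, the latter forwarded to the dataloader as `generations_per_sample`. A hypothetical task entry that would exercise the new branch (the label and dataset path are illustrative, not from the repo):

```python
from omegaconf import OmegaConf

icl_cfg = OmegaConf.create({
    'label': 'human_eval',
    'dataset_uri': 'local_data/human_eval.jsonl',  # illustrative path
    'icl_task_type': 'code_evaluation',
    'num_fewshot': [0],
})

# After _validate_cfg fills in defaults, the config would carry:
#   metric_names = ['InContextLearningCodeEvalAccuracy']
#   pass_at_k = 1    # the k in pass@k
#   num_beams = 20   # passed to the dataset as generations_per_sample
```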
6 changes: 3 additions & 3 deletions mcli/mcli-1b-eval.yaml
@@ -1,8 +1,8 @@
integrations:
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
- # git_branch: # Specify your git branch
- git_commit: 186dd19888a8c8874584f9e78619f3fb0348309f # TODO: repin after next release
+ git_branch: v0.3.0
+ # git_commit: # OR use your commit hash
pip_install: -e .[gpu]
ssh_clone: false # Should be true if using a private repo

@@ -33,7 +33,7 @@ parameters:
model_max_length: ${max_seq_len}
model:
name: mpt_causal_lm
- init_device: meta
+ init_device: mixed
d_model: 2048
n_heads: 16 # Modified 24->16 so that d_head == 128 to satisfy FlashAttention
n_layers: 24
2 changes: 1 addition & 1 deletion mcli/mcli-1b-max-seq-len-8k.yaml
@@ -1,7 +1,7 @@
integrations:
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
- git_branch: v0.2.0
+ git_branch: v0.3.0
# git_commit: # OR use your commit hash
pip_install: -e .[gpu]
ssh_clone: false # Should be true if using a private repo
2 changes: 1 addition & 1 deletion mcli/mcli-1b.yaml
@@ -1,7 +1,7 @@
integrations:
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
- git_branch: v0.2.0
+ git_branch: v0.3.0
# git_commit: # OR use your commit hash
pip_install: -e .[gpu]
ssh_clone: false # Should be true if using a private repo
3 changes: 2 additions & 1 deletion mcli/mcli-benchmark-mpt.yaml
@@ -11,7 +11,8 @@ image: mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
integrations:
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
- git_branch: v0.2.0
+ git_branch: v0.3.0
+ # git_commit: # OR use your commit hash
pip_install: '.[gpu]'

command: |
4 changes: 2 additions & 2 deletions mcli/mcli-convert-composer-to-hf.yaml
@@ -1,13 +1,13 @@
integrations:
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
- git_branch: v0.2.0
+ git_branch: v0.3.0
# git_commit: # OR use your commit hash
pip_install: -e .
ssh_clone: false # Should be true if using a private repo

command: |
- cd llm-foundry/llmfoundry/inference
+ cd llm-foundry/scripts/inference
python convert_composer_to_hf.py \
--composer_path s3://bucket/folder/checkpoint-path.pt \
--hf_output_path s3://bucket/folder/hf/ \
4 changes: 2 additions & 2 deletions mcli/mcli-hf-eval.yaml
@@ -1,7 +1,7 @@
integrations:
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
- git_branch: v0.2.0
+ git_branch: v0.3.0
# git_commit: # OR use your commit hash
pip_install: -e ".[gpu]"
ssh_clone: false # Should be true if using a private repo
@@ -11,7 +11,7 @@ command: |
composer eval/eval.py /mnt/config/parameters.yaml
# Mosaic Cloud will use run_name (with a unique suffix) to populate the env var $RUN_NAME
- run_name: all-eval
+ run_name: mpt-eval
gpu_num: 8
# gpu_type:
# cluster: # replace with your cluster here!
2 changes: 1 addition & 1 deletion mcli/mcli-hf-generate.yaml
@@ -1,7 +1,7 @@
integrations:
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
- git_branch: v0.2.0
+ git_branch: v0.3.0
# git_commit: # OR use your commit hash
pip_install: -e .[gpu]
ssh_clone: false # Should be true if using a private repo
3 changes: 2 additions & 1 deletion mcli/mcli-llama2-finetune.yaml
@@ -1,7 +1,8 @@
integrations:
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
- git_commit: 5ec4016b40652557d57a1d4949ad13a65251184b # TODO: repin this after next release
+ git_branch: v0.3.0
+ # git_commit: # OR use your commit hash
pip_install: -e .[gpu]
ssh_clone: false # Should be true if using a private repo

4 changes: 2 additions & 2 deletions mcli/mcli-openai-eval.yaml
@@ -1,8 +1,8 @@
integrations:
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
- git_branch: # use your branch
- # git_commit: 29d65cc26853c09f6de7542978056ddb0b07e98c # OR use your commit hash
+ git_branch: v0.3.0
+ # git_commit: # OR use your commit hash
pip_install: -e ".[gpu,openai]"
ssh_clone: false # Should be true if using a private repo

2 changes: 1 addition & 1 deletion mcli/mcli-pretokenize-oci-upload.yaml
@@ -14,7 +14,7 @@ integrations:
- oci-cli==3.23.2
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
- git_branch: v0.2.0
+ git_branch: v0.3.0
# git_commit: # OR use your commit hash
pip_install: '.'
ssh_clone: false # Should be true if using a private repo
(The remaining 15 of the 33 changed files are not shown here.)
