Merge branch 'main' into anna/evalloader
aspfohl authored Nov 28, 2023
2 parents 5b85218 + 613457a commit 1a0a3da
Showing 86 changed files with 72,706 additions and 535 deletions.
5 changes: 4 additions & 1 deletion .github/mcp/mcp_pytest.py
@@ -54,6 +54,9 @@
type=int,
default=1800,
help='Timeout for run (in seconds)')
parser.add_argument('--deps_group',
type=str,
help='Dependency group to install')
args = parser.parse_args()

name = args.name
@@ -89,7 +92,7 @@
clear_tmp_path_flag = '-o tmp_path_retention_policy=none'
command += f'''
pip install --upgrade --user .[all]
pip install --upgrade --user .[{args.deps_group}]
export COMMON_ARGS="-v --durations=20 -m '{args.pytest_markers}' {clear_tmp_path_flag}"
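
For context, a minimal sketch (not the actual script) of how the new --deps_group flag flows into the install step; the surrounding MCP launch logic is omitted and the pytest_markers default shown here is an assumption.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--pytest_markers', type=str, default='not gpu',
                    help='Markers to pass to pytest')
parser.add_argument('--deps_group', type=str,
                    help='Dependency group to install')
args = parser.parse_args()

# The install line is no longer hard-coded to .[all]; the workflow picks the
# extras group (e.g. 'all' or 'all-flash2') per test-matrix entry.
command = f'''
pip install --upgrade --user .[{args.deps_group}]
export COMMON_ARGS="-v --durations=20 -m '{args.pytest_markers}'"
'''
print(command)
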
8 changes: 2 additions & 6 deletions .github/workflows/pr-cpu.yaml
@@ -19,12 +19,8 @@ jobs:
strategy:
matrix:
include:
- name: 'cpu-latest'
container: mosaicml/pytorch:latest_cpu # mosaicml/pytorch:1.13.1_cpu-python3.10-ubuntu20.04
markers: 'not gpu'
pytest_command: 'coverage run -m pytest'
- name: 'cpu-2.0.1'
container: mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04
- name: 'cpu-1.13.1'
container: mosaicml/pytorch:1.13.1_cpu-python3.10-ubuntu20.04
markers: 'not gpu'
pytest_command: 'coverage run -m pytest'
- name: 'cpu-2.1.0'
13 changes: 6 additions & 7 deletions .github/workflows/pr-gpu.yaml
@@ -18,24 +18,22 @@ jobs:
uses: ./.github/workflows/pytest-gpu.yaml
strategy:
matrix:
# TODO: After the PR with the flash attention 2 images goes in, add the new unit test suite
include:
- name: 'gpu-latest'
container: mosaicml/pytorch:latest # mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
markers: 'gpu'
pytest_command: 'coverage run -m pytest'
- name: 'gpu-2.0.1'
container: mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04
- name: 'gpu-1.13.1'
container: mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
markers: 'gpu'
pytest_command: 'coverage run -m pytest'
deps_group: 'all'
- name: 'gpu-2.1.0'
container: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
markers: 'gpu'
pytest_command: 'coverage run -m pytest'
deps_group: 'all'
- name: 'gpu-2.1.0-flash2'
container: mosaicml/llm-foundry:2.1.0_cu121_flash2-latest
markers: 'gpu'
pytest_command: 'coverage run -m pytest'
deps_group: 'all-flash2'
name: ${{ matrix.name }}
if: github.repository_owner == 'mosaicml'
with:
@@ -45,5 +43,6 @@ jobs:
pytest-command: ${{ matrix.pytest_command }}
pytest-markers: ${{ matrix.markers }}
python-version: 3.9
deps-group: ${{ matrix.deps_group }}
secrets:
mcloud-api-key: ${{ secrets.MCLOUD_API_KEY }}
6 changes: 5 additions & 1 deletion .github/workflows/pytest-gpu.yaml
@@ -22,6 +22,9 @@ on:
required: false
type: string
default: 3.9
deps-group:
required: true
type: string
secrets:
mcloud-api-key:
required: true
@@ -77,4 +80,5 @@ jobs:
--image '${{ inputs.container }}' \
--pytest_markers '${{ inputs.pytest-markers }}' \
--pytest_command '${{ inputs.pytest-command }}' \
--timeout ${{ inputs.mcloud-timeout }} ${REF_ARGS}
--timeout ${{ inputs.mcloud-timeout }} ${REF_ARGS} \
--deps_group ${{ inputs.deps-group }}
2 changes: 1 addition & 1 deletion llmfoundry/__init__.py
@@ -75,4 +75,4 @@
'TiktokenTokenizerWrapper',
]

__version__ = '0.3.0'
__version__ = '0.4.0'
95 changes: 52 additions & 43 deletions llmfoundry/data/finetuning/tasks.py
@@ -159,7 +159,6 @@ def __init__(self,
f'local directory {local} does not contain split {split}'
)

# Build Dataset
super().__init__(
local=local,
remote=remote,
@@ -345,51 +344,57 @@ def build_from_hf(
with dist.local_rank_zero_download_and_wait(signal_file_path):
pass

dataset = hf_datasets.load_dataset(dataset_name, split=split, **kwargs)

def dataset_mapper(example: Dict):
if preprocessing_fn is not None:
example = preprocessing_fn(example)
return _tokenize_formatted_example(example, tokenizer)

detected_cpu_count = os.cpu_count() or 1
detected_cpus_with_margin = detected_cpu_count - 8
num_cpus_to_use = max(1, detected_cpus_with_margin)

columns_to_remove = list(dataset[0].keys())
tokenized_dataset = dataset.map(
dataset_mapper,
batched=False,
remove_columns=columns_to_remove,
num_proc=num_cpus_to_use,
desc='Tokenizing dataset',
)

pad_token_id = tokenizer.pad_token_id

def filter_long_or_empty_examples(example: Dict) -> bool:
less_than_max_seq_len = len(example['input_ids']) < max_seq_len
non_empty_input = len(example['input_ids']) > 0
non_empty_labels = len(example['labels']) > 0
non_padding_response = any(
token_id != pad_token_id for token_id in example['labels'])
return (less_than_max_seq_len and non_empty_input and
non_empty_labels and non_padding_response)

filtered_dataset = tokenized_dataset.filter(
filter_long_or_empty_examples,
num_proc=num_cpus_to_use,
desc='Filtering out long prompts',
)
error: Optional[Exception] = None
filtered_dataset = None
try:
dataset = hf_datasets.load_dataset(dataset_name,
split=split,
**kwargs)

def dataset_mapper(example: Dict):
if preprocessing_fn is not None:
example = preprocessing_fn(example)
return _tokenize_formatted_example(example, tokenizer)

detected_cpu_count = os.cpu_count() or 1
detected_cpus_with_margin = detected_cpu_count - 8
num_cpus_to_use = max(1, detected_cpus_with_margin)

columns_to_remove = list(dataset[0].keys())
tokenized_dataset = dataset.map(
dataset_mapper,
batched=False,
remove_columns=columns_to_remove,
num_proc=num_cpus_to_use,
desc='Tokenizing dataset',
)

examples_removed = len(tokenized_dataset) - len(filtered_dataset)
if examples_removed > 0:
warnings.warn(
f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}, '
+
'the prompt or response was empty, or the response was all padding tokens.'
pad_token_id = tokenizer.pad_token_id

def filter_long_or_empty_examples(example: Dict) -> bool:
less_than_max_seq_len = len(example['input_ids']) < max_seq_len
non_empty_input = len(example['input_ids']) > 0
non_empty_labels = len(example['labels']) > 0
non_padding_response = any(
token_id != pad_token_id for token_id in example['labels'])
return (less_than_max_seq_len and non_empty_input and
non_empty_labels and non_padding_response)

filtered_dataset = tokenized_dataset.filter(
filter_long_or_empty_examples,
num_proc=num_cpus_to_use,
desc='Filtering out long prompts',
)

examples_removed = len(tokenized_dataset) - len(filtered_dataset)
if examples_removed > 0:
warnings.warn(
f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}, '
+
'the prompt or response was empty, or the response was all padding tokens.'
)
except Exception as e:
error = e
# Now local rank 0 indicates to the other ranks that it is done
if dist.get_local_rank() == 0:
log.debug('Local rank 0 finished data prep')
@@ -403,7 +408,11 @@ def filter_long_or_empty_examples(example: Dict) -> bool:
if dist.get_local_rank() == 0:
os.remove(signal_file_path)

if error is not None:
log.error('Error during data prep')
raise error
log.debug('All ranks finished data prep')
assert filtered_dataset is not None
return filtered_dataset

def build_from_streaming(self, *args: Any,
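
A hedged sketch of the control-flow pattern this hunk introduces in build_from_hf: any exception raised while downloading, tokenizing, or filtering the dataset is captured, every rank still reaches the signal-file synchronization, and only then is the error re-raised. The helper names below are stand-ins, not llm-foundry APIs.

from typing import Callable, Optional


def prepare_data(load_and_process: Callable[[], object],
                 synchronize_ranks: Callable[[], None]):
    error: Optional[Exception] = None
    filtered_dataset = None
    try:
        # load_dataset -> map(tokenize) -> filter(long/empty examples)
        filtered_dataset = load_and_process()
    except Exception as e:
        error = e  # defer the raise so other ranks are not left waiting

    synchronize_ranks()  # every rank passes this point, error or not

    if error is not None:
        raise error  # safe to fail only after all ranks have synchronized
    assert filtered_dataset is not None
    return filtered_dataset
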
3 changes: 2 additions & 1 deletion llmfoundry/data/packing.py
@@ -5,6 +5,7 @@

import numpy as np
import torch
from composer.utils import using_torch_2
from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase

@@ -347,7 +348,7 @@ def profile_packing(
dataloader_cfg.dataset.packing_ratio = None
dataloader_cfg.drop_last = False
dataloader_cfg.num_workers = 0
dataloader_cfg.prefetch_factor = None
dataloader_cfg.prefetch_factor = None if using_torch_2() else 2
dataloader_cfg.persistent_workers = False

# Determine the packing_ratio values we'll try
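
The prefetch_factor guard above matters because, as I understand the DataLoader behavior, torch 1.x with num_workers=0 only accepts the integer default (2) for prefetch_factor, while torch 2.x defaults it to None. A rough local equivalent of what composer's using_torch_2() check enables (an assumption about its behavior, not its source):

from typing import Optional

import torch
from packaging import version


def default_prefetch_factor() -> Optional[int]:
    is_torch_2 = version.parse(torch.__version__).release >= (2, 0)
    return None if is_torch_2 else 2


# e.g. dataloader_cfg.prefetch_factor = default_prefetch_factor()
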
8 changes: 4 additions & 4 deletions llmfoundry/models/layers/attention.py
@@ -296,11 +296,11 @@ def flash_attn_fn(
# we use .view to modify {key, value}_unpad appropriately

key_unpad = repeat_kv_for_gqa(
key_unpad.view(batch_size, seqlen, kv_n_heads, -1),
n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)
key_unpad.view(1, key_unpad.size(0), kv_n_heads, -1),
n_heads // kv_n_heads).view(key_unpad.size(0), n_heads, -1)
value_unpad = repeat_kv_for_gqa(
value_unpad.view(batch_size, seqlen, kv_n_heads, -1),
n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)
value_unpad.view(1, value_unpad.size(0), kv_n_heads, -1),
n_heads // kv_n_heads).view(value_unpad.size(0), n_heads, -1)

dropout_p = dropout_p if training else 0.0

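
Sketch of why the view changes from (batch_size, seqlen, ...) to (1, total_tokens, ...): after unpadding, key_unpad/value_unpad hold only the non-padding tokens, so their first dimension is generally not batch_size * seqlen. repeat_kv below is a simplified stand-in for llm-foundry's repeat_kv_for_gqa, shown only to make the reshape concrete with toy sizes.

import torch


def repeat_kv(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (b, s, kv_n_heads, d) -> (b, s, kv_n_heads * n_rep, d)
    if n_rep == 1:
        return hidden
    b, s, kv_n_heads, d = hidden.shape
    return (hidden[:, :, :, None, :]
            .expand(b, s, kv_n_heads, n_rep, d)
            .reshape(b, s, kv_n_heads * n_rep, d))


total_tokens, kv_n_heads, n_heads, head_dim = 10, 2, 8, 16  # toy sizes
key_unpad = torch.randn(total_tokens, kv_n_heads, head_dim)

key_unpad = repeat_kv(
    key_unpad.view(1, key_unpad.size(0), kv_n_heads, -1),
    n_heads // kv_n_heads).view(key_unpad.size(0), n_heads, -1)
print(key_unpad.shape)  # torch.Size([10, 8, 16])
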
2 changes: 1 addition & 1 deletion llmfoundry/models/mpt/configuration_mpt.py
@@ -109,7 +109,7 @@ def __init__(
init_device (str): The device to use for parameter initialization.
logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
no_bias (bool): Whether to use bias in all layers.
verbose (int): The verbosity level. 0 is silent.
verbose (int): Deprecated.
embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
norm_type (str): choose type of norm to use
use_cache (bool): Whether or not the model should return the last key/values attentions
27 changes: 27 additions & 0 deletions llmfoundry/tokenizers/tiktoken.py
@@ -7,6 +7,8 @@
import torch
from transformers import PreTrainedTokenizer

DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible."""


class TiktokenTokenizerWrapper(PreTrainedTokenizer):
"""A thin wrapper around tiktoken to make it compatible with Hugging Face.
@@ -23,6 +25,7 @@ def __init__(self,
encoding_name: Optional[str] = None,
add_bos_token: bool = False,
add_eos_token: bool = False,
use_default_system_prompt: bool = False,
unk_token: Optional[str] = '<|endoftext|>',
eos_token: Optional[str] = '<|endoftext|>',
bos_token: Optional[str] = '<|endoftext|>',
Expand All @@ -39,6 +42,7 @@ def __init__(self,
Either model_name or encoding_name must be set, but not both.
add_bos_token (bool, optional): Whether to add bos tokens. Defaults to False.
add_eos_token (bool, optional): Whether to add eos tokens. Defaults to False.
use_default_system_prompt (bool, optional): Use the default system prompt or not. Defaults to False.
unk_token (Optional[str], optional): The unk token. Defaults to '<|endoftext|>'.
eos_token (Optional[str], optional): The eos token. Defaults to '<|endoftext|>'.
bos_token (Optional[str], optional): The bos token. Defaults to '<|endoftext|>'.
@@ -87,11 +91,13 @@ def pickle_Encoding(enc: Encoding):

self.add_bos_token = add_bos_token
self.add_eos_token = add_eos_token
self.use_default_system_prompt = use_default_system_prompt

super().__init__(model_name=model_name,
encoding_name=encoding_name,
add_bos_token=add_bos_token,
add_eos_token=add_eos_token,
use_default_system_prompt=use_default_system_prompt,
unk_token=unk_token,
eos_token=eos_token,
bos_token=bos_token,
@@ -107,6 +113,27 @@ def vocab_size(self) -> int:
def is_fast(self) -> bool:
return False

@property
def default_chat_template(self):
"""Chat ML Template for User/Assistant.
Pinning default Chat ML template in case defaults change.
"""
template = (
"{% set system_message = '' %}"
'{% if USE_DEFAULT_PROMPT == true %}'
"{{'<|im_start|>system\n' + 'DEFAULT_SYSTEM_PROMPT'}}"
'{% endif %}'
'{% for message in messages %}'
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
'{% endfor %}')
template = template.replace(
'USE_DEFAULT_PROMPT',
'true' if self.use_default_system_prompt else 'false')
template = template.replace('DEFAULT_SYSTEM_PROMPT',
DEFAULT_SYSTEM_PROMPT)
return template

def get_vocab(self) -> Dict[str, int]:
"""Returns vocab as a dict.
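
To see what the pinned ChatML template produces, here is a hedged, standalone rendering with plain Jinja2. In practice transformers' apply_chat_template would consume tokenizer.default_chat_template; the example messages are made up.

from jinja2 import Template

DEFAULT_SYSTEM_PROMPT = ('You are a helpful, respectful and honest assistant. '
                         'Always answer as helpfully as possible.')

template = (
    "{% set system_message = '' %}"
    '{% if USE_DEFAULT_PROMPT == true %}'
    "{{'<|im_start|>system\n' + 'DEFAULT_SYSTEM_PROMPT'}}"
    '{% endif %}'
    '{% for message in messages %}'
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    '{% endfor %}')
template = template.replace('USE_DEFAULT_PROMPT', 'true')  # use_default_system_prompt=True
template = template.replace('DEFAULT_SYSTEM_PROMPT', DEFAULT_SYSTEM_PROMPT)

messages = [{'role': 'user', 'content': 'Hello!'},
            {'role': 'assistant', 'content': 'Hi there.'}]
print(Template(template).render(messages=messages))
# <|im_start|>system
# You are a helpful, respectful and honest assistant. Always answer as helpfully as possible.<|im_start|>user
# Hello!<|im_end|>
# <|im_start|>assistant
# Hi there.<|im_end|>
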
3 changes: 2 additions & 1 deletion llmfoundry/utils/builders.py
@@ -350,11 +350,12 @@ def _validate_cfg(icl_cfg: DictConfig):
prompt_string=icl_cfg.prompt_string,
example_delimiter=icl_cfg.example_delimiter,
continuation_delimiter=icl_cfg.continuation_delimiter,
question_prelimiter=icl_cfg.get('question_prelimiter', ''),
destination_path=destination_path,
pass_at_k=icl_cfg.pass_at_k,
generations_per_sample=icl_cfg.num_beams,
has_categories=icl_cfg.get('has_categories', False),
)
cot_delimiter=icl_cfg.get('cot_delimiter', ''))
if hasattr(
icl_cfg,
'has_categories') and icl_cfg.has_categories and isinstance(
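
A small, hedged illustration of the new optional keys: DictConfig.get falls back to the given default, so existing eval configs that do not set question_prelimiter or cot_delimiter are unaffected (the config values below are invented).

from omegaconf import OmegaConf

icl_cfg = OmegaConf.create({
    'label': 'gsm8k',
    'continuation_delimiter': ' ',
    'cot_delimiter': ' #### ',
})
print(repr(icl_cfg.get('question_prelimiter', '')))  # '' (key absent -> default)
print(repr(icl_cfg.get('cot_delimiter', '')))        # ' #### ' (taken from the config)
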
2 changes: 1 addition & 1 deletion mcli/mcli-1b-eval.yaml
@@ -1,7 +1,7 @@
integrations:
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
git_branch: v0.3.0
git_branch: v0.4.0
# git_commit: # OR use your commit hash
pip_install: -e .[gpu]
ssh_clone: false # Should be true if using a private repo
3 changes: 1 addition & 2 deletions mcli/mcli-1b-max-seq-len-8k.yaml
@@ -1,7 +1,7 @@
integrations:
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
git_branch: v0.3.0
git_branch: v0.4.0
# git_commit: # OR use your commit hash
pip_install: -e .[gpu]
ssh_clone: false # Should be true if using a private repo
@@ -123,7 +123,6 @@ parameters:
activation_checkpointing_reentrant: false
activation_cpu_offload: false
limit_all_gathers: true
verbose: false

# Logging
progress_bar: false
2 changes: 1 addition & 1 deletion mcli/mcli-1b.yaml
@@ -1,7 +1,7 @@
integrations:
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
git_branch: v0.3.0
git_branch: v0.4.0
# git_commit: # OR use your commit hash
pip_install: -e .[gpu]
ssh_clone: false # Should be true if using a private repo
2 changes: 1 addition & 1 deletion mcli/mcli-benchmark-mpt.yaml
@@ -11,7 +11,7 @@ image: mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
integrations:
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
git_branch: v0.3.0
git_branch: v0.4.0
# git_commit: # OR use your commit hash
pip_install: '.[gpu]'

2 changes: 1 addition & 1 deletion mcli/mcli-convert-composer-to-hf.yaml
@@ -1,7 +1,7 @@
integrations:
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
git_branch: v0.3.0
git_branch: v0.4.0
# git_commit: # OR use your commit hash
pip_install: -e .
ssh_clone: false # Should be true if using a private repo
2 changes: 1 addition & 1 deletion mcli/mcli-hf-eval.yaml
@@ -1,7 +1,7 @@
integrations:
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
git_branch: v0.3.0
git_branch: v0.4.0
# git_commit: # OR use your commit hash
pip_install: -e ".[gpu]"
ssh_clone: false # Should be true if using a private repo
(Diff truncated; remaining changed files not shown.)
