
Commit

Merge branch 'main' into anna/asynceval
dakinggg authored Nov 30, 2023
2 parents 91bdf43 + 1191267 commit e8f4661
Showing 95 changed files with 73,386 additions and 711 deletions.
5 changes: 4 additions & 1 deletion .github/mcp/mcp_pytest.py
@@ -54,6 +54,9 @@
type=int,
default=1800,
help='Timeout for run (in seconds)')
parser.add_argument('--deps_group',
type=str,
help='Dependency group to install')
args = parser.parse_args()

name = args.name
@@ -89,7 +92,7 @@
clear_tmp_path_flag = '-o tmp_path_retention_policy=none'
command += f'''
pip install --upgrade --user .[all]
pip install --upgrade --user .[{args.deps_group}]
export COMMON_ARGS="-v --durations=20 -m '{args.pytest_markers}' {clear_tmp_path_flag}"
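The new --deps_group flag lets each CI job tell the MCP test launcher which optional-dependency group to install, instead of always installing .[all]. A minimal sketch of the same pattern (illustrative names, not the actual launcher script):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--deps_group',
                    type=str,
                    help='Dependency group to install')
args = parser.parse_args()

# The group name is interpolated into pip's "extras" syntax, so different
# jobs can install different extras sets, e.g. .[all] or .[all-flash2].
command = f'pip install --upgrade --user .[{args.deps_group}]'
print(command)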
14 changes: 6 additions & 8 deletions .github/workflows/docker.yaml
@@ -69,19 +69,17 @@ jobs:
GIT_SHA=$(echo ${{ github.sha }} | cut -c1-7)
echo "IMAGE_TAG=${GIT_SHA}" >> ${GITHUB_ENV}
if [ "${{ github.event_name }}" == "push" ]; then
echo "Triggered by push event."
PROD_REPO="mosaicml/llm-foundry"
IMAGE_TAG="${PROD_REPO}:${{matrix.name}}-${GIT_SHA},${PROD_REPO}:${{matrix.name}}-latest"
IMAGE_CACHE="${PROD_REPO}:${{matrix.name}}-buildcache"
elif [ "${{ github.event_name }}" == "pull_request" ]; then
if [ "${{ github.event_name }}" == "pull_request" ]; then
echo "Triggered by pull_request event."
STAGING_REPO="mosaicml/ci-staging"
IMAGE_TAG="${STAGING_REPO}:${{matrix.name}}-${GIT_SHA}"
IMAGE_CACHE="${STAGING_REPO}:${{matrix.name}}-buildcache"
else
echo "Triggered by unknown event: ${{ github.event_name }}"
exit 1
# Triggered by push or workflow_dispatch event
echo "Triggered by ${{ github.event_name }} event, releasing to prod"
PROD_REPO="mosaicml/llm-foundry"
IMAGE_TAG="${PROD_REPO}:${{matrix.name}}-${GIT_SHA},${PROD_REPO}:${{matrix.name}}-latest"
IMAGE_CACHE="${PROD_REPO}:${{matrix.name}}-buildcache"
fi
echo "IMAGE_TAG=${IMAGE_TAG}" >> ${GITHUB_ENV}
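With this change the workflow no longer fails on unrecognized events: pull requests still publish to the staging repository, and every other trigger (push or workflow_dispatch) now releases to the production repository. A rough Python sketch of the new tag-selection logic (illustrative only; the real logic lives in the shell step above):

def image_targets(event_name: str, image_name: str, git_sha: str):
    if event_name == 'pull_request':
        repo = 'mosaicml/ci-staging'
        tags = f'{repo}:{image_name}-{git_sha}'
    else:
        # push or workflow_dispatch: release to prod with a moving "latest" tag
        repo = 'mosaicml/llm-foundry'
        tags = f'{repo}:{image_name}-{git_sha},{repo}:{image_name}-latest'
    cache = f'{repo}:{image_name}-buildcache'
    return tags, cache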
8 changes: 2 additions & 6 deletions .github/workflows/pr-cpu.yaml
@@ -19,12 +19,8 @@ jobs:
strategy:
matrix:
include:
- name: 'cpu-latest'
container: mosaicml/pytorch:latest_cpu # mosaicml/pytorch:1.13.1_cpu-python3.10-ubuntu20.04
markers: 'not gpu'
pytest_command: 'coverage run -m pytest'
- name: 'cpu-2.0.1'
container: mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04
- name: 'cpu-1.13.1'
container: mosaicml/pytorch:1.13.1_cpu-python3.10-ubuntu20.04
markers: 'not gpu'
pytest_command: 'coverage run -m pytest'
- name: 'cpu-2.1.0'
13 changes: 6 additions & 7 deletions .github/workflows/pr-gpu.yaml
@@ -18,24 +18,22 @@ jobs:
uses: ./.github/workflows/pytest-gpu.yaml
strategy:
matrix:
# TODO: After the PR with the flash attention 2 images goes in, add the new unit test suite
include:
- name: 'gpu-latest'
container: mosaicml/pytorch:latest # mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
markers: 'gpu'
pytest_command: 'coverage run -m pytest'
- name: 'gpu-2.0.1'
container: mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04
- name: 'gpu-1.13.1'
container: mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
markers: 'gpu'
pytest_command: 'coverage run -m pytest'
deps_group: 'all'
- name: 'gpu-2.1.0'
container: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
markers: 'gpu'
pytest_command: 'coverage run -m pytest'
deps_group: 'all'
- name: 'gpu-2.1.0-flash2'
container: mosaicml/llm-foundry:2.1.0_cu121_flash2-latest
markers: 'gpu'
pytest_command: 'coverage run -m pytest'
deps_group: 'all-flash2'
name: ${{ matrix.name }}
if: github.repository_owner == 'mosaicml'
with:
@@ -45,5 +43,6 @@ jobs:
pytest-command: ${{ matrix.pytest_command }}
pytest-markers: ${{ matrix.markers }}
python-version: 3.9
deps-group: ${{ matrix.deps_group }}
secrets:
mcloud-api-key: ${{ secrets.MCLOUD_API_KEY }}
6 changes: 5 additions & 1 deletion .github/workflows/pytest-gpu.yaml
@@ -22,6 +22,9 @@ on:
required: false
type: string
default: 3.9
deps-group:
required: true
type: string
secrets:
mcloud-api-key:
required: true
@@ -76,4 +79,5 @@ jobs:
--image '${{ inputs.container }}' \
--pytest_markers '${{ inputs.pytest-markers }}' \
--pytest_command '${{ inputs.pytest-command }}' \
--timeout ${{ inputs.mcloud-timeout }} ${REF_ARGS}
--timeout ${{ inputs.mcloud-timeout }} ${REF_ARGS} \
--deps_group ${{ inputs.deps-group }}
2 changes: 1 addition & 1 deletion llmfoundry/__init__.py
@@ -75,4 +75,4 @@
'TiktokenTokenizerWrapper',
]

__version__ = '0.3.0'
__version__ = '0.4.0'
95 changes: 52 additions & 43 deletions llmfoundry/data/finetuning/tasks.py
@@ -159,7 +159,6 @@ def __init__(self,
f'local directory {local} does not contain split {split}'
)

# Build Dataset
super().__init__(
local=local,
remote=remote,
@@ -345,51 +344,57 @@ def build_from_hf(
with dist.local_rank_zero_download_and_wait(signal_file_path):
pass

dataset = hf_datasets.load_dataset(dataset_name, split=split, **kwargs)

def dataset_mapper(example: Dict):
if preprocessing_fn is not None:
example = preprocessing_fn(example)
return _tokenize_formatted_example(example, tokenizer)

detected_cpu_count = os.cpu_count() or 1
detected_cpus_with_margin = detected_cpu_count - 8
num_cpus_to_use = max(1, detected_cpus_with_margin)

columns_to_remove = list(dataset[0].keys())
tokenized_dataset = dataset.map(
dataset_mapper,
batched=False,
remove_columns=columns_to_remove,
num_proc=num_cpus_to_use,
desc='Tokenizing dataset',
)

pad_token_id = tokenizer.pad_token_id

def filter_long_or_empty_examples(example: Dict) -> bool:
less_than_max_seq_len = len(example['input_ids']) < max_seq_len
non_empty_input = len(example['input_ids']) > 0
non_empty_labels = len(example['labels']) > 0
non_padding_response = any(
token_id != pad_token_id for token_id in example['labels'])
return (less_than_max_seq_len and non_empty_input and
non_empty_labels and non_padding_response)

filtered_dataset = tokenized_dataset.filter(
filter_long_or_empty_examples,
num_proc=num_cpus_to_use,
desc='Filtering out long prompts',
)
error: Optional[Exception] = None
filtered_dataset = None
try:
dataset = hf_datasets.load_dataset(dataset_name,
split=split,
**kwargs)

def dataset_mapper(example: Dict):
if preprocessing_fn is not None:
example = preprocessing_fn(example)
return _tokenize_formatted_example(example, tokenizer)

detected_cpu_count = os.cpu_count() or 1
detected_cpus_with_margin = detected_cpu_count - 8
num_cpus_to_use = max(1, detected_cpus_with_margin)

columns_to_remove = list(dataset[0].keys())
tokenized_dataset = dataset.map(
dataset_mapper,
batched=False,
remove_columns=columns_to_remove,
num_proc=num_cpus_to_use,
desc='Tokenizing dataset',
)

examples_removed = len(tokenized_dataset) - len(filtered_dataset)
if examples_removed > 0:
warnings.warn(
f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}, '
+
'the prompt or response was empty, or the response was all padding tokens.'
pad_token_id = tokenizer.pad_token_id

def filter_long_or_empty_examples(example: Dict) -> bool:
less_than_max_seq_len = len(example['input_ids']) < max_seq_len
non_empty_input = len(example['input_ids']) > 0
non_empty_labels = len(example['labels']) > 0
non_padding_response = any(
token_id != pad_token_id for token_id in example['labels'])
return (less_than_max_seq_len and non_empty_input and
non_empty_labels and non_padding_response)

filtered_dataset = tokenized_dataset.filter(
filter_long_or_empty_examples,
num_proc=num_cpus_to_use,
desc='Filtering out long prompts',
)

examples_removed = len(tokenized_dataset) - len(filtered_dataset)
if examples_removed > 0:
warnings.warn(
f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}, '
+
'the prompt or response was empty, or the response was all padding tokens.'
)
except Exception as e:
error = e
# Now local rank 0 indicates to the other ranks that it is done
if dist.get_local_rank() == 0:
log.debug('Local rank 0 finished data prep')
@@ -403,7 +408,11 @@ def filter_long_or_empty_examples(example: Dict) -> bool:
if dist.get_local_rank() == 0:
os.remove(signal_file_path)

if error is not None:
log.error('Error during data prep')
raise error
log.debug('All ranks finished data prep')
assert filtered_dataset is not None
return filtered_dataset

def build_from_streaming(self, *args: Any,
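The dataset preparation on local rank 0 is now wrapped in a try/except so that, if anything fails, the rank still completes the signal-file handshake with the other local ranks (which are blocked waiting on it) before the error is re-raised. A stripped-down sketch of that pattern, with the actual load/tokenize/filter work replaced by a placeholder:

from typing import Optional

def do_data_prep():
    # Placeholder for the real work: load_dataset, map/tokenize, filter.
    return ['example']

error: Optional[Exception] = None
filtered_dataset = None
try:
    filtered_dataset = do_data_prep()
except Exception as e:
    # Record the failure instead of raising immediately, so the signal-file /
    # barrier handshake with the other local ranks still completes below.
    error = e

# ... here the real code writes the signal file and waits on the barrier ...

if error is not None:
    raise error
assert filtered_dataset is not None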
3 changes: 2 additions & 1 deletion llmfoundry/data/packing.py
@@ -5,6 +5,7 @@

import numpy as np
import torch
from composer.utils import using_torch_2
from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase

@@ -347,7 +348,7 @@ def profile_packing(
dataloader_cfg.dataset.packing_ratio = None
dataloader_cfg.drop_last = False
dataloader_cfg.num_workers = 0
dataloader_cfg.prefetch_factor = None
dataloader_cfg.prefetch_factor = None if using_torch_2() else 2
dataloader_cfg.persistent_workers = False

# Determine the packing_ratio values we'll try
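The using_torch_2() guard is needed because torch >= 2.0 expects prefetch_factor=None when num_workers == 0, while torch 1.13 only accepts its old default of 2 in that case. A minimal illustration, assuming composer's using_torch_2 helper as imported in the diff above:

from composer.utils import using_torch_2
from torch.utils.data import DataLoader

dataset = list(range(8))  # toy dataset, just for illustration
# Pick a prefetch_factor compatible with the installed torch version.
loader = DataLoader(dataset,
                    num_workers=0,
                    prefetch_factor=None if using_torch_2() else 2)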
8 changes: 4 additions & 4 deletions llmfoundry/models/layers/attention.py
@@ -296,11 +296,11 @@ def flash_attn_fn(
# we use .view to modify {key, value}_unpad appropriately

key_unpad = repeat_kv_for_gqa(
key_unpad.view(batch_size, seqlen, kv_n_heads, -1),
n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)
key_unpad.view(1, key_unpad.size(0), kv_n_heads, -1),
n_heads // kv_n_heads).view(key_unpad.size(0), n_heads, -1)
value_unpad = repeat_kv_for_gqa(
value_unpad.view(batch_size, seqlen, kv_n_heads, -1),
n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)
value_unpad.view(1, value_unpad.size(0), kv_n_heads, -1),
n_heads // kv_n_heads).view(value_unpad.size(0), n_heads, -1)

dropout_p = dropout_p if training else 0.0

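After unpadding, key_unpad and value_unpad are flattened to (total_nonpadding_tokens, kv_n_heads, head_dim), and that token count is generally smaller than batch_size * seqlen, so the old view could produce the wrong shape; the fix treats the unpadded tokens as a single sequence of length key_unpad.size(0). A shape-only sketch, where repeat_kv_for_gqa is a stand-in that only mimics the shape behavior of the helper in llmfoundry.models.layers.attention:

import torch

def repeat_kv_for_gqa(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (b, s, kv_n_heads, d) -> (b, s, kv_n_heads * n_rep, d)
    if n_rep == 1:
        return hidden
    b, s, kv_n_heads, d = hidden.shape
    hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)
    return hidden.reshape(b, s, kv_n_heads * n_rep, d)

total_tokens, kv_n_heads, head_dim, n_heads = 10, 2, 8, 8
key_unpad = torch.randn(total_tokens, kv_n_heads, head_dim)

# Treat the unpadded tokens as one long "batch" of length total_tokens,
# not as (batch_size, seqlen).
key_unpad = repeat_kv_for_gqa(
    key_unpad.view(1, key_unpad.size(0), kv_n_heads, -1),
    n_heads // kv_n_heads).view(key_unpad.size(0), n_heads, -1)
assert key_unpad.shape == (total_tokens, n_heads, head_dim)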
2 changes: 1 addition & 1 deletion llmfoundry/models/mpt/configuration_mpt.py
@@ -109,7 +109,7 @@ def __init__(
init_device (str): The device to use for parameter initialization.
logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
no_bias (bool): Whether to use bias in all layers.
verbose (int): The verbosity level. 0 is silent.
verbose (int): Deprecated.
embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
norm_type (str): choose type of norm to use
use_cache (bool): Whether or not the model should return the last key/values attentions
22 changes: 12 additions & 10 deletions llmfoundry/optim/lion8b.py
@@ -1,7 +1,7 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Dict, Iterable, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union

import torch

@@ -58,15 +58,17 @@ class DecoupledLionW_8bit(torch.optim.Optimizer):
device, or b) step() is executed on a non-CUDA parameter.
"""

def __init__(self,
params: Iterable[torch.Tensor],
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.99),
weight_decay: float = 0,
quantize: bool = True,
compress_state_dict: bool = False,
error_correction: bool = False,
_fused: bool = True): # XXX this flag is mostly for testing...
def __init__(
self,
params: Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]],
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.99),
weight_decay: float = 0,
quantize: bool = True,
compress_state_dict: bool = False,
error_correction: bool = False,
_fused: bool = True, # XXX this flag is mostly for testing...
):

if lr < 0.0:
raise ValueError('Invalid learning rate: {}'.format(lr))
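Widening the params annotation to Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]] documents that the optimizer, like other torch.optim.Optimizer subclasses, also accepts parameter groups with per-group options. A hypothetical usage sketch:

import torch
from llmfoundry.optim.lion8b import DecoupledLionW_8bit

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
param_groups = [
    {'params': model[0].parameters(), 'weight_decay': 0.0},
    {'params': model[1].parameters(), 'weight_decay': 1e-4},
]
# quantize=False keeps the sketch runnable on CPU-only machines.
opt = DecoupledLionW_8bit(param_groups, lr=1e-3, quantize=False)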