Commit c014bd0
Merge branch 'mosaicml:main' into cli99/configurable-act-ckpt-from-private
cli99 authored Feb 7, 2024
2 parents f85e29b + 105f766 commit c014bd0
Showing 16 changed files with 13,058 additions and 467 deletions.
45 changes: 32 additions & 13 deletions llmfoundry/data/finetuning/tasks.py
@@ -35,6 +35,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 import logging
 import os
 import warnings
+from functools import partial
 from pathlib import Path
 from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Union,
                     cast)
@@ -199,7 +200,7 @@ def _tokenize_prompt_response_formatted_example(
     return tokenizer(text=prompt, text_target=response)


-def _tokenize_formatted_example(
+def tokenize_formatted_example(
         example: Example,
         tokenizer: PreTrainedTokenizerBase) -> TokenizedExample:
     """Tokenizes a formatted example using the provided tokenizer.
@@ -228,6 +229,33 @@ def _tokenize_formatted_example(
         raise ValueError(f'Unknown conversation type {example_format=}')


+def is_valid_ift_example(pad_token_id: int, max_seq_len: int,
+                         example: Dict) -> bool:
+    """Check if the example is a valid ift example.
+
+    This function does the following checks:
+    a. Length of input_ids should be less than max_seq_len
+    b. Both input_ids and labels should not be empty
+    c. Labels should have at least 1 non-padding token.
+
+    Args:
+        pad_token_id (int): The id of the padding token.
+        max_seq_len (int): Maximum sequence length.
+        example (Dict): The input example after tokenization, which has
+            ``input_ids`` and ``labels`` fields.
+
+    Returns:
+        bool: Indicator of whether the input example is valid.
+    """
+    less_than_max_seq_len = len(example['input_ids']) < max_seq_len
+    non_empty_input = len(example['input_ids']) > 0
+    non_empty_labels = len(example['labels']) > 0
+    non_padding_response = any(
+        token_id != pad_token_id for token_id in example['labels'])
+    return (less_than_max_seq_len and non_empty_input and non_empty_labels and
+            non_padding_response)
+
+
 class StreamingFinetuningDataset(StreamingDataset):
     """Finetuning dataset with flexible tokenization using StreamingDataset.
@@ -347,7 +375,7 @@ def __init__(self,
     # How to process a sample
     def __getitem__(self, idx: int) -> Dict[str, Any]:
         sample = super().__getitem__(idx)
-        return _tokenize_formatted_example(sample, tokenizer=self.tokenizer)
+        return tokenize_formatted_example(sample, tokenizer=self.tokenizer)


 class DatasetConstructor:
@@ -550,7 +578,7 @@ def build_from_hf(
         def dataset_mapper(example: Dict):
             if preprocessing_fn is not None:
                 example = preprocessing_fn(example)
-            return _tokenize_formatted_example(example, tokenizer)
+            return tokenize_formatted_example(example, tokenizer)

         detected_cpu_count = os.cpu_count() or 1
         detected_cpus_with_margin = detected_cpu_count - 8
@@ -567,17 +595,8 @@ def dataset_mapper(example: Dict):

         pad_token_id = tokenizer.pad_token_id

-        def filter_long_or_empty_examples(example: Dict) -> bool:
-            less_than_max_seq_len = len(example['input_ids']) < max_seq_len
-            non_empty_input = len(example['input_ids']) > 0
-            non_empty_labels = len(example['labels']) > 0
-            non_padding_response = any(
-                token_id != pad_token_id for token_id in example['labels'])
-            return (less_than_max_seq_len and non_empty_input and
-                    non_empty_labels and non_padding_response)
-
         filtered_dataset = tokenized_dataset.filter(
-            filter_long_or_empty_examples,
+            partial(is_valid_ift_example, pad_token_id, max_seq_len),
             num_proc=num_cpus_to_use,
             desc='Filtering out long prompts',
         )
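Taken together, the tasks.py changes promote the tokenization helper to a public name (tokenize_formatted_example) and hoist the local filtering closure to module level as is_valid_ift_example, presumably so the same check can be reused by the conversion script below and handed to datasets.filter through functools.partial. A minimal sketch of the promoted predicate in use; the token ids and the pad id of 0 are invented for illustration:

from functools import partial

from llmfoundry.data.finetuning.tasks import is_valid_ift_example

# Toy tokenized examples; real ones come out of tokenize_formatted_example
# with ``input_ids`` and ``labels`` fields.
examples = [
    {'input_ids': [5, 6, 7], 'labels': [8, 9]},  # valid
    {'input_ids': [], 'labels': [8, 9]},  # empty input -> dropped
    {'input_ids': [5, 6], 'labels': [0, 0]},  # labels all padding -> dropped
]

pad_token_id, max_seq_len = 0, 2048
is_valid = partial(is_valid_ift_example, pad_token_id, max_seq_len)
assert [is_valid(ex) for ex in examples] == [True, False, False]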
12 changes: 10 additions & 2 deletions llmfoundry/utils/builders.py
@@ -503,7 +503,13 @@ def _validate_cfg(icl_cfg: DictConfig):
         if dist.get_local_rank() == 0 and os.path.exists(destination_path):
             os.remove(destination_path)
         dist.barrier()
-
+        early_stopping_criteria = icl_cfg.get('early_stopping_criteria',
+                                              None)
+        if isinstance(early_stopping_criteria, ListConfig):
+            early_stopping_criteria = om.to_container(
+                early_stopping_criteria)
+        assert early_stopping_criteria is None or isinstance(
+            early_stopping_criteria, list)
         dataloaders = get_icl_task_dataloader(
             icl_cfg.icl_task_type,
             icl_cfg.dataset_uri,
@@ -520,7 +526,9 @@ def _validate_cfg(icl_cfg: DictConfig):
             pass_at_k=icl_cfg.pass_at_k,
             generations_per_sample=icl_cfg.num_beams,
             has_categories=icl_cfg.get('has_categories', False),
-            cot_delimiter=icl_cfg.get('cot_delimiter', ''))
+            cot_delimiter=icl_cfg.get('cot_delimiter', ''),
+            early_stopping_criteria=early_stopping_criteria,
+            do_normalization=icl_cfg.get('do_normalization', True))
         if hasattr(
                 icl_cfg,
                 'has_categories') and icl_cfg.has_categories and isinstance(
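For reference, a small sketch of the config shape this handles; only early_stopping_criteria and do_normalization are taken from the diff, and the other task fields here are hypothetical. omegaconf parses a YAML list into a ListConfig, which is why it is converted to a plain Python list before being passed to get_icl_task_dataloader:

from omegaconf import ListConfig
from omegaconf import OmegaConf as om

# Hypothetical ICL task entry.
icl_cfg = om.create({
    'label': 'my_qa_task',
    'early_stopping_criteria': ['\n\n', 'Question:'],
    'do_normalization': True,
})

early_stopping_criteria = icl_cfg.get('early_stopping_criteria', None)
if isinstance(early_stopping_criteria, ListConfig):
    # om.to_container converts the ListConfig into a plain Python list.
    early_stopping_criteria = om.to_container(early_stopping_criteria)
assert early_stopping_criteria == ['\n\n', 'Question:']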
4 changes: 2 additions & 2 deletions mcli/mcli-hf-eval.yaml
@@ -50,5 +50,5 @@ parameters:
     limit_all_gathers: True


-  icl_tasks: "eval/yamls/tasks_v0.2.yaml"
-  eval_gauntlet: "eval/yamls/eval_gauntlet_v0.2.yaml"
+  icl_tasks: "eval/yamls/tasks_v0.3.yaml"
+  eval_gauntlet: "eval/yamls/eval_gauntlet_v0.3.yaml"
80 changes: 70 additions & 10 deletions scripts/data_prep/convert_finetuning_dataset.py
@@ -1,18 +1,24 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0

+import json
 import os
 import platform
+import warnings
 from argparse import ArgumentParser, Namespace
 from typing import Dict, Iterable, List, Optional, Union

 import datasets as hf_datasets
+import numpy as np
 import psutil
 from streaming import MDSWriter
 from torch.utils.data import DataLoader, IterableDataset
 from tqdm import tqdm

-from llmfoundry.data.finetuning.tasks import dataset_constructor
+from llmfoundry.data.finetuning.tasks import (dataset_constructor,
+                                              is_valid_ift_example,
+                                              tokenize_formatted_example)
+from llmfoundry.utils.builders import build_tokenizer


 def parse_args() -> Namespace:
@@ -23,7 +29,7 @@ def parse_args() -> Namespace:
         type=str,
         required=True,
         help=
-        'Name/path of the dataset (e.g., first argument to `datasets.load_dataset`)'
+        'Name of the dataset (e.g., first argument to `datasets.load_dataset`; for the jsonl data format, use `json`)'
     )
     parser.add_argument('--data_subset',
                         type=str,
@@ -38,6 +44,13 @@ def parse_args() -> Namespace:
                         default=None,
                         help='Name or import path of function used to preprocess (reformat) the dataset. ' +\
                              'See README for additional details.')
+    parser.add_argument(
+        '--data_files',
+        nargs='+',
+        default=[],
+        help=
+        'Data file for each split. If set, its length must match len(splits).'
+    )
     parser.add_argument(
         '--skip-preprocessing',
         action='store_true',
@@ -63,6 +76,9 @@ def parse_args() -> Namespace:
         default=None,
         help='(Optional) name of compression algorithm to use.')
     parser.add_argument('--num_workers', type=int, required=False, default=None)
+    parser.add_argument('--tokenizer', type=str, required=False, default=None)
+    parser.add_argument('--tokenizer_kwargs', type=str, required=False)
+    parser.add_argument('--max_seq_len', type=int, default=2048)

     parsed = parser.parse_args()

@@ -73,6 +89,17 @@ def parse_args() -> Namespace:
             f'--out_root={parsed.out_root} contains {os.listdir(parsed.out_root)} which cannot overlap with the requested splits {parsed.splits}.'
         )

+    if parsed.tokenizer_kwargs is not None:
+        parsed.tokenizer_kwargs = json.loads(parsed.tokenizer_kwargs)
+    else:
+        parsed.tokenizer_kwargs = {}
+
+    if len(parsed.data_files) > 0 and len(parsed.data_files) != len(
+            parsed.splits):
+        raise ValueError(
+            f'If data_files is set, data_files and splits must have the same length. Got {len(parsed.data_files)=} while {len(parsed.splits)=}'
+        )
+
     return parsed
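A sketch of how the new flags are meant to compose end to end; the flag wiring mirrors the diff, but the concrete file names and tokenizer kwargs are hypothetical:

import json
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument('--splits', nargs='+', default=['train'])
parser.add_argument('--data_files', nargs='+', default=[])
parser.add_argument('--tokenizer_kwargs', type=str, required=False)

args = parser.parse_args([
    '--splits', 'train', 'test',
    '--data_files', 'train.jsonl', 'test.jsonl',
    '--tokenizer_kwargs', '{"padding_side": "left"}',
])

# tokenizer_kwargs arrives as a JSON string and is decoded into a dict.
tokenizer_kwargs = json.loads(args.tokenizer_kwargs) if args.tokenizer_kwargs else {}
# Each entry of --data_files pairs positionally with one entry of --splits.
assert len(args.data_files) in (0, len(args.splits))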


@@ -170,12 +197,23 @@ def main(args: Namespace) -> None:
                 'include the "--skip-preprocessing" flag to avoid this error.'
             )

-    columns = ['prompt', 'response']
+    tokenizer = None
+    tokenizer_kwargs = args.tokenizer_kwargs
+    tokenizer_kwargs.update({'model_max_length': args.max_seq_len})
+    if args.tokenizer:
+        tokenizer = build_tokenizer(args.tokenizer, tokenizer_kwargs)
+        columns = {'input_ids': 'bytes', 'labels': 'bytes'}
+    else:
+        columns = {'prompt': 'str', 'response': 'str'}

-    for split_name in args.splits:
+    for i, split_name in enumerate(args.splits):
+        data_file = None
+        if len(args.data_files) > 0:
+            data_file = args.data_files[i]
         dataset = hf_datasets.load_dataset(path=args.dataset,
                                            name=args.data_subset,
                                            split=split_name,
+                                           data_files=data_file,
                                            streaming=True)
         loader = build_dataloader(dataset=dataset,
                                   batch_size=512,
@@ -190,12 +228,14 @@ def main(args: Namespace) -> None:
             keep_local = True
         else:
             keep_local = False
-        with MDSWriter(columns={key: 'str' for key in columns},
+        with MDSWriter(columns=columns,
                        out=out,
                        compression=args.compression,
                        keep_local=keep_local) as out:
+            examples_removed = 0
             for sample in tqdm(samples, desc=split_name):
                 formatted_sample = preprocessing_fn(sample)
+
                 if ('prompt'
                         not in formatted_sample) or ('response'
                                                      not in formatted_sample):
@@ … @@ def main(args: Namespace) -> None:
                     raise KeyError(
                         '"prompt" and "response" are required keys but at least one was missing ' +\
                         f'from {formatted_sample=}.'
                     )
-                encoded_sample = {
-                    key: formatted_sample[key].encode('utf-8')
-                    for key in columns
-                }
-                out.write(encoded_sample)
+                if tokenizer is not None:
+                    sample = tokenize_formatted_example(formatted_sample,
+                                                        tokenizer=tokenizer)
+                    if not is_valid_ift_example(tokenizer.pad_token_id,
+                                                args.max_seq_len, sample):
+                        examples_removed += 1
+                        continue
+
+                    sample_to_write = {}
+                    # convert the token-id lists to raw bytes for MDS storage
+                    for key in columns.keys():
+                        sample_to_write[key] = np.asarray(sample[key]).tobytes()
+                    out.write(sample_to_write)
+                else:
+                    encoded_sample = {
+                        key: formatted_sample[key].encode('utf-8')
+                        for key in columns.keys()
+                    }
+                    out.write(encoded_sample)
+            if tokenizer is not None and examples_removed > 0:
+                warnings.warn(
+                    f'Dropped {examples_removed} examples where the prompt was longer than {args.max_seq_len}, '
+                    +
+                    'the prompt or response was empty, or the response was all padding tokens.'
+                )


 if __name__ == '__main__':
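One consequence of the new tokenized path worth noting: when --tokenizer is set, each MDS column holds the raw bytes of a NumPy integer array rather than UTF-8 text, so a reader has to undo np.asarray(...).tobytes() with a matching dtype. A minimal round-trip sketch, assuming np.asarray picks int64 (the default integer dtype on most 64-bit platforms):

import numpy as np

sample = {'input_ids': [101, 2023, 102], 'labels': [2023, 102]}

# Write side, as in the script: int list -> ndarray -> raw bytes.
encoded = {key: np.asarray(val).tobytes() for key, val in sample.items()}

# Read side: raw bytes -> ndarray; the dtype must match the write side.
decoded = {
    key: np.frombuffer(val, dtype=np.int64).tolist()
    for key, val in encoded.items()
}
assert decoded == sample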