diff --git a/llmfoundry/cli/cli.py b/llmfoundry/cli/cli.py index 8e86e76467..6c4a2d12c4 100644 --- a/llmfoundry/cli/cli.py +++ b/llmfoundry/cli/cli.py @@ -1,29 +1,53 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from typing import Optional +from typing import Annotated, Optional -import typer +from typer import Argument, Typer -from llmfoundry.cli import registry_cli -from llmfoundry.train import train_from_yaml +from llmfoundry.cli import ( + data_prep_cli, + registry_cli, +) +from llmfoundry.command_utils import ( + eval_from_yaml, + train_from_yaml, +) -app = typer.Typer(pretty_exceptions_show_locals=False) +app = Typer(pretty_exceptions_show_locals=False) app.add_typer(registry_cli.app, name='registry') +app.add_typer(data_prep_cli.app, name='data_prep') @app.command(name='train') def train( - yaml_path: str = typer.Argument( - ..., - help='Path to the YAML configuration file', - ), # type: ignore - args_list: Optional[list[str]] = typer. - Argument(None, help='Additional command line arguments'), # type: ignore + yaml_path: Annotated[str, + Argument( + ..., + help='Path to the YAML configuration file', + )], + args_list: Annotated[ + Optional[list[str]], + Argument(help='Additional command line arguments')] = None, ): """Run the training with optional overrides from CLI.""" train_from_yaml(yaml_path, args_list) +@app.command(name='eval') +def eval( + yaml_path: Annotated[str, + Argument( + ..., + help='Path to the YAML configuration file', + )], + args_list: Annotated[ + Optional[list[str]], + Argument(help='Additional command line arguments')] = None, +): + """Run the eval with optional overrides from CLI.""" + eval_from_yaml(yaml_path, args_list) + + if __name__ == '__main__': app() diff --git a/llmfoundry/cli/data_prep_cli.py b/llmfoundry/cli/data_prep_cli.py new file mode 100644 index 0000000000..3ca53f4104 --- /dev/null +++ b/llmfoundry/cli/data_prep_cli.py @@ -0,0 +1,150 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Annotated, Optional + +import psutil +from typer import Option, Typer + +from llmfoundry.command_utils import ( + convert_dataset_hf_from_args, + convert_dataset_json_from_args, + convert_text_to_mds_from_args, +) + +app = Typer(pretty_exceptions_show_locals=False) + + +@app.command(name='convert_dataset_hf') +def convert_dataset_hf( + dataset: Annotated[str, Option(..., help='Name of the dataset')], + out_root: Annotated[str, Option(..., help='Output root directory')], + data_subset: Annotated[ + Optional[str], + Option(help='Subset of the dataset (e.g., "all" or "en")'), + ] = None, + splits: Annotated[str, + Option(help='Comma-separated list of dataset splits',), + ] = 'train, train_small, val, val_small, val_xsmall', + compression: Annotated[Optional[str], + Option(help='Compression type')] = None, + concat_tokens: Annotated[ + Optional[int], + Option(help='Concatenate tokens up to this many tokens')] = None, + tokenizer: Annotated[Optional[str], + Option(help='Tokenizer name')] = None, + tokenizer_kwargs: Annotated[ + Optional[str], + Option(help='Tokenizer keyword arguments in JSON format')] = None, + bos_text: Annotated[Optional[str], Option(help='BOS text')] = None, + eos_text: Annotated[Optional[str], Option(help='EOS text')] = None, + no_wrap: Annotated[ + bool, + Option(help='Do not wrap text across max_length boundaries'), + ] = False, + num_workers: Annotated[Optional[int], + Option(help='Number of workers')] = None, +): + """Converts dataset from 
HuggingFace into JSON files.""" + # Convert comma-separated splits into a list + splits_list = splits.split(',') if splits else [] + convert_dataset_hf_from_args( + dataset=dataset, + data_subset=data_subset, + splits=splits_list, + out_root=out_root, + compression=compression, + concat_tokens=concat_tokens, + tokenizer=tokenizer, + tokenizer_kwargs=tokenizer_kwargs, + bos_text=bos_text, + eos_text=eos_text, + no_wrap=no_wrap, + num_workers=num_workers, + ) + + +@app.command(name='convert_dataset_json') +def convert_dataset_json( + path: Annotated[str, Option(..., help='Path to the input data file')], + out_root: Annotated[str, Option(..., help='Output root directory')], + concat_tokens: Annotated[ + int, + Option( + ..., + help='Convert text to tokens and concatenate up to this many tokens', + )], + tokenizer: Annotated[str, Option(..., help='Tokenizer name')], + compression: Annotated[Optional[str], + Option(help='Compression type, if any')] = 'zstd', + split: Annotated[str, Option(help='Dataset split to process')] = 'train', + bos_text: Annotated[ + Optional[str], + Option(help='Text to insert at the beginning of each sequence')] = None, + eos_text: Annotated[ + Optional[str], + Option(help='Text to insert at the end of each sequence')] = None, + no_wrap: Annotated[ + bool, + Option(help='Do not wrap text across max_length boundaries')] = False, + num_workers: Annotated[ + Optional[int], + Option(help='Number of workers for data loading')] = None, +): + """Convert a dataset from JSON to MDS streaming format.""" + convert_dataset_json_from_args( + path=path, + split=split, + out_root=out_root, + compression=compression, + concat_tokens=concat_tokens, + tokenizer=tokenizer, + bos_text=bos_text, + eos_text=eos_text, + no_wrap=no_wrap, + num_workers=num_workers, + ) + + +@app.command(name='convert_text_to_mds') +def convert_text_to_mds( + output_folder: Annotated[str, Option(..., help='The folder to write output to')], + input_folder: Annotated[str, Option(..., help='The folder with text files to convert to MDS')], + concat_tokens: Annotated[int, Option(..., help='Convert text to tokens and concatenate up to this many tokens')], + tokenizer: Annotated[str, Option(..., help='The name of the tokenizer to use')], + bos_text: Annotated[Optional[str], Option(help='The text to prepend to each example to separate concatenated examples')] = None, + eos_text: Annotated[Optional[str], Option(help='The text to append to each example to separate concatenated examples')] = None, + compression: Annotated[str, Option(help='The compression algorithm to use for MDS writing')] = 'zstd', + use_tokenizer_eos: Annotated[bool, Option(help='Use the EOS text from the tokenizer')] = False, + no_wrap: Annotated[bool, Option(help='Whether to let text examples wrap across multiple training examples')] = False, + processes: Annotated[int, Option( + help='The number of processes to use to download and convert the dataset', + )] = min(max(psutil.cpu_count() - 2, 1), 32), # type: ignore + reprocess: Annotated[bool, Option( + help= + 'If true, reprocess the input_folder to MDS format. Otherwise, only reprocess upon changes to the input folder or dataset creation parameters.', + )] = False, + trust_remote_code: Annotated[bool, Option( + help='If true, allows custom code to be executed to load the tokenizer', + )] = False, + logging_level: Annotated[str, Option( + help='Logging level for the script. 
Default is INFO.', + )] = 'INFO', + +): + """Convert text files to MDS streaming format.""" + convert_text_to_mds_from_args( + output_folder=output_folder, + input_folder=input_folder, + compression=compression, + concat_tokens=concat_tokens, + tokenizer_name=tokenizer, + bos_text=bos_text, + eos_text=eos_text, + use_tokenizer_eos=use_tokenizer_eos, + no_wrap=no_wrap, + processes=processes, + reprocess=reprocess, + trust_remote_code=trust_remote_code, + logging_level=logging_level, + ) diff --git a/llmfoundry/cli/registry_cli.py b/llmfoundry/cli/registry_cli.py index 38ada51fd9..db090cd3aa 100644 --- a/llmfoundry/cli/registry_cli.py +++ b/llmfoundry/cli/registry_cli.py @@ -3,15 +3,15 @@ from typing import Optional -import typer from rich.console import Console from rich.table import Table +from typer import Typer from llmfoundry import registry from llmfoundry.utils.registry_utils import TypedRegistry console = Console() -app = typer.Typer(pretty_exceptions_show_locals=False) +app = Typer(pretty_exceptions_show_locals=False) def _get_registries(group: Optional[str] = None) -> list[TypedRegistry]: diff --git a/llmfoundry/command_utils/__init__.py b/llmfoundry/command_utils/__init__.py new file mode 100644 index 0000000000..995c5345e7 --- /dev/null +++ b/llmfoundry/command_utils/__init__.py @@ -0,0 +1,41 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 +from llmfoundry.command_utils.data_prep.convert_dataset_hf import ( + convert_dataset_hf, + convert_dataset_hf_from_args, +) +from llmfoundry.command_utils.data_prep.convert_dataset_json import ( + convert_dataset_json, + convert_dataset_json_from_args, +) +from llmfoundry.command_utils.data_prep.convert_text_to_mds import ( + convert_text_to_mds, + convert_text_to_mds_from_args, +) +from llmfoundry.command_utils.eval import ( + eval_from_yaml, + evaluate, +) +from llmfoundry.command_utils.train import ( + TRAIN_CONFIG_KEYS, + TrainConfig, + train, + train_from_yaml, + validate_config, +) + +__all__ = [ + 'train', + 'train_from_yaml', + 'TrainConfig', + 'TRAIN_CONFIG_KEYS', + 'validate_config', + 'evaluate', + 'eval_from_yaml', + 'convert_dataset_hf', + 'convert_dataset_hf_from_args', + 'convert_dataset_json', + 'convert_dataset_json_from_args', + 'convert_text_to_mds', + 'convert_text_to_mds_from_args', +] diff --git a/llmfoundry/command_utils/data_prep/__init__.py b/llmfoundry/command_utils/data_prep/__init__.py new file mode 100644 index 0000000000..80950cb7b4 --- /dev/null +++ b/llmfoundry/command_utils/data_prep/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 diff --git a/llmfoundry/command_utils/data_prep/convert_dataset_hf.py b/llmfoundry/command_utils/data_prep/convert_dataset_hf.py new file mode 100644 index 0000000000..f9bbe6b0cf --- /dev/null +++ b/llmfoundry/command_utils/data_prep/convert_dataset_hf.py @@ -0,0 +1,489 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Streaming dataset conversion scripts for C4 and The Pile.""" +import json +import os +import platform +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, Iterable, Optional, Union + +import datasets as hf_datasets +import psutil +import torch +from numpy.typing import NDArray +from streaming import MDSWriter +from torch.utils.data import DataLoader, Dataset, IterableDataset +from tqdm import tqdm +from transformers import PreTrainedTokenizerBase + +from llmfoundry.data import 
ConcatTokensDataset, NoConcatDataset +from llmfoundry.utils.builders import build_tokenizer + + +class ConcatMode(Enum): + NO_CONCAT = 'NO_CONCAT' + CONCAT_TOKENS = 'CONCAT_TOKENS' + + +@dataclass +class DataSplitConstants: + hf_split: str + folder_split: str + raw_samples: Optional[int] + truncated_samples: Union[int, None] + + +@dataclass +class DatasetConstants: + chars_per_sample: int + chars_per_token: int + splits: Dict[str, DataSplitConstants] = field(default_factory=dict) + + def __iter__(self): + for v in self.splits.values(): + yield v + + +class TrainSmallConstants(DataSplitConstants): + + def __init__( + self, + hf_split: str = 'train', + folder_split: str = 'train_small', + raw_samples: int = 100000, + truncated_samples: int = 100000, + ): + super().__init__(hf_split, folder_split, raw_samples, truncated_samples) + + +class ValSmallConstants(DataSplitConstants): + + def __init__( + self, + hf_split: str = 'validation', + folder_split: str = 'val_small', + raw_samples: int = 10000, + truncated_samples: int = 10000, + ): + super().__init__(hf_split, folder_split, raw_samples, truncated_samples) + + +class ValXSmallConstants(DataSplitConstants): + + def __init__( + self, + hf_split: str = 'validation', + folder_split: str = 'val_xsmall', + raw_samples: int = 3000, + truncated_samples: int = 3000, + ): + super().__init__(hf_split, folder_split, raw_samples, truncated_samples) + + +pileconstants = DatasetConstants( + chars_per_sample=6212, # Computed over validation set + chars_per_token=4, # OpenAI estimate +) +pileconstants.splits['train'] = DataSplitConstants( + hf_split='train', + folder_split='train', + raw_samples=210607728, + truncated_samples=None, +) +pileconstants.splits['train_small'] = DataSplitConstants( + hf_split='train', + folder_split='train_small', + raw_samples=100000, + truncated_samples=100000, +) +pileconstants.splits['val'] = DataSplitConstants( + hf_split='validation', + folder_split='val', + raw_samples=214670, + truncated_samples=None, +) +pileconstants.splits['val_small'] = DataSplitConstants( + hf_split='validation', + folder_split='val_small', + raw_samples=10000, + truncated_samples=10000, +) +pileconstants.splits['val_xsmall'] = DataSplitConstants( + hf_split='validation', + folder_split='val_xsmall', + raw_samples=3000, + truncated_samples=3000, +) + +c4constants = DatasetConstants( + chars_per_sample=2163, # Computed over validation set + chars_per_token=4, # OpenAI estimate +) +c4constants.splits['train'] = DataSplitConstants( + hf_split='train', + folder_split='train', + raw_samples=364868892, + truncated_samples=None, +) +c4constants.splits['train_small'] = DataSplitConstants( + hf_split='train', + folder_split='train_small', + raw_samples=100000, + truncated_samples=100000, +) +c4constants.splits['val'] = DataSplitConstants( + hf_split='validation', + folder_split='val', + raw_samples=364608, + truncated_samples=None, +) +c4constants.splits['val_small'] = DataSplitConstants( + hf_split='validation', + folder_split='val_small', + raw_samples=10000, + truncated_samples=10000, +) +c4constants.splits['val_xsmall'] = DataSplitConstants( + hf_split='validation', + folder_split='val_xsmall', + raw_samples=3000, + truncated_samples=3000, +) +c4constants.splits['val_xxsmall'] = DataSplitConstants( + hf_split='validation', + folder_split='val_xxsmall', + raw_samples=100, + truncated_samples=100, +) + +CONSTS = {'c4': c4constants, 'the_pile': pileconstants} + + +def build_hf_dataset( + dataset_name: str, + split: str, + mode: ConcatMode, + max_length: 
Optional[int] = None, + bos_text: str = '', + eos_text: str = '', + no_wrap: bool = False, + tokenizer: PreTrainedTokenizerBase = None, + data_subset: Union[str, None] = None, +) -> IterableDataset: + """Build an IterableDataset over the HF C4 or pile source data. + + Args: + dataset_name (str): Dataset name + split (str): Split name. + mode (ConcatMode): NO_CONCAT, or CONCAT_TOKENS + max_length (int): The length of concatenated tokens + bos_text (str): text to insert at the beginning of each sequence + eos_text (str): text to insert at the end of each sequence + no_wrap (bool): if concatenating, whether to wrap text across `max_length` boundaries + tokenizer (PreTrainedTokenizerBase): if mode is CONCAT_TOKENS, the tokenizer to use + data_subset (str): Referred to as "name" in HuggingFace datasets.load_dataset. + Typically "all" (The Pile) or "en" (c4). + + Returns: + An IterableDataset. + """ + hf_dataset = hf_datasets.load_dataset( + path=dataset_name, + name=data_subset, + split=split, + streaming=True, + ) + if mode == ConcatMode.NO_CONCAT: + dataset = NoConcatDataset(hf_dataset) + else: + if not isinstance(tokenizer, PreTrainedTokenizerBase): + raise ValueError( + f'{tokenizer=} must be of type PreTrainedTokenizerBase', + ) + if max_length is None: + raise ValueError(f'max_length must be set.') + if bos_text + eos_text == '': + test_tokens = tokenizer('test') + if test_tokens['input_ids'][ + 0] != tokenizer.bos_token_id and test_tokens['input_ids'][ + -1] != tokenizer.eos_token_id: + tok_error_msg = 'This tokenizer does not insert an EOS nor BOS token. ' + tok_error_msg += 'Concatenating with this tokenizer will result in sequences being ' + tok_error_msg += 'attached without a separating token. Please use another tokenizer, ' + tok_error_msg += 'such as facebook/opt-125m, or specify EOS/BOS text with e.g. ' + tok_error_msg += '--bos_text=<|endoftext|>.' + raise ValueError(tok_error_msg) + dataset = ConcatTokensDataset( + hf_dataset=hf_dataset, + tokenizer=tokenizer, + max_length=max_length, + bos_text=bos_text, + eos_text=eos_text, + no_wrap=no_wrap, + ) + return dataset + + +def _est_progress_denominator( + total_samples: int, + chars_per_sample: int, + chars_per_token: int, + mode: ConcatMode, + max_length: int, +): + est_tokens_per_sample = chars_per_sample // chars_per_token + if mode == ConcatMode.NO_CONCAT: + return total_samples + elif mode == ConcatMode.CONCAT_TOKENS: + return total_samples * est_tokens_per_sample // max_length + + +def build_dataloader( + dataset: Dataset, + batch_size: int, + num_workers: Optional[int], +) -> DataLoader: + if num_workers is None: + # Multiple workers is only supported on linux machines + if 'linux' or 'macos' in platform.platform().lower(): + num_workers = max(1, psutil.cpu_count()) + else: + num_workers = 0 + + # If using multiple workers, configure each worker to prefetch as many samples as it can, up to + # the aggregate device batch size + # If not using workers, the torch DataLoader expects the default value for prefetch_factor, + # which non-intuitively must be 2. + prefetch_factor = max( + 1, + 2 * batch_size // num_workers, + ) if num_workers > 0 else 2 + + return DataLoader( + dataset=dataset, + sampler=None, + batch_size=batch_size, + num_workers=num_workers, + prefetch_factor=prefetch_factor, + ) + + +def generate_samples( + loader: DataLoader, + truncate_num_samples: Optional[int] = None, +) -> Iterable[Union[Dict[str, bytes], Dict[str, NDArray]]]: + """Generator over samples of a dataloader. 
+ + Args: + loader (DataLoader): A dataloader emitting batches like {key: [sample0_bytes, sample1_bytes, sample2_bytes, ...]} + truncate_num_samples (Optional[int]): An optional # of samples to stop at. + + Yields: + Sample dicts. + """ + n_samples = 0 + for batch in loader: + keys = list(batch.keys()) + current_bs = len(batch[keys[0]]) + for idx in range(current_bs): + if truncate_num_samples is not None and n_samples == truncate_num_samples: + return + n_samples += 1 + yield { + k: + v[idx].numpy() if isinstance(v[idx], torch.Tensor) else v[idx] + for k, v in batch.items() + } + + +def convert_dataset_hf( + dataset: str, + data_subset: Optional[str], + splits: list[str], + out_root: str, + compression: Optional[str], + concat_tokens: Optional[int], + tokenizer: Optional[str], + tokenizer_kwargs: dict[str, Any], + bos_text: str, + eos_text: str, + no_wrap: bool, + num_workers: Optional[int], +) -> None: + """Converts HuggingFace datasets to MDS format. + + Args: + dataset (str): Name of the dataset + data_subset (Optional[str]): Subset of the dataset (e.g., "all" or "en") + splits (list[str]): Comma-separated list of dataset splits + out_root (str): Output root directory + compression (Optional[str]): Compression type + concat_tokens (Optional[int]): Concatenate tokens up to this many tokens + tokenizer (Optional[str]): Tokenizer name + tokenizer_kwargs (dict[str, Any]): Tokenizer keyword arguments + bos_text (str): BOS text + eos_text (str): EOS text + no_wrap (bool): Do not wrap text across max_length boundaries + num_workers (Optional[int]): Number of workers + + Raises: + KeyError: If constants are not defined for the split + """ + try: + dataset_constants = CONSTS[dataset] + except KeyError: + raise ValueError( + f'Constants for dataset "{dataset}" not found. 
Currently only "the_pile" and "c4" are supported.', + ) + + if concat_tokens is not None and tokenizer is not None: + mode = ConcatMode.CONCAT_TOKENS + built_tokenizer = build_tokenizer(tokenizer, tokenizer_kwargs) + # we will enforce length, so suppress warnings about sequences too long for the model + built_tokenizer.model_max_length = int(1e30) + columns = {'tokens': 'ndarray:int32'} + else: + mode = ConcatMode.NO_CONCAT + built_tokenizer = None + columns = {'text': 'str'} + + for split_name in splits: + try: + split = dataset_constants.splits[split_name] + except KeyError: + raise KeyError(f'Constants not defined for split {split_name}.') + hf_split = split.hf_split + folder_split = split.folder_split + expected_num_samples = split.raw_samples + truncate_num_samples = split.truncated_samples + # Only generate the splits requested + if folder_split not in splits: + continue + + # Get samples + hf_dataset = build_hf_dataset( + dataset_name=dataset, + data_subset=data_subset, + split=hf_split, + mode=mode, + max_length=concat_tokens, + bos_text=bos_text, + eos_text=eos_text, + no_wrap=no_wrap, + tokenizer=built_tokenizer, + ) + loader = build_dataloader( + dataset=hf_dataset, + batch_size=512, + num_workers=num_workers, + ) + samples = generate_samples( + loader, + truncate_num_samples=truncate_num_samples, + ) + + if expected_num_samples is not None and concat_tokens is not None: + denominator = truncate_num_samples if truncate_num_samples is not None else _est_progress_denominator( + total_samples=expected_num_samples, + chars_per_sample=dataset_constants.chars_per_sample, + chars_per_token=dataset_constants.chars_per_token, + mode=mode, + max_length=concat_tokens, + ) + else: + denominator = None + + # Write samples + print(f'Converting {folder_split} to MDS format...') + print( + f'Note: the progress bar is based on the dataset length before tokenization, and may finish at a value before 100%.', + ) + with MDSWriter( + columns=columns, + out=os.path.join(out_root, folder_split), + compression=compression, + ) as out: + if denominator is not None: + for sample in tqdm( + samples, + desc=folder_split, + total=denominator, + ): + out.write(sample) + else: + for sample in tqdm(samples, desc=folder_split): + out.write(sample) + + +def convert_dataset_hf_from_args( + dataset: str, + data_subset: Optional[str], + splits: list[str], + out_root: str, + compression: Optional[str], + concat_tokens: Optional[int], + tokenizer: Optional[str], + tokenizer_kwargs: Optional[str], + bos_text: Optional[str], + eos_text: Optional[str], + no_wrap: bool, + num_workers: Optional[int], +) -> None: + """A wrapper for `convert_dataset_hf` that parses arguments. 
+ + Args: + dataset (str): Name of the dataset + data_subset (Optional[str]): Subset of the dataset (e.g., "all" or "en") + splits (list[str]): Comma-separated list of dataset splits + out_root (str): Output root directory + compression (Optional[str]): Compression type + concat_tokens (Optional[int]): Concatenate tokens up to this many tokens + tokenizer (Optional[str]): Tokenizer name + tokenizer_kwargs (Optional[str]): Tokenizer keyword arguments in JSON format + bos_text (Optional[str]): BOS text + eos_text (Optional[str]): EOS text + no_wrap (bool): Do not wrap text across max_length boundaries + num_workers (Optional[int]): Number of workers + + Raises: + ValueError: If the output directory already contains the requested splits + ValueError: If `concat_tokens` is set but `tokenizer` is not + """ + if tokenizer_kwargs: + parsed_tokenizer_kwargs = json.loads(tokenizer_kwargs) + else: + parsed_tokenizer_kwargs = {} + + if os.path.isdir(out_root) and len( + set(os.listdir(out_root)).intersection(set(splits)), + ) > 0: + raise ValueError( + f'--out_root={out_root} contains {os.listdir(out_root)} which cannot overlap with the requested splits {splits}.', + ) + + # Make sure we have needed concat options + if ( + concat_tokens is not None and isinstance(concat_tokens, int) and + tokenizer is None + ): + raise ValueError( + 'When setting --concat_tokens, you must specify a --tokenizer', + ) + + # now that we have validated them, change BOS/EOS to strings and convert + convert_dataset_hf( + dataset=dataset, + data_subset=data_subset, + splits=splits, + out_root=out_root, + compression=compression, + concat_tokens=concat_tokens, + tokenizer=tokenizer, + tokenizer_kwargs=parsed_tokenizer_kwargs, + bos_text=bos_text if bos_text else '', + eos_text=eos_text if eos_text else '', + no_wrap=no_wrap, + num_workers=num_workers, + ) diff --git a/llmfoundry/command_utils/data_prep/convert_dataset_json.py b/llmfoundry/command_utils/data_prep/convert_dataset_json.py new file mode 100644 index 0000000000..9f174d1aaf --- /dev/null +++ b/llmfoundry/command_utils/data_prep/convert_dataset_json.py @@ -0,0 +1,222 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Streaming dataset conversion scripts for json files.""" +import os +from enum import Enum +from glob import glob +from typing import Optional + +import datasets as hf_datasets +from streaming import MDSWriter +from torch.utils.data import IterableDataset +from tqdm import tqdm +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from llmfoundry.data import ConcatTokensDataset, NoConcatDataset + + +class ConcatMode(Enum): + NO_CONCAT = 'NO_CONCAT' + CONCAT_TOKENS = 'CONCAT_TOKENS' + + +def build_hf_dataset( + path: str, + split: str, + mode: ConcatMode, + max_length: Optional[int] = None, + bos_text: str = '', + eos_text: str = '', + no_wrap: bool = False, + tokenizer: PreTrainedTokenizerBase = None, +) -> IterableDataset: + """Build an IterableDataset over the HF C4 or pile source data. + + Args: + dataset_name (str): Dataset name + split (str): Split name. 
+ mode (ConcatMode): NO_CONCAT, or CONCAT_TOKENS + max_length (int): The length of concatenated tokens + bos_text (str): text to insert at the beginning of each sequence + eos_text (str): text to insert at the end of each sequence + no_wrap (bool): if concatenating, whether to wrap text across `max_length` boundaries + tokenizer (PreTrainedTokenizerBase): if mode is CONCAT_TOKENS, the tokenizer to use + data_subset (str): Referred to as "name" in HuggingFace datasets.load_dataset. + Typically "all" (The Pile) or "en" (c4). + + Returns: + An IterableDataset. + """ + if os.path.isdir(path): + data_files = glob(f'{path}/*') + else: + data_files = path + + hf_dataset = hf_datasets.load_dataset( + 'json', + data_files=data_files, + split=split, + ) + + if mode == ConcatMode.NO_CONCAT: + dataset = NoConcatDataset(hf_dataset) + else: + if not isinstance(tokenizer, PreTrainedTokenizerBase): + raise ValueError( + f'{tokenizer=} must be of type PreTrainedTokenizerBase', + ) + if max_length is None: + raise ValueError(f'max_length must be set.') + if bos_text + eos_text == '': + test_tokens = tokenizer('test') + if test_tokens['input_ids'][ + 0] != tokenizer.bos_token_id and test_tokens['input_ids'][ + -1] != tokenizer.eos_token_id: + tok_error_msg = 'This tokenizer does not insert an EOS nor BOS token. ' + tok_error_msg += 'Concatenating with this tokenizer will result in sequences being ' + tok_error_msg += 'attached without a separating token. Please use another tokenizer, ' + tok_error_msg += 'such as facebook/opt-125m, or specify EOS/BOS text with e.g. ' + tok_error_msg += '--bos_text=<|endoftext|>.' + raise ValueError(tok_error_msg) + dataset = ConcatTokensDataset( + hf_dataset=hf_dataset, + tokenizer=tokenizer, + max_length=max_length, + bos_text=bos_text, + eos_text=eos_text, + no_wrap=no_wrap, + ) + return dataset + + +def convert_dataset_json( + path: str, + out_root: str, + compression: Optional[str], + concat_tokens: Optional[int], + split: str, + tokenizer: Optional[str] = None, + bos_text: str = '', + eos_text: str = '', + no_wrap: bool = False, + num_workers: Optional[int] = None, +) -> None: + """Create C4/pile streaming dataset. 
+ + Args: + path (str): Path to the input data file + out_root (str): Output root directory + compression (Optional[str]): Compression type, if any + concat_tokens (Optional[int]): Convert text to tokens and concatenate up to this many tokens + split (str): Dataset split to process + tokenizer (Optional[str]): Tokenizer name + bos_text (str): Text to insert at the beginning of each sequence + eos_text (str): Text to insert at the end of each sequence + no_wrap (bool): Do not wrap text across max_length boundaries + num_workers (Optional[int]): Number of workers for data loading + """ + if concat_tokens is not None: + mode = ConcatMode.CONCAT_TOKENS + built_tokenizer = AutoTokenizer.from_pretrained(tokenizer) + # we will enforce length, so suppress warnings about sequences too long for the model + built_tokenizer.model_max_length = int(1e30) + columns = {'tokens': 'ndarray:int32'} + else: + mode = ConcatMode.NO_CONCAT + built_tokenizer = None + columns = {'text': 'str'} + + # Get samples + dataset = build_hf_dataset( + path=path, + split=split, + mode=mode, + max_length=concat_tokens, + bos_text=bos_text, + eos_text=eos_text, + no_wrap=no_wrap, + tokenizer=built_tokenizer, + ) + + print('here') + + # Write samples + print(f'Converting to MDS format...') + print( + f'Note that the progress bar is based on the dataset length before tokenization.', + ) + print(f'It will finish at a value below 100% if tokenizing') + with MDSWriter( + columns=columns, + out=os.path.join(out_root), + compression=compression, + ) as out: + for sample in tqdm(dataset): + out.write(sample) + + +def convert_dataset_json_from_args( + path: str, + out_root: str, + compression: Optional[str], + concat_tokens: Optional[int], + split: str, + tokenizer: Optional[str] = None, + bos_text: Optional[str] = None, + eos_text: Optional[str] = None, + no_wrap: bool = False, + num_workers: Optional[int] = None, +) -> None: + """A wrapper for `convert_dataset_json` that parses arguments. 
+ + Args: + path (str): Path to the input data file + out_root (str): Output root directory + compression (Optional[str]): Compression type, if any + concat_tokens (Optional[int]): Convert text to tokens and concatenate up to this many tokens + split (str): Dataset split to process + tokenizer (Optional[str]): Tokenizer name + bos_text (Optional[str]): Text to insert at the beginning of each sequence + eos_text (Optional[str]): Text to insert at the end of each sequence + no_wrap (bool): Do not wrap text across max_length boundaries + num_workers (Optional[int]): Number of workers for data loading + + Raises: + ValueError: If the out_root directory exists and contains files that overlap with the requested splits + ValueError: If concat_tokens is set and a tokenizer is not provided + """ + if os.path.isdir(out_root) and len( + set(os.listdir(out_root)).intersection(set(split)), + ) > 0: + raise ValueError( + f'--out_root={out_root} contains {os.listdir(out_root)} which cannot overlap with the requested splits {split}.', + ) + + # Make sure we have needed concat options + if ( + concat_tokens is not None and isinstance(concat_tokens, int) and + tokenizer is None + ): + ValueError( + 'When setting --concat_tokens, you must specify a --tokenizer', + ) + + # now that we have validated them, change BOS/EOS to strings + if bos_text is None: + bos_text = '' + if eos_text is None: + eos_text = '' + + convert_dataset_json( + path=path, + out_root=out_root, + compression=compression, + concat_tokens=concat_tokens, + split=split, + tokenizer=tokenizer, + bos_text=bos_text, + eos_text=eos_text, + no_wrap=no_wrap, + num_workers=num_workers, + ) diff --git a/llmfoundry/command_utils/data_prep/convert_text_to_mds.py b/llmfoundry/command_utils/data_prep/convert_text_to_mds.py new file mode 100644 index 0000000000..14afe279fd --- /dev/null +++ b/llmfoundry/command_utils/data_prep/convert_text_to_mds.py @@ -0,0 +1,582 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import logging +import math +import os +import tempfile +from concurrent.futures import ProcessPoolExecutor +from functools import partial +from glob import glob +from typing import Dict, Iterable, List, Optional, Tuple, cast + +import numpy as np +from composer.utils import ( + ObjectStore, + maybe_create_object_store_from_uri, + parse_uri, +) +from numpy.typing import NDArray +from streaming import MDSWriter +from tqdm import tqdm +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from llmfoundry.data.data import AbstractConcatTokensDataset +from llmfoundry.utils.data_prep_utils import ( + DownloadingIterable, + download_file, + merge_shard_groups, +) +from llmfoundry.utils.exceptions import ( + InputFolderMissingDataError, + OutputFolderNotEmptyError, +) + +log = logging.getLogger(__name__) + +DONE_FILENAME = '.text_to_mds_conversion_done' + + +class ConcatTokensFromFilesDataset(AbstractConcatTokensDataset): + """An IterableDataset that returns token samples for MDSWriter from files. + + Returns dicts of {'tokens': ndarray:int32} + + Each file is considered a sequence. 
+ """ + + def __init__( + self, + files: Iterable[str], + tokenizer: PreTrainedTokenizerBase, + max_length: int, + bos_text: str, + eos_text: str, + no_wrap: bool, + ): + self.files = files + super().__init__(tokenizer, max_length, bos_text, eos_text, no_wrap) + log.info(f'Initialized ConcatTokensFromFilesDataset.') + + def __iter__(self) -> Iterable[Dict[str, NDArray]]: + log.info( + 'Starting iteration over files in ConcatTokensFromFilesDataset', + ) + buffer = [] + for file in self.files: + log.info(f'Processing file: {file}') + with open(file, 'r') as f: + buffer += self.bos_tokens + first_chunk = True + # Read the file in 1MB chunks to avoid memory issues + for chunk in iter(partial(f.read, 1000000), ''): + # Tokenize the chunk + encoded = self.tokenizer( + chunk, + truncation=False, + padding=False, + ) + iids = encoded['input_ids'] + + # If this is not the first chunk, remove the BOS token + if not first_chunk: + if iids[0] == self.tokenizer.bos_token_id: + iids = iids[1:] + + # Add the tokens to the buffer + buffer += iids + while len(buffer) >= self.max_length: + concat_sample = buffer[:self.max_length] + buffer = buffer[self. + max_length:] if self.should_wrap else [] + yield { + 'tokens': np.asarray(concat_sample, dtype=np.int32), + } + + first_chunk = False + + # Add the EOS token to the buffer to separate files. + buffer += self.eos_tokens + + # Yield any remaining samples of size max_length. + while len(buffer) >= self.max_length: + concat_sample = buffer[:self.max_length] + buffer = buffer[self.max_length:] if self.should_wrap else [] + yield {'tokens': np.asarray(concat_sample, dtype=np.int32)} + + log.info( + 'Finished iterating over files in ConcatTokensFromFilesDataset', + ) + + +def get_object_names(input_folder: str) -> List[str]: + """Get object names from a local or remote folder. + + Args: + input_folder (str): local or remote folder path. + """ + object_store = maybe_create_object_store_from_uri(input_folder) + if object_store is not None: + _, _, folder_prefix = parse_uri(input_folder) + names = [ + name for name in object_store.list_objects(folder_prefix) + if name.endswith('.txt') + ] + log.info(f'Found {len(names)} text files in remote storage') + else: + # input_folder is a local folder + names = [ + text_file for dirpath, _, _ in os.walk(input_folder) + for text_file in glob(os.path.join(dirpath, '*.txt')) + ] + # return names, sizes + log.info(f'Found {len(names)} text files at {input_folder}') + + return names + + +def get_task_args( + object_names: List[str], + output_root: str, + input_folder: str, + n_groups: int, + tokenizer_name: str, + concat_tokens: int, + eos_text: str, + bos_text: str, + no_wrap: bool, + compression: str, + trust_remote_code: bool, +) -> Iterable: + """Get download_and_convert arguments split across n_groups. + + Each group handles a portion of object_names. 
+ + Args: + object_names (List[str]): Names of objects to process + output_root (str): Folder to write MDS shards to + input_folder (str): Folder of text files to process + n_groups (int): Number of groups to split the object names into + tokenizer_name (str): Name of tokenizer to use + concat_tokens (int): Concatenate up to this many tokens + eos_text (str): Text to append to each example to separate concatenated samples + bos_text (str): Text to prepend to each example to separate concatenated samples + no_wrap: (bool): Whether to let text examples wrap across multiple training examples + compression (str): The compression algorithm to use for MDS writing + trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer + """ + log.info( + f'Preparing task arguments for {len(object_names)} objects across {n_groups} groups', + ) + num_objects = len(object_names) + objs_per_group = math.ceil(num_objects / n_groups) + for group, i in enumerate(range(0, num_objects, objs_per_group)): + output_subdir = os.path.join(output_root, str(group)) + log.info( + f'Created task for group {group} with {min(objs_per_group, num_objects - i)} objects', + ) + yield ( + object_names[i:min(i + objs_per_group, num_objects)], + output_subdir, + input_folder, + tokenizer_name, + concat_tokens, + eos_text, + bos_text, + no_wrap, + compression, + trust_remote_code, + ) + + +def download_and_convert_starargs(args: Tuple): + """Helper function to call download_and_convert with star args. + + This helps us use download_and_convert with multiprocessing. + """ + return download_and_convert(*args) + + +def download_and_convert( + file_names: List[str], + output_folder: str, + input_folder: str, + tokenizer_name: str, + concat_tokens: int, + eos_text: str, + bos_text: str, + no_wrap: bool, + compression: str, + trust_remote_code: bool, +): + """Downloads and converts text files to MDS format. 
+ + Args: + file_names (List[str]): Files to process + output_folder (str): Folder to write MDS shards to + input_folder (str): Folder of text files to process + tokenizer_name (str): Name of tokenizer to use + concat_tokens (int): Concatenate up to this many tokens + eos_text (str): Text to append to each example to separate concatenated samples + bos_text (str): Text to prepend to each example to separate concatenated samples + no_wrap: (bool): Whether to let text examples wrap across multiple training examples + compression (str): The compression algorithm to use for MDS writing + trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer + """ + log.info(f'Starting download and conversion for {len(file_names)} files') + + object_store = maybe_create_object_store_from_uri(input_folder) + + # Download file_names + with tempfile.TemporaryDirectory() as tmp_dir: + log.info(f'Created temporary directory: {tmp_dir}') + downloading_iter = DownloadingIterable( + object_names=file_names, + output_folder=tmp_dir, + object_store=object_store, + ) + log.info(f'Initializing tokenizer: {tokenizer_name}') + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, + trust_remote_code=trust_remote_code, + ) + tokenizer.model_max_length = 5000000000 # Hack to prevent warnings from HuggingFace + + # Use the ConcatTokensDataset from LLM-foundry to concatenate sequences of tokens up + # to the maximum sequence length + dataset = ConcatTokensFromFilesDataset( + files=downloading_iter, + max_length=concat_tokens, + tokenizer=tokenizer, + eos_text=eos_text, + bos_text=bos_text, + no_wrap=no_wrap, + ) + + columns = {'tokens': 'ndarray:int32'} + + log.info('Converting to MDS format...') + with MDSWriter( + out=output_folder, + columns=columns, + compression=compression, + ) as out: + for sample in tqdm(dataset): + out.write(sample) + + log.info(f'Completed download and conversion for {len(file_names)} files') + + +def is_remote_path(path: str) -> bool: + """Checks whether a path is a remote path. + + Args: + path (str): path to check + """ + backend, _, _ = parse_uri(path) + return backend != '' + + +def is_already_processed( + output_root: str, + args_str: str, + object_names: List[str], +) -> bool: + """Determines whether a group of text files has already been processed. + + Checks the done fie at output root to determine this. 
+ + Args: + output_root (str): Output folder where a done file may exist + args_str (str): String representation of the arguments + object_names (List[str]): Names of objects to convert to MDS format + """ + log.info( + f'Checking if {len(object_names)} objects have already been processed in {output_root}', + ) + + # Retrieve the done file contents + output_object_store = maybe_create_object_store_from_uri(output_root) + if output_object_store is not None: + # Download and read the done file from the remote object store + _, _, output_folder_prefix = parse_uri(output_root) + try: + with tempfile.TemporaryDirectory() as tmp_dir: + done_file = os.path.join(tmp_dir, DONE_FILENAME) + download_file( + object_store=output_object_store, + object_name=os.path.join( + output_folder_prefix, + DONE_FILENAME, + ), + output_filename=done_file, + ) + with open(done_file) as df: + done_file_contents = df.read().splitlines() + log.info(f'Retrieved done file contents from remote storage') + except FileNotFoundError: + log.info('Done file not found in remote storage') + return False + else: + # Read the local done file + done_file = os.path.join(output_root, DONE_FILENAME) + if not os.path.isfile(done_file): + log.info('Done file not found in local storage') + return False + with open(done_file) as df: + done_file_contents = df.read().splitlines() + log.info(f'Retrieved done file contents from local storage') + + # Compare the arguments + prev_args_str = done_file_contents[0] + if prev_args_str != args_str: + log.info('Arguments have changed, reprocessing required') + return False + + # Compare file names + prev_names = done_file_contents[1:] + if len(prev_names) != len(object_names): + log.info('Number of files has changed, reprocessing required') + return False + for idx, prev_name in enumerate(prev_names): + if object_names[idx] != prev_name: + log.info('File names have changed, reprocessing required') + return False + + log.info('All files have already been processed') + return True + + +def write_done_file(folder: str, args_str: str, object_names: List[str]): + """Write a file to signify completion. + + This the done file includes the arguments to processing and + a list of objects that were processed. + + Args: + folder (str): Folder to write the done file to + args_str (str): String representation of arguments + object_names (List[str]): List of objects to convert to MDS format + """ + with open(os.path.join(folder, DONE_FILENAME), 'w') as done_file: + log.info(f'Writing done file.') + done_file.write('\n'.join([args_str] + object_names) + '\n') + log.info(f'Done file written successfully') + + +def convert_text_to_mds( + tokenizer_name: str, + output_folder: str, + input_folder: str, + concat_tokens: int, + eos_text: str, + bos_text: str, + no_wrap: bool, + compression: str, + processes: int, + args_str: str, + reprocess: bool, + trust_remote_code: bool, +): + """Convert a folder of text files to MDS format. + + Args: + tokenizer_name (str): Name of tokenizer to use + output_folder (str): Folder to write MDS shards to + input_folder (str): Folder of text files to process + concat_tokens (int): Concatenate up to this many tokens + eos_text (str): Text to append to each example to separate concatenated samples + bos_text (str): Text to prepend to each example to separate concatenated samples + no_wrap: (bool): Whether to let text examples wrap across multiple training examples + compression (str): The compression algorithm to use for MDS writing + processes (int): The number of processes to use. 
+ args_str (str): String representation of the arguments + reprocess (bool): Whether to always reprocess the given folder of text files + trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer + """ + is_remote_output = is_remote_path(output_folder) + log.info(f'Output is remote: {is_remote_output}') + + object_names = get_object_names(input_folder) + if len(object_names) == 0: + log.error(f'No text files found in input folder: {input_folder}') + raise InputFolderMissingDataError(input_folder) + + # Check if the text files in the bucket have already been processed. + if not reprocess and is_already_processed( + output_folder, + args_str, + object_names, + ): + log.info( + f'Input folder {input_folder} is already processed at {output_folder} and ' + + + 'reprocess is set to False. Set reprocess to True if you would like to force reprocessing.', + ) + return + + # Use a temporary local directory if the output is remote and there are more than 1 processes + local_output_folder = tempfile.TemporaryDirectory( + ).name if is_remote_output else output_folder + log.info(f'Using local output folder: {local_output_folder}') + + if os.path.isdir(output_folder) and len(os.listdir(output_folder)) > 0: + log.error(f'Output folder is not empty: {output_folder}') + raise OutputFolderNotEmptyError(output_folder) + + if processes > 1: + log.info(f'Using multiprocessing with {processes} processes') + # Download and convert the text files in parallel + args = get_task_args( + object_names, + local_output_folder, + input_folder, + processes, + tokenizer_name, + concat_tokens, + eos_text, + bos_text, + no_wrap, + compression, + trust_remote_code, + ) + with ProcessPoolExecutor(max_workers=processes) as executor: + list(executor.map(download_and_convert_starargs, args)) + + log.info('Merging MDS shards from each process') + # Merge the mds shards from each of the processes into a single folder + merge_shard_groups(local_output_folder) + else: + log.info('Using single process for download and conversion') + download_and_convert( + object_names, + local_output_folder, + input_folder, + tokenizer_name, + concat_tokens, + eos_text, + bos_text, + no_wrap, + compression, + trust_remote_code, + ) + + # Write a done file with the args and object names + write_done_file(local_output_folder, args_str, object_names) + + if is_remote_output: + # Upload the local output to the remote location + output_object_store = cast( + ObjectStore, + maybe_create_object_store_from_uri(output_folder), + ) + _, _, output_folder_prefix = parse_uri(output_folder) + files_to_upload = os.listdir(local_output_folder) + + for file in files_to_upload: + assert not os.path.isdir(file) + remote_path = os.path.join(output_folder_prefix, file) + output_object_store.upload_object( + remote_path, + os.path.join(local_output_folder, file), + ) + + +def _configure_logging(logging_level: str): + """Configure logging. + + Args: + logging_level (str): Logging level. 
+ """ + logging.basicConfig( + format= + f'%(asctime)s: [%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s', + ) + logging_level = logging_level.upper() + logging.getLogger('llmfoundry').setLevel(logging_level) + logging.getLogger(__name__).setLevel(logging_level) + log.info(f'Logging level set to {logging_level}') + + +def convert_text_to_mds_from_args( + output_folder: str, + input_folder: str, + compression: str, + concat_tokens: int, + tokenizer_name: str, + bos_text: Optional[str], + eos_text: Optional[str], + use_tokenizer_eos: bool, + no_wrap: bool, + processes: int, + reprocess: bool, + trust_remote_code: bool, + logging_level: str, +) -> None: + """A wrapper for `convert_text_to_mds` to parse arguments. + + Args: + output_folder (str): Folder to write MDS shards to + input_folder (str): Folder of text files to process + compression (str): The compression algorithm to use for MDS writing + concat_tokens (int): Concatenate up to this many tokens + tokenizer_name (str): The name of the tokenizer to use + bos_text (Optional[str]): The text to prepend to each example to separate concatenated examples + eos_text (Optional[str]): The text to append to each example to separate concatenated examples + use_tokenizer_eos (bool): Use the EOS text from the tokenizer + no_wrap (bool): Whether to let text examples wrap across multiple training examples + processes (int): The number of processes to use to download and convert the dataset + reprocess (bool): If true, reprocess the input_folder to MDS format. Otherwise, only reprocess upon changes to the input folder or dataset creation parameters. + trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer + logging_level (str): Logging level for the script. Default is INFO. + + Raises: + ValueError: If `use_tokenizer_eos` is True and `eos_text` is not None + """ + if use_tokenizer_eos: + # Ensure that eos text is not specified twice. + if eos_text is not None: + ValueError( + 'Cannot set --eos_text with --use_tokenizer_eos. 
Please specify one.', + ) + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, + trust_remote_code=trust_remote_code, + ) + eos_text = tokenizer.eos_token + + # now that we have validated them, change BOS/EOS to strings + if bos_text is None: + bos_text = '' + if eos_text is None: + eos_text = '' + _configure_logging(logging_level) + + # Define args for _args_str + args = { + 'tokenizer': tokenizer_name, + 'output_folder': output_folder, + 'input_folder': input_folder, + 'compression': compression, + 'concat_tokens': concat_tokens, + 'eos_text': eos_text, + 'bos_text': bos_text, + 'no_wrap': no_wrap, + 'processes': processes, + 'reprocess': reprocess, + 'trust_remote_code': trust_remote_code, + } + convert_text_to_mds( + tokenizer_name=tokenizer_name, + output_folder=output_folder, + input_folder=input_folder, + concat_tokens=concat_tokens, + eos_text=eos_text, + bos_text=bos_text, + no_wrap=no_wrap, + compression=compression, + processes=processes, + reprocess=reprocess, + trust_remote_code=trust_remote_code, + args_str=str(args), + ) diff --git a/llmfoundry/command_utils/eval.py b/llmfoundry/command_utils/eval.py new file mode 100644 index 0000000000..630d418ad0 --- /dev/null +++ b/llmfoundry/command_utils/eval.py @@ -0,0 +1,496 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import logging +import os +import time +from typing import Any, Dict, Optional, Tuple, Union + +import pandas as pd +import torch +from composer.core import Callback +from composer.loggers.logger_destination import LoggerDestination +from composer.trainer import Trainer +from composer.utils import dist, get_device, reproducibility +from omegaconf import DictConfig +from omegaconf import OmegaConf as om + +from llmfoundry.utils import ( + find_mosaicml_logger, + log_eval_analytics, + maybe_create_mosaicml_logger, +) +from llmfoundry.utils.builders import ( + add_metrics_to_eval_loaders, + build_callback, + build_composer_model, + build_evaluators, + build_logger, + build_tokenizer, +) +from llmfoundry.utils.config_utils import ( + EVAL_CONFIG_KEYS, + EvalConfig, + log_config, + make_dataclass_and_log_config, + process_init_device, +) +from llmfoundry.utils.registry_utils import import_file + +log = logging.getLogger(__name__) + + +def evaluate_model( + tokenizer: Dict[str, Any], + model_name: str, + model: Dict[str, Any], + dist_timeout: Union[float, int], + run_name: str, + seed: int, + icl_tasks: Union[str, list[Dict[str, Any]]], + max_seq_len: int, + device_eval_batch_size: Union[int, float], + eval_gauntlet_config: Optional[Union[str, Dict[str, Any]]], + eval_loader_config: Optional[Union[Dict[str, Any], list[Dict[str, Any]]]], + fsdp_config: Optional[Dict[str, Any]], + loggers: list[LoggerDestination], + python_log_level: Optional[str], + precision: str, + eval_gauntlet_df: Optional[pd.DataFrame], + eval_subset_num_batches: int, + icl_subset_num_batches: Optional[int], + callback_configs: Optional[Dict[str, Any]], + metadata: Optional[Dict[str, str]], + logged_config: Dict[str, Any], + should_log_config: bool = True, + load_path: Optional[str] = None, +): + log.info(f'Evaluating model: {model_name}') + # Build tokenizer and model + tokenizer_cfg = tokenizer + tokenizer_name = tokenizer_cfg['name'] + tokenizer_kwargs = tokenizer_cfg.get('kwargs', {}) + tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) + + evaluators, logger_keys, eval_gauntlet_callback = build_evaluators( + eval_loader_config, + icl_tasks, + eval_gauntlet_config, + 
tokenizer=tokenizer, + device_eval_batch_size=device_eval_batch_size, + icl_seq_len=max_seq_len, + icl_subset_num_batches=icl_subset_num_batches, + ) + + # Callbacks + callbacks: list[Callback] = [ + build_callback(name=str(name), kwargs=callback_cfg) + for name, callback_cfg in callback_configs.items() + ] if callback_configs else [] + + if eval_gauntlet_callback is not None: + callbacks.append(eval_gauntlet_callback) + + if metadata is not None: + # Find the MosaicMLLogger + mosaicml_logger = find_mosaicml_logger(loggers) + + if mosaicml_logger is not None: + mosaicml_logger.log_metrics(metadata) + mosaicml_logger._flush_metadata(force_flush=True) + + if fsdp_config and model.get('load_in_8bit', False): + raise ValueError( + 'The FSDP config block is not supported when loading ' + + 'Hugging Face models in 8bit.', + ) + + init_context = process_init_device(model, fsdp_config) + + name = model.pop('name') + composer_model = build_composer_model( + name=name, + tokenizer=tokenizer, + init_context=init_context, + cfg=model, + ) + + # Now add the eval metrics + if eval_loader_config is not None: + train_metrics = composer_model.get_metrics(is_train=True) + evaluators = add_metrics_to_eval_loaders( + evaluators, + list(train_metrics.keys()), + ) + + if eval_gauntlet_df is None and eval_gauntlet_callback is not None: + eval_gauntlet_df = pd.DataFrame( + columns=['model_name'] + list(eval_gauntlet_callback.averages) + + [t['name'] for t in eval_gauntlet_callback.categories], + ) + + if name == 'mpt_causal_lm' and load_path is None: + raise ValueError( + 'MPT causal LMs require a load_path to the checkpoint for model evaluation.' + + + ' Please check your yaml and the model_cfg to ensure that load_path is set.', + ) + + assert composer_model is not None + + log.info(f'Building trainer for {model_name}...') + trainer = Trainer( + run_name=run_name, + seed=seed, + model=composer_model, + callbacks=callbacks, + loggers=loggers, + precision=precision, + fsdp_config=fsdp_config, + load_path=load_path, + load_weights_only=True, + progress_bar=False, + log_to_console=True, + dist_timeout=dist_timeout, + python_log_level=python_log_level, + ) + + if should_log_config: + log.info('Evaluation config:') + log_config(logged_config) + + log.info(f'Starting eval for {model_name}...') + if torch.cuda.is_available(): + torch.cuda.synchronize() + a = time.time() + trainer.eval( + eval_dataloader=evaluators, + subset_num_batches=eval_subset_num_batches, + ) + if torch.cuda.is_available(): + torch.cuda.synchronize() + b = time.time() + + log.info(f'Ran {model_name} eval in: {b-a} seconds') + return (trainer, logger_keys, eval_gauntlet_callback, eval_gauntlet_df) + + +def allow_toplevel_keys(cfg: Dict[str, Any]) -> Dict[str, Any]: + """Transform the config to allow top-level keys for model configuration. + + This function allows users to use the 'train.py' syntax in 'eval.py'. + It converts a config with top-level 'model', 'tokenizer', and (optionally) 'load_path' keys + into the nested 'models' list format required by 'eval.py'. 
+ + Input config format (train.py style): + ```yaml + model: + + load_path: /path/to/checkpoint + tokenizer: + + ``` + + Output config format (eval.py style): + ```yaml + models: + - model: + + tokenizer: + + load_path: /path/to/checkpoint + ``` + """ + if 'model' in cfg: + if 'models' in cfg: + raise ValueError( + 'Please specify either model or models in the config, not both', + ) + default_name = cfg.get('model').get('name') + model_cfg = { + 'model': cfg.pop('model'), + 'tokenizer': cfg.pop('tokenizer', None), + 'model_name': cfg.pop('model_name', default_name), + } + if 'tokenizer' not in model_cfg or model_cfg['tokenizer'] is None: + raise ValueError( + 'When specifying model, "tokenizer" must be provided in the config', + ) + if 'load_path' in cfg: + model_cfg['load_path'] = cfg.pop('load_path') + cfg['models'] = [model_cfg] + + return cfg + + +def evaluate(cfg: DictConfig) -> Tuple[list[Trainer], pd.DataFrame]: + # Run user provided code if specified + for code_path in cfg.get('code_paths', []): + import_file(code_path) + + logged_cfg, eval_config = make_dataclass_and_log_config( + cfg, + EvalConfig, + EVAL_CONFIG_KEYS, + transforms=[allow_toplevel_keys], + icl_tasks_required=True, + ) + + model_configs = eval_config.models + eval_gauntlet_config = eval_config.eval_gauntlet or eval_config.eval_gauntlet_str + + fsdp_config = eval_config.fsdp_config + + # Mandatory Evaluation Parameters + icl_tasks = eval_config.icl_tasks or eval_config.icl_tasks_str + if icl_tasks is None: + raise ValueError('icl_tasks must be specified in the config') + + # Optional Evaluation Parameters with default values + eval_loader_config = eval_config.eval_loader or eval_config.eval_loaders + default_run_name: str = os.environ.get('RUN_NAME', 'llm') + run_name = eval_config.run_name if eval_config.run_name else default_run_name + + reproducibility.seed_all(eval_config.seed) + dist.initialize_dist(get_device(None), timeout=eval_config.dist_timeout) + + if eval_config.python_log_level is not None: + logging.basicConfig( + # Example of format string + # 2022-06-29 11:22:26,152: rank0[822018][MainThread]: INFO: Message here + format= + f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s', + ) + logging.getLogger('llmfoundry').setLevel( + eval_config.python_log_level.upper(), + ) + + # Default argument values for evaluate_model + eval_gauntlet_df = None + models_df = None + composite_scores = None + trainers = [] + + # Build loggers + loggers: list[LoggerDestination] = [ + build_logger(name, logger_cfg) + for name, logger_cfg in (eval_config.loggers or {}).items() + ] + + mosaicml_logger = find_mosaicml_logger(loggers) + if mosaicml_logger is None: + mosaicml_logger = maybe_create_mosaicml_logger() + # mosaicml_logger will be None if run isn't on MosaicML platform + if mosaicml_logger is not None: + loggers.append(mosaicml_logger) + + # mosaicml_logger will be None if the run isn't from the MosaicML platform + if mosaicml_logger is not None: + log_eval_analytics( + mosaicml_logger, + model_configs, + icl_tasks, + eval_gauntlet_config, + ) + + for model_cfg in model_configs: + + attn_config = model_cfg['model'].get('attn_config', None) + if attn_config is not None: + seq_parallel_world_size = attn_config.get( + 'seq_parallel_world_size', + None, + ) + if seq_parallel_world_size is not None and seq_parallel_world_size != 1: + raise ValueError( + 'Offline eval does not support sequence parallelism.', + ) + + (trainer, logger_keys, eval_gauntlet_callback, + 
eval_gauntlet_df) = evaluate_model( + dist_timeout=eval_config.dist_timeout, + run_name=run_name, + seed=eval_config.seed, + icl_tasks=icl_tasks, + max_seq_len=eval_config.max_seq_len, + device_eval_batch_size=eval_config.device_eval_batch_size, + eval_gauntlet_config=eval_gauntlet_config, + eval_loader_config=eval_loader_config, + fsdp_config=fsdp_config, + loggers=loggers, + python_log_level=eval_config.python_log_level, + precision=eval_config.precision, + eval_gauntlet_df=eval_gauntlet_df, + callback_configs=eval_config.callbacks, + eval_subset_num_batches=eval_config.eval_subset_num_batches, + icl_subset_num_batches=eval_config.icl_subset_num_batches, + metadata=eval_config.metadata, + logged_config=logged_cfg, + should_log_config=eval_config.log_config, + **model_cfg, + ) + trainers.append(trainer) + + if eval_gauntlet_callback is not None: + composite_scores = eval_gauntlet_callback.eval_after_all( + trainer.state, + trainer.logger, + ) + + benchmark_to_taxonomy = {} + if eval_gauntlet_callback is not None: + for t in eval_gauntlet_callback.categories: + for b in t['benchmarks']: + benchmark_to_taxonomy[b['name']] = t['name'] + + assert 'model_name' in model_cfg, 'model_name must be specified in model config' + model_results = calculate_markdown_results( + logger_keys, + trainer, + benchmark_to_taxonomy, + model_cfg['model_name'], + ) + + if models_df is None: + models_df = model_results + else: + models_df = pd.concat([models_df, model_results], ignore_index=True) + + if eval_gauntlet_df is not None and eval_gauntlet_callback is not None: + assert composite_scores is not None + row = {'model_name': model_cfg['model_name']} + row.update({ + k.split('/')[-1]: v for k, v in composite_scores.items() + }) + eval_gauntlet_df = pd.concat([ + eval_gauntlet_df, + pd.DataFrame([row]), + ], + ignore_index=True) + + print(f'Printing gauntlet results for all models') + + print( + eval_gauntlet_df.sort_values( + list(eval_gauntlet_callback.averages.keys())[0], + ascending=False, + ).to_markdown(index=False), + ) + print(f'Printing complete results for all models') + assert models_df is not None + print(models_df.to_markdown(index=False)) + + trainer.close() + + return trainers, eval_gauntlet_df + + +def calculate_markdown_results( + logger_keys: list[str], + trainer: Trainer, + benchmark_to_taxonomy: Dict[str, str], + model_name: str, +): + results = {} + + for key in logger_keys: + # dl_name is either 2-tuple (benchmark_name, num_fewshot) + # or 3-tuple (benchmark_name, num_fewshot, subcategory) + dl_name, metric_name = key.split('/')[1:-1], key.split('/')[-1] + if 'Accuracy' not in metric_name: + continue + + metric = trainer.state.eval_metrics.get('/'.join(dl_name), + {}).get(metric_name, None) + + if metric is None: + continue + if dl_name[1] not in results: + results[dl_name[1]] = {} + + if dl_name[0] not in results[dl_name[1]]: + results[dl_name[1]][dl_name[0]] = {} + + if metric_name not in results[dl_name[1]][dl_name[0]]: + results[dl_name[1]][dl_name[0]][metric_name] = [] + + results[dl_name[1]][dl_name[0]][metric_name].append({ + 'val': metric.compute(), + 'subcat': dl_name[-1] if len(dl_name) == 3 else 'no_subcat', + }) + + df = pd.DataFrame( + columns=[ + 'Category', + 'Benchmark', + 'Subtask', + 'Accuracy', + 'Number few shot', + 'Model', + ], + ) + + for num_shot in results: + for benchmark in results[num_shot]: + for metric in results[num_shot][benchmark]: + subscores = results[num_shot][benchmark][metric] + if len(subscores) == 1: + row = { + 'Category': 
+                            benchmark_to_taxonomy.get(benchmark, ''),
+                        'Benchmark':
+                            benchmark,
+                        'Subtask':
+                            None,
+                        'Accuracy':
+                            subscores[0]['val'],
+                        'Number few shot':
+                            num_shot,
+                        'Model':
+                            model_name,
+                    }
+                    df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)
+                else:
+                    row = {
+                        'Category':
+                            benchmark_to_taxonomy.get(benchmark, ''),
+                        'Benchmark':
+                            benchmark,
+                        'Subtask':
+                            'Average',
+                        'Accuracy':
+                            sum(s['val'] for s in subscores) / len(subscores),
+                        'Number few shot':
+                            num_shot,
+                        'Model':
+                            model_name,
+                    }
+                    df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)
+                    for sub in subscores:
+                        row = {
+                            'Category':
+                                benchmark_to_taxonomy.get(benchmark, ''),
+                            'Benchmark':
+                                None,
+                            'Subtask':
+                                sub['subcat'],
+                            'Accuracy':
+                                sub['val'],
+                            'Number few shot':
+                                num_shot,
+                            'Model':
+                                model_name,
+                        }
+                        df = pd.concat([df, pd.DataFrame([row])],
+                                       ignore_index=True)
+    return df
+
+
+def eval_from_yaml(
+    yaml_path: str,
+    args_list: Optional[list[str]],
+) -> Tuple[list[Trainer], pd.DataFrame]:
+    """Run the evaluation with optional overrides from CLI."""
+    # Load yaml and CLI arguments.
+    om.clear_resolver('oc.env')
+    with open(yaml_path) as f:
+        yaml_cfg = om.load(f)
+    if args_list:
+        cli_cfg = om.from_cli(args_list)
+        yaml_cfg = om.merge(yaml_cfg, cli_cfg)
+    assert isinstance(yaml_cfg, DictConfig)
+    return evaluate(yaml_cfg)
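For reference, `eval_from_yaml` is the programmatic entry point that the new `eval` Typer command wraps: it loads the YAML config, merges any `key=value` overrides via OmegaConf, and calls `evaluate`, returning the per-model trainers and the (possibly `None`) gauntlet results table. A minimal usage sketch follows; the YAML path and the override value are illustrative and not part of this patch.

```python
# Hedged sketch of calling the new entry point directly.
# Assumes 'eval.yaml' is an eval config with a `models` list (or top-level
# `model`/`tokenizer` keys, handled by allow_toplevel_keys) and `icl_tasks`.
from llmfoundry.command_utils import eval_from_yaml

trainers, eval_gauntlet_df = eval_from_yaml(
    'eval.yaml',                    # path to the YAML configuration file
    ['device_eval_batch_size=4'],   # optional CLI-style OmegaConf overrides
)
# eval_gauntlet_df is None unless an eval_gauntlet is configured.
```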
diff --git a/llmfoundry/train/train.py b/llmfoundry/command_utils/train.py
similarity index 94%
rename from llmfoundry/train/train.py
rename to llmfoundry/command_utils/train.py
index 273372e1cd..feed1e9fb1 100644
--- a/llmfoundry/train/train.py
+++ b/llmfoundry/command_utils/train.py
@@ -36,8 +36,10 @@
     build_callback,
     build_composer_model,
     build_evaluators,
+    build_load_planner,
     build_logger,
     build_optimizer,
+    build_save_planner,
     build_scheduler,
     build_tokenizer,
 )
@@ -256,6 +258,31 @@ def train(cfg: DictConfig) -> Trainer:
     # Optional fsdp data, fine-tuning, and eval configs
     fsdp_config: Optional[Dict[str, Any]] = train_cfg.fsdp_config
 
+    if fsdp_config is not None:
+        if 'load_planner' in fsdp_config:
+            load_planners = list(fsdp_config['load_planner'].items())
+            if len(load_planners) > 1:
+                raise ValueError(
+                    'Only one load planner can be specified in the config.',
+                )
+            load_planner_name, load_planner_config = load_planners[0]
+            fsdp_config['load_planner'] = build_load_planner(
+                load_planner_name,
+                **load_planner_config,
+            )
+
+        if 'save_planner' in fsdp_config:
+            save_planners = list(fsdp_config['save_planner'].items())
+            if len(save_planners) > 1:
+                raise ValueError(
+                    'Only one save planner can be specified in the config.',
+                )
+            save_planner_name, save_planner_config = save_planners[0]
+            fsdp_config['save_planner'] = build_save_planner(
+                save_planner_name,
+                **save_planner_config,
+            )
+
     eval_loader_config = train_cfg.eval_loader if train_cfg.eval_loader is not None else train_cfg.eval_loaders
     icl_tasks_config = train_cfg.icl_tasks or train_cfg.icl_tasks_str
     eval_gauntlet_config = train_cfg.eval_gauntlet or train_cfg.eval_gauntlet_str
@@ -537,6 +564,12 @@ def train(cfg: DictConfig) -> Trainer:
         hf_checkpointer_callback._save_checkpoint(trainer.state, trainer.logger)
         return trainer
 
+    if train_cfg.only_composer_checkpoint:
+        log.info('Not training. Only saving composer checkpoint.')
+        trainer.save_checkpoint_to_save_folder()
+        log.info('Done saving checkpoint.')
+        return trainer
+
     if train_cfg.log_config:
         log.info('Logging config')
         log_config(logged_cfg)
@@ -561,6 +594,7 @@ def train_from_yaml(
 ) -> Trainer:
     """Run the training with optional overrides from CLI."""
     # Load yaml and CLI arguments.
+    om.clear_resolver('oc.env')
     with open(yaml_path) as f:
         yaml_cfg = om.load(f)
     if args_list:
diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 0adad8af4e..78bfb9c74c 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -913,6 +913,8 @@ def dataset_mapper(example: Dict):
         detected_cpu_count = os.cpu_count() or 1
         detected_cpus_with_margin = detected_cpu_count - 8
         num_cpus_to_use = max(1, detected_cpus_with_margin)
+        if len(dataset) < num_cpus_to_use:
+            num_cpus_to_use = 1
 
         columns_to_remove = list(dataset[0].keys())
         tokenized_dataset = dataset.map(
diff --git a/llmfoundry/eval/datasets/in_context_learning_evaluation.py b/llmfoundry/eval/datasets/in_context_learning_evaluation.py
index c87b38b09a..8a8b9de551 100644
--- a/llmfoundry/eval/datasets/in_context_learning_evaluation.py
+++ b/llmfoundry/eval/datasets/in_context_learning_evaluation.py
@@ -172,17 +172,26 @@ def __init__(
             self.dataset = self.dataset.map(strip_data)
 
         fewshot_rng = random.Random(fewshot_random_seed)
+        self._prepared = False
+        self.num_fewshot = num_fewshot
+        self.prompt_string = prompt_string
+        self.fewshot_rng = fewshot_rng
+
+    def _prepare_dataset(self):
         self.dataset: HFDataset = self.dataset.map(
             self._prep_example,
             with_indices=True,
             fn_kwargs={
-                'num_fewshot': num_fewshot,
-                'prompt_string': prompt_string,
-                'fewshot_rng': fewshot_rng,
+                'num_fewshot': self.num_fewshot,
+                'prompt_string': self.prompt_string,
+                'fewshot_rng': self.fewshot_rng,
             },
         )
+        self._prepared = True
 
     def __getitem__(self, index: int) -> Dict:
+        if not self._prepared:
+            self._prepare_dataset()
         return self.dataset[index]
 
     def __len__(self) -> int:
diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index 6ab8249bac..536cd0257d 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -90,6 +90,7 @@ def __init__(
         use_train_metrics: bool = True,
         additional_train_metrics: Optional[List] = None,
         additional_eval_metrics: Optional[List] = None,
+        should_save_peft_only: bool = True,
     ):
         config_overrides = config_overrides or {}
 
@@ -131,6 +132,7 @@ def __init__(
             eval_metrics=eval_metrics,
             init_device=init_device,
             peft_config=peft_config_object,
+            should_save_peft_only=should_save_peft_only,
         )
 
     @staticmethod
diff --git a/llmfoundry/models/hf/model_wrapper.py b/llmfoundry/models/hf/model_wrapper.py
index c667c6026a..7051986df8 100644
--- a/llmfoundry/models/hf/model_wrapper.py
+++ b/llmfoundry/models/hf/model_wrapper.py
@@ -40,6 +40,7 @@ def __init__(
         shift_labels: bool = False,
         init_device: Optional[str] = None,
         peft_config: Optional['PeftConfig'] = None,
+        should_save_peft_only: bool = True,
     ):
         super().__init__(
             model,
@@ -49,7 +50,7 @@ def __init__(
             eval_metrics=eval_metrics,
             shift_labels=shift_labels,
             peft_config=peft_config,
-            should_save_peft_only=True,
+            should_save_peft_only=should_save_peft_only,
         )
         self.prepare_inner_model(self.model, init_device)
diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index dde7d64cd7..8e740be2b3 100644
--- a/llmfoundry/models/layers/attention.py
+++ 
b/llmfoundry/models/layers/attention.py @@ -411,6 +411,7 @@ def __init__( clip_qkv: Optional[float] = None, qk_ln: bool = False, qk_gn: bool = False, + fused_qkv: bool = True, softmax_scale: Optional[float] = None, attn_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', @@ -426,6 +427,7 @@ def __init__( self.clip_qkv = clip_qkv self.qk_ln = qk_ln self.qk_gn = qk_gn + self.fused_qkv = fused_qkv self.d_model = d_model self.n_heads = n_heads @@ -462,7 +464,17 @@ def __init__( self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) self.attn_dropout_p = attn_pdrop - if self.reuse_kv_layer_idx is None: + if self.reuse_kv_layer_idx is not None: + self.Wq = build_fc( + name=fc_type_name, + in_features=self.d_model, + out_features=self.d_model, + fc_kwargs=fc_type, + ) + # for param init fn; enables shape based init of fused layers + fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)] + self.Wq._fused = (0, fuse_splits) + elif self.fused_qkv: self.Wqkv = build_fc( name=fc_type_name, in_features=self.d_model, @@ -482,9 +494,26 @@ def __init__( out_features=self.d_model, fc_kwargs=fc_type, ) + self.Wk = build_fc( + name=fc_type_name, + in_features=self.d_model, + out_features=self.kv_n_heads * self.head_dim, + fc_kwargs=fc_type, + ) + self.Wv = build_fc( + name=fc_type_name, + in_features=self.d_model, + out_features=self.kv_n_heads * self.head_dim, + fc_kwargs=fc_type, + ) # for param init fn; enables shape based init of fused layers - fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)] - self.Wq._fused = (0, fuse_splits) + q_fuse_splits = [i * self.head_dim for i in range(1, self.n_heads)] + kv_fuse_splits = [ + i * self.head_dim for i in range(1, self.kv_n_heads) + ] + self.Wq._fused = (0, q_fuse_splits) + self.Wk._fused = (0, kv_fuse_splits) + self.Wv._fused = (0, kv_fuse_splits) if self.qk_ln or self.qk_gn: norm_size = self.head_dim if qk_gn else d_model @@ -601,19 +630,29 @@ def get_qkv( query = self.q_ln(query).to(dtype).view(q_shape) return query, key, value - qkv = self.Wqkv(x) + if self.fused_qkv: + qkv = self.Wqkv(x) - if self.clip_qkv: - qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv) + if self.clip_qkv: + qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv) + + query, key, value = qkv.split( + [ + self.d_model, + self.kv_n_heads * self.head_dim, + self.kv_n_heads * self.head_dim, + ], + dim=2, + ) + else: + query = self.Wq(x) + key = self.Wk(x) + value = self.Wv(x) - query, key, value = qkv.split( - [ - self.d_model, - self.kv_n_heads * self.head_dim, - self.kv_n_heads * self.head_dim, - ], - dim=2, - ) + if self.clip_qkv: + query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv) + key = key.clamp(min=-self.clip_qkv, max=self.clip_qkv) + value = value.clamp(min=-self.clip_qkv, max=self.clip_qkv) if self.qk_ln or self.qk_gn: # Applying layernorm to qk @@ -753,6 +792,7 @@ def __init__( clip_qkv: Optional[float] = None, qk_ln: bool = False, qk_gn: bool = False, + fused_qkv: bool = True, softmax_scale: Optional[float] = None, attn_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', @@ -770,6 +810,7 @@ def __init__( clip_qkv=clip_qkv, qk_ln=qk_ln, qk_gn=qk_gn, + fused_qkv=fused_qkv, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, @@ -796,6 +837,7 @@ def __init__( clip_qkv: Optional[float] = None, qk_ln: bool = False, qk_gn: bool = False, + fused_qkv: bool = True, softmax_scale: Optional[float] = None, attn_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', @@ -813,6 +855,7 @@ def 
__init__( clip_qkv=clip_qkv, qk_ln=qk_ln, qk_gn=qk_gn, + fused_qkv=fused_qkv, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index a1fdc25f50..3de3744745 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -70,6 +70,8 @@ def __init__( attn_impl (str): The attention implementation to use. One of 'torch' or 'flash'. qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer. qk_gn (bool): Whether to apply group normalization to the queries and keys in the attention layer. + fused_qkv (bool): Whether to fuse the Wq, Wk, and Wv weight matrices in the attention layer. If True, the weights are fused into a single + Wqkv matrix, which can be faster for matmuls. If False, the weights are kept separate. Defaults to True. clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to this value. softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None, diff --git a/llmfoundry/models/utils/config_defaults.py b/llmfoundry/models/utils/config_defaults.py index 2b6fc2f7c7..c272a52dd4 100644 --- a/llmfoundry/models/utils/config_defaults.py +++ b/llmfoundry/models/utils/config_defaults.py @@ -15,6 +15,7 @@ 'attn_impl': 'flash', 'qk_ln': False, 'qk_gn': False, + 'fused_qkv': True, 'clip_qkv': None, 'softmax_scale': None, 'attn_uses_sequence_id': False, diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index 50481211ac..e31840d3fb 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -6,6 +6,7 @@ from composer.loggers import LoggerDestination from composer.models import ComposerModel from composer.optim import ComposerScheduler +from torch.distributed.checkpoint import LoadPlanner, SavePlanner from torch.optim import Optimizer from torch.utils.data import DataLoader as TorchDataloader from torch.utils.data import Dataset @@ -339,6 +340,42 @@ description=_config_transforms_description, ) +_load_planners_description = ( + """The load_planners registry is used to register classes that implement the LoadPlanner interface. + + The LoadPlanner will be passed as part of the FSDP config arg of the Trainer. It will be used to load distributed checkpoints. + + Returns: + LoadPlanner: The load planner. + """ +) + +load_planners = create_registry( + 'llmfoundry', + 'load_planners', + generic_type=Type[LoadPlanner], + entry_points=True, + description=_load_planners_description, +) + +_save_planners_description = ( + """The save_planners registry is used to register classes that implement the SavePlanner interface. + + The savePlanner will be passed as part of the FSDP config arg of the Trainer. It will be used to save distributed checkpoints. + + Returns: + SavePlanner: The save planner. 
+ """ +) + +save_planners = create_registry( + 'llmfoundry', + 'save_planners', + generic_type=Type[SavePlanner], + entry_points=True, + description=_save_planners_description, +) + __all__ = [ 'loggers', 'callbacks', @@ -363,4 +400,6 @@ 'fcs', 'icl_datasets', 'config_transforms', + 'load_planners', + 'save_planners', ] diff --git a/llmfoundry/train/__init__.py b/llmfoundry/train/__init__.py deleted file mode 100644 index 8a4c2749db..0000000000 --- a/llmfoundry/train/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2024 MosaicML LLM Foundry authors -# SPDX-License-Identifier: Apache-2.0 -from llmfoundry.train.train import ( - TRAIN_CONFIG_KEYS, - TrainConfig, - train, - train_from_yaml, - validate_config, -) - -__all__ = [ - 'train', - 'train_from_yaml', - 'TrainConfig', - 'TRAIN_CONFIG_KEYS', - 'validate_config', -] diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 012a0b704f..0437736f74 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -27,6 +27,7 @@ from composer.utils import dist from omegaconf import DictConfig from omegaconf import OmegaConf as om +from torch.distributed.checkpoint import LoadPlanner, SavePlanner from torch.optim.optimizer import Optimizer from torchmetrics import Metric from transformers import AutoTokenizer, PreTrainedTokenizerBase @@ -187,6 +188,44 @@ def build_icl_data_and_gauntlet( return icl_evaluators, logger_keys, eval_gauntlet_cb +def build_load_planner(name: str, **kwargs: Any) -> LoadPlanner: + """Builds a load planner from the registry. + + Args: + name: Name of the load planner to build. + + Returns: + LoadPlanner: The load planner. + """ + return construct_from_registry( + name=name, + registry=registry.load_planners, + partial_function=True, + pre_validation_function=LoadPlanner, + post_validation_function=None, + kwargs=kwargs, + ) + + +def build_save_planner(name: str, **kwargs: Any) -> SavePlanner: + """Builds a save planner from the registry. + + Args: + name: Name of the save planner to build. + + Returns: + savePlanner: The save planner. 
+ """ + return construct_from_registry( + name=name, + registry=registry.save_planners, + partial_function=True, + pre_validation_function=SavePlanner, + post_validation_function=None, + kwargs=kwargs, + ) + + def build_composer_model( name: str, cfg: Dict[str, Any], diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 2667fceb67..4b86de99b8 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -162,6 +162,7 @@ class TrainConfig: load_ignore_keys: Optional[List[str]] = None save_ignore_keys: Optional[List[str]] = None only_hf_checkpoint: bool = False + only_composer_checkpoint: bool = False # Dataloader device_train_microbatch_size: Union[str, int, float] = 'auto' diff --git a/scripts/data_prep/convert_dataset_hf.py b/scripts/data_prep/convert_dataset_hf.py index bf7f145610..3b893868b2 100644 --- a/scripts/data_prep/convert_dataset_hf.py +++ b/scripts/data_prep/convert_dataset_hf.py @@ -2,30 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 """Streaming dataset conversion scripts for C4 and The Pile.""" -import json -import os -import platform from argparse import ArgumentParser, Namespace -from dataclasses import dataclass, field -from enum import Enum -from typing import Dict, Iterable, Optional, Union -import datasets as hf_datasets -import psutil -import torch -from numpy.typing import NDArray -from streaming import MDSWriter -from torch.utils.data import DataLoader, Dataset, IterableDataset -from tqdm import tqdm -from transformers import PreTrainedTokenizerBase - -from llmfoundry.data import ConcatTokensDataset, NoConcatDataset -from llmfoundry.utils.builders import build_tokenizer - - -class ConcatMode(Enum): - NO_CONCAT = 'NO_CONCAT' - CONCAT_TOKENS = 'CONCAT_TOKENS' +from llmfoundry.command_utils import convert_dataset_hf_from_args def parse_args() -> Namespace: @@ -64,398 +43,22 @@ def parse_args() -> Namespace: parser.add_argument('--num_workers', type=int, required=False, default=None) parsed = parser.parse_args() - - if parsed.tokenizer_kwargs is not None: - parsed.tokenizer_kwargs = json.loads(parsed.tokenizer_kwargs) - else: - parsed.tokenizer_kwargs = {} - - if os.path.isdir(parsed.out_root) and len( - set(os.listdir(parsed.out_root)).intersection(set(parsed.splits)), - ) > 0: - raise ValueError( - f'--out_root={parsed.out_root} contains {os.listdir(parsed.out_root)} which cannot overlap with the requested splits {parsed.splits}.', - ) - - # Make sure we have needed concat options - if ( - parsed.concat_tokens is not None and - isinstance(parsed.concat_tokens, int) and parsed.tokenizer is None - ): - parser.error( - 'When setting --concat_tokens, you must specify a --tokenizer', - ) - - # now that we have validated them, change BOS/EOS to strings - if parsed.bos_text is None: - parsed.bos_text = '' - if parsed.eos_text is None: - parsed.eos_text = '' return parsed -@dataclass -class DataSplitConstants: - hf_split: str - folder_split: str - raw_samples: Optional[int] - truncated_samples: Union[int, None] - - -@dataclass -class DatasetConstants: - chars_per_sample: int - chars_per_token: int - splits: Dict[str, DataSplitConstants] = field(default_factory=dict) - - def __iter__(self): - for v in self.splits.values(): - yield v - - -class TrainSmallConstants(DataSplitConstants): - - def __init__( - self, - hf_split: str = 'train', - folder_split: str = 'train_small', - raw_samples: int = 100000, - truncated_samples: int = 100000, - ): - super().__init__(hf_split, folder_split, raw_samples, truncated_samples) - - 
-class ValSmallConstants(DataSplitConstants): - - def __init__( - self, - hf_split: str = 'validation', - folder_split: str = 'val_small', - raw_samples: int = 10000, - truncated_samples: int = 10000, - ): - super().__init__(hf_split, folder_split, raw_samples, truncated_samples) - - -class ValXSmallConstants(DataSplitConstants): - - def __init__( - self, - hf_split: str = 'validation', - folder_split: str = 'val_xsmall', - raw_samples: int = 3000, - truncated_samples: int = 3000, - ): - super().__init__(hf_split, folder_split, raw_samples, truncated_samples) - - -pileconstants = DatasetConstants( - chars_per_sample=6212, # Computed over validation set - chars_per_token=4, # OpenAI estimate -) -pileconstants.splits['train'] = DataSplitConstants( - hf_split='train', - folder_split='train', - raw_samples=210607728, - truncated_samples=None, -) -pileconstants.splits['train_small'] = DataSplitConstants( - hf_split='train', - folder_split='train_small', - raw_samples=100000, - truncated_samples=100000, -) -pileconstants.splits['val'] = DataSplitConstants( - hf_split='validation', - folder_split='val', - raw_samples=214670, - truncated_samples=None, -) -pileconstants.splits['val_small'] = DataSplitConstants( - hf_split='validation', - folder_split='val_small', - raw_samples=10000, - truncated_samples=10000, -) -pileconstants.splits['val_xsmall'] = DataSplitConstants( - hf_split='validation', - folder_split='val_xsmall', - raw_samples=3000, - truncated_samples=3000, -) - -c4constants = DatasetConstants( - chars_per_sample=2163, # Computed over validation set - chars_per_token=4, # OpenAI estimate -) -c4constants.splits['train'] = DataSplitConstants( - hf_split='train', - folder_split='train', - raw_samples=364868892, - truncated_samples=None, -) -c4constants.splits['train_small'] = DataSplitConstants( - hf_split='train', - folder_split='train_small', - raw_samples=100000, - truncated_samples=100000, -) -c4constants.splits['val'] = DataSplitConstants( - hf_split='validation', - folder_split='val', - raw_samples=364608, - truncated_samples=None, -) -c4constants.splits['val_small'] = DataSplitConstants( - hf_split='validation', - folder_split='val_small', - raw_samples=10000, - truncated_samples=10000, -) -c4constants.splits['val_xsmall'] = DataSplitConstants( - hf_split='validation', - folder_split='val_xsmall', - raw_samples=3000, - truncated_samples=3000, -) -c4constants.splits['val_xxsmall'] = DataSplitConstants( - hf_split='validation', - folder_split='val_xxsmall', - raw_samples=100, - truncated_samples=100, -) - -CONSTS = {'c4': c4constants, 'the_pile': pileconstants} - - -def build_hf_dataset( - dataset_name: str, - split: str, - mode: ConcatMode, - max_length: Optional[int] = None, - bos_text: str = '', - eos_text: str = '', - no_wrap: bool = False, - tokenizer: PreTrainedTokenizerBase = None, - data_subset: Union[str, None] = None, -) -> IterableDataset: - """Build an IterableDataset over the HF C4 or pile source data. - - Args: - dataset_name (str): Dataset name - split (str): Split name. - mode (ConcatMode): NO_CONCAT, or CONCAT_TOKENS - max_length (int): The length of concatenated tokens - bos_text (str): text to insert at the beginning of each sequence - eos_text (str): text to insert at the end of each sequence - no_wrap (bool): if concatenating, whether to wrap text across `max_length` boundaries - tokenizer (PreTrainedTokenizerBase): if mode is CONCAT_TOKENS, the tokenizer to use - data_subset (str): Referred to as "name" in HuggingFace datasets.load_dataset. 
- Typically "all" (The Pile) or "en" (c4). - - Returns: - An IterableDataset. - """ - hf_dataset = hf_datasets.load_dataset( - path=dataset_name, - name=data_subset, - split=split, - streaming=True, - ) - if mode == ConcatMode.NO_CONCAT: - dataset = NoConcatDataset(hf_dataset) - else: - if not isinstance(tokenizer, PreTrainedTokenizerBase): - raise ValueError( - f'{tokenizer=} must be of type PreTrainedTokenizerBase', - ) - if max_length is None: - raise ValueError(f'max_length must be set.') - if bos_text + eos_text == '': - test_tokens = tokenizer('test') - if test_tokens['input_ids'][ - 0] != tokenizer.bos_token_id and test_tokens['input_ids'][ - -1] != tokenizer.eos_token_id: - tok_error_msg = 'This tokenizer does not insert an EOS nor BOS token. ' - tok_error_msg += 'Concatenating with this tokenizer will result in sequences being ' - tok_error_msg += 'attached without a separating token. Please use another tokenizer, ' - tok_error_msg += 'such as facebook/opt-125m, or specify EOS/BOS text with e.g. ' - tok_error_msg += '--bos_text=<|endoftext|>.' - raise ValueError(tok_error_msg) - dataset = ConcatTokensDataset( - hf_dataset=hf_dataset, - tokenizer=tokenizer, - max_length=max_length, - bos_text=bos_text, - eos_text=eos_text, - no_wrap=no_wrap, - ) - return dataset - - -def _est_progress_denominator( - total_samples: int, - chars_per_sample: int, - chars_per_token: int, - mode: ConcatMode, - max_length: int, -): - est_tokens_per_sample = chars_per_sample // chars_per_token - if mode == ConcatMode.NO_CONCAT: - return total_samples - elif mode == ConcatMode.CONCAT_TOKENS: - return total_samples * est_tokens_per_sample // max_length - - -def build_dataloader( - dataset: Dataset, - batch_size: int, - num_workers: Optional[int], -) -> DataLoader: - if num_workers is None: - # Multiple workers is only supported on linux machines - if 'linux' or 'macos' in platform.platform().lower(): - num_workers = max(1, psutil.cpu_count()) - else: - num_workers = 0 - - # If using multiple workers, configure each worker to prefetch as many samples as it can, up to - # the aggregate device batch size - # If not using workers, the torch DataLoader expects the default value for prefetch_factor, - # which non-intuitively must be 2. - prefetch_factor = max( - 1, - 2 * batch_size // num_workers, - ) if num_workers > 0 else 2 - - return DataLoader( - dataset=dataset, - sampler=None, - batch_size=batch_size, - num_workers=num_workers, - prefetch_factor=prefetch_factor, - ) - - -def generate_samples( - loader: DataLoader, - truncate_num_samples: Optional[int] = None, -) -> Iterable[Union[Dict[str, bytes], Dict[str, NDArray]]]: - """Generator over samples of a dataloader. - - Args: - loader (DataLoader): A dataloader emitting batches like {key: [sample0_bytes, sample1_bytes, sample2_bytes, ...]} - truncate_num_samples (Optional[int]): An optional # of samples to stop at. - - Yields: - Sample dicts. - """ - n_samples = 0 - for batch in loader: - keys = list(batch.keys()) - current_bs = len(batch[keys[0]]) - for idx in range(current_bs): - if truncate_num_samples is not None and n_samples == truncate_num_samples: - return - n_samples += 1 - yield { - k: - v[idx].numpy() if isinstance(v[idx], torch.Tensor) else v[idx] - for k, v in batch.items() - } - - -def main(args: Namespace) -> None: - """Main: create C4/pile streaming dataset. - - Args: - args (Namespace): Commandline arguments. 
- """ - try: - dataset_constants = CONSTS[args.dataset] - except KeyError: - raise ValueError( - f'Constants for dataset "{args.dataset}" not found. Currently only "the_pile" and "c4" are supported.', - ) - - if args.concat_tokens is not None: - mode = ConcatMode.CONCAT_TOKENS - tokenizer = build_tokenizer(args.tokenizer, args.tokenizer_kwargs) - # we will enforce length, so suppress warnings about sequences too long for the model - tokenizer.model_max_length = int(1e30) - columns = {'tokens': 'ndarray:int32'} - else: - mode = ConcatMode.NO_CONCAT - tokenizer = None - columns = {'text': 'str'} - - for split_name in args.splits: - try: - split = dataset_constants.splits[split_name] - except KeyError: - raise KeyError(f'Constants not defined for split {split_name}.') - hf_split = split.hf_split - folder_split = split.folder_split - expected_num_samples = split.raw_samples - truncate_num_samples = split.truncated_samples - # Only generate the splits requested - if folder_split not in args.splits: - continue - - # Get samples - dataset = build_hf_dataset( - dataset_name=args.dataset, - data_subset=args.data_subset, - split=hf_split, - mode=mode, - max_length=args.concat_tokens, - bos_text=args.bos_text, - eos_text=args.eos_text, - no_wrap=args.no_wrap, - tokenizer=tokenizer, - ) - loader = build_dataloader( - dataset=dataset, - batch_size=512, - num_workers=args.num_workers, - ) - samples = generate_samples( - loader, - truncate_num_samples=truncate_num_samples, - ) - - if expected_num_samples is not None: - denominator = truncate_num_samples if truncate_num_samples is not None else _est_progress_denominator( - total_samples=expected_num_samples, - chars_per_sample=dataset_constants.chars_per_sample, - chars_per_token=dataset_constants.chars_per_token, - mode=mode, - max_length=args.concat_tokens, - ) - else: - denominator = None - - # Write samples - print(f'Converting {folder_split} to MDS format...') - print( - f'Note: the progress bar is based on the dataset length before tokenization, and may finish at a value before 100%.', - ) - with MDSWriter( - columns=columns, - out=os.path.join(args.out_root, folder_split), - compression=args.compression, - ) as out: - if denominator is not None: - for sample in tqdm( - samples, - desc=folder_split, - total=denominator, - ): - out.write(sample) - else: - for sample in tqdm(samples, desc=folder_split): - out.write(sample) - - if __name__ == '__main__': - main(parse_args()) + args = parse_args() + convert_dataset_hf_from_args( + dataset=args.dataset, + data_subset=args.data_subset, + splits=args.splits, + out_root=args.out_root, + compression=args.compression, + concat_tokens=args.concat_tokens, + tokenizer=args.tokenizer, + tokenizer_kwargs=args.tokenizer_kwargs, + bos_text=args.bos_text, + eos_text=args.eos_text, + no_wrap=args.no_wrap, + num_workers=args.num_workers, + ) diff --git a/scripts/data_prep/convert_dataset_json.py b/scripts/data_prep/convert_dataset_json.py index 37b0465692..5a6927ac75 100644 --- a/scripts/data_prep/convert_dataset_json.py +++ b/scripts/data_prep/convert_dataset_json.py @@ -2,24 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 """Streaming dataset conversion scripts for json files.""" -import os from argparse import ArgumentParser, Namespace -from enum import Enum -from glob import glob -from typing import Optional -import datasets as hf_datasets -from streaming import MDSWriter -from torch.utils.data import IterableDataset -from tqdm import tqdm -from transformers import AutoTokenizer, PreTrainedTokenizerBase - -from 
llmfoundry.data import ConcatTokensDataset, NoConcatDataset - - -class ConcatMode(Enum): - NO_CONCAT = 'NO_CONCAT' - CONCAT_TOKENS = 'CONCAT_TOKENS' +from llmfoundry.command_utils import convert_dataset_json_from_args def parse_args() -> Namespace: @@ -46,145 +31,19 @@ def parse_args() -> Namespace: parser.add_argument('--no_wrap', default=False, action='store_true') parsed = parser.parse_args() - - if os.path.isdir(parsed.out_root) and len( - set(os.listdir(parsed.out_root)).intersection(set(parsed.split)), - ) > 0: - raise ValueError( - f'--out_root={parsed.out_root} contains {os.listdir(parsed.out_root)} which cannot overlap with the requested splits {parsed.splits}.', - ) - - # Make sure we have needed concat options - if ( - parsed.concat_tokens is not None and - isinstance(parsed.concat_tokens, int) and parsed.tokenizer is None - ): - parser.error( - 'When setting --concat_tokens, you must specify a --tokenizer', - ) - - # now that we have validated them, change BOS/EOS to strings - if parsed.bos_text is None: - parsed.bos_text = '' - if parsed.eos_text is None: - parsed.eos_text = '' return parsed -def build_hf_dataset( - path: str, - split: str, - mode: ConcatMode, - max_length: Optional[int] = None, - bos_text: str = '', - eos_text: str = '', - no_wrap: bool = False, - tokenizer: PreTrainedTokenizerBase = None, -) -> IterableDataset: - """Build an IterableDataset over the HF C4 or pile source data. - - Args: - dataset_name (str): Dataset name - split (str): Split name. - mode (ConcatMode): NO_CONCAT, or CONCAT_TOKENS - max_length (int): The length of concatenated tokens - bos_text (str): text to insert at the beginning of each sequence - eos_text (str): text to insert at the end of each sequence - no_wrap (bool): if concatenating, whether to wrap text across `max_length` boundaries - tokenizer (PreTrainedTokenizerBase): if mode is CONCAT_TOKENS, the tokenizer to use - data_subset (str): Referred to as "name" in HuggingFace datasets.load_dataset. - Typically "all" (The Pile) or "en" (c4). - - Returns: - An IterableDataset. - """ - if os.path.isdir(path): - data_files = glob(f'{path}/*') - else: - data_files = path - - hf_dataset = hf_datasets.load_dataset( - 'json', - data_files=data_files, - split=split, - ) - - if mode == ConcatMode.NO_CONCAT: - dataset = NoConcatDataset(hf_dataset) - else: - if not isinstance(tokenizer, PreTrainedTokenizerBase): - raise ValueError( - f'{tokenizer=} must be of type PreTrainedTokenizerBase', - ) - if max_length is None: - raise ValueError(f'max_length must be set.') - if bos_text + eos_text == '': - test_tokens = tokenizer('test') - if test_tokens['input_ids'][ - 0] != tokenizer.bos_token_id and test_tokens['input_ids'][ - -1] != tokenizer.eos_token_id: - tok_error_msg = 'This tokenizer does not insert an EOS nor BOS token. ' - tok_error_msg += 'Concatenating with this tokenizer will result in sequences being ' - tok_error_msg += 'attached without a separating token. Please use another tokenizer, ' - tok_error_msg += 'such as facebook/opt-125m, or specify EOS/BOS text with e.g. ' - tok_error_msg += '--bos_text=<|endoftext|>.' - raise ValueError(tok_error_msg) - dataset = ConcatTokensDataset( - hf_dataset=hf_dataset, - tokenizer=tokenizer, - max_length=max_length, - bos_text=bos_text, - eos_text=eos_text, - no_wrap=no_wrap, - ) - return dataset - - -def main(args: Namespace) -> None: - """Main: create C4/pile streaming dataset. - - Args: - args (Namespace): Commandline arguments. 
- """ - if args.concat_tokens is not None: - mode = ConcatMode.CONCAT_TOKENS - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) - # we will enforce length, so suppress warnings about sequences too long for the model - tokenizer.model_max_length = int(1e30) - columns = {'tokens': 'ndarray:int32'} - else: - mode = ConcatMode.NO_CONCAT - tokenizer = None - columns = {'text': 'str'} - - # Get samples - dataset = build_hf_dataset( +if __name__ == '__main__': + args = parse_args() + convert_dataset_json_from_args( path=args.path, + out_root=args.out_root, + compression=args.compression, + concat_tokens=args.concat_tokens, split=args.split, - mode=mode, - max_length=args.concat_tokens, + tokenizer=args.tokenizer, bos_text=args.bos_text, eos_text=args.eos_text, no_wrap=args.no_wrap, - tokenizer=tokenizer, - ) - - print('here') - - # Write samples - print(f'Converting to MDS format...') - print( - f'Note that the progress bar is based on the dataset length before tokenization.', ) - print(f'It will finish at a value below 100% if tokenizing') - with MDSWriter( - columns=columns, - out=os.path.join(args.out_root), - compression=args.compression, - ) as out: - for sample in tqdm(dataset): - out.write(sample) - - -if __name__ == '__main__': - main(parse_args()) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index f664f5baca..3b88ba668f 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -82,6 +82,8 @@ def to_cf(self: SparkConnectClient, - Total row count of all parts of the result. - A boolean indicating whether the result has been truncated. """ + log.info(f'Executing query plan with format: {type}') + req = self._execute_plan_request_with_metadata() req.plan.CopyFrom(plan) @@ -166,6 +168,7 @@ def collect_as_cf(self: DataFrame, - Total row count of all parts of the result. - A boolean indicating whether the result is truncated or overflowed. """ + log.info(f'Collecting DataFrame as cloud fetch with format: {type}') query = self._plan.to_proto(self._session.client) # pyright: ignore return self._session.client.to_cf(query, type) # pyright: ignore @@ -182,13 +185,18 @@ def iterative_combine_jsons(json_directory: str, output_file: str) -> None: json_directory(str): directory containing the JSONL files output_file(str): path to the output combined JSONL file """ + log.info( + f'Starting to combine JSON files from {json_directory} into {output_file}', + ) json_files = [f for f in os.listdir(json_directory) if f.endswith('.jsonl')] + log.info(f'Found {len(json_files)} JSON files to combine') with open(output_file, 'w') as outfile: for file_name in json_files: + log.debug(f'Processing file: {file_name}') with open(os.path.join(json_directory, file_name), 'r') as infile: for line in infile: outfile.write(line) - log.info('JSON files have been combined into a JSONL file.') + log.info('JSON files have been successfully combined into a JSONL file.') def run_query( @@ -207,6 +215,9 @@ def run_query( spark (Optional[SparkSession]): spark session collect (bool): whether to get the underlying data from spark dataframe """ + log.info(f'Executing query using method: {method}') + log.debug(f'Query: {query}') + if method == 'dbsql': if cursor is None: raise ValueError(f'cursor cannot be None if using method dbsql') @@ -247,6 +258,8 @@ def download( resp_format (str): whether to use arrow or json when collect compressed (bool): if data is compressed before downloading. 
Need decompress if compressed=True. """ + log.info(f'Downloading part {ipart} from URL: {url}') + resp = requests.get(url) if resp.status_code == 200: if resp_format == 'json': @@ -294,6 +307,7 @@ def format_tablename(table_name: str) -> str: Args: table_name (str): catalog.scheme.tablename on UC """ + log.debug(f'Formatting table name: {table_name}') match = re.match(TABLENAME_PATTERN, table_name) if match is None: @@ -337,6 +351,7 @@ def fetch_data( Returns: None: The function doesn't return any value, but writes the result to a JSONL file. """ + log.info(f'Fetching data from {start} to {end} using method: {method}') query = f""" WITH NumberedRows AS ( SELECT @@ -428,6 +443,11 @@ def fetch( sparkSession (pyspark.sql.sparksession): spark session dbsql (databricks.sql.connect): dbsql session """ + log.info(f'Starting data fetch for table: {tablename}') + log.info( + f'Method: {method}, Batch size: {batch_size}, Processes: {processes}', + ) + cursor = dbsql.cursor() if dbsql is not None else None try: nrows = get_total_rows( @@ -505,6 +525,11 @@ def validate_and_get_cluster_info( http_path (Optional[str]): http path to use for sql connect use_serverless (bool): whether to use serverless or not """ + log.info('Validating cluster information and getting connection details') + log.debug( + f'Cluster ID: {cluster_id}, Host: {databricks_host}, Use Serverless: {use_serverless}', + ) + method = 'dbsql' dbsql = None sparkSession = None @@ -575,6 +600,10 @@ def validate_and_get_cluster_info( def fetch_DT(args: Namespace) -> None: """Fetch UC Delta Table to local as jsonl.""" log.info(f'Start .... Convert delta to json') + log.info('Starting Delta Table to JSON conversion process') + log.info(f'Delta Table: {args.delta_table_name}') + log.info(f'Output Folder: {args.json_output_folder}') + log.info(f'Output Filename: {args.json_output_filename}') obj = urllib.parse.urlparse(args.json_output_folder) if obj.scheme != '': @@ -626,6 +655,8 @@ def fetch_DT(args: Namespace) -> None: os.path.join(args.json_output_folder, args.json_output_filename), ) + log.info('Delta Table to JSON conversion completed successfully') + if __name__ == '__main__': parser = ArgumentParser( @@ -695,3 +726,4 @@ def fetch_DT(args: Namespace) -> None: tik = time.time() fetch_DT(args) log.info(f'Elapsed time {time.time() - tik}') + log.info('Delta Table to JSON conversion script completed') diff --git a/scripts/data_prep/convert_text_to_mds.py b/scripts/data_prep/convert_text_to_mds.py index 92c36eb35d..c808fa871f 100644 --- a/scripts/data_prep/convert_text_to_mds.py +++ b/scripts/data_prep/convert_text_to_mds.py @@ -2,107 +2,17 @@ # SPDX-License-Identifier: Apache-2.0 import logging -import math -import os -import tempfile from argparse import ArgumentParser, Namespace -from concurrent.futures import ProcessPoolExecutor -from functools import partial -from glob import glob -from typing import Dict, Iterable, List, Tuple, cast -import numpy as np import psutil -from composer.utils import ( - ObjectStore, - maybe_create_object_store_from_uri, - parse_uri, -) -from numpy.typing import NDArray -from streaming import MDSWriter -from tqdm import tqdm -from transformers import AutoTokenizer, PreTrainedTokenizerBase -from llmfoundry.data.data import AbstractConcatTokensDataset -from llmfoundry.utils.data_prep_utils import ( - DownloadingIterable, - download_file, - merge_shard_groups, -) -from llmfoundry.utils.exceptions import ( - InputFolderMissingDataError, - OutputFolderNotEmptyError, -) +from llmfoundry.command_utils import 
convert_text_to_mds_from_args log = logging.getLogger(__name__) DONE_FILENAME = '.text_to_mds_conversion_done' -class ConcatTokensFromFilesDataset(AbstractConcatTokensDataset): - """An IterableDataset that returns token samples for MDSWriter from files. - - Returns dicts of {'tokens': ndarray:int32} - - Each file is considered a sequence. - """ - - def __init__( - self, - files: Iterable[str], - tokenizer: PreTrainedTokenizerBase, - max_length: int, - bos_text: str, - eos_text: str, - no_wrap: bool, - ): - self.files = files - super().__init__(tokenizer, max_length, bos_text, eos_text, no_wrap) - - def __iter__(self) -> Iterable[Dict[str, NDArray]]: - - buffer = [] - for file in self.files: - with open(file, 'r') as f: - buffer += self.bos_tokens - first_chunk = True - # Read the file in 1MB chunks to avoid memory issues - for chunk in iter(partial(f.read, 1000000), ''): - # Tokenize the chunk - encoded = self.tokenizer( - chunk, - truncation=False, - padding=False, - ) - iids = encoded['input_ids'] - - # If this is not the first chunk, remove the BOS token - if not first_chunk: - if iids[0] == self.tokenizer.bos_token_id: - iids = iids[1:] - - # Add the tokens to the buffer - buffer += iids - while len(buffer) >= self.max_length: - concat_sample = buffer[:self.max_length] - buffer = buffer[self. - max_length:] if self.should_wrap else [] - yield { - 'tokens': np.asarray(concat_sample, dtype=np.int32), - } - - first_chunk = False - - # Add the EOS token to the buffer to separate files. - buffer += self.eos_tokens - - # Yield any remaining samples of size max_length. - while len(buffer) >= self.max_length: - concat_sample = buffer[:self.max_length] - buffer = buffer[self.max_length:] if self.should_wrap else [] - yield {'tokens': np.asarray(concat_sample, dtype=np.int32)} - - def parse_args() -> Namespace: """Parse commandline arguments.""" parser = ArgumentParser( @@ -203,418 +113,23 @@ def parse_args() -> Namespace: help='Logging level for the script. Default is INFO.', ) parsed = parser.parse_args() - - # Set eos token. - if parsed.use_tokenizer_eos: - # Ensure that eos text is not specified twice. - if parsed.eos_text is not None: - parser.error( - 'Cannot set --eos_text with --use_tokenizer_eos. Please specify one.', - ) - tokenizer = AutoTokenizer.from_pretrained( - parsed.tokenizer, - trust_remote_code=parsed.trust_remote_code, - ) - parsed.eos_text = tokenizer.eos_token - - # now that we have validated them, change BOS/EOS to strings - if parsed.bos_text is None: - parsed.bos_text = '' - if parsed.eos_text is None: - parsed.eos_text = '' return parsed -def get_object_names(input_folder: str) -> List[str]: - """Get object names from a local or remote folder. - - Args: - input_folder (str): local or remote folder path. 
- """ - object_store = maybe_create_object_store_from_uri(input_folder) - if object_store is not None: - _, _, folder_prefix = parse_uri(input_folder) - names = [ - name for name in object_store.list_objects(folder_prefix) - if name.endswith('.txt') - ] - else: - # input_folder is a local folder - names = [ - text_file for dirpath, _, _ in os.walk(input_folder) - for text_file in glob(os.path.join(dirpath, '*.txt')) - ] - # return names, sizes - log.info(f'Found {len(names)} text files at {input_folder}') - - return names - - -def get_task_args( - object_names: List[str], - output_root: str, - input_folder: str, - n_groups: int, - tokenizer_name: str, - concat_tokens: int, - eos_text: str, - bos_text: str, - no_wrap: bool, - compression: str, - trust_remote_code: bool, -) -> Iterable: - """Get download_and_convert arguments split across n_groups. - - Each group handles a portion of object_names. - - Args: - object_names (List[str]): Names of objects to process - output_root (str): Folder to write MDS shards to - input_folder (str): Folder of text files to process - n_groups (int): Number of groups to split the object names into - tokenizer_name (str): Name of tokenizer to use - concat_tokens (int): Concatenate up to this many tokens - eos_text (str): Text to append to each example to separate concatenated samples - bos_text (str): Text to prepend to each example to separate concatenated samples - no_wrap: (bool): Whether to let text examples wrap across multiple training examples - compression (str): The compression algorithm to use for MDS writing - trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer - """ - num_objects = len(object_names) - objs_per_group = math.ceil(num_objects / n_groups) - for group, i in enumerate(range(0, num_objects, objs_per_group)): - output_subdir = os.path.join(output_root, str(group)) - yield ( - object_names[i:min(i + objs_per_group, num_objects)], - output_subdir, - input_folder, - tokenizer_name, - concat_tokens, - eos_text, - bos_text, - no_wrap, - compression, - trust_remote_code, - ) - - -def download_and_convert_starargs(args: Tuple): - """Helper function to call download_and_convert with star args. - - This helps us use download_and_convert with multiprocessing. - """ - return download_and_convert(*args) - - -def download_and_convert( - file_names: List[str], - output_folder: str, - input_folder: str, - tokenizer_name: str, - concat_tokens: int, - eos_text: str, - bos_text: str, - no_wrap: bool, - compression: str, - trust_remote_code: bool, -): - """Downloads and converts text files to MDS format. 
- - Args: - file_names (List[str]): Files to process - output_folder (str): Folder to write MDS shards to - input_folder (str): Folder of text files to process - tokenizer_name (str): Name of tokenizer to use - concat_tokens (int): Concatenate up to this many tokens - eos_text (str): Text to append to each example to separate concatenated samples - bos_text (str): Text to prepend to each example to separate concatenated samples - no_wrap: (bool): Whether to let text examples wrap across multiple training examples - compression (str): The compression algorithm to use for MDS writing - trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer - """ - object_store = maybe_create_object_store_from_uri(input_folder) - - # Download file_names - with tempfile.TemporaryDirectory() as tmp_dir: - downloading_iter = DownloadingIterable( - object_names=file_names, - output_folder=tmp_dir, - object_store=object_store, - ) - tokenizer = AutoTokenizer.from_pretrained( - tokenizer_name, - trust_remote_code=trust_remote_code, - ) - tokenizer.model_max_length = 5000000000 # Hack to prevent warnings from HuggingFace - - # Use the ConcatTokensDataset from LLM-foundry to concatenate sequences of tokens up - # to the maximum sequence length - dataset = ConcatTokensFromFilesDataset( - files=downloading_iter, - max_length=concat_tokens, - tokenizer=tokenizer, - eos_text=eos_text, - bos_text=bos_text, - no_wrap=no_wrap, - ) - - columns = {'tokens': 'ndarray:int32'} - - log.info('Converting to MDS format...') - with MDSWriter( - out=output_folder, - columns=columns, - compression=compression, - ) as out: - for sample in tqdm(dataset): - out.write(sample) - - -def is_remote_path(path: str) -> bool: - """Checks whether a path is a remote path. - - Args: - path (str): path to check - """ - backend, _, _ = parse_uri(path) - return backend != '' - - -def is_already_processed( - output_root: str, - args_str: str, - object_names: List[str], -) -> bool: - """Determines whether a group of text files has already been processed. - - Checks the done fie at output root to determine this. 
- - Args: - output_root (str): Output folder where a done file may exist - args_str (str): String representation of the arguments - object_names (List[str]): Names of objects to convert to MDS format - """ - # Retrieve the done file contents - output_object_store = maybe_create_object_store_from_uri(output_root) - if output_object_store is not None: - # Download and read the done file from the remote object store - _, _, output_folder_prefix = parse_uri(output_root) - try: - with tempfile.TemporaryDirectory() as tmp_dir: - done_file = os.path.join(tmp_dir, DONE_FILENAME) - download_file( - object_store=output_object_store, - object_name=os.path.join( - output_folder_prefix, - DONE_FILENAME, - ), - output_filename=done_file, - ) - with open(done_file) as df: - done_file_contents = df.read().splitlines() - except FileNotFoundError: - return False - else: - # Read the local done file - done_file = os.path.join(output_root, DONE_FILENAME) - if not os.path.isfile(done_file): - return False - with open(done_file) as df: - done_file_contents = df.read().splitlines() - # Compare the arguments - prev_args_str = done_file_contents[0] - if prev_args_str != args_str: - return False - - # Compare file names - prev_names = done_file_contents[1:] - if len(prev_names) != len(object_names): - return False - for idx, prev_name in enumerate(prev_names): - if object_names[idx] != prev_name: - return False - return True - - -def write_done_file(folder: str, args_str: str, object_names: List[str]): - """Write a file to signify completion. - - This the done file includes the arguments to processing and - a list of objects that were processed. - - Args: - folder (str): Folder to write the done file to - args_str (str): String representation of arguments - object_names (List[str]): List of objects to convert to MDS format - """ - with open(os.path.join(folder, DONE_FILENAME), 'w') as done_file: - done_file.write('\n'.join([args_str] + object_names) + '\n') - - -def convert_text_to_mds( - tokenizer_name: str, - output_folder: str, - input_folder: str, - concat_tokens: int, - eos_text: str, - bos_text: str, - no_wrap: bool, - compression: str, - processes: int, - args_str: str, - reprocess: bool, - trust_remote_code: bool, -): - """Convert a folder of text files to MDS format. - - Args: - tokenizer_name (str): Name of tokenizer to use - output_folder (str): Folder to write MDS shards to - input_folder (str): Folder of text files to process - concat_tokens (int): Concatenate up to this many tokens - eos_text (str): Text to append to each example to separate concatenated samples - bos_text (str): Text to prepend to each example to separate concatenated samples - no_wrap: (bool): Whether to let text examples wrap across multiple training examples - compression (str): The compression algorithm to use for MDS writing - processes (int): The number of processes to use. - args_str (str): String representation of the arguments - reprocess (bool): Whether to always reprocess the given folder of text files - trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer - """ - is_remote_output = is_remote_path(output_folder) - - object_names = get_object_names(input_folder) - if len(object_names) == 0: - raise InputFolderMissingDataError(input_folder) - - # Check if the text files in the bucket have already been processed. 
- if not reprocess and is_already_processed( - output_folder, - args_str, - object_names, - ): - log.info( - f'Input folder {input_folder} is already processed at {output_folder} and ' - + - 'reprocess is set to False. Set reprocess to True if you would like to force reprocessing.', - ) - return - - # Use a temporary local directory if the output is remote and there are more than 1 processes - local_output_folder = tempfile.TemporaryDirectory( - ).name if is_remote_output else output_folder - - if os.path.isdir(output_folder) and len(os.listdir(output_folder)) > 0: - raise OutputFolderNotEmptyError(output_folder) - - if processes > 1: - # Download and convert the text files in parallel - args = get_task_args( - object_names, - local_output_folder, - input_folder, - processes, - tokenizer_name, - concat_tokens, - eos_text, - bos_text, - no_wrap, - compression, - trust_remote_code, - ) - with ProcessPoolExecutor(max_workers=processes) as executor: - list(executor.map(download_and_convert_starargs, args)) - - # Merge the mds shards from each of the processes into a single folder - merge_shard_groups(local_output_folder) - else: - download_and_convert( - object_names, - local_output_folder, - input_folder, - tokenizer_name, - concat_tokens, - eos_text, - bos_text, - no_wrap, - compression, - trust_remote_code, - ) - - # Write a done file with the args and object names - write_done_file(local_output_folder, args_str, object_names) - - if is_remote_output: - # Upload the local output to the remote location - output_object_store = cast( - ObjectStore, - maybe_create_object_store_from_uri(output_folder), - ) - _, _, output_folder_prefix = parse_uri(output_folder) - files_to_upload = os.listdir(local_output_folder) - - for file in files_to_upload: - assert not os.path.isdir(file) - remote_path = os.path.join(output_folder_prefix, file) - output_object_store.upload_object( - remote_path, - os.path.join(local_output_folder, file), - ) - - -def _args_str(original_args: Namespace) -> str: - """Create a string from the args to determine whether to reprocess. - - Args: - original_args (Namespace): Arguments to main function. - """ - # Take the arguments that influence the final result. - # reprocess and max_mds_writer_workers are not taken. - args = Namespace( - tokenizer_name=original_args.tokenizer, - output_folder=original_args.output_folder, - input_folder=original_args.input_folder, - concat_tokens=original_args.concat_tokens, - eos_text=original_args.eos_text, - bos_text=original_args.bos_text, - no_wrap=original_args.no_wrap, - compression=original_args.compression, - processes=original_args.processes, - ) - - return str(args) - - -def _configure_logging(logging_level: str): - """Configure logging. - - Args: - logging_level (str): Logging level. 
- """ - logging.basicConfig( - format= - f'%(asctime)s: [%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s', - ) - logging_level = logging_level.upper() - logging.getLogger('llmfoundry').setLevel(logging_level) - logging.getLogger(__name__).setLevel(logging_level) - log.info(f'Logging level set to {logging_level}') - - if __name__ == '__main__': args = parse_args() - _configure_logging(args.logging_level) - convert_text_to_mds( - tokenizer_name=args.tokenizer, + convert_text_to_mds_from_args( output_folder=args.output_folder, input_folder=args.input_folder, + compression=args.compression, concat_tokens=args.concat_tokens, - eos_text=args.eos_text, + tokenizer_name=args.tokenizer, bos_text=args.bos_text, + eos_text=args.eos_text, + use_tokenizer_eos=args.use_tokenizer_eos, no_wrap=args.no_wrap, - compression=args.compression, processes=args.processes, reprocess=args.reprocess, trust_remote_code=args.trust_remote_code, - args_str=_args_str(args), + logging_level=args.logging_level, ) diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index 29a03b72cc..caafda4b87 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -1,494 +1,9 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 - -import logging -import os import sys -import time -from typing import Any, Dict, List, Optional, Tuple, Union - -import pandas as pd -import torch -from composer.core import Callback -from composer.loggers.logger_destination import LoggerDestination -from composer.trainer import Trainer -from composer.utils import dist, get_device, reproducibility -from omegaconf import DictConfig -from omegaconf import OmegaConf as om -from rich.traceback import install - -from llmfoundry.utils import ( - find_mosaicml_logger, - log_eval_analytics, - maybe_create_mosaicml_logger, -) - -install() -from llmfoundry.utils.builders import ( - add_metrics_to_eval_loaders, - build_callback, - build_composer_model, - build_evaluators, - build_logger, - build_tokenizer, -) -from llmfoundry.utils.config_utils import ( - EVAL_CONFIG_KEYS, - EvalConfig, - log_config, - make_dataclass_and_log_config, - process_init_device, -) -from llmfoundry.utils.registry_utils import import_file - -log = logging.getLogger(__name__) - - -def evaluate_model( - tokenizer: Dict[str, Any], - model_name: str, - model: Dict[str, Any], - dist_timeout: Union[float, int], - run_name: str, - seed: int, - icl_tasks: Union[str, List[Dict[str, Any]]], - max_seq_len: int, - device_eval_batch_size: Union[int, float], - eval_gauntlet_config: Optional[Union[str, Dict[str, Any]]], - eval_loader_config: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]], - fsdp_config: Optional[Dict[str, Any]], - loggers: List[LoggerDestination], - python_log_level: Optional[str], - precision: str, - eval_gauntlet_df: Optional[pd.DataFrame], - eval_subset_num_batches: int, - icl_subset_num_batches: Optional[int], - callback_configs: Optional[Dict[str, Any]], - metadata: Optional[Dict[str, str]], - logged_config: Dict[str, Any], - should_log_config: bool = True, - load_path: Optional[str] = None, -): - log.info(f'Evaluating model: {model_name}') - # Build tokenizer and model - tokenizer_cfg = tokenizer - tokenizer_name = tokenizer_cfg['name'] - tokenizer_kwargs = tokenizer_cfg.get('kwargs', {}) - tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) - - evaluators, logger_keys, eval_gauntlet_callback = build_evaluators( - eval_loader_config, - icl_tasks, - eval_gauntlet_config, - tokenizer=tokenizer, - 
device_eval_batch_size=device_eval_batch_size, - icl_seq_len=max_seq_len, - icl_subset_num_batches=icl_subset_num_batches, - ) - - # Callbacks - callbacks: List[Callback] = [ - build_callback(name=str(name), kwargs=callback_cfg) - for name, callback_cfg in callback_configs.items() - ] if callback_configs else [] - - if eval_gauntlet_callback is not None: - callbacks.append(eval_gauntlet_callback) - - if metadata is not None: - # Find the MosaicMLLogger - mosaicml_logger = find_mosaicml_logger(loggers) - - if mosaicml_logger is not None: - mosaicml_logger.log_metrics(metadata) - mosaicml_logger._flush_metadata(force_flush=True) - - if fsdp_config and model.get('load_in_8bit', False): - raise ValueError( - 'The FSDP config block is not supported when loading ' + - 'Hugging Face models in 8bit.', - ) - - init_context = process_init_device(model, fsdp_config) - - name = model.pop('name') - composer_model = build_composer_model( - name=name, - tokenizer=tokenizer, - init_context=init_context, - cfg=model, - ) - - # Now add the eval metrics - if eval_loader_config is not None: - train_metrics = composer_model.get_metrics(is_train=True) - evaluators = add_metrics_to_eval_loaders( - evaluators, - list(train_metrics.keys()), - ) - - if eval_gauntlet_df is None and eval_gauntlet_callback is not None: - eval_gauntlet_df = pd.DataFrame( - columns=['model_name'] + list(eval_gauntlet_callback.averages) + - [t['name'] for t in eval_gauntlet_callback.categories], - ) - - if name == 'mpt_causal_lm' and load_path is None: - raise ValueError( - 'MPT causal LMs require a load_path to the checkpoint for model evaluation.' - + - ' Please check your yaml and the model_cfg to ensure that load_path is set.', - ) - - assert composer_model is not None - - log.info(f'Building trainer for {model_name}...') - trainer = Trainer( - run_name=run_name, - seed=seed, - model=composer_model, - callbacks=callbacks, - loggers=loggers, - precision=precision, - fsdp_config=fsdp_config, - load_path=load_path, - load_weights_only=True, - progress_bar=False, - log_to_console=True, - dist_timeout=dist_timeout, - python_log_level=python_log_level, - ) - - if should_log_config: - log.info('Evaluation config:') - log_config(logged_config) - - log.info(f'Starting eval for {model_name}...') - if torch.cuda.is_available(): - torch.cuda.synchronize() - a = time.time() - trainer.eval( - eval_dataloader=evaluators, - subset_num_batches=eval_subset_num_batches, - ) - if torch.cuda.is_available(): - torch.cuda.synchronize() - b = time.time() - - log.info(f'Ran {model_name} eval in: {b-a} seconds') - return (trainer, logger_keys, eval_gauntlet_callback, eval_gauntlet_df) - - -def allow_toplevel_keys(cfg: Dict[str, Any]) -> Dict[str, Any]: - """Transform the config to allow top-level keys for model configuration. - - This function allows users to use the 'train.py' syntax in 'eval.py'. - It converts a config with top-level 'model', 'tokenizer', and (optionally) 'load_path' keys - into the nested 'models' list format required by 'eval.py'. 
- - Input config format (train.py style): - ```yaml - model: - - load_path: /path/to/checkpoint - tokenizer: - - ``` - - Output config format (eval.py style): - ```yaml - models: - - model: - - tokenizer: - - load_path: /path/to/checkpoint - ``` - """ - if 'model' in cfg: - if 'models' in cfg: - raise ValueError( - 'Please specify either model or models in the config, not both', - ) - default_name = cfg.get('model').get('name') - model_cfg = { - 'model': cfg.pop('model'), - 'tokenizer': cfg.pop('tokenizer', None), - 'model_name': cfg.pop('model_name', default_name), - } - if 'tokenizer' not in model_cfg or model_cfg['tokenizer'] is None: - raise ValueError( - 'When specifying model, "tokenizer" must be provided in the config', - ) - if 'load_path' in cfg: - model_cfg['load_path'] = cfg.pop('load_path') - cfg['models'] = [model_cfg] - - return cfg - - -def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: - # Run user provided code if specified - for code_path in cfg.get('code_paths', []): - import_file(code_path) - - logged_cfg, eval_config = make_dataclass_and_log_config( - cfg, - EvalConfig, - EVAL_CONFIG_KEYS, - transforms=[allow_toplevel_keys], - icl_tasks_required=True, - ) - - model_configs = eval_config.models - eval_gauntlet_config = eval_config.eval_gauntlet or eval_config.eval_gauntlet_str - - fsdp_config = eval_config.fsdp_config - - # Mandatory Evaluation Parameters - icl_tasks = eval_config.icl_tasks or eval_config.icl_tasks_str - if icl_tasks is None: - raise ValueError('icl_tasks must be specified in the config') - - # Optional Evaluation Parameters with default values - eval_loader_config = eval_config.eval_loader or eval_config.eval_loaders - default_run_name: str = os.environ.get('RUN_NAME', 'llm') - run_name = eval_config.run_name if eval_config.run_name else default_run_name - - reproducibility.seed_all(eval_config.seed) - dist.initialize_dist(get_device(None), timeout=eval_config.dist_timeout) - - if eval_config.python_log_level is not None: - logging.basicConfig( - # Example of format string - # 2022-06-29 11:22:26,152: rank0[822018][MainThread]: INFO: Message here - format= - f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s', - ) - logging.getLogger('llmfoundry').setLevel( - eval_config.python_log_level.upper(), - ) - - # Default argument values for evaluate_model - eval_gauntlet_df = None - models_df = None - composite_scores = None - trainers = [] - - # Build loggers - loggers: List[LoggerDestination] = [ - build_logger(name, logger_cfg) - for name, logger_cfg in (eval_config.loggers or {}).items() - ] - - mosaicml_logger = find_mosaicml_logger(loggers) - if mosaicml_logger is None: - mosaicml_logger = maybe_create_mosaicml_logger() - # mosaicml_logger will be None if run isn't on MosaicML platform - if mosaicml_logger is not None: - loggers.append(mosaicml_logger) - - # mosaicml_logger will be None if the run isn't from the MosaicML platform - if mosaicml_logger is not None: - log_eval_analytics( - mosaicml_logger, - model_configs, - icl_tasks, - eval_gauntlet_config, - ) - - for model_cfg in model_configs: - - attn_config = model_cfg['model'].get('attn_config', None) - if attn_config is not None: - seq_parallel_world_size = attn_config.get( - 'seq_parallel_world_size', - None, - ) - if seq_parallel_world_size is not None and seq_parallel_world_size != 1: - raise ValueError( - 'Offline eval does not support sequence parallelism.', - ) - - (trainer, logger_keys, eval_gauntlet_callback, - 
eval_gauntlet_df) = evaluate_model( - dist_timeout=eval_config.dist_timeout, - run_name=run_name, - seed=eval_config.seed, - icl_tasks=icl_tasks, - max_seq_len=eval_config.max_seq_len, - device_eval_batch_size=eval_config.device_eval_batch_size, - eval_gauntlet_config=eval_gauntlet_config, - eval_loader_config=eval_loader_config, - fsdp_config=fsdp_config, - loggers=loggers, - python_log_level=eval_config.python_log_level, - precision=eval_config.precision, - eval_gauntlet_df=eval_gauntlet_df, - callback_configs=eval_config.callbacks, - eval_subset_num_batches=eval_config.eval_subset_num_batches, - icl_subset_num_batches=eval_config.icl_subset_num_batches, - metadata=eval_config.metadata, - logged_config=logged_cfg, - should_log_config=eval_config.log_config, - **model_cfg, - ) - trainers.append(trainer) - - if eval_gauntlet_callback is not None: - composite_scores = eval_gauntlet_callback.eval_after_all( - trainer.state, - trainer.logger, - ) - - benchmark_to_taxonomy = {} - if eval_gauntlet_callback is not None: - for t in eval_gauntlet_callback.categories: - for b in t['benchmarks']: - benchmark_to_taxonomy[b['name']] = t['name'] - - assert 'model_name' in model_cfg, 'model_name must be specified in model config' - model_results = calculate_markdown_results( - logger_keys, - trainer, - benchmark_to_taxonomy, - model_cfg['model_name'], - ) - - if models_df is None: - models_df = model_results - else: - models_df = pd.concat([models_df, model_results], ignore_index=True) - - if eval_gauntlet_df is not None and eval_gauntlet_callback is not None: - assert composite_scores is not None - row = {'model_name': model_cfg['model_name']} - row.update({ - k.split('/')[-1]: v for k, v in composite_scores.items() - }) - eval_gauntlet_df = pd.concat([ - eval_gauntlet_df, - pd.DataFrame([row]), - ], - ignore_index=True) - - print(f'Printing gauntlet results for all models') - - print( - eval_gauntlet_df.sort_values( - list(eval_gauntlet_callback.averages.keys())[0], - ascending=False, - ).to_markdown(index=False), - ) - print(f'Printing complete results for all models') - assert models_df is not None - print(models_df.to_markdown(index=False)) - - trainer.close() - - return trainers, eval_gauntlet_df - - -def calculate_markdown_results( - logger_keys: List[str], - trainer: Trainer, - benchmark_to_taxonomy: Dict[str, str], - model_name: str, -): - results = {} - - for key in logger_keys: - # dl_name is either 2-tuple (benchmark_name, num_fewshot) - # or 3-tuple (benchmark_name, num_fewshot, subcategory) - dl_name, metric_name = key.split('/')[1:-1], key.split('/')[-1] - if 'Accuracy' not in metric_name: - continue - - metric = trainer.state.eval_metrics.get('/'.join(dl_name), - {}).get(metric_name, None) - - if metric is None: - continue - if dl_name[1] not in results: - results[dl_name[1]] = {} - - if dl_name[0] not in results[dl_name[1]]: - results[dl_name[1]][dl_name[0]] = {} - - if metric_name not in results[dl_name[1]][dl_name[0]]: - results[dl_name[1]][dl_name[0]][metric_name] = [] - - results[dl_name[1]][dl_name[0]][metric_name].append({ - 'val': metric.compute(), - 'subcat': dl_name[-1] if len(dl_name) == 3 else 'no_subcat', - }) - - df = pd.DataFrame( - columns=[ - 'Category', - 'Benchmark', - 'Subtask', - 'Accuracy', - 'Number few shot', - 'Model', - ], - ) - - for num_shot in results: - for benchmark in results[num_shot]: - for metric in results[num_shot][benchmark]: - subscores = results[num_shot][benchmark][metric] - if len(subscores) == 1: - row = { - 'Category': 
benchmark_to_taxonomy.get(benchmark, ''), - 'Benchmark': benchmark, - 'Subtask': None, - 'Accuracy': subscores[0]['val'], - 'Number few shot': num_shot, - 'Model': model_name, - } - df = pd.concat([df, pd.DataFrame([row])], ignore_index=True) - else: - row = { - 'Category': - benchmark_to_taxonomy.get(benchmark, ''), - 'Benchmark': - benchmark, - 'Subtask': - 'Average', - 'Accuracy': - sum(s['val'] for s in subscores) / len(subscores), - 'Number few shot': - num_shot, - 'Model': - model_name, - } - df = pd.concat([df, pd.DataFrame([row])], ignore_index=True) - for sub in subscores: - row = { - 'Category': - benchmark_to_taxonomy.get(benchmark, ''), - 'Benchmark': - None, - 'Subtask': - sub['subcat'], - 'Accuracy': - sub['val'], - 'Number few shot': - num_shot, - 'Model': - model_name, - } - df = pd.concat([df, pd.DataFrame([row])], - ignore_index=True) - return df +from llmfoundry.command_utils import eval_from_yaml if __name__ == '__main__': yaml_path, args_list = sys.argv[1], sys.argv[2:] - with open(yaml_path) as f: - yaml_cfg = om.load(f) - cli_cfg = om.from_cli(args_list) - cfg = om.merge(yaml_cfg, cli_cfg) - assert isinstance(cfg, DictConfig) - main(cfg) + eval_from_yaml(yaml_path, args_list) diff --git a/scripts/train/train.py b/scripts/train/train.py index 3c8973048b..728010d13a 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import sys -from llmfoundry.train import train_from_yaml +from llmfoundry.command_utils import train_from_yaml if __name__ == '__main__': yaml_path, args_list = sys.argv[1], sys.argv[2:] diff --git a/tests/a_scripts/data_prep/test_convert_dataset_hf.py b/tests/a_scripts/data_prep/test_convert_dataset_hf.py index 4c5d1a6bba..e09c54ca70 100644 --- a/tests/a_scripts/data_prep/test_convert_dataset_hf.py +++ b/tests/a_scripts/data_prep/test_convert_dataset_hf.py @@ -2,29 +2,26 @@ # SPDX-License-Identifier: Apache-2.0 import os -from argparse import Namespace from pathlib import Path -from scripts.data_prep.convert_dataset_hf import main as main_hf +from llmfoundry.command_utils import convert_dataset_hf def test_download_script_from_api(tmp_path: Path): # test calling it directly path = os.path.join(tmp_path, 'my-copy-c4-1') - main_hf( - Namespace( - **{ - 'dataset': 'c4', - 'data_subset': 'en', - 'splits': ['val_xsmall'], - 'out_root': path, - 'compression': None, - 'concat_tokens': None, - 'bos_text': None, - 'eos_text': None, - 'no_wrap': False, - 'num_workers': None, - }, - ), + convert_dataset_hf( + dataset='c4', + data_subset='en', + splits=['val_xsmall'], + out_root=path, + compression=None, + concat_tokens=None, + bos_text='', + eos_text='', + no_wrap=False, + num_workers=None, + tokenizer=None, + tokenizer_kwargs={}, ) assert os.path.exists(path) diff --git a/tests/a_scripts/data_prep/test_convert_dataset_json.py b/tests/a_scripts/data_prep/test_convert_dataset_json.py index 912e44cd0c..4f70a35637 100644 --- a/tests/a_scripts/data_prep/test_convert_dataset_json.py +++ b/tests/a_scripts/data_prep/test_convert_dataset_json.py @@ -2,28 +2,23 @@ # SPDX-License-Identifier: Apache-2.0 import os -from argparse import Namespace from pathlib import Path -from scripts.data_prep.convert_dataset_json import main as main_json +from llmfoundry.command_utils import convert_dataset_json def test_json_script_from_api(tmp_path: Path): # test calling it directly path = os.path.join(tmp_path, 'my-copy-arxiv-1') - main_json( - Namespace( - **{ - 'path': 'scripts/data_prep/example_data/arxiv.jsonl', - 
'out_root': path, - 'compression': None, - 'split': 'train', - 'concat_tokens': None, - 'bos_text': None, - 'eos_text': None, - 'no_wrap': False, - 'num_workers': None, - }, - ), + convert_dataset_json( + path='scripts/data_prep/example_data/arxiv.jsonl', + out_root=path, + compression=None, + split='train', + concat_tokens=None, + bos_text='', + eos_text='', + no_wrap=False, + num_workers=None, ) assert os.path.exists(path) diff --git a/tests/a_scripts/data_prep/test_convert_text_to_mds.py b/tests/a_scripts/data_prep/test_convert_text_to_mds.py index 8dac151f55..f4c160790a 100644 --- a/tests/a_scripts/data_prep/test_convert_text_to_mds.py +++ b/tests/a_scripts/data_prep/test_convert_text_to_mds.py @@ -13,11 +13,7 @@ from streaming import StreamingDataset from transformers import AutoTokenizer -from llmfoundry.utils.exceptions import ( - InputFolderMissingDataError, - OutputFolderNotEmptyError, -) -from scripts.data_prep.convert_text_to_mds import ( +from llmfoundry.command_utils.data_prep.convert_text_to_mds import ( DONE_FILENAME, convert_text_to_mds, download_and_convert, @@ -25,6 +21,10 @@ merge_shard_groups, write_done_file, ) +from llmfoundry.utils.exceptions import ( + InputFolderMissingDataError, + OutputFolderNotEmptyError, +) class MockObjectStore(): @@ -83,15 +83,15 @@ def _assert_files_exist(prefix: str, files: List[str]): @pytest.mark.parametrize('processes', [1, 2, 3]) @patch.object(ProcessPoolExecutor, 'map', new=Mock(wraps=_mock_map)) @patch( - 'scripts.data_prep.convert_text_to_mds.maybe_create_object_store_from_uri', + 'llmfoundry.command_utils.data_prep.convert_text_to_mds.maybe_create_object_store_from_uri', ) -@patch('scripts.data_prep.convert_text_to_mds.parse_uri') +@patch('llmfoundry.command_utils.data_prep.convert_text_to_mds.parse_uri') @patch( - 'scripts.data_prep.convert_text_to_mds.download_and_convert', + 'llmfoundry.command_utils.data_prep.convert_text_to_mds.download_and_convert', wraps=download_and_convert, ) @patch( - 'scripts.data_prep.convert_text_to_mds.merge_shard_groups', + 'llmfoundry.command_utils.data_prep.convert_text_to_mds.merge_shard_groups', wraps=merge_shard_groups, ) def test_single_and_multi_process( diff --git a/tests/a_scripts/eval/test_eval.py b/tests/a_scripts/eval/test_eval.py index 01f3760d26..fc0dc8a882 100644 --- a/tests/a_scripts/eval/test_eval.py +++ b/tests/a_scripts/eval/test_eval.py @@ -11,10 +11,10 @@ from composer import Trainer from composer.loggers import InMemoryLogger +from llmfoundry.command_utils import evaluate from llmfoundry.utils import build_tokenizer from llmfoundry.utils.builders import build_composer_model from llmfoundry.utils.config_utils import EVAL_CONFIG_KEYS, to_dict_container -from scripts.eval.eval import main # noqa: E402 from tests.data_utils import create_c4_dataset_xxsmall, gpt_tiny_cfg @@ -75,7 +75,7 @@ def test_icl_eval( eval_cfg = copy.deepcopy(eval_cfg) eval_cfg.models[0].load_path = mock_saved_model_path assert isinstance(eval_cfg, om.DictConfig) - main(eval_cfg) + evaluate(eval_cfg) out, _ = capfd.readouterr() expected_results = '| Category | Benchmark | Subtask | Accuracy | Number few shot | Model |\n|:----------------------------|:---------------|:----------|-----------:|:------------------|:---------|\n| language_understanding_lite | lambada_openai | | 0 | 0-shot | tiny_mpt |' assert expected_results in out @@ -135,14 +135,14 @@ def test_loader_eval( test_cfg.loggers = om.DictConfig({'inmemory': om.DictConfig({})}) # This test uses a training yaml with training-only keys present. 
- # We exclude these keys before calling `main` from the eval script. + # We exclude these keys before calling `evaluate` from the eval script. allowed_keys = EVAL_CONFIG_KEYS present_keys = set(test_cfg.keys()) keys_to_pop = present_keys.difference(allowed_keys) [test_cfg.pop(key) for key in keys_to_pop] - trainers, eval_gauntlet_df = main(test_cfg) + trainers, eval_gauntlet_df = evaluate(test_cfg) assert eval_gauntlet_df is None assert len(trainers) == 1 # one per model diff --git a/tests/a_scripts/eval/test_eval_inputs.py b/tests/a_scripts/eval/test_eval_inputs.py index 0ca5765a26..86243ba154 100644 --- a/tests/a_scripts/eval/test_eval_inputs.py +++ b/tests/a_scripts/eval/test_eval_inputs.py @@ -8,7 +8,7 @@ from omegaconf import DictConfig from omegaconf import OmegaConf as om -from scripts.eval.eval import main # noqa: E402 +from llmfoundry.command_utils import evaluate class TestHuggingFaceEvalYAMLInputs: @@ -44,7 +44,7 @@ def test_mispelled_mandatory_params_fail(self, cfg: DictConfig) -> None: ValueError, )): cfg[p + '-mispelled'] = cfg.pop(p) - main(cfg) + evaluate(cfg) cfg[p] = cfg.pop(p + '-mispelled') def test_optional_mispelled_params_raise_error( @@ -68,7 +68,7 @@ def test_optional_mispelled_params_raise_error( updated_param = param + '-mispelling' cfg[updated_param] = orig_value with pytest.raises(ValueError): - main(cfg) + evaluate(cfg) # restore configs. cfg = copy.deepcopy(old_cfg) @@ -105,4 +105,4 @@ def test_empty_load_path_raises_error(self, cfg: DictConfig) -> None: + ' Please check your yaml and the model_cfg to ensure that load_path is set.' cfg.models[0].load_path = None with pytest.raises(ValueError, match=error_string): - main(cfg) + evaluate(cfg) diff --git a/tests/a_scripts/train/test_train.py b/tests/a_scripts/train/test_train.py index a49f1ac07a..1f724a6070 100644 --- a/tests/a_scripts/train/test_train.py +++ b/tests/a_scripts/train/test_train.py @@ -11,8 +11,8 @@ from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om -from llmfoundry.train import TrainConfig # noqa: E402 -from llmfoundry.train import TRAIN_CONFIG_KEYS, train, validate_config +from llmfoundry.command_utils import TrainConfig # noqa: E402 +from llmfoundry.command_utils import TRAIN_CONFIG_KEYS, train, validate_config from llmfoundry.utils.config_utils import ( make_dataclass_and_log_config, update_batch_size_info, diff --git a/tests/a_scripts/train/test_train_inputs.py b/tests/a_scripts/train/test_train_inputs.py index 328a06a69e..73540afe2f 100644 --- a/tests/a_scripts/train/test_train_inputs.py +++ b/tests/a_scripts/train/test_train_inputs.py @@ -9,7 +9,7 @@ from omegaconf import DictConfig from omegaconf import OmegaConf as om -from llmfoundry.train import train # noqa: E402 +from llmfoundry.command_utils import train def make_fake_index_file(path: str) -> None: diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index a489002399..21d73c0d34 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -5,7 +5,6 @@ import pathlib import random import shutil -from argparse import Namespace from contextlib import nullcontext as does_not_raise from pathlib import Path from typing import ContextManager, Literal, Optional, Union @@ -22,6 +21,7 @@ from streaming import MDSWriter from streaming.base.util import clean_stale_shared_memory +from llmfoundry.command_utils import convert_dataset_hf from llmfoundry.data import build_dataloader, build_finetuning_dataloader from llmfoundry.data.finetuning.collator import ( 
_HF_IGNORE_INDEX, @@ -56,7 +56,6 @@ UnknownExampleTypeError, ) # yapf: enable -from scripts.data_prep.convert_dataset_hf import main as main_hf from scripts.data_prep.convert_finetuning_dataset import get_columns_and_format from tests.data_utils import ( make_tiny_conversation_ft_dataset, @@ -204,42 +203,34 @@ def test_correct_padding( path = get_abs_data_path(data_local) shutil.rmtree(path, ignore_errors=True) if pretokenize: - main_hf( - Namespace( - **{ - 'dataset': 'c4', - 'data_subset': 'en', - 'splits': [split], - 'out_root': path, - 'compression': None, - 'concat_tokens': 2048, - 'tokenizer': tokenizer_name, - 'tokenizer_kwargs': {}, - 'bos_text': bos_text, - 'eos_text': eos_text, - 'no_wrap': False, - 'num_workers': None, - }, - ), + convert_dataset_hf( + dataset='c4', + data_subset='en', + splits=[split], + out_root=path, + compression=None, + concat_tokens=2048, + tokenizer=tokenizer_name, + tokenizer_kwargs={}, + bos_text=bos_text, + eos_text=eos_text, + no_wrap=False, + num_workers=None, ) else: - main_hf( - Namespace( - **{ - 'dataset': 'c4', - 'data_subset': 'en', - 'splits': [split], - 'out_root': path, - 'compression': None, - 'concat_tokens': None, - 'tokenizer': tokenizer_name, - 'tokenizer_kwargs': {}, - 'bos_text': bos_text, - 'eos_text': eos_text, - 'no_wrap': False, - 'num_workers': None, - }, - ), + convert_dataset_hf( + dataset='c4', + data_subset='en', + splits=[split], + out_root=path, + compression=None, + concat_tokens=None, + tokenizer=tokenizer_name, + tokenizer_kwargs={}, + bos_text=bos_text, + eos_text=eos_text, + no_wrap=False, + num_workers=None, ) if not os.path.isdir(path): raise RuntimeError(f'c4 dataset at {path} not set up as expected') diff --git a/tests/data_utils.py b/tests/data_utils.py index 9653d8579a..ea64943735 100644 --- a/tests/data_utils.py +++ b/tests/data_utils.py @@ -4,16 +4,16 @@ import json import os import shutil -from argparse import Namespace from pathlib import Path from typing import Dict, List, Optional from omegaconf import DictConfig from omegaconf import OmegaConf as om -from scripts.data_prep.convert_dataset_hf import main as main_hf # noqa: E402 -from scripts.data_prep.convert_dataset_json import \ - main as main_json # noqa: E402 +from llmfoundry.command_utils import ( + convert_dataset_hf, + convert_dataset_json, +) def make_tiny_ft_dataset( @@ -230,23 +230,19 @@ def create_c4_dataset_xxsmall(path: Path) -> str: downloaded_split = 'val_xxsmall' # very fast to convert # Hyperparameters from https://github.com/mosaicml/llm-foundry/blob/340a56658560ebceb2a3aa69d6e37813e415acd0/README.md#L188 - main_hf( - Namespace( - **{ - 'dataset': 'c4', - 'data_subset': 'en', - 'splits': [downloaded_split], - 'out_root': c4_dir, - 'compression': None, - 'concat_tokens': 2048, - 'tokenizer': 'EleutherAI/gpt-neox-20b', - 'tokenizer_kwargs': {}, - 'bos_text': '', - 'eos_text': '<|endoftext|>', - 'no_wrap': False, - 'num_workers': 8, - }, - ), + convert_dataset_hf( + dataset='c4', + data_subset='en', + splits=[downloaded_split], + out_root=c4_dir, + compression=None, + concat_tokens=2048, + tokenizer='EleutherAI/gpt-neox-20b', + tokenizer_kwargs={}, + bos_text='', + eos_text='<|endoftext|>', + no_wrap=False, + num_workers=8, ) # copy the small downloaded_split to other c4 splits for mocking purposes @@ -269,20 +265,16 @@ def create_arxiv_dataset(path: Path) -> str: if not os.getcwd().endswith('scripts'): arxiv_path = os.path.join('scripts', arxiv_path) - main_json( - Namespace( - **{ - 'path': arxiv_path, - 'out_root': arxiv_dir, - 
'compression': None, - 'split': downloaded_split, - 'concat_tokens': None, - 'bos_text': None, - 'eos_text': None, - 'no_wrap': False, - 'num_workers': None, - }, - ), + convert_dataset_json( + path=arxiv_path, + out_root=arxiv_dir, + compression=None, + split=downloaded_split, + concat_tokens=None, + bos_text='', + eos_text='', + no_wrap=False, + num_workers=None, ) return arxiv_dir diff --git a/tests/models/layers/test_attention.py b/tests/models/layers/test_attention.py new file mode 100644 index 0000000000..bdffe2b49f --- /dev/null +++ b/tests/models/layers/test_attention.py @@ -0,0 +1,160 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from llmfoundry.models.layers.layer_builders import build_attention_layer + + +@pytest.mark.parametrize( + 'attn_name', + ['multihead_attention', 'grouped_query_attention', 'multiquery_attention'], +) +@pytest.mark.parametrize('dim', [1024]) +def test_unfused_wqkv(attn_name: str, dim: int): + d_head = 128 + n_heads = dim // d_head + + generic_attn_kwargs = { + 'd_model': dim, + 'n_heads': n_heads, + 'fc_type': { + 'name': 'torch', + }, + 'device': 'cpu', + 'attn_pdrop': 0.0, + 'attn_impl': 'torch', + 'qk_ln': False, + 'qk_gn': False, + 'clip_qkv': None, + 'softmax_scale': None, + 'sliding_window_size': -1, + } + + if attn_name == 'grouped_query_attention': + kv_n_heads = 2 + generic_attn_kwargs['kv_n_heads'] = kv_n_heads + elif attn_name == 'multiquery_attention': + kv_n_heads = 1 + elif attn_name == 'multihead_attention': + kv_n_heads = n_heads + else: + raise ValueError(f'Unknown attention name: {attn_name}') + + attn_config_fused = generic_attn_kwargs.copy() + attn_config_fused['fused_qkv'] = True + + attn_config_unfused = generic_attn_kwargs.copy() + attn_config_unfused['fused_qkv'] = False + + attn_fused = build_attention_layer( + name=attn_name, + attn_kwargs=attn_config_fused, + ) + attn_unfused = build_attention_layer( + name=attn_name, + attn_kwargs=attn_config_unfused, + ) + + # Make sure unfused attention has the same params as the fused one. + fused_wqkv = attn_fused.Wqkv.weight.detach().clone() + kv_heads_len = (fused_wqkv.shape[0] - dim) // 2 + Wq_shape_before = (attn_unfused.Wq.weight.shape, attn_unfused.Wq.bias.shape) + Wk_shape_before = (attn_unfused.Wk.weight.shape, attn_unfused.Wk.bias.shape) + Wv_shape_before = (attn_unfused.Wv.weight.shape, attn_unfused.Wv.bias.shape) + + attn_unfused.Wq.weight.data = fused_wqkv[:dim, :] + attn_unfused.Wk.weight.data = fused_wqkv[dim:dim + kv_heads_len, :] + attn_unfused.Wv.weight.data = fused_wqkv[dim + kv_heads_len:, :] + attn_unfused.out_proj.weight.data = attn_fused.out_proj.weight + attn_unfused.Wq.bias.data = attn_fused.Wqkv.bias[:dim] + attn_unfused.Wk.bias.data = attn_fused.Wqkv.bias[dim:dim + kv_heads_len] + attn_unfused.Wv.bias.data = attn_fused.Wqkv.bias[dim + kv_heads_len:] + attn_unfused.out_proj.bias.data = attn_fused.out_proj.bias + + # Make sure initialization fuse splits are as expected. 
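+    # The checks below assume the llmfoundry convention that fused linear
+    # layers carry a `_fused` attribute of the form (split_dim, split_indices),
+    # which the parameter-init helpers use to initialize each head-sized slice
+    # independently. As a worked example under the parametrization above, with
+    # dim=1024 and d_head=128 (so n_heads=8) and kv_n_heads=2
+    # (grouped_query_attention), the fused Wqkv has (8 + 2 * 2) * 128 = 1536
+    # output rows split every 128 rows, while the unfused Wq splits its 1024
+    # rows every 128 and Wk/Wv each split their 256 rows once, at row 128.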
+ all_fuse_splits = ( + 0, + [i * d_head for i in range(1, n_heads + 2 * kv_n_heads)], + ) + q_fuse_splits = (0, [i * d_head for i in range(1, n_heads)]) + kv_fuse_splits = (0, [i * d_head for i in range(1, kv_n_heads)]) + + assert attn_fused.Wqkv._fused == all_fuse_splits + assert attn_unfused.Wq._fused == q_fuse_splits + assert attn_unfused.Wk._fused == kv_fuse_splits + assert attn_unfused.Wv._fused == kv_fuse_splits + + assert torch.allclose( + attn_fused.Wqkv.weight, + torch.cat( + [ + attn_unfused.Wq.weight, + attn_unfused.Wk.weight, + attn_unfused.Wv.weight, + ], + dim=0, + ), + ) + assert torch.allclose( + attn_fused.Wqkv.bias, + torch.cat( + [ + attn_unfused.Wq.bias, + attn_unfused.Wk.bias, + attn_unfused.Wv.bias, + ], + dim=0, + ), + ) + assert torch.allclose( + attn_fused.out_proj.weight, + attn_unfused.out_proj.weight, + ) + assert torch.allclose(attn_fused.out_proj.bias, attn_unfused.out_proj.bias) + + assert Wq_shape_before == ( + attn_unfused.Wq.weight.shape, + attn_unfused.Wq.bias.shape, + ) + assert Wk_shape_before == ( + attn_unfused.Wk.weight.shape, + attn_unfused.Wk.bias.shape, + ) + assert Wv_shape_before == ( + attn_unfused.Wv.weight.shape, + attn_unfused.Wv.bias.shape, + ) + + x1 = torch.randn(1, 1, dim) + x2 = x1.detach().clone() + x1.requires_grad = True + x2.requires_grad = True + + out_fused, _, _ = attn_fused(x1) + out_unfused, _, _ = attn_unfused(x2) + + assert torch.allclose(out_fused, out_unfused) + + # Dummy loss function is simply the sum. + loss_fused = out_fused.sum() + loss_fused.backward() + + loss_unfused = out_unfused.sum() + loss_unfused.backward() + + assert isinstance(x1.grad, torch.Tensor) + assert isinstance(x2.grad, torch.Tensor) + assert torch.allclose(x1.grad, x2.grad) + combined_grad = torch.concat( + [ + attn_unfused.Wq.weight.grad, + attn_unfused.Wk.weight.grad, + attn_unfused.Wv.weight.grad, + ], + dim=0, + ) + assert isinstance(attn_fused.Wqkv.weight.grad, torch.Tensor) + assert isinstance(combined_grad, torch.Tensor) + assert torch.allclose(attn_fused.Wqkv.weight.grad, combined_grad) diff --git a/tests/test_registry.py b/tests/test_registry.py index 7ee95442c8..aa0c93ee13 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -44,6 +44,8 @@ def test_expected_registries_exist(): 'fcs', 'icl_datasets', 'config_transforms', + 'load_planners', + 'save_planners', } assert existing_registries == expected_registry_names diff --git a/tests/utils/test_builders.py b/tests/utils/test_builders.py index dfcb5b327c..fb6cb0c5df 100644 --- a/tests/utils/test_builders.py +++ b/tests/utils/test_builders.py @@ -13,17 +13,24 @@ from composer.callbacks import Generate from composer.core import Evaluator from composer.loggers import WandBLogger +from torch.distributed.checkpoint.default_planner import ( + DefaultLoadPlanner, + DefaultSavePlanner, +) from transformers import PreTrainedTokenizerBase from llmfoundry.callbacks import HuggingFaceCheckpointer +from llmfoundry.registry import load_planners, save_planners from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper from llmfoundry.utils.builders import ( add_metrics_to_eval_loaders, build_callback, build_eval_loaders, build_evaluators, + build_load_planner, build_logger, build_optimizer, + build_save_planner, build_tokenizer, ) @@ -345,6 +352,34 @@ def test_build_eval_loaders(monkeypatch: pytest.MonkeyPatch): assert eval_loaders2[1].metric_names == [] +def test_build_load_planner(): + # Dummy LoadPlanner for testing + class DummyLoadPlanner(DefaultLoadPlanner): + + def 
__init__(self, is_test: bool): + self.is_test = is_test + + load_planners.register('dummy', func=DummyLoadPlanner) + load_planner = build_load_planner('dummy', is_test=True) + + assert isinstance(load_planner, DummyLoadPlanner) + assert load_planner.is_test is True + + +def test_build_save_planner(): + # Dummy SavePlanner for testing + class DummySavePlanner(DefaultSavePlanner): + + def __init__(self, is_test: bool): + self.is_test = is_test + + save_planners.register('dummy', func=DummySavePlanner) + save_planner = build_save_planner('dummy', is_test=True) + + assert isinstance(save_planner, DummySavePlanner) + assert save_planner.is_test is True + + def test_add_metrics_to_eval_loaders(): evaluators = [ Evaluator(