Skip to content

Commit

Permalink
Merge branch 'migrate_subclasses_to_foundry' into openai_compatible_g…
Browse files Browse the repository at this point in the history
…auntlet
  • Loading branch information
bmosaicml committed Apr 12, 2024
2 parents bb2728b + 03f7e91 commit 65f1a3e
Show file tree
Hide file tree
Showing 10 changed files with 53 additions and 33 deletions.
33 changes: 23 additions & 10 deletions llmfoundry/eval/datasets/in_context_learning_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import os
import random
import warnings
from typing import Any, Dict, Iterable, List, Optional, Union
from typing import Any, Dict, Iterable, List, Optional, Sequence, Union

import torch
import transformers
Expand Down Expand Up @@ -93,14 +93,17 @@ class InContextLearningDataset(Dataset):
strip_dataset (bool): Boolean for whether to strip whitespace from data. Trailing whitespace can cause degenerative outputs,
so unless whitespace should be preserved (for example in code), this should be set to True.
padding_side (str): Side of the content and answer on which to apply padding. Can be either 'right' or 'left'.
tokenize_labels (bool): Whether or not the labels should be tokenized. Generally determined by which metric a dataset uses.
padding_size (int): The final size of the tensor after padding. Defaults to max_sequence_length.
base_batch (Dict): The base dictionary upon which a batch is created. See above for more details.
base_mapping (Dict): A mapping of batch keys to dataset columns, used to create batches. See above for more details.
hf_loading_vars (Dict): A dictionary containing keyword arguments to be passed into `load_dataset` if dataset is being pulled from HF.
hf_parsing_map (Dict): A dictionary containing a mapping from HF columns to ICL dataset keys. The dictionary should be formatted {icl_key:[hf_key1, hf_key2]}.
Column contents will be concatenated with ' ' separating them. If not included, will load the columns already present in the HF dataset.
tokenize_labels (bool): Whether or not the labels should be tokenized. Generally determined by which metric a dataset uses.
generation_kwargs (Dict): A dictionary containing keyword arguments to be passed along to the model's generate function.
static_keys (List): A list of the key values which will be broadcast across a batch (e.g. it is the same for each batch element).
list_keys (List): A list of the batch keys whose values are lists which will be split using list methods during calls to split_batch.
tensor_keys (List): A list of the batch keys whose values are tensors which will be split using tensor methods during calls to split_batch.
"""

def __init__(
Expand All @@ -121,15 +124,15 @@ def __init__(
strip_dataset: bool = True,
padding_side: str = 'right',
tokenize_labels: bool = True,
static_keys: Optional[List] = None,
list_keys: Optional[List] = None,
tensor_keys: Optional[List] = None,
padding_size: Optional[int] = None,
base_batch: Optional[Dict] = None,
batch_mapping: Optional[Dict] = None,
hf_loading_vars: Optional[Dict] = None,
hf_parsing_map: Optional[Dict] = None,
generation_kwargs: Optional[Dict] = None,
static_keys: Optional[List] = None,
list_keys: Optional[List] = None,
tensor_keys: Optional[List] = None,
):
self.tokenizer = tokenizer
self.prefix_space = tokenizer_needs_prefix_space(self.tokenizer)
Expand Down Expand Up @@ -473,21 +476,24 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
return batch

def split_batch(self, batch: Any,
microbatch_size: int) -> List[Dict[str, Any]]:
microbatch_size: Union[int, float]) -> Sequence[Any]:
"""Handling for certain specialty columns that must be split into
batches in different formats.
Args:
batch (Dict): Batch of data
microbatch_size (int): Size of microbatches
microbatch_size (int | float): Size of microbatches
Returns:
List: List of chunked batches
"""
# Don't split kwargs that don't change
# Normally split torch tensors
# List split lists of strings
if isinstance(microbatch_size, float):
raise ValueError(
'split_batch does not support floating point microbatch_size.')
chunked = {}
for k, v in batch.items():
if k in self.static_keys:
Expand Down Expand Up @@ -901,7 +907,7 @@ def get_num_samples_in_batch(self, batch: Dict[str, torch.Tensor]) -> int:
return batch['input_ids'].shape[0] // self.num_choices

def split_batch(self, batch: Any,
microbatch_size: int) -> List[Dict[str, Any]]:
microbatch_size: Union[int, float]) -> Sequence[Any]:
"""Split batch while ensuring all continuations are in the same
microbatch.
Expand All @@ -913,11 +919,14 @@ def split_batch(self, batch: Any,
microbatch_size and real attributes by microbatch_size * num_choices.
Args:
batch (Dict): Batch of data
microbatch_size (int): Size of microbatches
microbatch_size (int | float): Size of microbatches
Returns:
list: List of chunked batches
"""
if isinstance(microbatch_size, float):
raise ValueError(
'split_batch does not support floating point microbatch_size.')
chunked = {}
for k, v in batch.items():
if k in self.static_keys:
Expand Down Expand Up @@ -1175,7 +1184,7 @@ class InContextLearningCodeEvalDataset(InContextLearningDataset):
for more details):
- pad_token_id: ID for padding token, derived automatically
- num_beams: How many beams to search for generations, set to 1
- num_beams: How many beams to search for generations, default set to 1
- do_sample: Determines whether model is sampling or greedily decoding. Always set to True
- use_cache: Whether or not to use past key values to speed up sampling. Always set to True
Expand Down Expand Up @@ -1485,6 +1494,10 @@ def build_icl_dataloader(
)
effective_batchsize = batch_size
elif icl_task_type == 'code_evaluation':
warnings.warn(
VersionedDeprecationWarning(
"ICL task type 'code_evaluation' is deprecated and will no longer be supported. ",
'v0.7.0'))
dataset = InContextLearningCodeEvalDataset(
dataset_uri=dataset_uri,
tokenizer=tokenizer,
Expand Down
3 changes: 0 additions & 3 deletions llmfoundry/eval/datasets/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Utility and helper functions for datasets."""
from __future__ import annotations

Expand Down
3 changes: 0 additions & 3 deletions llmfoundry/eval/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""A collection of common torchmetrics."""

from llmfoundry.eval.metrics.nlp import (
Expand Down
1 change: 0 additions & 1 deletion llmfoundry/models/hf/hf_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import warnings
from typing import TYPE_CHECKING, Any, Dict, Mapping

# required for loading a python model into composer
from composer.models.huggingface import peft_installed
from composer.utils import dist
from omegaconf import DictConfig
Expand Down
25 changes: 21 additions & 4 deletions scripts/data_prep/convert_text_to_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,13 @@ def parse_args() -> Namespace:
help='If true, reprocess the input_folder to mds format. Otherwise, ' +
'only reprocess upon changes to the input folder or dataset creation parameters.',
)
parser.add_argument(
'--trust-remote-code',
type=bool,
required=False,
default=False,
help='If true, allows custom code to be executed to load the tokenizer',
)

parsed = parser.parse_args()

Expand All @@ -124,7 +131,8 @@ def parse_args() -> Namespace:
parser.error(
'Cannot set --eos_text with --use_tokenizer_eos. Please specify one.'
)
tokenizer = AutoTokenizer.from_pretrained(parsed.tokenizer)
tokenizer = AutoTokenizer.from_pretrained(
parsed.tokenizer, trust_remote_code=parsed.trust_remote_code)
parsed.eos_text = tokenizer.eos_token

# now that we have validated them, change BOS/EOS to strings
Expand Down Expand Up @@ -171,6 +179,7 @@ def get_task_args(
bos_text: str,
no_wrap: bool,
compression: str,
trust_remote_code: bool,
) -> Iterable:
"""Get download_and_convert arguments split across n_groups.
Expand All @@ -187,6 +196,7 @@ def get_task_args(
bos_text (str): Text to prepend to each example to separate concatenated samples
no_wrap (bool): Whether to let text examples wrap across multiple training examples
compression (str): The compression algorithm to use for MDS writing
trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer
"""
num_objects = len(object_names)
objs_per_group = math.ceil(num_objects / n_groups)
Expand All @@ -202,6 +212,7 @@ def get_task_args(
bos_text,
no_wrap,
compression,
trust_remote_code,
)


Expand All @@ -223,6 +234,7 @@ def download_and_convert(
bos_text: str,
no_wrap: bool,
compression: str,
trust_remote_code: bool,
):
"""Downloads and converts text files to MDS format.
Expand All @@ -236,6 +248,7 @@ def download_and_convert(
bos_text (str): Text to prepend to each example to separate concatenated samples
no_wrap (bool): Whether to let text examples wrap across multiple training examples
compression (str): The compression algorithm to use for MDS writing
trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer
"""
object_store = maybe_create_object_store_from_uri(input_folder)

Expand All @@ -244,7 +257,8 @@ def download_and_convert(
downloading_iter = DownloadingIterable(object_names=file_names,
output_folder=tmp_dir,
object_store=object_store)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name, trust_remote_code=trust_remote_code)
tokenizer.model_max_length = 5000000000 # Hack to prevent warnings from HuggingFace

# Use the ConcatTokensDataset from LLM-foundry to concatenate sequences of tokens up
Expand Down Expand Up @@ -353,6 +367,7 @@ def convert_text_to_mds(
processes: int,
args_str: str,
reprocess: bool,
trust_remote_code: bool,
):
"""Convert a folder of text files to MDS format.
Expand All @@ -368,6 +383,7 @@ def convert_text_to_mds(
processes (int): The number of processes to use.
args_str (str): String representation of the arguments
reprocess (bool): Whether to always reprocess the given folder of text files
trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer
"""
is_remote_output = is_remote_path(output_folder)

Expand Down Expand Up @@ -396,7 +412,7 @@ def convert_text_to_mds(
# Download and convert the text files in parallel
args = get_task_args(object_names, local_output_folder, input_folder,
processes, tokenizer_name, concat_tokens, eos_text,
bos_text, no_wrap, compression)
bos_text, no_wrap, compression, trust_remote_code)
with ProcessPoolExecutor(max_workers=processes) as executor:
list(executor.map(download_and_convert_starargs, args))

Expand All @@ -405,7 +421,7 @@ def convert_text_to_mds(
else:
download_and_convert(object_names, local_output_folder, input_folder,
tokenizer_name, concat_tokens, eos_text, bos_text,
no_wrap, compression)
no_wrap, compression, trust_remote_code)

# Write a done file with the args and object names
write_done_file(local_output_folder, args_str, object_names)
Expand Down Expand Up @@ -462,6 +478,7 @@ def _args_str(original_args: Namespace) -> str:
compression=args.compression,
processes=args.processes,
reprocess=args.reprocess,
trust_remote_code=args.trust_remote_code,
args_str=_args_str(args))
except Exception as e:
if mosaicml_logger is not None:
Expand Down
2 changes: 1 addition & 1 deletion scripts/eval/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ You can use the default `icl_tasks` and `eval_gauntlet` configs or specify your

ICL evaluation measures a model’s ability to solve novel problems by being provided examples in-context without ever being specifically trained to answer such questions.

We supports a number of standard ICL formats and allow users to upload their own datasets that correspond to these formats. All of our ICL task types are implemented in `llm-foundry/llmfoundry/eval/datasets/in_context_learning_evaluation.py` while all of our ICL
We support a number of standard ICL formats and allow users to upload their own datasets that correspond to these formats. All of our ICL task types are implemented in `llm-foundry/llmfoundry/eval/datasets/in_context_learning_evaluation.py` while all of our ICL
metrics are implemented in `llm-foundry/llmfoundry/eval/metrics/nlp.py`. You can see which metrics work with which task types in the `llmfoundry.utils.builders.build_icl_evaluators` helper function.

This document explains the ICL formats compatible with [Composer](https://github.com/mosaicml/composer), summarizes how to add new datasets in those formats, and catalogs the datasets currently used by the research team to evaluate models.
Expand Down
10 changes: 5 additions & 5 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,14 @@ def validate_config(cfg: DictConfig):
fsdp_config = cfg.get('fsdp_config', None)
act_ckpt = fsdp_config.get('activation_checkpointing', False)
act_ckpt_reentrant = fsdp_config.get(
'activation_checkpointing_reentrant', True)
if fsdp_config is not None and act_ckpt == True and act_ckpt_reentrant == False:
'activation_checkpointing_reentrant', False)
if fsdp_config is not None and act_ckpt == True and act_ckpt_reentrant == True:
warnings.warn(
'`te.Linear` layers do not support activation_checkpointing with '
+ '`activation_checkpointing_reentrant = False`. ' +
'Setting cfg.fsdp_config.activation_checkpointing_reentrant=True.'
+ '`activation_checkpointing_reentrant = True`. ' +
'Setting cfg.fsdp_config.activation_checkpointing_reentrant=False.'
)
cfg.fsdp_config.activation_checkpointing_reentrant = True
cfg.fsdp_config.activation_checkpointing_reentrant = False

if cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') == 'te_ln_mlp':
warnings.warn(
Expand Down
3 changes: 3 additions & 0 deletions tests/a_scripts/data_prep/test_convert_text_to_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def call_convert_text_to_mds() -> None:
processes=processes,
args_str='Namespace()',
reprocess=False,
trust_remote_code=False,
)

call_convert_text_to_mds()
Expand Down Expand Up @@ -195,6 +196,7 @@ def call_convert_text_to_mds(reprocess: bool):
processes=1,
args_str='Namespace()',
reprocess=reprocess,
trust_remote_code=False,
)

# Create input text data
Expand Down Expand Up @@ -234,6 +236,7 @@ def test_input_folder_not_exist(tmp_path: pathlib.Path):
processes=1,
args_str='Namespace()',
reprocess=False,
trust_remote_code=False,
)


Expand Down
3 changes: 0 additions & 3 deletions tests/eval/test_in_context_learning_datasets.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import contextlib
import os
import random
Expand Down
3 changes: 0 additions & 3 deletions tests/eval/test_nlp_metrics.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any, List

import pytest
Expand Down

0 comments on commit 65f1a3e

Please sign in to comment.