
Commit

Merge branch 'main' into shashank/seq_id_flash_attn
ShashankMosaicML authored Nov 30, 2023
2 parents 371e3a2 + 3100859 commit fa2a2ee
Showing 17 changed files with 815 additions and 191 deletions.
32 changes: 12 additions & 20 deletions llmfoundry/data/dataloader.py
@@ -11,6 +11,12 @@
from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
from llmfoundry.data.text_data import build_text_dataloader

LOADER_NAME_TO_FUNCTION = {
'text': build_text_dataloader,
'text_denoising': build_text_denoising_dataloader,
'finetuning': build_finetuning_dataloader,
}


def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
device_batch_size: int) -> DataSpec:
@@ -22,23 +28,9 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
device_batch_size (int): The size of the batches (number of examples)
that the dataloader will produce.
"""
if cfg.name == 'text':
return build_text_dataloader(
cfg,
tokenizer,
device_batch_size,
)
elif cfg.name == 'text_denoising':
return build_text_denoising_dataloader(
cfg,
tokenizer,
device_batch_size,
)
elif cfg.name == 'finetuning':
return build_finetuning_dataloader(
cfg,
tokenizer,
device_batch_size,
)
else:
raise ValueError(f'Not sure how to build dataloader with config: {cfg}')
if cfg.name not in LOADER_NAME_TO_FUNCTION:
allowed = ', '.join(LOADER_NAME_TO_FUNCTION.keys())
raise ValueError(f'Expected dataloader name to be one of {allowed}' +
f' but found name "{cfg.name}" in config: {cfg}')

return LOADER_NAME_TO_FUNCTION[cfg.name](cfg, tokenizer, device_batch_size)
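
For illustration, a minimal self-contained sketch of the dictionary-dispatch pattern this hunk introduces; the builder functions and config fields below are stand-ins, not the real llm-foundry loaders.

```python
# Self-contained sketch of dict-based dispatch; the builders below are
# stand-ins for the real llm-foundry dataloader builders.
from typing import Callable, Dict


def build_text_loader(cfg: dict) -> str:
    return f"text loader for {cfg['dataset']}"


def build_finetuning_loader(cfg: dict) -> str:
    return f"finetuning loader for {cfg['dataset']}"


LOADERS: Dict[str, Callable[[dict], str]] = {
    'text': build_text_loader,
    'finetuning': build_finetuning_loader,
}


def build(cfg: dict) -> str:
    # Unknown names fail fast with the list of allowed loaders.
    if cfg['name'] not in LOADERS:
        allowed = ', '.join(LOADERS)
        raise ValueError(f'Expected loader name to be one of {allowed}, '
                         f'but found "{cfg["name"]}"')
    return LOADERS[cfg['name']](cfg)


print(build({'name': 'text', 'dataset': 'c4'}))  # text loader for c4
```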
81 changes: 81 additions & 0 deletions llmfoundry/utils/builders.py
@@ -28,12 +28,14 @@
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om
from torch.optim.optimizer import Optimizer
from torchmetrics import Metric
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from llmfoundry.callbacks import (EvalGauntlet, FDiffMetrics, GlobalLRScaling,
HuggingFaceCheckpointer, LayerFreezing,
MonolithicCheckpointSaver,
ScheduledGarbageCollector)
from llmfoundry.data.dataloader import build_dataloader
from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion,
DecoupledLionW, DecoupledLionW_8bit)
from llmfoundry.optim.scheduler import InverseSquareRootWithWarmupScheduler
@@ -42,6 +44,85 @@
log = logging.getLogger(__name__)


def build_evaluators(
eval_loader_config: Optional[Union[DictConfig, ListConfig]],
icl_tasks_config: Optional[Union[str, ListConfig]],
eval_gauntlet_config: Optional[Union[str, DictConfig]],
*,
tokenizer: PreTrainedTokenizerBase,
device_eval_batch_size: int,
icl_seq_len: int,
icl_subset_num_batches: Optional[int],
) -> Tuple[List[Evaluator], List[str], Optional[EvalGauntlet]]:

evaluators = []
if eval_loader_config is not None:
evaluators = build_eval_loaders(
eval_loader_config,
tokenizer,
device_eval_batch_size,
)

logger_keys = []
eval_gauntlet_callback = None
if icl_tasks_config is not None:
icl_evaluators, logger_keys, eval_gauntlet_callback = build_icl_data_and_gauntlet(
icl_tasks_config,
eval_gauntlet_config,
tokenizer,
device_eval_batch_size,
icl_seq_len,
icl_subset_num_batches,
)
evaluators.extend(icl_evaluators)

return evaluators, logger_keys, eval_gauntlet_callback


def build_eval_loaders(
eval_loader_config: Union[DictConfig, ListConfig],
tokenizer: PreTrainedTokenizerBase,
device_eval_batch_size: int,
) -> List[Evaluator]:
evaluators: List[Evaluator] = []
if isinstance(eval_loader_config, ListConfig):
eval_configs: ListConfig = eval_loader_config
is_multi_eval = True
else:
eval_configs = ListConfig([eval_loader_config])
is_multi_eval = False

for eval_config in eval_configs:
eval_dataloader = build_dataloader(eval_config, tokenizer,
device_eval_batch_size)
eval_loader: Evaluator = Evaluator(
label=f'eval/{eval_config.label}' if is_multi_eval else 'eval',
dataloader=eval_dataloader,
# Load the eval data to fail fast. metrics will get added
# later in add_metrics_to_eval_loaders, after the model is loaded
metric_names=[],
)
evaluators.append(eval_loader)
return evaluators
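
For reference, a small sketch of the single-versus-multi eval loader normalization used above, exercised directly with omegaconf; the config fields and labels are illustrative.

```python
# Sketch of the single-vs-multi eval loader normalization, exercised
# directly with omegaconf; the config fields and labels are illustrative.
from omegaconf import ListConfig, OmegaConf

single = OmegaConf.create({'name': 'text', 'label': 'c4'})
multi = OmegaConf.create([{'name': 'text', 'label': 'c4'},
                          {'name': 'text', 'label': 'wiki'}])

for cfg in (single, multi):
    if isinstance(cfg, ListConfig):
        eval_configs, is_multi_eval = cfg, True
    else:
        # A single DictConfig is wrapped so the loop below covers both cases.
        eval_configs, is_multi_eval = ListConfig([cfg]), False
    labels = [f'eval/{c.label}' if is_multi_eval else 'eval'
              for c in eval_configs]
    print(labels)
# ['eval']
# ['eval/c4', 'eval/wiki']
```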


def add_metrics_to_eval_loaders(
evaluators: List[Evaluator],
metrics: Dict[str, Metric],
) -> List[Evaluator]:
metric_names = list(metrics.keys())
eval_loaders, other_evaluators = [], []
for evaluator in evaluators:
if evaluator.metric_names == []:
evaluator.metric_names = metric_names
eval_loaders.append(evaluator)
else:
other_evaluators.append(evaluator)

# Put the base eval_loaders first
return eval_loaders + other_evaluators
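
A minimal sketch of the back-filling behavior of add_metrics_to_eval_loaders, with a stand-in dataclass in place of composer's Evaluator; the labels and metric names are illustrative.

```python
# Sketch of the metric back-filling behavior; FakeEvaluator stands in for
# composer.core.Evaluator and the metric names are illustrative.
from dataclasses import dataclass, field
from typing import List


@dataclass
class FakeEvaluator:
    label: str
    metric_names: List[str] = field(default_factory=list)


def add_metrics(evaluators: List[FakeEvaluator],
                metric_names: List[str]) -> List[FakeEvaluator]:
    eval_loaders, others = [], []
    for ev in evaluators:
        if ev.metric_names == []:
            # Only the bare eval loaders (built with metric_names=[]) get filled in.
            ev.metric_names = metric_names
            eval_loaders.append(ev)
        else:
            others.append(ev)
    return eval_loaders + others  # base eval loaders come first


evals = [FakeEvaluator('eval/c4'),
         FakeEvaluator('icl/piqa', ['InContextLearningAccuracy'])]
for ev in add_metrics(evals, ['LanguageCrossEntropy']):
    print(ev.label, ev.metric_names)
# eval/c4 ['LanguageCrossEntropy']
# icl/piqa ['InContextLearningAccuracy']
```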


def build_icl_data_and_gauntlet(
icl_tasks_config: Union[str, ListConfig],
eval_gauntlet_config: Optional[Union[str, DictConfig]],
11 changes: 6 additions & 5 deletions llmfoundry/utils/data_prep_utils.py
@@ -96,15 +96,16 @@ def __init__(

def __iter__(self):
for object_name in self.object_names:
object_name = object_name.strip('/')
output_filename = os.path.join(self.output_folder, object_name)
# Default output_filename, used for local paths.
output_filename = object_name

# Download objects if remote path.
if self.object_store is not None:
output_filename = os.path.join(self.output_folder,
object_name.strip('/'))
self.object_store.download_object(object_name=object_name,
filename=output_filename,
overwrite=True)
else:
# Inputs are local so we do not need to download them.
output_filename = object_name

with open(output_filename) as _txt_file:
txt = _txt_file.read()
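
For context, a minimal sketch of the local-versus-remote filename resolution this hunk appears to implement; the helper name and the has_object_store flag are illustrative, not part of the diff.

```python
# Sketch of the local-vs-remote filename resolution; the helper name and the
# has_object_store flag are illustrative, not part of the diff.
import os


def resolve_output_filename(object_name: str, output_folder: str,
                            has_object_store: bool) -> str:
    if has_object_store:
        # Remote input: it gets downloaded into output_folder first.
        return os.path.join(output_folder, object_name.strip('/'))
    # Local input: no download needed, read the file where it is.
    return object_name


print(resolve_output_filename('data/shard-0.txt', '/tmp/out', True))
# /tmp/out/data/shard-0.txt
print(resolve_output_filename('/local/data/shard-0.txt', '/tmp/out', False))
# /local/data/shard-0.txt
```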
58 changes: 58 additions & 0 deletions llmfoundry/utils/prompt_files.py
@@ -0,0 +1,58 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import os
from typing import List, Optional

PROMPTFILE_PREFIX = 'file::'


def load_prompts(prompts: List[str],
prompt_delimiter: Optional[str] = None) -> List[str]:
"""Loads a set of prompts, both free text and from file.
Args:
prompts (List[str]): List of free text prompts and prompt files
prompt_delimiter (Optional str): Delimiter for text file
If not provided, assumes the prompt file is a single prompt (non-delimited)
Returns:
List of prompt string(s)
"""
prompt_strings = []
for prompt in prompts:
if prompt.startswith(PROMPTFILE_PREFIX):
prompts = load_prompts_from_file(prompt, prompt_delimiter)
prompt_strings.extend(prompts)
else:
prompt_strings.append(prompt)
return prompt_strings


def load_prompts_from_file(prompt_path: str,
prompt_delimiter: Optional[str] = None) -> List[str]:
"""Load a set of prompts from a text fie.
Args:
prompt_path (str): Path for text file
prompt_delimiter (Optional str): Delimiter for text file
If not provided, assumes the prompt file is a single prompt (non-delimited)
Returns:
List of prompt string(s)
"""
if not prompt_path.startswith(PROMPTFILE_PREFIX):
raise ValueError(f'prompt_path must start with {PROMPTFILE_PREFIX}')

_, prompt_file_path = prompt_path.split(PROMPTFILE_PREFIX, maxsplit=1)
prompt_file_path = os.path.expanduser(prompt_file_path)
if not os.path.isfile(prompt_file_path):
raise FileNotFoundError(
f'{prompt_file_path=} does not match any existing files.')

with open(prompt_file_path, 'r') as f:
prompt_string = f.read()

if prompt_delimiter is None:
return [prompt_string]
return [i for i in prompt_string.split(prompt_delimiter) if i]
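
A usage sketch for the new helpers (assumes llm-foundry is installed); the temp file contents and the '====' delimiter are illustrative.

```python
# Usage sketch for load_prompts (assumes llm-foundry is installed); the temp
# file contents and the '====' delimiter are illustrative.
import tempfile

from llmfoundry.utils.prompt_files import load_prompts

with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
    f.write('Tell me a joke.\n====\nSummarize this document.')
    path = f.name

# Free-text prompts and 'file::'-prefixed paths can be mixed in one call.
prompts = load_prompts(['What is 2 + 2?', f'file::{path}'],
                       prompt_delimiter='\n====\n')
print(prompts)
# ['What is 2 + 2?', 'Tell me a joke.', 'Summarize this document.']
```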
53 changes: 39 additions & 14 deletions scripts/eval/eval.py
@@ -6,7 +6,7 @@
import sys
import time
import warnings
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import pandas as pd
import torch
@@ -21,13 +21,14 @@

from llmfoundry.models import MPTForCausalLM
from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY
from llmfoundry.utils.builders import (build_icl_data_and_gauntlet,
build_logger, build_tokenizer)
from llmfoundry.utils.builders import (add_metrics_to_eval_loaders,
build_evaluators, build_logger,
build_tokenizer)
from llmfoundry.utils.config_utils import pop_config, process_init_device


def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
num_retries: int) -> Optional[ComposerModel]:
num_retries: int) -> ComposerModel:
try:
from peft import PeftModel
except ImportError as e:
@@ -43,7 +44,8 @@ def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
}

retries = 0
while retries < num_retries:
composer_model_wrapper = None
while retries < num_retries and composer_model_wrapper is None:
try:
trust_remote_code = model_cfg.get('trust_remote_code', True)
use_auth_token = model_cfg.get('use_auth_token', False)
@@ -58,7 +60,6 @@ def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,

composer_model_wrapper = COMPOSER_MODEL_REGISTRY[model_cfg.name](
peft_model, tokenizer)
return composer_model_wrapper
except Exception as e:
retries += 1
if retries >= num_retries:
@@ -68,19 +69,21 @@ def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
f'Got exception {str(e)} while loading model {model_cfg.name}. {num_retries-retries} retries remaining'
)

assert composer_model_wrapper is not None
return composer_model_wrapper


def load_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
fsdp_config: Optional[Dict],
num_retries: int) -> Optional[ComposerModel]:
fsdp_config: Optional[Dict], num_retries: int) -> ComposerModel:
init_context = process_init_device(model_cfg, fsdp_config)

retries = 0
composer_model = None
with init_context:
while retries < num_retries:
while retries < num_retries and composer_model is None:
try:
composer_model = COMPOSER_MODEL_REGISTRY[model_cfg.name](
model_cfg, tokenizer)
return composer_model
except Exception as e:
retries += 1
if retries >= num_retries:
@@ -90,6 +93,9 @@ def load_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
f'Got exception {str(e)} while loading model {model_cfg.name}. {num_retries-retries} retries remaining'
)

assert composer_model is not None
return composer_model
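
For reference, a generic sketch of the retry-until-built pattern both loaders now follow (a single return after the loop instead of returning inside the try); the flaky builder is a stand-in for the real model construction.

```python
# Generic sketch of the retry pattern: keep retrying until the object is built
# or retries run out, with a single return after the loop.
# flaky_build is a stand-in for the real model construction.
import random
from typing import Optional


def flaky_build() -> str:
    if random.random() < 0.5:
        raise RuntimeError('transient failure')
    return 'model'


def build_with_retries(num_retries: int = 3) -> str:
    retries = 0
    model: Optional[str] = None
    while retries < num_retries and model is None:
        try:
            model = flaky_build()
        except Exception as e:
            retries += 1
            if retries >= num_retries:
                raise e
            print(f'Got exception {e}; {num_retries - retries} retries remaining')
    assert model is not None
    return model


print(build_with_retries())
```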


def evaluate_model(
model_cfg: DictConfig,
@@ -100,6 +106,7 @@
max_seq_len: int,
device_eval_batch_size: int,
eval_gauntlet_config: Optional[Union[str, DictConfig]],
eval_loader_config: Optional[Union[DictConfig, ListConfig]],
fsdp_config: Optional[Dict],
num_retries: int,
loggers_cfg: Dict[str, Any],
@@ -118,9 +125,15 @@
tokenizer_kwargs = tokenizer_cfg.get('kwargs', {})
tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)

evaluators, logger_keys, eval_gauntlet_callback = build_icl_data_and_gauntlet(
icl_tasks, eval_gauntlet_config, tokenizer, device_eval_batch_size,
max_seq_len, icl_subset_num_batches)
evaluators, logger_keys, eval_gauntlet_callback = build_evaluators(
eval_loader_config,
icl_tasks,
eval_gauntlet_config,
tokenizer=tokenizer,
device_eval_batch_size=device_eval_batch_size,
icl_seq_len=max_seq_len,
icl_subset_num_batches=icl_subset_num_batches,
)

callbacks = []
if eval_gauntlet_callback is not None:
@@ -143,6 +156,11 @@
composer_model = load_model(model_cfg.model, tokenizer, fsdp_config,
num_retries)

# Now add the eval metrics
if eval_loader_config is not None:
train_metrics = composer_model.get_metrics(is_train=True)
evaluators = add_metrics_to_eval_loaders(evaluators, train_metrics)

if eval_gauntlet_df is None and eval_gauntlet_callback is not None:
eval_gauntlet_df = pd.DataFrame(
columns=['model_name'] +
@@ -186,7 +204,7 @@ def evaluate_model(
return (trainer, logger_keys, eval_gauntlet_callback, eval_gauntlet_df)


def main(cfg: DictConfig):
def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]:
om.resolve(cfg)
model_configs: ListConfig = pop_config(cfg, 'models', must_exist=True)
eval_gauntlet_config: Optional[Union[str, DictConfig]] = pop_config(
@@ -228,6 +246,8 @@ def main(cfg: DictConfig):
default_value='debug')

# Optional Evaluation Parameters with default values
eval_loader_config: Optional[Union[DictConfig, ListConfig]] = pop_config(
cfg, 'eval_loader', must_exist=False, default_value=None)
seed: int = pop_config(cfg, 'seed', must_exist=False, default_value=17)
dist_timeout: Union[float, int] = pop_config(cfg,
'dist_timeout',
@@ -274,6 +294,7 @@ def main(cfg: DictConfig):
eval_gauntlet_df = None
models_df = None
composite_scores = None
trainers = []
for model_cfg in model_configs:
(trainer, logger_keys, eval_gauntlet_callback,
eval_gauntlet_df) = evaluate_model(
Expand All @@ -285,13 +306,15 @@ def main(cfg: DictConfig):
max_seq_len=max_seq_len,
device_eval_batch_size=device_eval_batch_size,
eval_gauntlet_config=eval_gauntlet_config,
eval_loader_config=eval_loader_config,
fsdp_config=fsdp_config,
num_retries=num_retries,
loggers_cfg=loggers_cfg,
python_log_level=python_log_level,
precision=precision,
eval_gauntlet_df=eval_gauntlet_df,
icl_subset_num_batches=icl_subset_num_batches)
trainers.append(trainer)

if eval_gauntlet_callback is not None:
composite_scores = eval_gauntlet_callback.eval_after_all(
@@ -330,6 +353,8 @@ def main(cfg: DictConfig):
assert models_df is not None
print(models_df.to_markdown(index=False))

return trainers, eval_gauntlet_df


def calculate_markdown_results(logger_keys: List[str], trainer: Trainer,
benchmark_to_taxonomy: Dict[str, str],