Long context evals using hugging face hosted datasets #709

Closed
60 changes: 57 additions & 3 deletions llmfoundry/utils/builders.py
Contributor

Hi @maxisawesome!

It might be worth passing the Hugging Face variables into the get_icl_task_dataloader function. Maybe add

hf_loading_vars=icl_cfg.get('hf_loading_vars', {}),
hf_parsing_map=icl_cfg.get('hf_parsing_map', {})

at line 304 originally and at line 358 in your new commit. These allow you to pass parameters into Hugging Face's load_dataset function. In particular, this was helpful for specifying which split of the Hugging Face dataset I'd like to evaluate, such as hf_loading_vars = {'split': 'train'}.
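
Concretely, a minimal sketch of that suggestion (the wrapper name is hypothetical, and it assumes the composer branch's get_icl_task_dataloader accepts hf_loading_vars and hf_parsing_map keyword arguments as described above):

from composer.datasets.in_context_learning_evaluation import get_icl_task_dataloader

def build_hf_icl_dataloader(icl_cfg, **dataloader_kwargs):
    # Forward optional Hugging Face options from the task config; both default
    # to empty dicts when the YAML does not set them.
    return get_icl_task_dataloader(
        icl_cfg.icl_task_type,
        icl_cfg.dataset_uri,
        hf_loading_vars=icl_cfg.get('hf_loading_vars', {}),  # e.g. {'split': 'train'}
        hf_parsing_map=icl_cfg.get('hf_parsing_map', {}),
        **dataloader_kwargs,  # tokenizer, batch size, etc., as in builders.py
    )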

@@ -7,10 +7,12 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import datasets as hf_datasets
import json
from composer import algorithms
from composer.callbacks import (EarlyStopper, Generate, LRMonitor,
MemoryMonitor, OptimizerMonitor,
RuntimeEstimator, SpeedMonitor)
from composer.callbacks import (EarlyStopper, Generate, LRMonitor, MemoryMonitor,
OptimizerMonitor, RuntimeEstimator, EvalOutputLogging,
SpeedMonitor)
from composer.core import Algorithm, Callback, Evaluator
from composer.datasets.in_context_learning_evaluation import \
get_icl_task_dataloader
@@ -118,6 +120,8 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback:
return EarlyStopper(**kwargs)
elif name == 'hf_checkpointer':
return HuggingFaceCheckpointer(**kwargs)
elif name == 'eval_output_logging':
return EvalOutputLogging(**kwargs)
else:
raise ValueError(f'Not sure how to build callback: {name}')

@@ -219,6 +223,51 @@ def build_tokenizer(

return tokenizer

def prep_hf_dataset(icl_cfg: DictConfig):
    """Temporary hack to read HF datasets while the composer PR is still WIP.

    Loads the Hugging Face dataset named by ``dataset_uri`` (minus the ``hf://``
    prefix), flattens it into ``context``/``answer`` pairs using the column
    mapping in ``hf_cols``, and writes the result to a local .jsonl file.
    """
    hf_dataset_uri = icl_cfg.dataset_uri.replace("hf://", "")
    dataset_args = icl_cfg.hf_vars
    if "split" not in dataset_args:
        dataset_args["split"] = "test"

    # TODO: should I use tmp here?
    dataset_id = hf_dataset_uri.replace("/", "_")
    arg_suffix = '_'.join(str(dataset_arg) for dataset_arg in dataset_args.values())
    output_filepath = f"/tmp/{dataset_id}_{arg_suffix}.jsonl"
    if os.path.isfile(output_filepath):
        print(f"Output file already exists for {icl_cfg.label}, skipping dataset processing and saving")
    else:
        print(f"Processing {icl_cfg.label}")
        dataset = hf_datasets.load_dataset(hf_dataset_uri, **dataset_args)
        if "pivot_col" in icl_cfg.hf_cols:
            # Each row holds a shared "pivot" document plus parallel lists of
            # per-example inputs/outputs; expand it into one (context, answer)
            # pair per list entry, prepending the pivot document to each input.
            def _augment_data(examples):
                outputs = []
                contexts = []
                for i, doc in enumerate(examples[icl_cfg.hf_cols["pivot_col"]]):
                    for j in range(len(examples[icl_cfg.hf_cols["inputs"][0]][i])):
                        instruction = ''.join([examples[input_col][i][j] for input_col in icl_cfg.hf_cols["inputs"]])
                        contexts.append(doc + "\n" + instruction)
                        outputs.append(''.join([examples[output_col][i][j] for output_col in icl_cfg.hf_cols['outputs']]))
                return {"context": contexts, "answer": outputs}

            dataset = dataset.map(
                _augment_data,
                batched=True,
                remove_columns=dataset.column_names,
                batch_size=1000,
            )
        else:
            # One example per row: concatenate the input columns into the
            # context and the output columns into the answer.
            dataset = dataset.map(
                lambda example: {
                    "context": ''.join([str(example[col]) for col in icl_cfg.hf_cols['inputs']]),
                    "answer": ''.join([str(example[col]) for col in icl_cfg.hf_cols['outputs']])
                }
            )
        with open(output_filepath, 'w') as outfile:
            for entry in dataset:
                json.dump(entry, outfile)
                outfile.write('\n')
    return output_filepath


def build_icl_evaluators(
icl_tasks: Union[str, ListConfig],
@@ -284,6 +333,7 @@ def _validate_cfg(icl_cfg: DictConfig):
if 'num_beams' not in icl_cfg:
icl_cfg.num_beams = 20


for icl_cfg in icl_tasks_list:
assert isinstance(icl_cfg, DictConfig)
_validate_cfg(icl_cfg)
@@ -301,6 +351,10 @@ def _validate_cfg(icl_cfg: DictConfig):
os.remove(destination_path)
dist.barrier()

if "hf://" in icl_cfg.dataset_uri:
new_uri = prep_hf_dataset(icl_cfg)
icl_cfg.dataset_uri = new_uri

dataloaders = get_icl_task_dataloader(
icl_cfg.icl_task_type,
icl_cfg.dataset_uri,
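
For reference, a hypothetical sketch of the task config shape that the new prep_hf_dataset path expects; the dataset name and column names below are placeholders, not taken from this PR:

from omegaconf import OmegaConf

# An "hf://" dataset_uri routes the task through prep_hf_dataset in builders.py.
icl_cfg = OmegaConf.create({
    'label': 'my_long_context_task',
    'dataset_uri': 'hf://some-org/some-dataset',  # placeholder dataset
    'hf_vars': {'split': 'test'},  # forwarded to datasets.load_dataset
    'hf_cols': {'inputs': ['question'], 'outputs': ['answer']},  # placeholder column names
    'icl_task_type': 'question_answering',
    'num_fewshot': [0],
})

# build_icl_evaluators then calls prep_hf_dataset(icl_cfg), which writes a local
# .jsonl file and rewrites icl_cfg.dataset_uri to point at it.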
30 changes: 22 additions & 8 deletions mcli/mcli-hf-eval.yaml
@@ -1,20 +1,22 @@
integrations:
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
git_branch: v0.3.0
git_branch: output_eval_logging
# git_commit: # OR use your commit hash
pip_install: -e ".[gpu]"
ssh_clone: false # Should be true if using a private repo

command: |
pip uninstall mosaicml -y
pip install git+https://github.com/bmosaicml/composer.git@error_logging_callback
cd llm-foundry/scripts
composer eval/eval.py /mnt/config/parameters.yaml

# Mosaic Cloud will use run_name (with a unique suffix) to populate the env var $RUN_NAME
run_name: mpt-eval
run_name: output-logger-test
gpu_num: 8
# gpu_type:
# cluster: # replace with your cluster here!
gpu_type: a100_80gb
cluster: r1z1 # replace with your cluster here!

image: mosaicml/llm-foundry:2.0.1_cu118-latest

@@ -31,13 +33,13 @@ parameters:
model_name: mosaicml/mpt-7b-instruct
# Tokenizer
tokenizer:
name: EleutherAI/gpt-neox-20b
name: mosaicml/mpt-7b-instruct
kwargs:
model_max_length: ${max_seq_len}

model:
name: hf_causal_lm
pretrained_model_name_or_path: mosaicml/mpt-7b-instruct
pretrained_model_name_or_path: mosaicml/mpt-7b-instruct
init_device: mixed
pretrained: true
use_auth_token: false
Expand All @@ -50,5 +52,17 @@ parameters:
limit_all_gathers: True


icl_tasks: 'eval/yamls/tasks.yaml'
eval_gauntlet: 'eval/yamls/eval_gauntlet.yaml'
icl_tasks:
-
label: jeopardy
dataset_uri: eval/local_data/world_knowledge/jeopardy_all.jsonl # ADD YOUR OWN DATASET URI
num_fewshot: [10]
icl_task_type: language_modeling
continuation_delimiter: "\nAnswer: " # this separates questions from answers
has_categories: true

callbacks:
eval_output_logging:
subset_sample: -1
output_directory: s3://mosaicml-internal-checkpoints-test/test_icl_output_logger_7b

68 changes: 68 additions & 0 deletions mcli/mcli-rlhf-eval.yaml
@@ -0,0 +1,68 @@
integrations:
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
git_branch: output_eval_logging
# git_commit: # OR use your commit hash
pip_install: -e ".[gpu]"
ssh_clone: false # Should be true if using a private repo

command: |
pip uninstall mosaicml -y
pip install git+https://github.com/bmosaicml/composer.git@error_logging_callback
cd llm-foundry/scripts
composer eval/eval.py /mnt/config/parameters.yaml

# Mosaic Cloud will use run_name (with a unique suffix) to populate the env var $RUN_NAME
run_name: output-logger-rlhf-prompts
gpu_num: 8
gpu_type: a100_80gb
cluster: r1z1 # replace with your cluster here!

image: mosaicml/llm-foundry:2.0.1_cu118-latest

# The below is injected as a YAML file: /mnt/config/parameters.yaml
parameters:
dist_timeout: 6000
seed: 1
max_seq_len: 1024
device_eval_batch_size: 1
precision: amp_fp16

models:
-
model_name: mosaicml/mpt-30b-instruct
# Tokenizer
tokenizer:
name: mosaicml/mpt-30b-instruct
kwargs:
model_max_length: ${max_seq_len}

model:
name: hf_causal_lm
pretrained_model_name_or_path: mosaicml/mpt-30b-instruct
init_device: mixed
pretrained: true
use_auth_token: false

# FSDP config for model sharding
fsdp_config:
sharding_strategy: FULL_SHARD
mixed_precision: FULL
forward_prefetch: True
limit_all_gathers: True


icl_tasks:
-
label: rlhf_prompts
dataset_uri: eval/local_data/rlhf_prompts/rlhf_prompts.jsonl # ADD YOUR OWN DATASET URI
num_fewshot: [0]
icl_task_type: question_answering
has_categories: true

callbacks:
eval_output_logging:
print_only_incorrect: false
subset_sample: -1
output_directory: s3://mosaicml-internal-checkpoints-test/30b_instruct_rlhf_prompts

26 changes: 21 additions & 5 deletions scripts/eval/eval.py
@@ -7,7 +7,7 @@
import time
import warnings
from typing import Any, Dict, List, Optional, Union

from composer.core.callback import Callback
import pandas as pd
import torch
from composer.loggers.logger_destination import LoggerDestination
@@ -21,7 +21,7 @@

from llmfoundry.models import MPTForCausalLM
from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY
from llmfoundry.utils.builders import (build_icl_data_and_gauntlet,
from llmfoundry.utils.builders import (build_icl_data_and_gauntlet, build_callback,
build_logger, build_tokenizer)
from llmfoundry.utils.config_utils import pop_config, process_init_device

@@ -107,6 +107,7 @@ def evaluate_model(
precision: str,
eval_gauntlet_df: Optional[pd.DataFrame],
icl_subset_num_batches: Optional[int],
callback_configs: Optional[Dict]
):

print(f'Evaluating model: {model_cfg.model_name}', flush=True)
@@ -122,7 +123,12 @@
icl_tasks, eval_gauntlet_config, tokenizer, device_eval_batch_size,
max_seq_len, icl_subset_num_batches)

callbacks = []
# Callbacks
callbacks: List[Callback] = [
build_callback(str(name), callback_cfg)
for name, callback_cfg in callback_configs.items()
] if callback_configs else []

if eval_gauntlet_callback is not None:
callbacks.append(eval_gauntlet_callback)

@@ -174,6 +180,7 @@ def evaluate_model(
dist_timeout=dist_timeout,
python_log_level=python_log_level,
)


if torch.cuda.is_available():
torch.cuda.synchronize()
@@ -252,7 +259,11 @@ def main(cfg: DictConfig):
default_value=None)
# Pop out interpolation variables.
pop_config(cfg, 'model_name_or_path', must_exist=False, default_value=None)

callback_configs: Optional[DictConfig] = pop_config(cfg,
'callbacks',
must_exist=False,
default_value=None)

# Warn for unused parameters
for key in cfg:
warnings.warn(
@@ -291,7 +302,9 @@ def _validate_cfg(icl_cfg: DictConfig):
python_log_level=python_log_level,
precision=precision,
eval_gauntlet_df=eval_gauntlet_df,
icl_subset_num_batches=icl_subset_num_batches)
icl_subset_num_batches=icl_subset_num_batches,
callback_configs=callback_configs
)

if eval_gauntlet_callback is not None:
composite_scores = eval_gauntlet_callback.eval_after_all(
@@ -331,6 +344,9 @@ def main(cfg: DictConfig):
print(models_df.to_markdown(index=False))





def calculate_markdown_results(logger_keys: List[str], trainer: Trainer,
benchmark_to_taxonomy: Dict[str, str],
model_name: str):
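
To summarize the callback wiring added in this diff, here is a minimal sketch; the output_directory value is a placeholder, and it assumes EvalOutputLogging on the referenced composer branch accepts the keyword arguments used in the YAMLs above:

from llmfoundry.utils.builders import build_callback

# Shape of the 'callbacks' section popped from the eval YAML and passed to
# evaluate_model as callback_configs:
callback_configs = {
    'eval_output_logging': {
        'subset_sample': -1,
        'output_directory': 's3://my-bucket/eval-outputs',  # placeholder
    },
}

# Inside evaluate_model, each entry becomes a Composer callback:
callbacks = [
    build_callback(str(name), callback_cfg)
    for name, callback_cfg in callback_configs.items()
] if callback_configs else []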