Execution Prediction #2659

Closed · wants to merge 22 commits
269 changes: 269 additions & 0 deletions composer/datasets/in_context_learning_evaluation.py
@@ -29,6 +29,7 @@
'InContextLearningMultipleChoiceTaskDataset',
'InContextLearningCodeEvalDataset',
'InContextLearningQATaskDataset',
'InContextLearningExecutionPredictionTaskDataset',
'get_icl_task_dataloader',
]

@@ -341,6 +342,261 @@ def split_batch(self, batch: Any, microbatch_size: int):
return [{k: v[idx] for k, v in chunked.items()} for idx in range(num_chunks)]


class InContextLearningExecutionPredictionTaskDataset(Dataset):
"""A dataset that construct batches for in-context learning code tracing evaluation

Args:
dataset_uri (str): Either a local path, or a remote path beginning with ``s3://``, or another backend
supported by :meth:`composer.utils.maybe_create_object_store_from_uri`. Dataset must consist of rows of JSON data points with "context",
and "continuation". See tests/datasets/local_data/lambada_small.jsonl.
tokenizer (Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast]): The tokenizer used to transform data into batches
batch_size (int): Size of a batch used for eval
max_seq_len (int): The sequence length expected by the model
pad_tok_id (int): The special token reserved for padding the ends of batches
num_fewshot (int): The number of complete fewshot examples to prepend before each test example
prompt_string (str): Prompt string to put once before all fewshot examples/test examples (e.g. 'translate english to french')
example_delimiter (str): Separator that goes between individual (context, continuation) pairs (e.g. '\n') continuation_delimiter: (str): Separator that goes between context and continuation in each example (e.g. '->')
destination_path (str): Temporary path to store downloaded datasets
fewshot_random_seed (int): Random seed used to select fewshot examples
"""

def __init__(
self,
dataset_uri: str,
tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast],
max_seq_len: int,
pad_tok_id: int,
num_fewshot: int,
prompt_string: str,
example_delimiter: str,
destination_path: str,
fewshot_random_seed: int,
generations_per_sample: int,
pass_at_k: int = 1,
top_p: Optional[float] = 0.95,
top_k: Optional[int] = 40,
):
try:
from datasets import load_dataset # pyright: ignore [reportGeneralTypeIssues]
except ImportError as e:
raise MissingConditionalImportError(extra_deps_group='nlp',
conda_package='datasets',
conda_channel='conda-forge') from e
with dist.local_rank_zero_download_and_wait(destination_path):
if dist.get_local_rank() == 0:
get_file(dataset_uri, destination_path, overwrite=True)
dataset = load_dataset('json', data_files=destination_path, split='train', streaming=False)
self.samples = list(
dataset.map(
lambda examples: {
'task_id': examples['task_id'],
'prompt': examples['prompt'],
'canonical_solution': examples['canonical_solution'],
'test_inputs': examples['test_inputs'],
'test_outputs': examples['test_outputs'],
'test': examples['test'],
'language': examples['language'],
}))

if generations_per_sample < pass_at_k:
raise ValueError(
f'generations_per_sample ({generations_per_sample}) must be greater than or equal to pass_at_k ({pass_at_k}) for code evaluation.'
)

self.pass_at_k = pass_at_k
self.generations_per_sample = generations_per_sample

self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
self.pad_tok_id = pad_tok_id
self.padding_side = 'left'
self.top_p = top_p
self.top_k = top_k
fewshot_rng = random.Random(fewshot_random_seed)
self.max_answer_length = 0
self.max_prompt_length = 0
self.encoded_dataset = self.prep_examples(num_fewshot, prompt_string, example_delimiter, fewshot_rng)

@staticmethod
def stringify_input(input_tuple):
tup = eval(input_tuple)
res = ', '.join([f'{json.dumps(x)}' for x in tup])
return res
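
# For illustration (assuming test inputs are stored as stringified tuples, as the ``eval`` above implies):
# stringify_input("(1, 'a')") -> '1, "a"', which can be spliced directly between the parentheses of a call.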

@staticmethod
def _write_assert_statement(language, fn_name, input_val, output_val, fewshot_idx=''):
if language == 'python':
if output_val is not None:
return f'\n\ndef test{fewshot_idx}():\n\tassert {fn_name}({input_val}) == {output_val}'
else:
return f'\n\ndef test{fewshot_idx}():\n\tassert {fn_name}({input_val}) =='
else:
raise ValueError(f'Unsupported language: {language}')
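
# For illustration, _write_assert_statement('python', 'add', '1, 2', '3', '0') returns
# '\n\ndef test0():\n\tassert add(1, 2) == 3', while passing output_val=None yields the
# truncated form '\n\ndef test0():\n\tassert add(1, 2) ==' for the model to complete.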

def prep_examples(self, num_fewshot: int, prompt_string: str, example_delimiter: str, fewshot_rng: random.Random):
"""Prepares a set of language modeling tasks into tokenized format with prompt and fewshot examples.

Each task consists of a context and a continuation as well as an optional prompt and optional list of
example context/continuation pairs which precede the test context/continuation pair.

Args:
num_fewshot (int): Number of examples context/continuation pairs to prepend to the test pair
prompt_string (str): The prompt to prepend to all inputs
example_delimiter (str): The delimiter used to separate each individual context/continuation pair
continuation_delimiter (str): The delimiter used to separate each context from its continuation
fewshot_rng (random.Random): Random number generator used to select fewshot examples

Returns:
dict: Contains the context, the continuation, and the preamble (prompt + fewshot examples)
"""
examples = []
max_answer_length = 0 # this is used to determine the expected generation length
max_prompt_length = 0 # this is used to determine padding
for sample_idx in tqdm(range(len(self.samples))):

preamble = f'"""\n{prompt_string}\n"""'

if num_fewshot > 0:
fewshot_idxs = _get_fewshot_sample_idxs(len(self.samples), num_fewshot, sample_idx, fewshot_rng)
for idx, fewshot_idx in enumerate(fewshot_idxs):
prompt, soln, entry_point, test_in, test_out = (
self.samples[fewshot_idx]['prompt'],
self.samples[fewshot_idx]['canonical_solution'],
self.samples[fewshot_idx]['entry_point'],
self.samples[fewshot_idx]['test_inputs'],
self.samples[fewshot_idx]['test_outputs'],
)

test_idx = fewshot_rng.choice(range(0, len(test_in)))
assert_stmt = self._write_assert_statement(self.samples[sample_idx]['language'], entry_point,
self.stringify_input(test_in[test_idx]),
test_out[test_idx], str(idx))
example = f"""{example_delimiter}{prompt}{soln}{assert_stmt}"""

preamble += example

prompt, soln, entry_point, test_in, test_out, language = (
self.samples[sample_idx]['prompt'],
self.samples[sample_idx]['canonical_solution'],
self.samples[sample_idx]['entry_point'],
self.samples[sample_idx]['test_inputs'],
self.samples[sample_idx]['test_outputs'],
self.samples[sample_idx]['language'],
)
for inp, out in zip(test_in, test_out):
encoded_example = {}
assert_stmt = self._write_assert_statement(
language,
entry_point,
self.stringify_input(inp),
None # final assert statement is incomplete
)
context = f"""{example_delimiter}{prompt}{soln}{assert_stmt}"""

# If the preamble is empty then this will be a 0-length list, unless the tokenizer adds special tokens to empty strings (e.g. OPT tokenizer)
encoded_example['preamble'] = self.tokenizer(preamble)
# If there is an EOS token added, we need to remove it so it is not in the middle of the prompt
if self.tokenizer.eos_token_id is not None and len(
encoded_example['preamble']
['input_ids']) > 1 and encoded_example['preamble']['input_ids'][-1] == self.tokenizer.eos_token_id:
encoded_example['preamble']['input_ids'] = encoded_example['preamble']['input_ids'][:-1]

encoded_example['prompt'] = self.tokenizer(context, add_special_tokens=False)
encoded_example['prompt_text'] = self.samples[sample_idx]['prompt']
encoded_example['language'] = self.samples[sample_idx]['language']
encoded_example['expected_output'] = out
examples.append(encoded_example)

max_answer_length = max(max_answer_length,
len(self.tokenizer(out, add_special_tokens=False)['input_ids']))
max_prompt_length = max(
max_prompt_length,
len(encoded_example['preamble']['input_ids'] + encoded_example['prompt']['input_ids']))

self.max_answer_length = max_answer_length
self.max_prompt_length = max_prompt_length
return examples

def __getitem__(self, index):
return self.encoded_dataset[index]

def __len__(self):
return len(self.encoded_dataset)

def collate_fn(self, data):
inputs, prompts, outputs, languages = [], [], [], []
for sample in data:
preamble, prompt, language = (
sample['preamble'],
sample['prompt'],
sample['language'],
)
context_enc = preamble['input_ids'] + prompt['input_ids']
inp, _ = _make_padded_input(context_enc, [],
self.max_prompt_length,
self.pad_tok_id,
padding_side=self.padding_side)

inputs.append(inp)
outputs.append(sample['expected_output'])
prompts.append(self.tokenizer.decode(context_enc))
languages.append(language)

batch = {
'input_ids': torch.stack(inputs),
'mode': 'generate',
'prompts': prompts, # list of prompts
'languages': languages, # list of languages
'pass_at_k': self.pass_at_k,
'generation_length': self.max_answer_length,
'labels': outputs,
'generation_kwargs': {
'pad_token_id': self.pad_tok_id,
'num_beams': 1, # single beam
'num_return_sequences': self.generations_per_sample, # how many gens per prompt
'do_sample': True,
'top_p': self.top_p,
'top_k': self.top_k,
'use_cache': True,
}
}
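# True for real tokens, False for left padding, so generation ignores the padded prefix.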
batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id)
return batch

def get_num_samples_in_batch(self, batch) -> int:
# Count number of inputs in the batch
return batch['input_ids'].shape[0]

def split_batch(self, batch: Any, microbatch_size: int):
# Don't split kwargs that don't change
# Normally split torch tensors
# List split lists of strings
no_split = ['mode', 'generation_length', 'pass_at_k', 'generation_kwargs']
normal_split = ['input_ids', 'attention_mask']
list_split = [
'labels', 'tests', 'canonical_solutions', 'entry_points', 'test_inputs', 'test_outputs', 'prompts',
'languages'
]
chunked = {}
for k, v in batch.items():
if k in no_split:
# Defer broadcasting until we know num_chunks
pass
elif k in list_split:
chunked[k] = _split_list(v, microbatch_size)
elif k in normal_split:
chunked[k] = _default_split_batch(v, microbatch_size)
else:
raise ValueError(f'Unexpected key {k}')
num_chunks = len(chunked['input_ids'])
for k, v in batch.items():
if isinstance(v, (int, float, str, bool, dict)):
chunked[k] = [v] * num_chunks

return [{k: v[idx] for k, v in chunked.items()} for idx in range(num_chunks)]


class InContextLearningLMTaskDataset(Dataset):
"""A dataset that construct batches for in-context learning language modeling evaluation

@@ -1238,6 +1494,19 @@ def build_icl_dataloader(
pass_at_k=pass_at_k,
generations_per_sample=generations_per_sample)
effective_batchsize = batch_size
elif icl_task_type == 'code_execution_prediction':
dataset = InContextLearningExecutionPredictionTaskDataset(dataset_uri,
tokenizer,
max_seq_len,
pad_tok_id,
num_fewshot,
prompt_string,
example_delimiter,
destination_path=destination_path,
fewshot_random_seed=fewshot_random_seed,
pass_at_k=pass_at_k,
generations_per_sample=generations_per_sample)
effective_batchsize = batch_size
else:
raise Exception(f'Unrecognized ICL task type: {icl_task_type}')

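For reference, a minimal usage sketch of the new dataset class based on the constructor shown above; the dataset URI, tokenizer, and argument values are hypothetical, and in an evaluation harness the loader would normally be built through get_icl_task_dataloader with icl_task_type='code_execution_prediction' instead:

import transformers
from torch.utils.data import DataLoader

from composer.datasets.in_context_learning_evaluation import InContextLearningExecutionPredictionTaskDataset

tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')  # hypothetical tokenizer choice
dataset = InContextLearningExecutionPredictionTaskDataset(
    dataset_uri='s3://my-bucket/execution_prediction.jsonl',  # hypothetical dataset location
    tokenizer=tokenizer,
    max_seq_len=2048,
    pad_tok_id=tokenizer.eos_token_id,
    num_fewshot=2,
    prompt_string='Complete the final assert statement with the value returned by the function.',
    example_delimiter='\n',
    destination_path='/tmp/execution_prediction.jsonl',
    fewshot_random_seed=1234,
    generations_per_sample=1,
)
dataloader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collate_fn)
batch = next(iter(dataloader))  # contains 'input_ids', 'attention_mask', 'labels', 'generation_kwargs', ...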
25 changes: 8 additions & 17 deletions composer/metrics/__init__.py
@@ -5,34 +5,25 @@

from composer.metrics.map import MAP
from composer.metrics.metrics import CrossEntropy, Dice, LossMetric, MIoU
from composer.metrics.nlp import (BinaryF1Score, InContextLearningCodeEvalAccuracy, InContextLearningLMAccuracy,
from composer.metrics.nlp import (BinaryF1Score, InContextLearningCodeEvalAccuracy,
InContextLearningCodeExecutionPredictionAccuracy, InContextLearningLMAccuracy,
InContextLearningLMExpectedCalibrationError,
InContextLearningMCExpectedCalibrationError, InContextLearningMetric,
InContextLearningMultipleChoiceAccuracy, InContextLearningQAAccuracy,
LanguageCrossEntropy, LanguagePerplexity, MaskedAccuracy)

__all__ = [
'MAP',
'MIoU',
'Dice',
'CrossEntropy',
'LossMetric',
'BinaryF1Score',
'LanguageCrossEntropy',
'MaskedAccuracy',
'LanguagePerplexity',
'InContextLearningLMAccuracy',
'InContextLearningMultipleChoiceAccuracy',
'InContextLearningQAAccuracy',
'InContextLearningMCExpectedCalibrationError',
'InContextLearningLMExpectedCalibrationError',
'InContextLearningMetric',
'InContextLearningCodeEvalAccuracy',
'MAP', 'MIoU', 'Dice', 'CrossEntropy', 'LossMetric', 'BinaryF1Score', 'LanguageCrossEntropy', 'MaskedAccuracy',
'LanguagePerplexity', 'InContextLearningLMAccuracy', 'InContextLearningMultipleChoiceAccuracy',
'InContextLearningQAAccuracy', 'InContextLearningMCExpectedCalibrationError',
'InContextLearningLMExpectedCalibrationError', 'InContextLearningMetric', 'InContextLearningCodeEvalAccuracy',
'InContextLearningCodeExecutionPredictionAccuracy'
]

METRIC_DEFAULT_CTORS = {
'InContextLearningLMAccuracy': InContextLearningLMAccuracy,
'InContextLearningMultipleChoiceAccuracy': InContextLearningMultipleChoiceAccuracy,
'InContextLearningQAAccuracy': InContextLearningQAAccuracy,
'InContextLearningCodeEvalAccuracy': InContextLearningCodeEvalAccuracy,
'InContextLearningCodeExecutionPredictionAccuracy': InContextLearningCodeExecutionPredictionAccuracy,
}