
Commit

Merge branch 'main' into pickle-tiktoken
irenedea authored Oct 30, 2023
2 parents dd809a6 + e40689f commit 6831f52
Showing 11 changed files with 189 additions and 661 deletions.
88 changes: 63 additions & 25 deletions llmfoundry/callbacks/eval_gauntlet_callback.py
@@ -22,6 +22,32 @@ class Weighting(Enum):
LOG_SAMPLE_SZ = 3


def calculate_named_averages(average_names: Dict[str, list],
category_scores: Dict[str, float]):
"""Calculates the named averages based off the raw category scores.
For each named average, take a simple average of all the category scores associated with that named average.
Args:
average_names (dict[str, list]): Contains a mapping of named averages to which category scores that average should consist of.
category_scores (dict[str, float]): Contains the raw scores corresponding to each category.
"""
average_scores = {}
for avg_name, category_list in average_names.items():
composite_subset = {
category: score
for category, score in category_scores.items()
if category in category_list
}
if len(composite_subset.values()) > 0:
average_scores[avg_name] = sum(composite_subset.values()) / len(
composite_subset.values())
else:
average_scores[avg_name] = 0

return average_scores
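
A minimal usage sketch of the new helper, with category names and scores invented purely for illustration (not taken from this commit):

    average_names = {'core_average': ['world_knowledge', 'reasoning']}
    category_scores = {'world_knowledge': 0.4, 'reasoning': 0.6, 'safety': 0.9}
    # 'core_average' is the simple mean of the two listed categories: (0.4 + 0.6) / 2
    print(calculate_named_averages(average_names, category_scores))
    # -> {'core_average': 0.5}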


class EvalGauntlet(Callback):
"""The EvalGauntlet aggregates ICL eval results.
@@ -31,7 +57,7 @@ class EvalGauntlet(Callback):
Args:
logger_keys (list): These are the exact keys that the individual benchmark metrics will be
logged under in the logger after eval
tasks (dict): This contains the list of categories, as well as the subtasks within them, the
categories (dict): This contains the list of categories, as well as the subtasks within them, the
random baseline accuracy of each subtask, and the number of fewshot examples
used for the task. See `llmfoundry/scripts/eval/yamls/eval_gauntlet.yaml` to see the structure.
weighting (Weighting): The weighting scheme used to balance different tasks within each category.
@@ -43,6 +69,7 @@ class EvalGauntlet(Callback):
rescale_accuracy (bool): Flag determining whether to rescale the accuracy on each benchmark
by (1-random_baseline_accuracy) before aggregating. Using this ensures that all benchmarks max out at 1.0.
benchmark_sizes (Optional[dict]): Optional data on benchmark sizes, used when not relying on equal weighting.
averages (Optional[dict]): Optional dictionary mapping each average name to the list of categories used to produce that named average.
"""

def __init__(self,
@@ -51,7 +78,8 @@ def __init__(self,
weighting: str = 'EQUAL',
subtract_random_baseline: bool = True,
rescale_accuracy: bool = True,
benchmark_sizes: Optional[dict] = None):
benchmark_sizes: Optional[dict] = None,
averages: Optional[dict] = None):
if isinstance(logger_keys, dict):
raise ValueError(
'logger_keys now requires a list type as input, not a dict')
@@ -66,13 +94,12 @@ def __init__(self,
)

self.categories = categories
self.category_names = [conf.get('name') for conf in self.categories]
self.weighting = Weighting[weighting]
self.subtract_random_baseline = subtract_random_baseline
self.rescale_accuracy = rescale_accuracy
self.logger_keys = logger_keys

for category in self.categories:

for benchmark in category['benchmarks']:
bench_name = f"{benchmark['name']}/{benchmark['num_fewshot']}-shot"

@@ -95,7 +122,20 @@ def __init__(self,
assert weight is not None
benchmark['weighting'] = weight

def compute_averages(self, state: State) -> Dict[str, float]:
self.averages = {}
if averages is not None:
self.averages = averages
else:
# if no averages spec provided, simply average everything
self.averages['default_average'] = self.category_names

for avg_name in self.averages:
if avg_name in self.category_names:
raise ValueError(
f'Found average name `{avg_name}` used as category name. Average names and category names must be non-overlapping.'
)

def extract_metrics_from_state(self, state: State) -> Dict[str, float]:
results = {}

for key in self.logger_keys:
@@ -121,31 +161,30 @@ def compute_averages(self, state: State) -> Dict[str, float]:
return {k: sum(v) / len(v) for k, v in results.items()}

def eval_after_all(self, state: State, logger: Logger) -> Dict[str, float]:
new_metrics = self.compute_averages(state)
if len(new_metrics) == 0:
computed_metrics = self.extract_metrics_from_state(state)
if len(computed_metrics) == 0:
return {}
composite_scores = {}

category_scores = {}
for category in self.categories:
missing_metrics = []
composite_scores[category['name']] = []
category_scores[category['name']] = []
for benchmark in category['benchmarks']:
key = f"{benchmark['name']}/{benchmark['num_fewshot']}-shot"

if key not in new_metrics:
if key not in computed_metrics:
log.warning(
f'Could not find results for benchmark: {benchmark}.')
missing_metrics.append(key)
else:
score = new_metrics[key]
score = computed_metrics[key]

if self.subtract_random_baseline:
score -= benchmark['random_baseline']

if self.rescale_accuracy and self.subtract_random_baseline:
score /= 1.0 - benchmark['random_baseline']

composite_scores[category['name']].append({
category_scores[category['name']].append({
'name': benchmark['name'],
'score': score,
'weighting': benchmark['weighting']
@@ -155,23 +194,22 @@ def eval_after_all(self, state: State, logger: Logger) -> Dict[str, float]:
log.warning(
f"Removing category `{category['name']}` from scores because benchmarks were missing: {missing_metrics}"
)
del composite_scores[category['name']]
del category_scores[category['name']]
continue
total_weight = sum(
k['weighting'] for k in composite_scores[category['name']])
composite_scores[category['name']] = sum(
k['weighting'] for k in category_scores[category['name']])
category_scores[category['name']] = sum(
k['score'] * (k['weighting'] / total_weight)
for k in composite_scores[category['name']])
for k in category_scores[category['name']])

composite_scores = {
named_averages = calculate_named_averages(self.averages,
category_scores)
category_scores.update(named_averages)
category_scores = {
f'icl/metrics/eval_gauntlet/{k}': v
for k, v in composite_scores.items()
for k, v in category_scores.items()
}

composite_scores['icl/metrics/eval_gauntlet/average'] = sum(
composite_scores.values()) / len(composite_scores.values()) if len(
composite_scores.values()) > 0 else 0
if logger is not None:
logger.log_metrics(composite_scores)
logger.log_metrics(category_scores)

return composite_scores
return category_scores
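
With this change, the dictionary returned (and logged) might look like the following; the keys and scores are invented for illustration and assume no `averages` spec was provided:

    # Hypothetical output of eval_after_all after this change:
    # {
    #     'icl/metrics/eval_gauntlet/world_knowledge': 0.42,
    #     'icl/metrics/eval_gauntlet/reasoning': 0.37,
    #     'icl/metrics/eval_gauntlet/default_average': 0.395,
    # }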
4 changes: 4 additions & 0 deletions llmfoundry/models/layers/llama_attention_monkeypatch.py
@@ -78,6 +78,8 @@ def llama_attention_patch_torch(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
# Temporary fix for llama2 transformers compatibility, padding_mask will be deprecated in the next transformers release after 4.34.1.
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if use_cache:
raise NotImplementedError(
@@ -186,6 +188,8 @@ def llama_attention_patch_triton(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
# Temporary fix for llama2 transformers compatibility, padding_mask will be deprecated in the next transformers release after 4.34.1.
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if use_cache:
raise NotImplementedError(