Skip to content

Commit

Permalink
metadata in eval.py
Browse files: browse the repository at this point in the history
  • Loading branch information
aspfohl committed Dec 18, 2023
1 parent 1a485d4 commit cd2a31d
Showing 1 changed file with 32 additions and 2 deletions.
34 changes: 32 additions & 2 deletions scripts/eval/eval.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -10,6 +10,7 @@

import pandas as pd
import torch
from composer.loggers import MosaicMLLogger
from composer.loggers.logger_destination import LoggerDestination
from composer.models.base import ComposerModel
from composer.trainer import Trainer
Expand All @@ -24,7 +25,8 @@
from llmfoundry.utils.builders import (add_metrics_to_eval_loaders,
build_evaluators, build_logger,
build_tokenizer)
from llmfoundry.utils.config_utils import pop_config, process_init_device
from llmfoundry.utils.config_utils import (log_config, pop_config,
process_init_device)


def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
Expand Down Expand Up @@ -114,6 +116,7 @@ def evaluate_model(
precision: str,
eval_gauntlet_df: Optional[pd.DataFrame],
icl_subset_num_batches: Optional[int],
metadata: Optional[Dict[str, str]],
):

print(f'Evaluating model: {model_cfg.model_name}', flush=True)
Expand Down Expand Up @@ -144,6 +147,20 @@ def evaluate_model(
for name, logger_cfg in loggers_cfg.items()
]

if metadata is not None:
# Flatten the metadata for logging
loggers_cfg.pop('metadata', None)
loggers_cfg.update(metadata, merge=True)

# Find the MosaicMLLogger
mosaicml_logger = next((
logger for logger in loggers if isinstance(logger, MosaicMLLogger)),
None)

if mosaicml_logger is not None:
mosaicml_logger.log_metrics(metadata)
mosaicml_logger._flush_metadata(force_flush=True)

if fsdp_config and model_cfg.model.get('load_in_8bit', False):
raise ValueError(
'The FSDP config block is not supported when loading ' +
Expand Down Expand Up @@ -177,6 +194,7 @@ def evaluate_model(

assert composer_model is not None

print(f'Building trainer for {model_cfg.model_name}...')
trainer = Trainer(
run_name=run_name,
seed=seed,
Expand All @@ -193,13 +211,18 @@ def evaluate_model(
python_log_level=python_log_level,
)

print('Logging config')
log_config(loggers_cfg)

print(f'Starting eval for {model_cfg.model_name}...')
if torch.cuda.is_available():
torch.cuda.synchronize()
a = time.time()
trainer.eval(eval_dataloader=evaluators)
if torch.cuda.is_available():
torch.cuda.synchronize()
b = time.time()

print(f'Ran {model_cfg.model_name} eval in: {b-a} seconds')
return (trainer, logger_keys, eval_gauntlet_callback, eval_gauntlet_df)

Expand Down Expand Up @@ -270,6 +293,12 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]:
'icl_subset_num_batches',
must_exist=False,
default_value=None)
metadata: Optional[Dict[str, str]] = pop_config(cfg,
'metadata',
must_exist=False,
default_value=None,
convert=True)

# Pop out interpolation variables.
pop_config(cfg, 'model_name_or_path', must_exist=False, default_value=None)

Expand Down Expand Up @@ -313,7 +342,8 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]:
python_log_level=python_log_level,
precision=precision,
eval_gauntlet_df=eval_gauntlet_df,
icl_subset_num_batches=icl_subset_num_batches)
icl_subset_num_batches=icl_subset_num_batches,
metadata=metadata)
trainers.append(trainer)

if eval_gauntlet_callback is not None:
Expand Down

0 comments on commit cd2a31d

Please sign in to comment.