Skip to content

Commit

Permalink
Merge branch 'main' into milo/conversion-cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
milocress authored Mar 29, 2024
2 parents 133d6ba + 7a8a156 commit e6dfe50
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
7 changes: 4 additions & 3 deletions llmfoundry/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from composer.callbacks import (EarlyStopper, Generate, LRMonitor,
MemoryMonitor, MemorySnapshot, OOMObserver,
OptimizerMonitor, RuntimeEstimator,
from composer.callbacks import (EarlyStopper, EvalOutputLogging, Generate,
LRMonitor, MemoryMonitor, MemorySnapshot,
OOMObserver, OptimizerMonitor, RuntimeEstimator,
SpeedMonitor)

from llmfoundry.callbacks.async_eval_callback import AsyncEval
Expand Down Expand Up @@ -33,6 +33,7 @@
callbacks.register('mono_checkpoint_saver', func=MonolithicCheckpointSaver)
callbacks.register('scheduled_gc', func=ScheduledGarbageCollector)
callbacks.register('oom_observer', func=OOMObserver)
callbacks.register('eval_output_logging', func=EvalOutputLogging)

callbacks_with_config.register('async_eval', func=AsyncEval)
callbacks_with_config.register('curriculum_learning', func=CurriculumLearning)
Expand Down
19 changes: 16 additions & 3 deletions scripts/eval/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import pandas as pd
import torch
from composer.core import Callback
from composer.loggers.logger_destination import LoggerDestination
from composer.trainer import Trainer
from composer.utils import dist, get_device, reproducibility
Expand All @@ -23,8 +24,9 @@

install()
from llmfoundry.utils.builders import (add_metrics_to_eval_loaders,
build_composer_model, build_evaluators,
build_logger, build_tokenizer)
build_callback, build_composer_model,
build_evaluators, build_logger,
build_tokenizer)
from llmfoundry.utils.config_utils import (log_config, pop_config,
process_init_device)
from llmfoundry.utils.registry_utils import import_file
Expand All @@ -49,6 +51,7 @@ def evaluate_model(
eval_gauntlet_df: Optional[pd.DataFrame],
eval_subset_num_batches: int,
icl_subset_num_batches: Optional[int],
callback_configs: Optional[DictConfig],
metadata: Optional[Dict[str, str]],
logged_config: DictConfig,
should_log_config: bool = True,
Expand All @@ -73,7 +76,12 @@ def evaluate_model(
icl_subset_num_batches=icl_subset_num_batches,
)

callbacks = []
# Callbacks
callbacks: List[Callback] = [
build_callback(str(name), callback_cfg)
for name, callback_cfg in callback_configs.items()
] if callback_configs else []

if eval_gauntlet_callback is not None:
callbacks.append(eval_gauntlet_callback)

Expand Down Expand Up @@ -238,6 +246,10 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]:

# Pop out interpolation variables.
pop_config(cfg, 'model_name_or_path', must_exist=False, default_value=None)
callback_configs: Optional[DictConfig] = pop_config(cfg,
'callbacks',
must_exist=False,
default_value=None)

# Warn for unused parameters
for key in cfg:
Expand Down Expand Up @@ -296,6 +308,7 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]:
python_log_level=python_log_level,
precision=precision,
eval_gauntlet_df=eval_gauntlet_df,
callback_configs=callback_configs,
eval_subset_num_batches=eval_subset_num_batches,
icl_subset_num_batches=icl_subset_num_batches,
metadata=metadata,
Expand Down

0 comments on commit e6dfe50

Please sign in to comment.