Add eval loader to eval script
aspfohl committed Nov 16, 2023
1 parent e796218 commit 774c62e
Showing 3 changed files with 39 additions and 15 deletions.
25 changes: 25 additions & 0 deletions llmfoundry/utils/builders.py
@@ -31,6 +31,7 @@
                                   HuggingFaceCheckpointer, LayerFreezing,
                                   MonolithicCheckpointSaver,
                                   ScheduledGarbageCollector)
+from llmfoundry.data.dataloader import build_dataloader
 from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion,
                               DecoupledLionW, DecoupledLionW_8bit)
 from llmfoundry.optim.scheduler import InverseSquareRootWithWarmupScheduler
@@ -39,6 +40,30 @@
 log = logging.getLogger(__name__)
 
 
+def build_eval_loader(
+    eval_loader_config: DictConfig,
+    tokenizer: PreTrainedTokenizerBase,
+    device_eval_batch_size: int,
+) -> List[Evaluator]:
+    evaluators = []
+
+    is_multi_eval = isinstance(eval_loader_config, ListConfig)
+    eval_configs = eval_loader_config if is_multi_eval else [eval_loader_config]
+    for eval_config in eval_configs:
+        eval_dataloader = build_dataloader(eval_config, tokenizer,
+                                           device_eval_batch_size)
+
+        # For training, metrics are added after the model is created
+        # For eval, we'll use Evaluator's default, which is to use what's
+        # returned by model.get_metrics()
+        eval_loader = Evaluator(
+            label=f'eval/{eval_config.label}' if is_multi_eval else 'eval',
+            dataloader=eval_dataloader,
+        )
+        evaluators.append(eval_loader)
+    return evaluators
+
+
 def build_icl_data_and_gauntlet(
     icl_tasks_config: Union[str, ListConfig],
     eval_gauntlet_config: Optional[Union[str, DictConfig]],
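For context, here is a minimal usage sketch of the new builder (not part of this commit). It assumes a Hugging Face tokenizer and an OmegaConf ListConfig; the loader name, labels, and dataset fields are illustrative placeholders, not values taken from the repository:

# Hypothetical sketch: two eval loader configs built into two Evaluators.
# 'name', 'label', and the dataset fields are placeholders and must match
# whatever dataloader type is actually registered in llmfoundry.
from omegaconf import OmegaConf
from transformers import AutoTokenizer

from llmfoundry.utils.builders import build_eval_loader

tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

eval_loader_config = OmegaConf.create([
    {
        'name': 'finetuning',
        'label': 'set_a',  # -> Evaluator label 'eval/set_a'
        'dataset': {
            'hf_name': 'my-org/eval-set-a',
            'split': 'test',
            'max_seq_len': 2048,
        },
    },
    {
        'name': 'finetuning',
        'label': 'set_b',  # -> Evaluator label 'eval/set_b'
        'dataset': {
            'hf_name': 'my-org/eval-set-b',
            'split': 'test',
            'max_seq_len': 2048,
        },
    },
])

# One Evaluator per entry; passing a single DictConfig instead would yield
# one Evaluator labeled just 'eval'.
evaluators = build_eval_loader(eval_loader_config, tokenizer,
                               device_eval_batch_size=8)
assert [e.label for e in evaluators] == ['eval/set_a', 'eval/set_b']
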
12 changes: 11 additions & 1 deletion scripts/eval/eval.py
@@ -21,7 +21,8 @@
 
 from llmfoundry.models import MPTForCausalLM
 from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY
-from llmfoundry.utils.builders import (build_icl_data_and_gauntlet,
+from llmfoundry.utils.builders import (build_eval_loader,
+                                       build_icl_data_and_gauntlet,
                                        build_logger, build_tokenizer)
 from llmfoundry.utils.config_utils import pop_config, process_init_device
 
@@ -100,6 +101,7 @@ def evaluate_model(
     max_seq_len: int,
     device_eval_batch_size: int,
     eval_gauntlet_config: Optional[Union[str, DictConfig]],
+    eval_loader_config: Optional[Union[DictConfig, ListConfig]],
     fsdp_config: Optional[Dict],
     num_retries: int,
     loggers_cfg: Dict[str, Any],
@@ -122,6 +124,11 @@
         icl_tasks, eval_gauntlet_config, tokenizer, device_eval_batch_size,
         max_seq_len, icl_subset_num_batches)
 
+    if eval_loader_config is not None:
+        loader_evaluators = build_eval_loader(eval_loader_config, tokenizer,
+                                              device_eval_batch_size)
+        evaluators.extend(loader_evaluators)
+
     callbacks = []
     if eval_gauntlet_callback is not None:
         callbacks.append(eval_gauntlet_callback)
@@ -228,6 +235,8 @@ def main(cfg: DictConfig):
                                                  default_value='debug')
 
     # Optional Evaluation Parameters with default values
+    eval_loader_config: Optional[Union[DictConfig, ListConfig]] = pop_config(
+        cfg, 'eval_loader', must_exist=False, default_value=None)
     seed: int = pop_config(cfg, 'seed', must_exist=False, default_value=17)
     dist_timeout: Union[float, int] = pop_config(cfg,
                                                  'dist_timeout',
@@ -285,6 +294,7 @@
             max_seq_len=max_seq_len,
             device_eval_batch_size=device_eval_batch_size,
             eval_gauntlet_config=eval_gauntlet_config,
+            eval_loader_config=eval_loader_config,
             fsdp_config=fsdp_config,
             num_retries=num_retries,
             loggers_cfg=loggers_cfg,
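For illustration, a sketch of how the new optional eval_loader section might look in an eval config, written here as a Python OmegaConf snippet rather than a YAML file; the dataset details are placeholders, not values from the repository:

# Hypothetical config fragment exercising the new 'eval_loader' key; only
# that key is new in this commit, everything else is illustrative context.
from omegaconf import OmegaConf

cfg = OmegaConf.create("""
seed: 17
device_eval_batch_size: 8
eval_loader:              # optional; omit it to keep the previous behavior
  name: finetuning
  label: held_out
  dataset:
    hf_name: my-org/my-eval-set
    split: test
    max_seq_len: 2048
""")

# eval.py pops 'eval_loader' with must_exist=False and default_value=None,
# so evaluate_model only calls build_eval_loader when the section is present.
print(OmegaConf.to_yaml(cfg.eval_loader))
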
17 changes: 3 additions & 14 deletions scripts/train/train.py
@@ -27,6 +27,7 @@
                         MPTForCausalLM)
 from llmfoundry.data.dataloader import build_dataloader
 from llmfoundry.utils.builders import (build_algorithm, build_callback,
+                                       build_eval_loader,
                                        build_icl_data_and_gauntlet,
                                        build_logger, build_optimizer,
                                        build_scheduler, build_tokenizer)
@@ -529,22 +530,10 @@ def main(cfg: DictConfig) -> Trainer:
     evaluators = []
     eval_loaders = []
     if eval_loader_config is not None:
-        is_multi_eval = isinstance(eval_loader_config, ListConfig)
-        eval_configs = eval_loader_config if is_multi_eval else [
-            eval_loader_config
-        ]
-        for eval_config in eval_configs:
-            eval_dataloader = build_dataloader(eval_config, tokenizer,
-                                               device_eval_batch_size)
-            eval_loader = Evaluator(
-                label=f'eval/{eval_config.label}' if is_multi_eval else 'eval',
-                dataloader=eval_dataloader,
-                metric_names=[],  # we will add these after model is created
-            )
-            eval_loaders.append(eval_loader)
+        eval_loaders = build_eval_loader(eval_loader_config, tokenizer,
+                                         device_eval_batch_size)
 
     eval_gauntlet_callback = None
-
     if icl_tasks_config is not None:
         icl_evaluators, _, eval_gauntlet_callback = build_icl_data_and_gauntlet(
             icl_tasks_config, eval_gauntlet_config, tokenizer,
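The comment in build_eval_loader notes that, during training, metric names are attached to these evaluators only after the model exists. A rough sketch of what that downstream step could look like; the helper name and the use of get_metrics are assumptions about the surrounding train.py code, which this diff does not show:

# Hedged sketch (not from this diff): once the ComposerModel is built,
# train.py can point each eval-only Evaluator at the model's eval metrics,
# since build_eval_loader does not set metric_names in the training path.
from typing import List

from composer.core import Evaluator
from composer.models import ComposerModel


def add_metrics_to_eval_loaders(model: ComposerModel,
                                eval_loaders: List[Evaluator]) -> List[Evaluator]:
    metric_names = list(model.get_metrics(is_train=False).keys())
    for eval_loader in eval_loaders:
        eval_loader.metric_names = metric_names
    return eval_loaders
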
