From 204d2f7c5a729ba5955f7723751e4db5a179e056 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Mon, 27 Nov 2023 23:25:07 +0000 Subject: [PATCH] fix typing and formatting --- llmfoundry/data/dataloader.py | 9 +++------ llmfoundry/utils/builders.py | 24 ++++++++++++++---------- scripts/eval/eval.py | 5 ++--- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/llmfoundry/data/dataloader.py b/llmfoundry/data/dataloader.py index 71d31b6808..6974e3ba0b 100644 --- a/llmfoundry/data/dataloader.py +++ b/llmfoundry/data/dataloader.py @@ -18,11 +18,8 @@ } -def build_dataloader( - cfg: DictConfig, - tokenizer: PreTrainedTokenizerBase, - device_batch_size: int, -) -> DataSpec: +def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, + device_batch_size: int) -> DataSpec: """Builds a dataloader from a config. Args: @@ -33,7 +30,7 @@ def build_dataloader( """ if cfg.name not in LOADER_NAME_TO_FUNCTION: - allowed = ", ".join(LOADER_NAME_TO_FUNCTION.keys()) + allowed = ', '.join(LOADER_NAME_TO_FUNCTION.keys()) raise ValueError(f'Expected dataloader name to be one of {allowed}' + f' but found name "{cfg.name}" in config: {cfg}') diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index d2ea7f46c6..c2fbf3f71f 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -42,22 +42,26 @@ def build_eval_loader( - eval_loader_config: DictConfig, + eval_loader_config: Union[DictConfig, ListConfig], model: Union[Any, ComposerHFCausalLM], tokenizer: PreTrainedTokenizerBase, device_eval_batch_size: int, -) -> Evaluator: +) -> List[Evaluator]: assert model.train_metrics is not None eval_metric_names = list(model.train_metrics.keys()) - evaluators = [] + evaluators: List[Evaluator] = [] + if isinstance(eval_loader_config, ListConfig): + eval_configs: ListConfig = eval_configs + is_multi_eval = True + else: + eval_configs = ListConfig([eval_loader_config]) + is_multi_eval = False - is_multi_eval = isinstance(eval_loader_config, ListConfig) - eval_configs = eval_loader_config if is_multi_eval else [eval_loader_config] for eval_config in eval_configs: eval_dataloader = build_dataloader(eval_config, tokenizer, device_eval_batch_size) - eval_loader = Evaluator( + eval_loader: Evaluator = Evaluator( label=f'eval/{eval_config.label}' if is_multi_eval else 'eval', dataloader=eval_dataloader, metric_names=eval_metric_names, @@ -220,8 +224,8 @@ def build_tokenizer( signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed_tokenizer_setup' - if dist.is_available() and dist.is_initialized() and dist.get_world_size( - ) > 1: + if dist.is_available() and dist.is_initialized( + ) and dist.get_world_size() > 1: # Make sure the tokenizer files are downloaded and cached first by local rank 0 with dist.local_rank_zero_download_and_wait(signal_file_path): pass @@ -240,8 +244,8 @@ def build_tokenizer( int(1e30), ) - if dist.is_available() and dist.is_initialized() and dist.get_world_size( - ) > 1: + if dist.is_available() and dist.is_initialized( + ) and dist.get_world_size() > 1: if dist.get_local_rank() == 0: with open(signal_file_path, 'wb') as f: f.write(b'local_rank0_completed_tokenizer_setup') diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index ef54e3234a..1f306f4de4 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -334,9 +334,8 @@ def main(cfg: DictConfig): if eval_gauntlet_df is not None and eval_gauntlet_callback is not None: assert composite_scores is not None row = {'model_name': model_cfg['model_name']} - row.update({ - k.split('/')[-1]: v for k, v in composite_scores.items() - }) + row.update( + {k.split('/')[-1]: v for k, v in composite_scores.items()}) eval_gauntlet_df = pd.concat( [eval_gauntlet_df, pd.DataFrame([row])], ignore_index=True)