From 4965cd2e405c14e563676434aef9b3793061de17 Mon Sep 17 00:00:00 2001
From: Milo Cress
Date: Wed, 10 Apr 2024 02:42:38 +0000
Subject: [PATCH] fix dictconfig stuff again

---
 scripts/train/train.py | 30 ++++++++++++++++--------------
 1 file changed, 16 insertions(+), 14 deletions(-)

diff --git a/scripts/train/train.py b/scripts/train/train.py
index 92f7ae2714..a1d6741d7c 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -121,7 +121,8 @@ def validate_config(cfg: TrainConfig):
                         `label` attribute.')
             loaders.append(loader)
     else:
-        loaders.append(eval_loader)
+        if eval_loader is not None:
+            loaders.append(eval_loader)
     for loader in loaders:
         if loader['name'] == 'text':
             if cfg.model['name'] == 'hf_t5':
@@ -257,11 +258,11 @@ def main(cfg: DictConfig) -> Trainer:
     dist.initialize_dist(get_device(None), timeout=dist_timeout)
 
     # Mandatory model training configs
-    model_config: Dict[str, Any] = scfg.model
+    model_config: DictConfig = DictConfig(scfg.model)
     tokenizer_config: Dict[str, Any] = scfg.tokenizer
     optimizer_config: Dict[str, Any] = scfg.optimizer
     scheduler_config: Dict[str, Any] = scfg.scheduler
-    train_loader_config: Dict[str, Any] = scfg.train_loader
+    train_loader_config: DictConfig = DictConfig(scfg.train_loader)
 
     # Optional fsdp data, fine-tuning, and eval configs
     fsdp_config: Optional[Dict[str, Any]] = scfg.fsdp_config
@@ -269,20 +270,20 @@
     if scfg.eval_loader is not None and scfg.eval_loaders is not None:
         raise ValueError(
             'Only one of `eval_loader` or `eval_loaders` should be provided.')
-    eval_loader_config: Optional[Union[Dict[str, Any], List[Dict[
-        str,
-        Any]]]] = scfg.eval_loader if scfg.eval_loader is not None else scfg.eval_loaders
-    icl_tasks_config: Optional[Union[
-        List[Dict[str, Any]],
-        str]] = scfg.icl_tasks if scfg.icl_tasks is not None else scfg.icl_tasks_str
-    eval_gauntlet_config: Optional[Union[
-        DictConfig,
-        str]] = scfg.eval_gauntlet if scfg.eval_gauntlet is not None else scfg.eval_gauntlet_str
+    eval_loader_config: Optional[Union[DictConfig, ListConfig]] = DictConfig(
+        scfg.eval_loader) if scfg.eval_loader is not None else ListConfig(
+            scfg.eval_loaders) if scfg.eval_loaders is not None else None
+    icl_tasks_config: Optional[Union[ListConfig, str]] = ListConfig(
+        scfg.icl_tasks) if scfg.icl_tasks is not None else scfg.icl_tasks_str
+    eval_gauntlet_config: Optional[Union[DictConfig, str]] = DictConfig(
+        scfg.eval_gauntlet
+    ) if scfg.eval_gauntlet is not None else scfg.eval_gauntlet_str
     icl_subset_num_batches: Optional[int] = scfg.icl_subset_num_batches
     icl_seq_len: Optional[int] = scfg.icl_seq_len
     # Optional logging, evaluation and callback configs
     logger_configs: Optional[Dict[str, Any]] = scfg.loggers
-    callback_configs: Optional[Dict[str, Any]] = scfg.callbacks
+    callback_configs: Optional[DictConfig] = DictConfig(
+        scfg.callbacks) if scfg.callbacks else None
     algorithm_configs: Optional[Dict[str, Any]] = scfg.algorithms
 
     # Mandatory hyperparameters for training
@@ -392,7 +393,8 @@
 
     # Profiling
     profiler: Optional[Profiler] = None
-    profiler_cfg: Optional[DictConfig] = scfg.profiler
+    profiler_cfg: Optional[DictConfig] = DictConfig(
+        scfg.profiler) if scfg.profiler is not None else None
     if profiler_cfg:
         profiler_schedule_cfg: Dict = pop_config(profiler_cfg,
                                                  'schedule',
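
Note: the recurring pattern in this patch is guard-then-wrap: each optional config section is converted to an omegaconf container only when it is present, so an unset section propagates as `None` instead of being wrapped. Below is a minimal standalone sketch of that pattern, not the patch itself; `SCfg` and `resolve_eval_loader` are hypothetical stand-ins for the script's structured config and its inline logic, and only omegaconf's `DictConfig`/`ListConfig` are assumed.

```python
# Minimal sketch of the guard-then-wrap pattern used in this patch.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

from omegaconf import DictConfig, ListConfig


@dataclass
class SCfg:
    """Hypothetical stand-in for the script's structured TrainConfig."""
    eval_loader: Optional[Dict[str, Any]] = None
    eval_loaders: Optional[List[Dict[str, Any]]] = None


def resolve_eval_loader(
        scfg: SCfg) -> Optional[Union[DictConfig, ListConfig]]:
    # Mirror the patch's mutual-exclusion check before wrapping.
    if scfg.eval_loader is not None and scfg.eval_loaders is not None:
        raise ValueError(
            'Only one of `eval_loader` or `eval_loaders` should be provided.')
    if scfg.eval_loader is not None:
        return DictConfig(scfg.eval_loader)
    if scfg.eval_loaders is not None:
        return ListConfig(scfg.eval_loaders)
    # Neither section was configured: propagate None so downstream code
    # can distinguish "not configured" from an empty config.
    return None


print(resolve_eval_loader(SCfg(eval_loader={'name': 'text'})))  # {'name': 'text'}
print(resolve_eval_loader(SCfg()))  # None
```

The same guard explains the one-line `validate_config` change: `eval_loader` may legitimately be `None` when `eval_loaders` is used, so it is only appended to `loaders` when present.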