From a787d57f0725e0837c7ac23110619ee245d8a09e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 19 Jan 2024 23:29:32 -0500 Subject: [PATCH] normalize pretraining_dataset configuration --- src/axolotl/utils/config.py | 3 +++ src/axolotl/utils/data.py | 5 +---- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 0eaab56a0a..be4ddf9782 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -155,6 +155,9 @@ def normalize_config(cfg): if isinstance(cfg.learning_rate, str): cfg.learning_rate = float(cfg.learning_rate) + if isinstance(cfg.pretraining_dataset, dict): + cfg.pretraining_dataset = [cfg.pretraining_dataset] + log_gpu_memory_usage(LOG, "baseline", cfg.device) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index fc4409fc54..a0fd3ea1a8 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -71,10 +71,7 @@ def prepare_dataset(cfg, tokenizer): else: path = cfg.pretraining_dataset name = None - if isinstance(cfg.pretraining_dataset, dict): - path = cfg.pretraining_dataset["path"] - name = cfg.pretraining_dataset["name"] - elif isinstance(cfg.pretraining_dataset, list) and isinstance( + if isinstance(cfg.pretraining_dataset, list) and isinstance( cfg.pretraining_dataset[0], dict ): path = cfg.pretraining_dataset[0]["path"]