diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 0eaab56a0a..be4ddf9782 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -155,6 +155,9 @@ def normalize_config(cfg): if isinstance(cfg.learning_rate, str): cfg.learning_rate = float(cfg.learning_rate) + if isinstance(cfg.pretraining_dataset, dict): + cfg.pretraining_dataset = [cfg.pretraining_dataset] + log_gpu_memory_usage(LOG, "baseline", cfg.device) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index fc4409fc54..a0fd3ea1a8 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -71,10 +71,7 @@ def prepare_dataset(cfg, tokenizer): else: path = cfg.pretraining_dataset name = None - if isinstance(cfg.pretraining_dataset, dict): - path = cfg.pretraining_dataset["path"] - name = cfg.pretraining_dataset["name"] - elif isinstance(cfg.pretraining_dataset, list) and isinstance( + if isinstance(cfg.pretraining_dataset, list) and isinstance( cfg.pretraining_dataset[0], dict ): path = cfg.pretraining_dataset[0]["path"]