Skip to content

Commit

Permalink
normalize pretraining_dataset configuration
Browse files (browse the repository at this point in the history)
  • Loading branch information
winglian committed Jan 20, 2024
1 parent 585c11d commit a787d57
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
3 changes: 3 additions & 0 deletions src/axolotl/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ def normalize_config(cfg):
if isinstance(cfg.learning_rate, str):
cfg.learning_rate = float(cfg.learning_rate)

+if isinstance(cfg.pretraining_dataset, dict):
+cfg.pretraining_dataset = [cfg.pretraining_dataset]
+
log_gpu_memory_usage(LOG, "baseline", cfg.device)


Expand Down
5 changes: 1 addition & 4 deletions src/axolotl/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,7 @@ def prepare_dataset(cfg, tokenizer):
else:
path = cfg.pretraining_dataset
name = None
-if isinstance(cfg.pretraining_dataset, dict):
-path = cfg.pretraining_dataset["path"]
-name = cfg.pretraining_dataset["name"]
-elif isinstance(cfg.pretraining_dataset, list) and isinstance(
+if isinstance(cfg.pretraining_dataset, list) and isinstance(
cfg.pretraining_dataset[0], dict
):
path = cfg.pretraining_dataset[0]["path"]
Expand Down

0 comments on commit a787d57

Please sign in to comment.