Fix and document test_datasets (#1228)
* Make sure test_datasets are used and treated like val_set_size.

* Add test_datasets docs.

* Apply suggestions from code review

---------

Co-authored-by: Wing Lian <[email protected]>
DreamGenX and winglian authored Jan 31, 2024
1 parent 8608d80 commit 5787e1a
Showing 3 changed files with 14 additions and 2 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -607,6 +607,17 @@ datasets:
# For `completion` datasets only, uses the provided field instead of `text` column
field:

# A list of one or more datasets to evaluate the model with.
# You can use either test_datasets or val_set_size, but not both.
test_datasets:
- path: /workspace/data/eval.jsonl
ds_type: json
# You need to specify a split. For "json" datasets the default split is called "train".
split: train
type: completion
data_files:
- /workspace/data/eval.jsonl

# use RL training: dpo, ipo, kto_pair
rl:

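The new README block above states that `test_datasets` and `val_set_size` are mutually exclusive. As a minimal illustration (not code from this commit, and assuming a plain dict-shaped config), the either/or rule could be checked like this:

```python
# Hypothetical helper: enforce the either/or rule described in the README docs.
def check_eval_config(cfg: dict) -> None:
    has_test_datasets = bool(cfg.get("test_datasets"))
    has_val_set_size = float(cfg.get("val_set_size", 0) or 0) > 0
    if has_test_datasets and has_val_set_size:
        raise ValueError("Use either test_datasets or val_set_size, not both.")


check_eval_config({"test_datasets": [{"path": "/workspace/data/eval.jsonl"}]})  # ok
check_eval_config({"val_set_size": 0.05})                                       # ok
```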
3 changes: 2 additions & 1 deletion src/axolotl/core/trainer_builder.py
@@ -735,7 +735,7 @@ def build(self, total_num_steps):
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
training_arguments_kwargs["dataloader_drop_last"] = True

- if self.cfg.val_set_size == 0:
+ if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
# no eval set, so don't eval
training_arguments_kwargs["evaluation_strategy"] = "no"
elif self.cfg.eval_steps:
@@ -822,6 +822,7 @@ def build(self, total_num_steps):
self.cfg.load_best_model_at_end is not False
or self.cfg.early_stopping_patience
)
and not self.cfg.test_datasets
and self.cfg.val_set_size > 0
and self.cfg.save_steps
and self.cfg.eval_steps
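The hunks above make the evaluation setup depend on `test_datasets` as well as `val_set_size`: evaluation is disabled only when neither is configured, and the load-best-model guard now skips its `val_set_size` check when explicit test datasets are provided. A rough standalone sketch of the first condition (a hypothetical function, not the trainer builder's actual API):

```python
# Hypothetical distillation of the changed condition: evaluation is turned off
# only when there is no explicit test set AND no validation split.
def evaluation_strategy(test_datasets, val_set_size, eval_steps=None):
    if not test_datasets and val_set_size == 0:
        return "no"
    if eval_steps:
        return "steps"
    return "epoch"  # assumed fallback for illustration; not shown in this diff


assert evaluation_strategy(None, 0) == "no"
assert evaluation_strategy([{"path": "eval.jsonl"}], 0, eval_steps=50) == "steps"
```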
2 changes: 1 addition & 1 deletion src/axolotl/utils/data.py
@@ -440,7 +440,7 @@ def load_prepare_datasets(
split="train",
) -> Tuple[Dataset, Dataset, List[Prompter]]:
dataset, prompters = load_tokenized_prepared_datasets(
- tokenizer, cfg, default_dataset_prepared_path
+ tokenizer, cfg, default_dataset_prepared_path, split=split
)

if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
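The `split=split` plumbing above is what lets `test_datasets` load a split other than the training one. For context on the README note that JSON datasets default to a split named "train", here is an illustrative use of the Hugging Face `datasets` library (not code from this repository):

```python
# Loading a local JSONL file with the datasets library yields a DatasetDict
# whose only split is named "train", which is why the README example sets
# `split: train` even for an evaluation file.
from datasets import load_dataset

eval_dict = load_dataset("json", data_files="/workspace/data/eval.jsonl")
print(eval_dict)                 # DatasetDict({'train': Dataset(...)})
eval_split = eval_dict["train"]  # the split referenced by `split: train`
```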
