diff --git a/docs/rlhf.md b/docs/rlhf.md
index 774c3992f9..3283880bdd 100644
--- a/docs/rlhf.md
+++ b/docs/rlhf.md
@@ -34,6 +34,16 @@ datasets:
 rl: ipo
 ```
 
+#### Using local dataset files
+```yaml
+datasets:
+  - ds_type: json
+    data_files:
+      - orca_rlhf.jsonl
+    split: train
+    type: chatml.intel
+```
+
 #### Trl autounwrap for peft
 
 Trl supports autounwrapping peft models, so that a ref model does not need to be additionally loaded, leading to less VRAM needed. This is on by default. To turn it off, pass the following config.
diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py
index 76b655afb6..745a530c0a 100644
--- a/src/axolotl/cli/preprocess.py
+++ b/src/axolotl/cli/preprocess.py
@@ -13,6 +13,7 @@
     check_user_token,
     load_cfg,
     load_datasets,
+    load_rl_datasets,
     print_axolotl_text_art,
 )
 from axolotl.common.cli import PreprocessCliArgs
@@ -43,7 +44,11 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
         LOG.warning(msg)
         parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
 
-    _ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    if parsed_cfg.rl:
+        load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    else:
+        load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+
     LOG.info(
         Fore.GREEN
         + f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index d055fd4764..8c13eb78da 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -996,6 +996,12 @@ def build_training_arguments(self, total_num_steps):
         training_args_kwargs["lr_scheduler_kwargs"] = (
             self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
         )
+        if self.cfg.remove_unused_columns is not None:
+            training_args_kwargs[
+                "remove_unused_columns"
+            ] = self.cfg.remove_unused_columns
+        else:
+            training_args_kwargs["remove_unused_columns"] = False
 
         if self.cfg.dataloader_pin_memory is not None:
             training_args_kwargs[
@@ -1013,7 +1019,6 @@ def build_training_arguments(self, total_num_steps):
         training_args = TrainingArguments(
             per_device_train_batch_size=self.cfg.micro_batch_size,
             max_steps=self.cfg.max_steps or total_num_steps,
-            remove_unused_columns=False,
             gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
             learning_rate=self.cfg.learning_rate,
             save_strategy="steps",
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 5839f74f69..6dd1ec5602 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -7,6 +7,7 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
+import yaml
 from datasets import (
     Dataset,
     DatasetDict,
@@ -853,6 +854,41 @@ def encode_packed_pretraining(
     return chunked_data
 
 
+def _get_path(ds_hash, cfg):
+    prepared_ds_path = (
+        Path(cfg.dataset_prepared_path) / ds_hash
+        if cfg.dataset_prepared_path
+        else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
+    )
+
+    return prepared_ds_path
+
+
+def _load_preprocessed_ds(cfg, sub_cfg):
+    ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
+    prepared_ds_path = _get_path(ds_hash, cfg)
+    dataset = None
+
+    if (
+        cfg.dataset_prepared_path
+        and any(prepared_ds_path.glob("*"))
+        and not cfg.is_preprocess
+    ):
+        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
+        dataset = load_from_disk(str(prepared_ds_path))
+
+    return dataset
+
+
+def _save_preprocessed_ds(cfg, sub_cfg, dataset):
+    ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
+    prepared_ds_path = _get_path(ds_hash, cfg)
+
+    if cfg.is_preprocess and is_main_process():
+        LOG.info(f"Saving preprocessed dataset to disk at {prepared_ds_path}...")
+        dataset.save_to_disk(str(prepared_ds_path))
+
+
 def load_prepare_dpo_datasets(cfg):
     def load_split(dataset_cfgs, _cfg):
         split_datasets: List[Any] = []
@@ -889,12 +925,25 @@ def load_split(dataset_cfgs, _cfg):
         return concatenate_datasets(split_datasets)
 
     with zero_first(is_main_process()):
-        train_dataset = load_split(cfg.datasets, cfg)
+        train_is_preprocessed = False
+        eval_is_preprocessed = False
+        if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
+            train_is_preprocessed = True
+        else:
+            train_dataset = load_split(cfg.datasets, cfg)
         eval_dataset = None
         if cfg.test_datasets:
-            eval_dataset = load_split(cfg.test_datasets, cfg)
+            if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
+                eval_is_preprocessed = True
+            else:
+                eval_dataset = load_split(cfg.test_datasets, cfg)
         if not eval_dataset:
             eval_dataset = None
+        if not train_is_preprocessed:
+            _save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
+        if eval_dataset and not eval_is_preprocessed:
+            _save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
+
     return train_dataset, eval_dataset
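
For context (not part of the patch), here is a rough sketch of how the new path gets used. With this change, `axolotl.cli.preprocess` routes RL configs through `load_rl_datasets`, and `load_prepare_dpo_datasets` caches the prepared splits under `dataset_prepared_path`, keyed by an md5 of the dataset sub-config. The `rl: dpo` value and the `last_run_prepared` path below are illustrative assumptions, not taken from the diff:

```yaml
# Illustrative preprocessing config; the dataset entry mirrors the
# local-file example added to docs/rlhf.md above.
rl: dpo
datasets:
  - ds_type: json
    data_files:
      - orca_rlhf.jsonl
    split: train
    type: chatml.intel
# Prepared splits are cached here under an md5 hash of the dataset config.
dataset_prepared_path: last_run_prepared
```

Running `python -m axolotl.cli.preprocess config.yml` builds and saves the train (and test, if configured) splits via `_save_preprocessed_ds`; a later training run with the same dataset config and `dataset_prepared_path` picks them up through `_load_preprocessed_ds` instead of rebuilding them, since saving only happens when `cfg.is_preprocess` is set and loading only when it is not.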