Skip to content

Commit

Permalink
more dpo fixes for dataset loading and docs (#1185) [skip ci]
Browse files Browse the repository at this point in the history
* more dpo fixes for dataset loading and docs

* preprocess dpo datasets
  • Loading branch information
winglian authored Jan 24, 2024
1 parent b97254a commit dbd2523
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 4 deletions.
10 changes: 10 additions & 0 deletions docs/rlhf.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ datasets:
rl: ipo
```
#### Using local dataset files
```yaml
datasets:
- ds_type: json
data_files:
- orca_rlhf.jsonl
split: train
type: chatml.intel
```
#### Trl autounwrap for peft
Trl supports autounwrapping peft models, so that a ref model does not need to be additionally loaded, leading to less VRAM needed. This is on by default. To turn it off, pass the following config.
Expand Down
7 changes: 6 additions & 1 deletion src/axolotl/cli/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
check_user_token,
load_cfg,
load_datasets,
load_rl_datasets,
print_axolotl_text_art,
)
from axolotl.common.cli import PreprocessCliArgs
Expand Down Expand Up @@ -43,7 +44,11 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH

_ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cfg.rl:
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)

LOG.info(
Fore.GREEN
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
Expand Down
7 changes: 6 additions & 1 deletion src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,12 @@ def build_training_arguments(self, total_num_steps):
training_args_kwargs["lr_scheduler_kwargs"] = (
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
if self.cfg.remove_unused_columns is not None:
training_args_kwargs[
"remove_unused_columns"
] = self.cfg.remove_unused_columns
else:
training_args_kwargs["remove_unused_columns"] = False

if self.cfg.dataloader_pin_memory is not None:
training_args_kwargs[
Expand All @@ -1013,7 +1019,6 @@ def build_training_arguments(self, total_num_steps):
training_args = TrainingArguments(
per_device_train_batch_size=self.cfg.micro_batch_size,
max_steps=self.cfg.max_steps or total_num_steps,
remove_unused_columns=False,
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
learning_rate=self.cfg.learning_rate,
save_strategy="steps",
Expand Down
53 changes: 51 additions & 2 deletions src/axolotl/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import yaml
from datasets import (
Dataset,
DatasetDict,
Expand Down Expand Up @@ -853,6 +854,41 @@ def encode_packed_pretraining(
return chunked_data


def _get_path(ds_hash, cfg):
prepared_ds_path = (
Path(cfg.dataset_prepared_path) / ds_hash
if cfg.dataset_prepared_path
else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
)

return prepared_ds_path


def _load_preprocessed_ds(cfg, sub_cfg):
ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
prepared_ds_path = _get_path(ds_hash, cfg)
dataset = None

if (
cfg.dataset_prepared_path
and any(prepared_ds_path.glob("*"))
and not cfg.is_preprocess
):
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
dataset = load_from_disk(str(prepared_ds_path))

return dataset


def _save_preprocessed_ds(cfg, sub_cfg, dataset):
ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
prepared_ds_path = _get_path(ds_hash, cfg)

if cfg.is_preprocess and is_main_process():
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
dataset.save_to_disk(str(prepared_ds_path))


def load_prepare_dpo_datasets(cfg):
def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = []
Expand Down Expand Up @@ -889,12 +925,25 @@ def load_split(dataset_cfgs, _cfg):
return concatenate_datasets(split_datasets)

with zero_first(is_main_process()):
train_dataset = load_split(cfg.datasets, cfg)
train_is_preprocessed = False
eval_is_preprocessed = False
if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
train_is_preprocessed = True
else:
train_dataset = load_split(cfg.datasets, cfg)

eval_dataset = None
if cfg.test_datasets:
eval_dataset = load_split(cfg.test_datasets, cfg)
if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
eval_is_preprocessed = True
else:
eval_dataset = load_split(cfg.test_datasets, cfg)
if not eval_dataset:
eval_dataset = None

if not train_is_preprocessed:
_save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
if eval_dataset and not eval_is_preprocessed:
_save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)

return train_dataset, eval_dataset

0 comments on commit dbd2523

Please sign in to comment.