diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 15da78b094..9aab7b39fb 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -17,7 +17,6 @@ # add src to the pythonpath so we don't need to pip install this from accelerate.commands.config import config_args from art import text2art -from datasets import concatenate_datasets, load_dataset from huggingface_hub import HfApi from huggingface_hub.utils import LocalTokenNotFoundError from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer @@ -30,7 +29,7 @@ normalize_config, validate_config, ) -from axolotl.utils.data import prepare_dataset +from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process from axolotl.utils.mlflow_ import setup_mlflow_env_vars @@ -343,81 +342,7 @@ def load_rl_datasets( cfg: DictDefault, cli_args: TrainerCliArgs, # pylint: disable=unused-argument ) -> TrainDatasetMeta: - train_datasets: List[Any] = [] - for i, ds_cfg in enumerate(cfg.datasets): - train_datasets.insert(i, load_dataset(ds_cfg["path"], split=ds_cfg["split"])) - # eval_dataset = load_dataset( - # cfg.test_datasets[0]["path"], split=cfg.test_datasets[0]["split"] - # ) - eval_dataset = None - - def argilla_apply_chatml(sample): # pylint: disable=possibly-unused-variable - if "system" in sample and sample["system"]: - sample["prompt"] = ( - f"<|im_start|>system\n{sample['system']}<|im_end|>\n" - f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" - ) - else: - sample[ - "prompt" - ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" - sample["chosen"] = f"{sample['chosen_response']}<|im_end|>" - sample["rejected"] = f"{sample['rejected_response']}<|im_end|>" - return sample - - def intel_apply_chatml(sample): # pylint: disable=possibly-unused-variable - if "system" in sample and sample["system"]: - sample["prompt"] = ( - f"<|im_start|>system\n{sample['system']}<|im_end|>\n" - f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" - ) - else: - sample[ - "prompt" - ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" - sample["chosen"] = f"{sample['chosen']}<|im_end|>" - sample["rejected"] = f"{sample['rejected']}<|im_end|>" - return sample - - def apply_chatml(sample): # pylint: disable=possibly-unused-variable - if "system" in sample and sample["system"]: - sample["prompt"] = ( - f"<|im_start|>system\n{sample['system']}<|im_end|>\n" - f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" - ) - else: - sample[ - "prompt" - ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" - sample["chosen"] = f"{sample['chosen']}<|im_end|>" - sample["rejected"] = f"{sample['rejected']}<|im_end|>" - return sample - - def ultra_apply_chatml(sample): # pylint: disable=possibly-unused-variable - if "system" in sample and sample["system"]: - sample["prompt"] = ( - f"<|im_start|>system\n{sample['system']}<|im_end|>\n" - f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" - ) - else: - sample[ - "prompt" - ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" - sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>" - sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>" - return sample - - for i, data_set in enumerate(train_datasets): - _type = cfg.datasets[i]["type"] - 
ds_type_fn = locals()[_type] - train_datasets[i] = data_set.map( - ds_type_fn, - desc="Mapping RL Dataset", - ) - train_dataset = concatenate_datasets(train_datasets) - - # eval_dataset = eval_dataset.map(intel_apply_chatml) - + train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg) total_num_steps = int( math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) ) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index c3b01e6c6e..e109db7f84 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -12,14 +12,19 @@ from dataclasses import dataclass, field from functools import wraps from pathlib import Path -from typing import Optional, Type, Union +from typing import List, Optional, Type, Union import torch import transformers from datasets import Dataset from torch.optim.lr_scheduler import OneCycleLR from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler -from transformers import EarlyStoppingCallback, Trainer, TrainingArguments +from transformers import ( + EarlyStoppingCallback, + Trainer, + TrainerCallback, + TrainingArguments, +) from transformers.trainer_utils import seed_worker from trl import DPOTrainer @@ -460,6 +465,7 @@ class TrainerBuilderBase(abc.ABC): _train_dataset = None _eval_dataset = None _model_ref = None + _peft_config = None def __init__(self, cfg, model, tokenizer): self.cfg = cfg @@ -490,13 +496,26 @@ def eval_dataset(self): def eval_dataset(self, dataset): self._eval_dataset = dataset + @property + def peft_config(self): + return self._peft_config + + @peft_config.setter + def peft_config(self, peft_config): + self._peft_config = peft_config + @abstractmethod def build(self, total_num_steps): pass - @abstractmethod - def get_callbacks(self): - pass + def get_callbacks(self) -> List[TrainerCallback]: + callbacks = [] + if self.cfg.use_wandb: + callbacks.append( + SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) + ) + + return callbacks @abstractmethod def get_post_trainer_create_callbacks(self, trainer): @@ -504,12 +523,6 @@ def get_post_trainer_create_callbacks(self, trainer): Callbacks added after the trainer is created, usually b/c these need access to the trainer """ - -class HFCausalTrainerBuilder(TrainerBuilderBase): - """ - Build the HuggingFace training args/trainer for Causal models - """ - def hook_pre_create_training_args(self, training_arguments_kwargs): # TODO return training_arguments_kwargs @@ -526,10 +539,16 @@ def hook_post_create_trainer(self, trainer): # TODO return trainer + +class HFCausalTrainerBuilder(TrainerBuilderBase): + """ + Build the HuggingFace training args/trainer for Causal models + """ + def get_callbacks(self): - callbacks = [] + callbacks = super().get_callbacks() callbacks.append(GPUStatsCallback(self.cfg)) - callbacks.append(EvalFirstStepCallback) + callbacks.append(EvalFirstStepCallback()) if self.cfg.relora_steps: callbacks.append(ReLoRACallback(self.cfg)) @@ -538,7 +557,7 @@ def get_callbacks(self): hasattr(self.model, "use_bettertransformer") and self.model.use_bettertransformer is True ): - callbacks.append(SaveBetterTransformerModelCallback) + callbacks.append(SaveBetterTransformerModelCallback()) if self.cfg.use_wandb: callbacks.append( @@ -931,7 +950,7 @@ class HFDPOTrainerBuilder(TrainerBuilderBase): """ def get_callbacks(self): - callbacks = [] + callbacks = super().get_callbacks() return callbacks def get_post_trainer_create_callbacks(self, trainer): @@ -949,21 +968,60 @@ def 
build_training_arguments(self, total_num_steps): ]: if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None: training_args_kwargs[arg] = getattr(self.cfg, arg) + + if self.cfg.hub_model_id: + training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id + training_args_kwargs["push_to_hub"] = True + training_args_kwargs["hub_private_repo"] = True + training_args_kwargs["hub_always_push"] = True + + if self.cfg.hub_strategy: + training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy + + if self.cfg.save_safetensors is not None: + training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors + + if self.eval_dataset: + training_args_kwargs["evaluation_strategy"] = "steps" + training_args_kwargs["eval_steps"] = self.cfg.eval_steps + else: + training_args_kwargs["evaluation_strategy"] = "no" + if self.cfg.bf16 or self.cfg.bfloat16: + training_args_kwargs["bf16"] = True + + training_args_kwargs["lr_scheduler_type"] = ( + self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine" + ) + training_args_kwargs["lr_scheduler_kwargs"] = ( + self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {} + ) + + if self.cfg.dataloader_pin_memory is not None: + training_args_kwargs[ + "dataloader_pin_memory" + ] = self.cfg.dataloader_pin_memory + if self.cfg.dataloader_num_workers is not None: + training_args_kwargs[ + "dataloader_num_workers" + ] = self.cfg.dataloader_num_workers + if self.cfg.dataloader_prefetch_factor is not None: + training_args_kwargs[ + "dataloader_prefetch_factor" + ] = self.cfg.dataloader_prefetch_factor + training_args = TrainingArguments( per_device_train_batch_size=self.cfg.micro_batch_size, - max_steps=total_num_steps, + max_steps=self.cfg.max_steps or total_num_steps, remove_unused_columns=False, gradient_accumulation_steps=self.cfg.gradient_accumulation_steps, learning_rate=self.cfg.learning_rate, - evaluation_strategy="no", - # eval_steps=self.cfg.eval_steps, save_strategy="steps", save_steps=self.cfg.save_steps, output_dir=self.cfg.output_dir, warmup_steps=self.cfg.warmup_steps, - bf16=True, gradient_checkpointing=self.cfg.gradient_checkpointing, - gradient_checkpointing_kwargs={"use_reentrant": False}, + gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs + or {"use_reentrant": False}, logging_first_step=True, logging_steps=1, optim=self.cfg.optimizer, @@ -982,22 +1040,27 @@ def build(self, total_num_steps): dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing elif self.cfg.rl == "kto_pair": dpo_trainer_kwargs["loss_type"] = "kto_pair" - + if self.eval_dataset: + dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset + if self.cfg.adapter and self.peft_config: + dpo_trainer_kwargs["peft_config"] = self.peft_config dpo_trainer = DPOTrainer( self.model, self.model_ref, args=training_args, beta=self.cfg.dpo_beta or 0.1, train_dataset=self.train_dataset, - # eval_dataset=self.eval_dataset, - eval_dataset=None, tokenizer=self.tokenizer, max_length=self.cfg.sequence_len, max_target_length=None, max_prompt_length=self.cfg.sequence_len, generate_during_eval=True, + callbacks=self.get_callbacks(), **dpo_trainer_kwargs, ) + dpo_trainer = self.hook_post_create_trainer(dpo_trainer) + for callback in self.get_post_trainer_create_callbacks(dpo_trainer): + dpo_trainer.add_callback(callback) return dpo_trainer diff --git a/src/axolotl/prompt_strategies/dpo/__init__.py b/src/axolotl/prompt_strategies/dpo/__init__.py new file mode 100644 index 0000000000..3c1c808005 --- /dev/null +++ 
b/src/axolotl/prompt_strategies/dpo/__init__.py @@ -0,0 +1,21 @@ +""" +module for DPO style dataset transform strategies +""" + +import importlib +import logging + +LOG = logging.getLogger("axolotl") + + +def load(strategy, cfg): + try: + load_fn = strategy.split(".")[-1] + strategy = ".".join(strategy.split(".")[:-1]) + mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo") + func = getattr(mod, load_fn) + load_kwargs = {} + return func(cfg, **load_kwargs) + except Exception: # pylint: disable=broad-exception-caught + LOG.warning(f"unable to load strategy {strategy}") + return None diff --git a/src/axolotl/prompt_strategies/dpo/chatml.py b/src/axolotl/prompt_strategies/dpo/chatml.py new file mode 100644 index 0000000000..e0840f7622 --- /dev/null +++ b/src/axolotl/prompt_strategies/dpo/chatml.py @@ -0,0 +1,85 @@ +""" +DPO strategies for chatml +""" + + +def argilla( + cfg, +): # pylint: disable=possibly-unused-variable,unused-argument + def transform_fn(sample): + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen_response']}<|im_end|>" + sample["rejected"] = f"{sample['rejected_response']}<|im_end|>" + return sample + + return transform_fn + + +def intel(cfg): # pylint: disable=possibly-unused-variable,unused-argument + """ + For Intel Orca DPO Pairs + """ + + def transform_fn(sample): + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen']}<|im_end|>" + sample["rejected"] = f"{sample['rejected']}<|im_end|>" + return sample + + return transform_fn + + +def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argument + def transform_fn(sample): + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen']}<|im_end|>" + sample["rejected"] = f"{sample['rejected']}<|im_end|>" + return sample + + return transform_fn + + +def ultra(cfg): # pylint: disable=possibly-unused-variable,unused-argument + """ + for ultrafeedback binarized conversations + """ + + def transform_fn(sample): + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>" + sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>" + return sample + + return transform_fn diff --git a/src/axolotl/prompt_strategies/dpo/zephyr.py b/src/axolotl/prompt_strategies/dpo/zephyr.py new file mode 100644 index 0000000000..02bce8a338 --- /dev/null +++ 
b/src/axolotl/prompt_strategies/dpo/zephyr.py @@ -0,0 +1,21 @@ +""" +DPO strategies for zephyr +""" + + +def nectar(cfg): # pylint: disable=possibly-unused-variable,unused-argument + def transform_fn(sample): + data = {} + data["prompt"] = ( + "<|system|>\n\n" + "<|user|>\n" + f"{sample['prompt']}\n" + "<|assistant|>\n" + ) + answers = sorted(sample["answers"], key=lambda x: x["rank"]) + data["chosen"] = answers[-1]["answer"] + data["rejected"] = answers[-2]["answer"] + + return data + + return transform_fn diff --git a/src/axolotl/train.py b/src/axolotl/train.py index d68ae46b14..79b8802345 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -96,7 +96,12 @@ def train( freeze_parameters_except(model, cfg.unfrozen_parameters) trainer = setup_trainer( - cfg, train_dataset, eval_dataset, (model, model_ref), tokenizer, total_num_steps + cfg, + train_dataset, + eval_dataset, + (model, model_ref, peft_config), + tokenizer, + total_num_steps, ) if hasattr(model, "config"): diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 5c4cd148be..fb2eb9bc42 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -4,7 +4,7 @@ import logging from collections import defaultdict from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from datasets import ( @@ -21,6 +21,7 @@ from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.datasets import TokenizedPromptDataset from axolotl.prompt_strategies import load +from axolotl.prompt_strategies.dpo import load as load_dpo from axolotl.prompt_tokenizers import ( AlpacaMultipleChoicePromptTokenizingStrategy, AlpacaPromptTokenizingStrategy, @@ -850,3 +851,50 @@ def encode_packed_pretraining( chunked_data[feature].append(collated_features[feature].squeeze(0)) return chunked_data + + +def load_prepare_dpo_datasets(cfg): + def load_split(dataset_cfgs, _cfg): + split_datasets: List[Any] = [] + for i, ds_cfg in enumerate(dataset_cfgs): + if ds_cfg["ds_type"] == "json": + for data_file in ds_cfg["data_files"]: + data_files = {ds_cfg["split"]: data_file} + ds = load_dataset( # pylint: disable=invalid-name + "json", + data_files=data_files, + split=ds_cfg["split"], + ) + split_datasets.insert(i, ds) + else: + ds = load_dataset( # pylint: disable=invalid-name + ds_cfg["path"], + split=ds_cfg["split"], + ) + split_datasets.insert(i, ds) + + for i, data_set in enumerate(split_datasets): + _type = dataset_cfgs[i]["type"] + if _type: + ds_transform_fn = load_dpo(_type, _cfg) + split_datasets[i] = data_set.map( + ds_transform_fn, + desc="Mapping RL Dataset", + ) + else: + # If no `type` is provided, assume the dataset is already in the expected format with + # "prompt", "chosen" and "rejected" already preprocessed + split_datasets[i] = data_set + + return concatenate_datasets(split_datasets) + + with zero_first(is_main_process()): + train_dataset = load_split(cfg.datasets, cfg) + + eval_dataset = None + if cfg.test_datasets: + eval_dataset = load_split(cfg.test_datasets, cfg) + if not eval_dataset: + eval_dataset = None + + return train_dataset, eval_dataset diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index d75926952f..6ba1e3704b 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -682,7 +682,12 @@ def load_model( lora_config = None if not reference_model or cfg.lora_model_dir: - model, lora_config = load_adapter(model, cfg, cfg.adapter) + # if we're not loading 
the reference model, then we're loading the model for training + # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config + if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora: + _, lora_config = load_lora(model, cfg, inference=False, config_only=True) + else: + model, lora_config = load_adapter(model, cfg, cfg.adapter) if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit): model.to(f"cuda:{cfg.local_rank}") @@ -770,8 +775,8 @@ def find_all_linear_names(model): return list(lora_module_names) -def load_lora(model, cfg, inference=False): - # type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]] +def load_lora(model, cfg, inference=False, config_only=False): + # type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]] from peft import LoraConfig, PeftModel, get_peft_model @@ -794,6 +799,9 @@ def load_lora(model, cfg, inference=False): task_type="CAUSAL_LM", ) + if config_only: + return None, lora_config + if cfg.lora_model_dir: LOG.debug("Loading pretained PEFT - LoRA") model_kwargs: Any = {} diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index dfd3385b73..2e9d782c74 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -316,9 +316,10 @@ def prepare_optim_env(cfg): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): - if cfg.rl: + if cfg.rl in ["dpo", "ipo", "kto_pair"]: trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer) trainer_builder.model_ref = model[1] + trainer_builder.peft_config = model[2] else: trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer) diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py new file mode 100644 index 0000000000..ac3c6d0693 --- /dev/null +++ b/tests/e2e/test_dpo.py @@ -0,0 +1,157 @@ +""" +E2E tests for lora llama +""" + +import logging +import os +import unittest +from pathlib import Path + +from axolotl.cli import load_rl_datasets +from axolotl.common.cli import TrainerCliArgs +from axolotl.train import train +from axolotl.utils.config import normalize_config +from axolotl.utils.dict import DictDefault + +from .utils import with_temp_dir + +LOG = logging.getLogger("axolotl.tests.e2e") +os.environ["WANDB_DISABLED"] = "true" + + +class TestDPOLlamaLora(unittest.TestCase): + """ + Test case for DPO Llama models using LoRA + """ + + @with_temp_dir + def test_dpo_lora(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "sequence_len": 1024, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 64, + "lora_alpha": 32, + "lora_dropout": 0.1, + "lora_target_linear": True, + "special_tokens": {}, + "rl": "dpo", + "datasets": [ + { + "path": "Intel/orca_dpo_pairs", + "type": "chatml.intel", + "split": "train", + }, + ], + "num_epochs": 1, + "micro_batch_size": 4, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "paged_adamw_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "warmup_steps": 5, + "gradient_checkpointing": True, + "gradient_checkpointing_kwargs": {"use_reentrant": True}, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) 
/ "checkpoint-20/adapter_model.safetensors").exists() + + @with_temp_dir + def test_kto_pair_lora(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "sequence_len": 1024, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 64, + "lora_alpha": 32, + "lora_dropout": 0.1, + "lora_target_linear": True, + "special_tokens": {}, + "rl": "kto_pair", + "datasets": [ + { + "path": "Intel/orca_dpo_pairs", + "type": "chatml.intel", + "split": "train", + }, + ], + "num_epochs": 1, + "micro_batch_size": 4, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "paged_adamw_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "warmup_steps": 5, + "gradient_checkpointing": True, + "gradient_checkpointing_kwargs": {"use_reentrant": True}, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() + + @with_temp_dir + def test_ipo_lora(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "sequence_len": 1024, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 64, + "lora_alpha": 32, + "lora_dropout": 0.1, + "lora_target_linear": True, + "special_tokens": {}, + "rl": "ipo", + "datasets": [ + { + "path": "Intel/orca_dpo_pairs", + "type": "chatml.intel", + "split": "train", + }, + ], + "num_epochs": 1, + "micro_batch_size": 4, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "paged_adamw_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "warmup_steps": 5, + "gradient_checkpointing": True, + "gradient_checkpointing_kwargs": {"use_reentrant": True}, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()