From 7a2d427019fcbd6ae6b916af3156c909ff56849e Mon Sep 17 00:00:00 2001 From: Anna Shors Date: Fri, 6 Dec 2024 14:46:20 -0800 Subject: [PATCH] feat: add sequence packing support for DPO (#423) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: ashors1 Signed-off-by: Terry Kong Signed-off-by: NeMo-Aligner CI Signed-off-by: abukharin Signed-off-by: Oliver Koenig Signed-off-by: arendu Co-authored-by: Terry Kong Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Alexander Bukharin <59148829+abukharin3@users.noreply.github.com> Co-authored-by: abukharin Co-authored-by: oliver könig Co-authored-by: Adi Renduchintala --- .github/workflows/cicd-main.yml | 1 + CHANGELOG.md | 1 + docs/user-guide/dpo.rst | 61 ++++ .../data/dpo/prepare_packed_dpo_dataset.py | 270 ++++++++++++++++++ examples/nlp/gpt/train_gpt_dpo.py | 16 +- nemo_aligner/algorithms/dpo.py | 14 +- nemo_aligner/data/nlp/builders.py | 12 +- nemo_aligner/data/nlp/datasets.py | 198 ++++++++++++- .../models/nlp/gpt/megatron_gpt_dpo_model.py | 257 +++++++++++++---- nemo_aligner/utils/distributed.py | 17 +- tests/functional/dpo.sh | 19 +- tests/functional/test_cases/dpo-llama3 | 4 +- tests/functional/test_cases/dpo-llama3-pack | 26 ++ tests/functional/test_data/dummy-dpo.jsonl | 150 +++++----- .../test_data/dummy_dpo_packed_90.npy | Bin 0 -> 66341 bytes tests/test_datasets.py | 177 +++++++++++- tests/test_distributed.py | 56 ++-- 17 files changed, 1086 insertions(+), 193 deletions(-) create mode 100644 examples/nlp/data/dpo/prepare_packed_dpo_dataset.py create mode 100755 tests/functional/test_cases/dpo-llama3-pack create mode 100644 tests/functional/test_data/dummy_dpo_packed_90.npy diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index d2d27e95a..3f11fa876 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -92,6 +92,7 @@ jobs: - ppo-llama3-pp2-reshard - reinforce-llama3-pp2-reshard - dpo-llama3 + - dpo-llama3-pack - kd-llama3 - sft-llama3 - rm-llama3 diff --git a/CHANGELOG.md b/CHANGELOG.md index ad2da0d5f..ca8bc3b37 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ## [Next Version] ### New Features and Optimizations +- Sequence packing is now supported when running DPO. - Added support for Knowledge Distillation with SFT. See the [tutorial](docs/user-guide/knowledge-distillation.rst) for details. - Added support for Megatron Core’s distributed optimizer, which can be configured using `++model.optim.name=mcore_distributed_optim`. - Introduced `ScopedTimer` as a successor to `SyncedTimer`. `SyncedTimer` is marked for deprecation and will be removed in the next version. diff --git a/docs/user-guide/dpo.rst b/docs/user-guide/dpo.rst index 901ceee37..fa75941d0 100644 --- a/docs/user-guide/dpo.rst +++ b/docs/user-guide/dpo.rst @@ -27,6 +27,7 @@ The algorithm is identified with the ``dpo.preference_loss`` config variable. We To use the RPO algorithm, each dataset example should have ``chosen_reward`` and ``rejected_reward``, which might come from human labelers or reward models. If ``chosen_reward`` and ``rejected_reward`` are not existent in the data, ``dpo.default_chosen_reward`` and ``dpo.default_rejected_reward`` are used. + Obtain a Pretrained Model ######################### To start, we must first get a pretrained model to align. There are two models we recommend to get started. The rest of the tutorial will work with either model, but for demonstration purposes, we will use the smaller 2B model. @@ -80,6 +81,9 @@ For best DPO training performance, it is recommended that you start with a SFT m DPO Model Training ################## +Prepare your Dataset +==================== + Before running the core DPO training, you must prepare your training and validation data to the format required for DPO training. DPO expects ``.jsonl`` files where each line is a JSON dict corresponding to a single, complete sample, as shown below:: {"prompt": "Which year was the Magna Carta signed?", "chosen_response": "1215", "rejected_response": "I refuse to answer this question."} @@ -94,6 +98,63 @@ Always follow the prompt-response template format used during your SFT training Your JSONL file must contain at least as many samples as the Global Batch Size (GBS) you plan to use during training. For example, if GBS = 64, ensure that both your training and validation files include at least 64 samples. Using a file with fewer samples than the GBS will result in a crash. +Sequence Packing with DPO +========================= + +We also support packed sequence training with DPO. Sequence packing is a training technique in which multiple training examples are concatenated to create one longer sequence. This approach eliminates the need for padding and improves GPU utilization. +Refer to the `sequence packing documentation `_ for a detailed overview of sequence packing and its advantages. This document +discusses sequence packing for SFT in particular, but the same benefits apply to DPO. + +Packing your DPO dataset is done as a preprocessing step in NeMo and NeMo-Aligner. We provide a `script https://github.com/NVIDIA/NeMo-Aligner/blob/ashors/dpo-packing/examples/nlp/data/dpo/prepare_packed_dpo_dataset.py`_ to pack your DPO dataset. This script assumes you already have a prepared DPO-format dataset. Three main steps are run in this script: + + #. The online processing code in ``DPOModelDataset`` is run. This includes tasks such as prompt template manipulation and tokenization. The result is an array of tokenized sequences, represented by indices. + #. Chosen and rejected sequences are concatenated. + #. The tokenized sequences are grouped by length and a packing algorithm is run. + + +You can read more about packing algorithms `here `_. Currently, two variants of ``first_fit`` are supported: + + #. ``first_fit_decreasing``: sorts the sequences in decreasing order before applying the first-fit algorithm. It generates a more optimal packing, but it tends to keep all short sequences together, which may have an impact for convergence. + #. ``first_fit_shuffle``: runs first-fit in a random order. Packing is less optimal but it keeps the dataset order random. The recommendation is to run first_fit_shuffle and check the packed sequence lengths. If they are similar to the target length (i.e. efficient packing), then use shuffle. Otherwise try first_fit_decreasing. + + +The following is an example of running the packing script to prepare your DPO dataset: + +.. code-block:: bash + + python examples/nlp/data/dpo/prepare_packed_dpo_dataset.py \ + model.data.data_prefix=/path/to/training.jsonl \ + +model.encoder_seq_length=2048 \ + +tokenizer_path=/path/to/tokenizer/model \ + +output_dir=/path/to/output_folder \ + +pack_sizes=[4096] \ + +tokenizer_type= + [ +packing_algorithm=first_fit_shuffle \ ] + [ ++model.seed=0 ] + + +Because this script packs chosen and rejected sequences together, ``pack_sizes`` should always be at least double ``model.encoder_seq_length``. +When running training using the packed dataset, ``model.encoder_seq_length`` should be set to the ``packed_size`` used for the packed dataset. + +To use the packed dataset during training, add the following line to your train command: + +.. code-block:: bash + + ++model.data.data_impl=packed_jsonl + + +A few notes to keep in mind when running training with sequence packing: + + #. Make sure to pack your train, validation, and test datasets. + #. Sequence packing can only be run with a micro batch size of 1. + #. Sequence packing is supported via Transformer Engine, so be sure to enable transformer engine in your config by setting `++model.transformer_engine=True`. + #. Sequence packing increases the number of examples processed per global batch. Try to scale your global batch size accordingly by setting the new + global batch size to approximately ``unpacked_global_batch_size / avg_num_sequences_per_pack``. The average number of sequences per pack is printed to stdout after ``prepare_packed_dpo_dataset.py`` completes. + + +Begin Training +============== + Once your data is processed into the correct format, you are ready to begin DPO training. You must start with a pretrained or SFT trained model. For this section, we will use the SFT model trained in the previous step to train the DPO model. For the purposes of the following sections, we assume that your training ``.jsonl`` file is located in ``/path/to/train_dpo_format.jsonl`` and your validation ``.jsonl`` file is located in ``/path/to/valid_dpo_format.jsonl``. diff --git a/examples/nlp/data/dpo/prepare_packed_dpo_dataset.py b/examples/nlp/data/dpo/prepare_packed_dpo_dataset.py new file mode 100644 index 000000000..df8db5d86 --- /dev/null +++ b/examples/nlp/data/dpo/prepare_packed_dpo_dataset.py @@ -0,0 +1,270 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Tuple + +import numpy as np +import torch +from tqdm import tqdm + +from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTDataset +from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.sequence_packing_utils import create_hist, create_packing_strategy +from nemo_aligner.data.nlp.builders import build_train_valid_test_dpo_datasets +from nemo_aligner.data.nlp.datasets import DPOModelDataset + +if TYPE_CHECKING: + from omegaconf import DictConfig + +""" +Script to prepare packed dataset from a DPO dataset in the jsonl format. +Three main steps are run in this script: +1. The online processing code in DPOModelDataset is run (including prompt template manipulation, +sequence length truncation, tokenization, etc) and the result is an array of tokenized sequences, +represented by indices). +2. chosen and rejected sequences are concatenated for each example +3. The sequences are grouped by length, and a packing algorithm is run. (https://en.wikipedia.org/wiki/Bin_packing_problem#Offline_algorithms) +Currently, two variants of "first fit" are supported. +"first_fit_decreasing" sorts the sequences in decreasing order before applying first-fit. +It generates a more optimal packing, but it tends to keep all short sequences together, which may affect convergence. +"first_fit_shuffle" runs first-fit in a random order. Packing is less optimal but it keeps the dataset order random. +The recommendation is to run "first_fit_shuffle" and check the packed sequence lengths in the printout. +If they are similar to the target length (i.e. packing is efficient), then use shuffle. Otherwise try first_fit_decreasing. + +Example usage: + +python scripts/nlp_language_modeling/prepare_packed_dpo_dataset.py \ + model.data.train_ds.file_names=[/path/to/training.jsonl] \ + model.encoder_seq_length=1024 \ + +tokenizer_path= \ + +tokenizer_type=sentencepiece \ + +output_dir=/path/to/output_folder \ + +pack_sizes=[2048,4096,8192] + +Note: + - Tokenizer path supports SentencePiece tokenizer and HF tokenizer. + For SentencePiece tokenizer, specify the file /path/to/tokenizer.model + For HF tokenizer, specify a folder /path/to/hf_folder which contains tokenizer.json, tokenizer_config.json + and special_tokens_map.json or the HF name of the tokenizer to use (e.g. "meta-llama/Meta-Llama-3-8B") + + - If your model or dataset requires non-default configs for DPO training in NeMo, you will + need to pass in the same configs to ``model.data.train_ds`` as you would for training with unpacked dataset. + + - ``model.encoder_seq_length`` is the length to truncate each sequence before packing multiple sequences + to the size of packed sequence (``pack_size``). + + - ``pack_sizes`` is a list of packed sequence lengths. In this example, there will be three output files, one for + each pack size. The output files are named ``/packed_{pack_size}_seed{seed}.npy``. + This argument is a list because you will likely want to experiment with a few ``pack_sizes`` to find out which length + can fill the GPU memory without exceeding it. Adjusting ``pack_size`` is analogous to adjusting the micro batch size in + the unpacked case. + - **important**: ``pack_sizes`` should be at least double the value of model.encoder_seq_length in order to guarantee + that chosen and rejected sequences for a given example can be packed together. +""" + + +def tokenize_dataset(cfg: "DictConfig", tokenizer_type): + """ + Tokenizes a dataset using the same configuration file as DPOModelDataset. + + This function reads a dataset and tokenizes based on the provided configuration. + + Args: + cfg: A Hydra configuration object containing parameters for tokenization. + + Returns: + A NumPy array containing the tokenized sequences from the dataset. + """ + + logging.info("Tokenizing dataset...") + + if tokenizer_type == "huggingface": + # pass in either a local Hugging Face folder which contains tokenizer.json or a path to the tokenizer on huggingface + tokenizer = get_nmt_tokenizer(library="huggingface", model_name=cfg.tokenizer_path, use_fast=True) + elif tokenizer_type == "sentencepiece": + tokenizer = get_nmt_tokenizer(library="sentencepiece", tokenizer_model=cfg.tokenizer_path) + else: + raise ValueError(f"unsupported tokenizer type {tokenizer_type}") + + with open(cfg.model.data.data_prefix, "r", encoding="utf_8") as fr: + data_payload = [json.loads(line.strip()) for line in fr] + documents = np.arange(len(data_payload), step=1, dtype=np.int32) + dataset = DPOModelDataset( + cfg=cfg.model, + name="packing_dataset", + tokenizer=tokenizer, + data_prefix=cfg.model.data.data_prefix, + documents=documents, + data=data_payload, + seq_length=cfg.model.data.seq_length, + seed=cfg.model.get("seed", 1234), + drop_last=True, ## False not currently supported + pad_chosen_rejected_to_max=False, + ) + + combined_dataset = [] + for item in dataset: + if item["ignore_example"]: + continue + input_ids = torch.cat((item["chosen"], item["rejected"])).numpy() + labels = torch.cat((item["chosen_labels"], item["rejected_labels"])).numpy() + reward = torch.tensor([item["chosen_reward"], item["rejected_reward"]]).numpy() + boundary = len(item["chosen"]) + lengths = np.array([item["chosen_length"], item["rejected_length"]]) + new_item = { + "input_ids": input_ids, + "labels": labels, + "reward": reward, + "lengths": lengths, + "boundary": boundary, + } + combined_dataset.append(new_item) + + return np.array(combined_dataset) + + +## modified version of https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/sequence_packing_utils.py#L178 for DPO +## pack size should be at least 2*encoder_seq_length since the packed sequences include both the chosen and rejected sequences +## for a given example +def fill_packing_strategy( + assignments: List[List[int]], sequences: Dict[int, List[Dict]], pack_size: int +) -> List[Dict]: + """ + Fills the packing strategy with actual sequence data based on assignments and sequence information. + + This function takes the assignments generated by the packing algorithm (containing sequence length indices), + the original sequences data, and the pack size. It iterates through the assignments, retrieves the corresponding + sequences from the sequences dictionary, and constructs the final output data structure with input IDs, loss masks + (if available), and starting indices for each sequence in a packed sequence. + + Args: + assignments: A list of lists, where each inner list represents a bin and contains the indices of the + sequence lengths assigned to that bin (output of 'create_packing_strategy'). + sequences: A dictionary where keys are sequence lengths and values are lists of corresponding sequences + from the dataset (output of 'create_hist'). + pack_size: The maximum capacity of each bin. + + Returns: + output_data: A list of dictionaries, where each dictionary represents a packed sequence with its input IDs, + loss mask (if available), and starting indices. + """ + ifile_handles = dict() + for seq_len in tqdm(range(pack_size + 1)): + per_seq_data = sequences[seq_len] + if len(per_seq_data) > 0: + perm = np.random.permutation(len(per_seq_data)) + + perm = np.random.permutation(len(per_seq_data)) + input_ids = np.array([x["input_ids"] for x in per_seq_data])[perm].tolist() + labels = np.array([x["labels"] for x in per_seq_data])[perm].tolist() + reward = np.array([x["reward"] for x in per_seq_data])[perm].tolist() + lengths = np.array([x["lengths"] for x in per_seq_data])[perm].tolist() + boundary = np.array([x["boundary"] for x in per_seq_data])[perm].tolist() + + ifile_handles[seq_len] = (input_ids, labels, reward, lengths, boundary) + + input_ids, labels, reward, lengths, seq_boundaries = {}, {}, {}, {}, {} + + for oindex, assignment in tqdm(enumerate(assignments), total=len(assignments)): + _input_ids, _labels, _reward, _lengths, _seq_boundaries = [], [], [], [], [0] + + for seq_length in assignment: + + previous_seq_len = len(_input_ids) + + _input_ids.extend(ifile_handles[seq_length][0].pop()) + _labels.extend(ifile_handles[seq_length][1].pop()) + _reward.extend(ifile_handles[seq_length][2].pop()) + _lengths.extend(ifile_handles[seq_length][3].pop()) + + ## store the boundaries for the chosen, rejected sequences + _seq_boundaries.append(previous_seq_len + ifile_handles[seq_length][4].pop()) + _seq_boundaries.append(len(_input_ids)) + + input_ids[oindex] = _input_ids + labels[oindex] = _labels + reward[oindex] = _reward + lengths[oindex] = _lengths + seq_boundaries[oindex] = _seq_boundaries + + output_data = [] + for i in range(len(input_ids)): + item_dict = { + "input_ids": input_ids[i], + "labels": labels[i], + "reward": reward[i], + "lengths": lengths[i], + "seq_boundaries": seq_boundaries[i], + } + output_data.append(item_dict) + + # (input_ids, labels, reward, lengths, boundary) = length 5 + for i in range(5): + assert all( + not seq[i] for seq in ifile_handles.values() + ), "Error: There are items left over from the assignment" + return output_data + + +@dataclass +class PackingArgs: + output_dir: str = "output" + pack_sizes: Tuple[int] = (2048,) + packing_algorithm: str = "first_fit_shuffle" + tokenizer_type: str = "sentencepiece" ## one of "huggingface" or "sentencepiece" + + def from_config(self, cfg: "DictConfig"): + for required_arg in ("output_dir", "pack_sizes"): + assert cfg.get(required_arg, None), f"Please specify +{required_arg}=..." + self.output_dir = cfg.output_dir + self.pack_sizes = cfg.pack_sizes + self.packing_algorithm = cfg.get("packing_algorithm", "first_fit_shuffle") + self.tokenizer_type = cfg.tokenizer_type + return self + + +@hydra_runner(config_path="../../gpt/conf", config_name="gpt_dpo") +def main(cfg: "DictConfig") -> None: + args = PackingArgs().from_config(cfg) + dataset = tokenize_dataset(cfg, args.tokenizer_type) + sequences, histogram = create_hist( + dataset, 2 * cfg.model.data.seq_length + ) ## multiply by 2 because packed sequences include chosen and rejected + for pack_size in args.pack_sizes: + assignments = create_packing_strategy(histogram, pack_size, args.packing_algorithm) + output_data = fill_packing_strategy(assignments, sequences, pack_size) + + # save output data + os.makedirs(args.output_dir, exist_ok=True) + output_path = os.path.join(args.output_dir, f"packed_{pack_size}_seed{cfg.model.get('seed', 1234)}.npy") + np.save(output_path, output_data) + logging.info(f"Done, output written to {output_path}") + + logging.info( + f""" +✅ Packed datasets with pack sizes {args.pack_sizes} are prepared successfully. +To train with packed sequences, you need to make changes to the DPO config file. +See the NeMo-Aligner sequence packing documentation for more details: +https://github.com/NVIDIA/NeMo-Aligner/blob/main/docs/user-guide/dpo.rst#sequence-packing-with-dpo +""" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/nlp/gpt/train_gpt_dpo.py b/examples/nlp/gpt/train_gpt_dpo.py index f16a9dacf..f50b9a786 100644 --- a/examples/nlp/gpt/train_gpt_dpo.py +++ b/examples/nlp/gpt/train_gpt_dpo.py @@ -20,7 +20,12 @@ from nemo.utils import logging from nemo.utils.exp_manager import exp_manager from nemo_aligner.algorithms.dpo import DPOTrainer, dpo_custom_collate -from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_dpo_datasets, identity_collate +from nemo_aligner.data.nlp.builders import ( + build_dataloader, + build_train_valid_test_dpo_datasets, + build_train_valid_test_dpo_packed_datasets, + identity_collate, +) from nemo_aligner.models.nlp.gpt.megatron_gpt_dpo_model import MegatronGPTDPOModel from nemo_aligner.utils.distributed import Timer from nemo_aligner.utils.train_script_utils import ( @@ -85,7 +90,11 @@ def main(cfg) -> None: # use the entire dataset train_valid_test_num_samples = [-1 * cfg.model.global_batch_size] * 3 - train_ds, validation_ds, _ = build_train_valid_test_dpo_datasets( + if cfg.model.data.data_impl == "packed_jsonl": + build_fn = build_train_valid_test_dpo_packed_datasets + else: + build_fn = build_train_valid_test_dpo_datasets + train_ds, validation_ds, _ = build_fn( cfg=cfg.model, data_prefix=cfg.model.data.data_prefix, data_impl=cfg.model.data.data_impl, @@ -96,6 +105,7 @@ def main(cfg) -> None: tokenizer=ptl_model.tokenizer, ) + collate = train_ds.global_collate_fn if cfg.model.data.data_impl == "packed_jsonl" else dpo_custom_collate train_dataloader = build_dataloader( cfg=cfg, dataset=train_ds, @@ -136,7 +146,7 @@ def main(cfg) -> None: val_dataloader=val_dataloader, test_dataloader=None, collate_fn=partial( - dpo_custom_collate, + collate, eos_id=ptl_model.tokenizer.eos_id, reset_position_ids=cfg.model.data.get("reset_position_ids", False), reset_attention_mask=cfg.model.data.get("reset_attention_mask", False), diff --git a/nemo_aligner/algorithms/dpo.py b/nemo_aligner/algorithms/dpo.py index 626b7b58e..f499c44b7 100644 --- a/nemo_aligner/algorithms/dpo.py +++ b/nemo_aligner/algorithms/dpo.py @@ -367,12 +367,18 @@ def augment_dataloader(self, dataloader): batch = next(iter_dataloader) batch = self.collate_fn(batch) logprobs = self.model.get_ref_policy_logprobs(batch).cpu() - chosen_logps, reject_logps = torch.split(logprobs, len(logprobs) // 2, dim=0) - batch["ref_policy_log_probs_chosen"] = chosen_logps - batch["ref_policy_log_probs_rejected"] = reject_logps + packed = "input_ids" in batch + if not packed: + chosen_logps, reject_logps = torch.split(logprobs, len(logprobs) // 2, dim=0) + batch["ref_policy_log_probs_chosen"] = chosen_logps + batch["ref_policy_log_probs_rejected"] = reject_logps + else: + batch["ref_policy_log_probs"] = logprobs yield batch - del logprobs, chosen_logps, reject_logps + del logprobs + if not packed: + del chosen_logps, reject_logps except StopIteration: break diff --git a/nemo_aligner/data/nlp/builders.py b/nemo_aligner/data/nlp/builders.py index 97b68ffe4..85ecac1e1 100644 --- a/nemo_aligner/data/nlp/builders.py +++ b/nemo_aligner/data/nlp/builders.py @@ -44,6 +44,7 @@ from nemo.utils import logging from nemo_aligner.data.nlp.datasets import ( DPOModelDataset, + DPOPackedDataset, KnowledgeDistillationDataset, KTOModelDataset, RegressionRewardModelDataset, @@ -112,12 +113,14 @@ def _build_dataset(current_data_prefix, current_num_samples): elif data_impl.startswith("json"): with open(current_data_prefix, "r", encoding="utf_8") as fr: data_payload = [json.loads(line.strip()) for line in fr] + elif data_impl == "packed_jsonl": + data_payload = np.load(current_data_prefix, allow_pickle=True) elif data_impl == "chunked_jsonl": assert isinstance(n_chunks, int) and n_chunks >= 1, f"Not valid n_chunks {n_chunks}" data_payload = ChunkedJsonl(current_data_prefix, n_chunks, n_examples_per_chunk) else: raise RuntimeError( - f"data.data_impl must be one of mmap, json, jsonl or chunked_jsonl, but got {data_impl}" + f"data.data_impl must be one of mmap, json, jsonl, packed_jsonl, or chunked_jsonl, but got {data_impl}" ) total_num_of_documents = len(data_payload) @@ -319,11 +322,15 @@ def _build_train_valid_test_datasets( elif data_impl.startswith("json"): with open(data_prefix, "r", encoding="utf_8") as fr: data_payload = [json.loads(line.strip()) for line in fr] + elif data_impl == "packed_jsonl": + data_payload = np.load(data_prefix, allow_pickle=True) elif data_impl == "chunked_jsonl": assert isinstance(n_chunks, int) and n_chunks >= 1, f"Not valid n_chunks {n_chunks}" data_payload = ChunkedJsonl(data_prefix, n_chunks, n_examples_per_chunk) else: - raise RuntimeError(f"data.data_impl must be one of mmap, json, jsonl or chunked_jsonl, but got {data_impl}") + raise RuntimeError( + f"data.data_impl must be one of mmap, json, jsonl, packed_jsonl, or chunked_jsonl, but got {data_impl}" + ) total_num_of_documents = len(data_payload) splits = get_train_valid_test_split_(splits_string, total_num_of_documents) @@ -372,6 +379,7 @@ def build_dataset(index, name): build_train_valid_test_rlhf_datasets = partial(build_train_valid_test_datasets, RLHFDataset) build_train_valid_test_rm_datasets = partial(build_train_valid_test_datasets, RewardModelDataset) build_train_valid_test_dpo_datasets = partial(build_train_valid_test_datasets, DPOModelDataset) +build_train_valid_test_dpo_packed_datasets = partial(build_train_valid_test_datasets, DPOPackedDataset) build_train_valid_test_kto_datasets = partial(build_train_valid_test_datasets, KTOModelDataset) build_train_valid_test_regression_rm_datasets = partial(build_train_valid_test_datasets, RegressionRewardModelDataset) build_train_valid_test_knowledge_distillation_datasets = partial( diff --git a/nemo_aligner/data/nlp/datasets.py b/nemo_aligner/data/nlp/datasets.py index a07bf61a1..b1b6e2d6c 100644 --- a/nemo_aligner/data/nlp/datasets.py +++ b/nemo_aligner/data/nlp/datasets.py @@ -14,6 +14,7 @@ """Custom datasets for RLHF training""" +import math import os from typing import Dict, List @@ -30,6 +31,7 @@ ) from nemo.core import Dataset from nemo.utils import logging +from nemo_aligner.utils import parallel_state class KnowledgeDistillationDataset(Dataset): @@ -311,7 +313,17 @@ class DPOModelDataset(Dataset): """ def __init__( - self, cfg, tokenizer, name, data_prefix, documents, data, seq_length, seed, drop_last=True, + self, + cfg, + tokenizer, + name, + data_prefix, + documents, + data, + seq_length, + seed, + drop_last=True, + pad_chosen_rejected_to_max=True, ): super().__init__() self.cfg = cfg @@ -321,6 +333,10 @@ def __init__( self.seq_length = seq_length self.tokenizer = tokenizer + ## pad_chosen_rejected_to_max should be true unless iterating through the + ## dataset as a data preparation step for packing + self.pad_chosen_rejected_to_max = pad_chosen_rejected_to_max + self.reset_position_ids = cfg.data.get("reset_position_ids", False) self.reset_attention_mask = cfg.data.get("reset_attention_mask", False) self.eod_mask_loss = cfg.data.get("eod_mask_loss", False) @@ -455,19 +471,32 @@ def __getitem__(self, idx): max_curr_seq_len = max(chosen_len, reject_len) - chosen_tokens = torch.nn.functional.pad( - torch.LongTensor(chosen), (0, max_curr_seq_len - chosen_len), mode="constant", value=self.eos_id - ) - rejected_tokens = torch.nn.functional.pad( - torch.LongTensor(reject), (0, max_curr_seq_len - reject_len), mode="constant", value=self.eos_id - ) - labels_chosen_tokens = torch.nn.functional.pad( - torch.LongTensor(chosen_labels), (0, max_curr_seq_len - len(chosen_labels)), mode="constant", value=-100 - ) - labels_reject_tokens = torch.nn.functional.pad( - torch.LongTensor(reject_labels), (0, max_curr_seq_len - len(reject_labels)), mode="constant", value=-100 - ) + if self.pad_chosen_rejected_to_max: + chosen_tokens = torch.nn.functional.pad( + torch.LongTensor(chosen), (0, max_curr_seq_len - chosen_len), mode="constant", value=self.eos_id + ) + rejected_tokens = torch.nn.functional.pad( + torch.LongTensor(reject), (0, max_curr_seq_len - reject_len), mode="constant", value=self.eos_id + ) + labels_chosen_tokens = torch.nn.functional.pad( + torch.LongTensor(chosen_labels), + (0, max_curr_seq_len - len(chosen_labels)), + mode="constant", + value=-100, + ) + labels_reject_tokens = torch.nn.functional.pad( + torch.LongTensor(reject_labels), + (0, max_curr_seq_len - len(reject_labels)), + mode="constant", + value=-100, + ) + else: + chosen_tokens = torch.LongTensor(chosen) + rejected_tokens = torch.LongTensor(reject) + labels_chosen_tokens = torch.LongTensor(chosen_labels) + labels_reject_tokens = torch.LongTensor(reject_labels) + ignore_example = False # ignore the example whose tokenized text exceeds max seq length. if max_curr_seq_len > self.seq_length: logging.warning( @@ -480,6 +509,7 @@ def __getitem__(self, idx): labels_reject_tokens = torch.ones_like(rejected_tokens) * (-100) chosen_len = self.nograd_length reject_len = self.nograd_length + ignore_example = True output = { "chosen": chosen_tokens, @@ -490,7 +520,149 @@ def __getitem__(self, idx): "rejected_labels": labels_reject_tokens, "chosen_reward": payload.get("chosen_reward", self.default_chosen_reward), "rejected_reward": payload.get("rejected_reward", self.default_rejected_reward), + "ignore_example": ignore_example, + } + return output + + +class DPOPackedDataset(DPOModelDataset): + """A dataset class for DPO with sequence packing. Data is expected to be + pre-tokenized and pre-packed using examples/nlp/data/dpo/prepare_packed_dpo_dataset.py. + """ + + REWARDS_PAD_ID = -1000 + LABELS_PAD_ID = -100 + + def __init__( + self, + cfg, + tokenizer, + name, + data_prefix, + documents, + data, + seq_length, + seed, + drop_last=True, # return_cu_seqlen: bool = True ## should always be true + ): + + super().__init__(cfg, tokenizer, name, data_prefix, documents, data, seq_length, seed, drop_last) + self.data_prefix = data_prefix + + def __getitem__(self, idx): + return self.data[idx] + + def _ceil_to_nearest(self, n, m): + return (n + m - 1) // m * m + + def _maybe_cast_to_list(self, x): + return [item.tolist() if isinstance(item, np.ndarray) else item for item in x] + + def _collate_item(self, item, max_length, pad_id): + item = self._maybe_cast_to_list(item) + item = [x + [pad_id] * (max_length - len(x)) for x in item] + return item + + ## reset_position_ids, reset_attention_mask and eod_mask_loss are unused but are needed to match the API of dpo_custom_collate + def global_collate_fn( + self, + batch, + eos_id, + reset_position_ids=False, + reset_attention_mask=False, + eod_mask_loss=False, + pad_length_to_multiple_of: int | None = None, + ): + def combine_keys(key): + return [item[key] for item in batch] + + lengths = combine_keys("lengths") + rewards = combine_keys("reward") + seq_boundaries = combine_keys("seq_boundaries") + + input_ids = [ + np.concatenate( + [ + item["input_ids"][item["seq_boundaries"][i] : item["seq_boundaries"][i + 1] - 1] + for i in range(len(item["seq_boundaries"]) - 1) + ] + ) + for item in batch + ] + labels = [ + np.concatenate( + [ + item["labels"][item["seq_boundaries"][i] + 1 : item["seq_boundaries"][i + 1]] + for i in range(len(item["seq_boundaries"]) - 1) + ] + ) + for item in batch + ] + + if pad_length_to_multiple_of: + max_seq_len = torch.tensor(max(ex.shape[0] for ex in input_ids), device=torch.cuda.current_device()) + torch.distributed.all_reduce( + max_seq_len, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_data_parallel_group() + ) + max_length = math.ceil(max_seq_len / pad_length_to_multiple_of) * pad_length_to_multiple_of + else: + # pad to the nearest multiple of 16 for FP8 training + # for many datasets in practice, all packed sequence lengths are very close to the + # target length (2048, 4096, 8192), so there is very minimal padding + max_length = max(len(l) for l in input_ids) + max_length = min(self.seq_length, self._ceil_to_nearest(max_length, 16)) + + position_ids: List[List[int]] = [] + cu_seqlens: List[List[int]] = [] + for item in batch: + position_ids.append([]) + cu_seqlens.append([0]) + seqlens = np.array(item["seq_boundaries"][1:]) - np.array(item["seq_boundaries"][:-1]) + for l in seqlens: + position_ids[-1].extend(list(range(l - 1))) ## l - 1 to exclude labels + cu_seqlens[-1].append(cu_seqlens[-1][-1] + l - 1) + # set last seq to the max seq len because rope and attn kernels expect no padding + cu_seqlens[-1][-1] = max_length + + assert len(input_ids[0]) == len( + position_ids[0] + ), "Dataset problem: input_ids and position_ids lengths don't match" + + input_ids = self._collate_item(input_ids, max_length=max_length, pad_id=self.tokenizer.eos_id) + labels = self._collate_item(labels, max_length=max_length, pad_id=self.LABELS_PAD_ID) + position_ids = self._collate_item(position_ids, max_length=max_length, pad_id=0) + + max_num_sequences = max(len(l) for l in lengths) + lengths = self._collate_item(lengths, max_length=max_num_sequences, pad_id=0) + rewards = self._collate_item(rewards, max_length=max_num_sequences, pad_id=self.REWARDS_PAD_ID) + + output = { + "input_ids": torch.LongTensor(input_ids), + "labels": torch.LongTensor(labels), + "lengths": torch.LongTensor(lengths), + "rewards": torch.FloatTensor(rewards), + "position_ids": torch.LongTensor(position_ids), } + + cu_seqlens = self._collate_item(cu_seqlens, max_length=max(len(l) for l in cu_seqlens) + 1, pad_id=-1) + + # Pre-generate `cu_seqlens_argmin` and `max_seqlen` as CPU tensor to avoid device-to-host copies. + cu_seqlens = torch.IntTensor(cu_seqlens) + cu_seqlens_argmin = torch.argmin(cu_seqlens, dim=1, keepdim=True) + seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1] + max_seqlen, _ = seqlens.max(dim=1, keepdim=True) + + output.update( + { + "attention_mask": torch.LongTensor( + [1] * len(input_ids) + ), # no attention mask is needed for packed seq, this serves as a placeholder + "cu_seqlens": torch.IntTensor(cu_seqlens), # cu_seqlens_q must be in dtype torch.int32 + "cu_seqlens_argmin": cu_seqlens_argmin, # only required for perf + "max_seqlen": max_seqlen, # only required for perf + } + ) + return output diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py index 952b4e897..c5404ac7b 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py @@ -30,6 +30,7 @@ ) from nemo.collections.nlp.parts.mixins.nlp_adapter_mixins import NLPAdapterModelMixin from nemo.collections.nlp.parts.utils_funcs import get_last_rank +from nemo_aligner.data.nlp.datasets import DPOPackedDataset from nemo_aligner.models.alignable_interface import SupervisedInterface from nemo_aligner.utils import parallel_state from nemo_aligner.utils.distributed import broadcast_2d_tensor, from_parallel_logits_to_logprobs @@ -74,16 +75,19 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self.gt_reward_scale = self.cfg.dpo.get("gt_reward_scale", 1.0) @torch.no_grad() - def gather_and_split_rewards(self, pi_logprobs, ref_logprobs, labels, average_log_probs=False): + def gather_and_split_rewards(self, pi_logprobs, ref_logprobs, labels, cu_seqlens=None, average_log_probs=False): pi_logprobs = pi_logprobs.detach() dp_group = parallel_state.get_data_parallel_group() batch_logs = self.get_reduced_masked_logps( - pi_logprobs - ref_logprobs, labels[:, 1:], average_log_probs=average_log_probs + pi_logprobs - ref_logprobs, labels, cu_seqlens, average_log_probs=average_log_probs ) - output_list = [torch.zeros_like(batch_logs) for _ in range(dp_group.size())] + num_examples_on_this_rank = torch.tensor(batch_logs.size(), device=torch.cuda.current_device()) + num_examples = [torch.zeros_like(num_examples_on_this_rank) for _ in range(dp_group.size())] + torch.distributed.all_gather(num_examples, num_examples_on_this_rank, group=dp_group) + output_list = [torch.zeros(size, device=torch.cuda.current_device()) for size in num_examples] torch.distributed.all_gather(output_list, batch_logs, group=dp_group) @@ -96,6 +100,7 @@ def gather_and_split_rewards(self, pi_logprobs, ref_logprobs, labels, average_lo def get_forward_output_and_loss_func(self, validation_step=False, logprobs_only=False): def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None): batch = next(dataloader_iter) + packed = "input_ids" in batch required_keys = set() if parallel_state.get_pipeline_model_parallel_world_size() == 1: @@ -104,46 +109,63 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ # there is a problem with apex ignoring the mask on the older models # so we will always give the attention mask required_keys.add("attention_mask") + if "cu_seqlens" in batch: + required_keys.add("cu_seqlens") if parallel_state.is_pipeline_first_stage(): - required_keys.update(("chosen", "rejected", "position_ids")) + if packed: + required_keys.update(("input_ids", "position_ids")) + ## batch not packed --> chosen and rejected are separate keys + else: + required_keys.update(("chosen", "rejected", "position_ids")) if parallel_state.is_pipeline_last_stage(): - required_keys.update( - ( - "ref_policy_log_probs_chosen", - "ref_policy_log_probs_rejected", - "chosen_labels", - "rejected_labels", - "chosen_rewards", - "rejected_rewards", + if not packed: + required_keys.update( + ( + "ref_policy_log_probs_chosen", + "ref_policy_log_probs_rejected", + "chosen_labels", + "rejected_labels", + "chosen_rewards", + "rejected_rewards", + ) + ) + else: + required_keys.update( + ("ref_policy_log_probs", "labels", "rewards",) ## chosen and rejected interleaved ) - ) batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in batch.items()} - tokens, labels, ref_logprobs, gt_rewards = None, None, None, None - if batch["chosen"] is not None and batch["rejected"] is not None: - tokens = torch.cat((batch["chosen"], batch["rejected"]), dim=0) - - if batch["chosen_labels"] is not None and batch["rejected_labels"] is not None: - labels = torch.cat((batch["chosen_labels"], batch["rejected_labels"]), dim=0) - - if ( - batch.get("ref_policy_log_probs_chosen") is not None - and batch.get("ref_policy_log_probs_rejected") is not None - ): - ref_logprobs = torch.cat( - (batch["ref_policy_log_probs_chosen"], batch["ref_policy_log_probs_rejected"]), dim=0 - ) + tokens, labels, ref_logprobs, gt_rewards, cu_seqlens = None, None, None, None, None + if packed: ## packed sequence + tokens = batch["input_ids"] + labels = batch["labels"] + gt_rewards = batch["rewards"] + ref_logprobs = batch.get("ref_policy_log_probs", None) + else: + if batch["chosen"] is not None and batch["rejected"] is not None: + tokens = torch.cat((batch["chosen"], batch["rejected"]), dim=0) + if batch["chosen_labels"] is not None and batch["rejected_labels"] is not None: + labels = torch.cat((batch["chosen_labels"], batch["rejected_labels"]), dim=0) + if ( + batch.get("ref_policy_log_probs_chosen") is not None + and batch.get("ref_policy_log_probs_rejected") is not None + ): + ref_logprobs = torch.cat( + (batch["ref_policy_log_probs_chosen"], batch["ref_policy_log_probs_rejected"]), dim=0 + ) - if batch["chosen_rewards"] is not None and batch["rejected_rewards"] is not None: - gt_rewards = torch.cat((batch["chosen_rewards"], batch["rejected_rewards"]), dim=0) + if batch["chosen_rewards"] is not None and batch["rejected_rewards"] is not None: + gt_rewards = torch.cat((batch["chosen_rewards"], batch["rejected_rewards"]), dim=0) # this is necessary if MBS > 1 with the new GBS padding logic, as you may get batch dim > 1 in some configs # these two lines ensure your position_ids and attn_mask are always B=1 # position_ids = batch["position_ids"][0:1] - attention_mask = batch["attention_mask"][0:1] + + ## if using packing via TE, attention mask is generated in TE + attention_mask = batch["attention_mask"][0:1] if not packed else None # Model forward pass forward_args = { @@ -162,6 +184,29 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ else: forward_args.pop("loss_mask") + if "cu_seqlens" in batch: # packed sequence from DPOPackedDataset + # these args are passed eventually into TEDotProductAttention.forward() + cu_seqlens = batch["cu_seqlens"].squeeze() # remove batch size dimension (mbs=1) + + max_seqlen = batch["max_seqlen"].squeeze() if "max_seqlen" in batch else None + cu_seqlens_argmin = batch["cu_seqlens_argmin"] if "cu_seqlens_argmin" in batch else None + + # remove -1 "paddings" added in collate_fn + if cu_seqlens_argmin is not None: + cu_seqlens = cu_seqlens[: cu_seqlens_argmin.item()] + else: + cu_seqlens = cu_seqlens[: torch.argmin(cu_seqlens)] + + from megatron.core.packed_seq_params import PackedSeqParams + + forward_args["packed_seq_params"] = PackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_kv=max_seqlen, + qkv_format="thd", + ) + output_tensor = model(**forward_args) # in this nemo version the model and autocast dtypes are not synced @@ -173,8 +218,13 @@ def logprobs_func(output_tensor, non_loss_data=True): # This function is expected to be used only when `collect_non_loss_data=True` in the fwd_bwd_function of Megatron-LM. # See https://github.com/NVIDIA/Megatron-LM/blob/0bc3547702464501feefeb5523b7a17e591b21fa/megatron/core/pipeline_parallel/schedules.py#L228 assert non_loss_data + logprobs = from_parallel_logits_to_logprobs( - vocab_parallel_logits=output_tensor, target=labels, inference_only=True, higher_stability=True, + vocab_parallel_logits=output_tensor, + target=labels, + inference_only=True, + higher_stability=True, + ignore_last=not packed, ) return {"logprobs": logprobs} @@ -187,20 +237,26 @@ def loss_func(output_tensor): target=labels, inference_only=validation_step, higher_stability=True, + ignore_last=not packed, ) + if not packed: + labels_for_loss = labels[:, 1:] + else: + labels_for_loss = labels preference_loss, acc_chosen = self.loss_func( per_token_logps, ref_logprobs, - labels[:, 1:], + labels_for_loss, gt_rewards, + cu_seqlens, average_log_probs=self.preference_avg_log_probs, ) sft_loss = torch.zeros_like(preference_loss) if self.sft_loss_weight != 0: sft_loss = self.sft_loss_func( - per_token_logps, labels[:, 1:], average_log_probs=self.sft_avg_log_probs + per_token_logps, labels_for_loss, cu_seqlens, average_log_probs=self.sft_avg_log_probs ) loss = self.preference_loss_weight * preference_loss + self.sft_loss_weight * sft_loss @@ -212,7 +268,11 @@ def loss_func(output_tensor): ) = average_losses_across_data_parallel_group([loss, preference_loss, sft_loss, acc_chosen]) out_chosen, out_rejected = self.gather_and_split_rewards( - per_token_logps, ref_logprobs, labels, average_log_probs=self.preference_avg_log_probs + per_token_logps, + ref_logprobs, + labels_for_loss, + cu_seqlens, + average_log_probs=self.preference_avg_log_probs, ) return ( @@ -238,9 +298,44 @@ def split_output_tensor(self, output_tensor): chosen_logps, reject_logps = torch.split(output_tensor.float(), len(output_tensor) // 2, dim=0) return chosen_logps, reject_logps - def get_reduced_masked_logps(self, logps, labels, average_log_probs=False): + def get_reduced_masked_logps(self, logps, labels, cu_seqlens=None, average_log_probs=False): assert logps.shape == labels.shape, "logps and labels shape mismatch" + ## mbs = 1 + logps = logps.squeeze() + labels = labels.squeeze() + + ## break up the packed batch into an unpacked batch + if cu_seqlens is not None: + + ## cu_seqlens has an extra entry if the final example is padded. + ## we have to handle the case where the final example is padded and + ## the case where it is not separately. + split = cu_seqlens[1:-1] if len(cu_seqlens) % 2 == 1 else cu_seqlens[1:-2] + split = split.long().cpu() + logp_unpacked = list(torch.tensor_split(logps, split, -1)) + labels_unpacked = list(torch.tensor_split(labels, split, -1)) + lengths = [ex.shape[-1] for ex in logp_unpacked] + max_length = max(lengths) + + for i in range(len(logp_unpacked)): + logp_unpacked[i] = torch.nn.functional.pad( + logp_unpacked[i], (0, max_length - logp_unpacked[i].shape[-1]), "constant", + ) + labels_unpacked[i] = torch.nn.functional.pad( + labels_unpacked[i], (0, max_length - labels_unpacked[i].shape[-1]), "constant", -100 + ) + + unpacked_logps = logp_unpacked[::2] ## chosen + unpacked_logps_rejected = logp_unpacked[1::2] ## rejected + unpacked_labels = labels_unpacked[::2] + unpacked_labels_rejected = labels_unpacked[1::2] + + unpacked_logps.extend(unpacked_logps_rejected) + unpacked_labels.extend(unpacked_labels_rejected) + logps = torch.stack(unpacked_logps, 0) + labels = torch.stack(unpacked_labels, 0) + loss_mask = (labels > -1).float() if average_log_probs: @@ -249,9 +344,9 @@ def get_reduced_masked_logps(self, logps, labels, average_log_probs=False): else: return (logps * loss_mask).sum(-1) - def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_probs=False): + def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, cu_seqlens=None, average_log_probs=False): rewards = self.get_reduced_masked_logps( - pi_logprobs - ref_logprobs, labels, average_log_probs=average_log_probs + pi_logprobs - ref_logprobs, labels, cu_seqlens=cu_seqlens, average_log_probs=average_log_probs, ) chosen_rewards, reject_rewards = self.split_output_tensor(rewards) rewards_delta = chosen_rewards - reject_rewards @@ -262,7 +357,11 @@ def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_p logbeta_hat_chosen = torch.nn.functional.logsigmoid(self.ref_policy_kl_penalty * rewards_delta) logbeta_hat_rejected = torch.nn.functional.logsigmoid(-self.ref_policy_kl_penalty * rewards_delta) - chosen_gt_rewards, reject_gt_rewards = self.split_output_tensor(gt_rewards) + if cu_seqlens is not None: ## packed sequence + gt_rewards = gt_rewards[gt_rewards != DPOPackedDataset.REWARDS_PAD_ID] + chosen_gt_rewards, reject_gt_rewards = gt_rewards[::2], gt_rewards[1::2] + else: + chosen_gt_rewards, reject_gt_rewards = self.split_output_tensor(gt_rewards) gt_rewards_delta = self.gt_reward_scale * (chosen_gt_rewards - reject_gt_rewards) logalpha_hat_chosen = torch.nn.functional.logsigmoid(gt_rewards_delta) logalpha_hat_rejected = torch.nn.functional.logsigmoid(-gt_rewards_delta) @@ -275,7 +374,11 @@ def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_p logbeta_hat_chosen = torch.nn.functional.logsigmoid(self.ref_policy_kl_penalty * rewards_delta) logbeta_hat_rejected = torch.nn.functional.logsigmoid(-self.ref_policy_kl_penalty * rewards_delta) - chosen_gt_rewards, reject_gt_rewards = self.split_output_tensor(gt_rewards) + if cu_seqlens is not None: ## packed sequence + gt_rewards = gt_rewards[gt_rewards != DPOPackedDataset.REWARDS_PAD_ID] + chosen_gt_rewards, reject_gt_rewards = gt_rewards[::2], gt_rewards[1::2] + else: + chosen_gt_rewards, reject_gt_rewards = self.split_output_tensor(gt_rewards) gt_rewards_delta = self.gt_reward_scale * (chosen_gt_rewards - reject_gt_rewards) logalpha_hat_chosen = torch.nn.functional.logsigmoid(gt_rewards_delta) logalpha_hat_rejected = torch.nn.functional.logsigmoid(-gt_rewards_delta) @@ -287,7 +390,11 @@ def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_p elif self.preference_loss == "ipo": loss = torch.mean((chosen_rewards - reject_rewards - 1.0 / (2.0 * self.ref_policy_kl_penalty)) ** 2, 0) elif self.preference_loss == "rpo_sq": - chosen_gt_rewards, reject_gt_rewards = self.split_output_tensor(gt_rewards) + if cu_seqlens is not None: ## packed sequence + gt_rewards = gt_rewards[gt_rewards != DPOPackedDataset.REWARDS_PAD_ID] + chosen_gt_rewards, reject_gt_rewards = gt_rewards[::2], gt_rewards[1::2] + else: + chosen_gt_rewards, reject_gt_rewards = self.split_output_tensor(gt_rewards) gt_rewards_delta = self.gt_reward_scale * (chosen_gt_rewards - reject_gt_rewards) loss = torch.mean((self.ref_policy_kl_penalty * rewards_delta - gt_rewards_delta) ** 2, 0) @@ -300,19 +407,38 @@ def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_p return loss, acc_chosen - def sft_loss_func(self, pi_logprobs, labels, average_log_probs=False): - logprobs = self.get_reduced_masked_logps(pi_logprobs, labels, average_log_probs=average_log_probs) + def sft_loss_func(self, pi_logprobs, labels, cu_seqlens=None, average_log_probs=False): + logprobs = self.get_reduced_masked_logps( + pi_logprobs, labels, cu_seqlens=cu_seqlens, average_log_probs=average_log_probs + ) chosen_logprobs, _ = self.split_output_tensor(logprobs) return -chosen_logprobs.mean(0) def get_loss_and_metrics(self, batch, forward_only): - seq_length = batch["chosen"].shape[1] + packed = "input_ids" in batch + if packed: + seq_length = batch["input_ids"].shape[1] + else: + seq_length = batch["chosen"].shape[1] data_iter = get_iterator_k_split(batch, get_num_microbatches()) set_sync_funcs(self, forward_only) fwd_bwd_function = get_forward_backward_func() + micro_batch_size = self.cfg.micro_batch_size + if not packed: + # each minibatch has 2 comparisons so tensor shape will be mbs * 2 + micro_batch_size *= 2 + else: + assert micro_batch_size == 1, ( + f"Packed sequence is only supported with micro batch size 1," + f" but your micro batch size is {micro_batch_size}." + ) + assert self.cfg.get( + "transformer_engine", False + ), "Transformer Engine should be enabled when using sequence packing." + losses_reduced_per_micro_batch = fwd_bwd_function( forward_step_func=self.get_forward_output_and_loss_func(forward_only, logprobs_only=False), data_iterator=data_iter, @@ -320,8 +446,7 @@ def get_loss_and_metrics(self, batch, forward_only): num_microbatches=get_num_microbatches(), forward_only=forward_only, seq_length=seq_length, - micro_batch_size=self.cfg.micro_batch_size - * 2, # each minibatch has 2 comparisons so tensor shape will be mbs * 2 + micro_batch_size=micro_batch_size, ) # only the last stages of the pipeline return losses @@ -405,10 +530,25 @@ def finish_validation_step(self): @torch.no_grad() def get_logprob_batch(self, batch): - seq_length = batch["chosen"].shape[1] - batch_size = batch["chosen"].shape[0] + packed = "input_ids" in batch + if packed: + k = "input_ids" + else: + k = "chosen" + seq_length = batch[k].shape[1] + batch_size = batch[k].shape[0] num_microbatches = divide(batch_size, self.cfg.dpo.log_prob_forward_micro_batch_size) + micro_batch_size = self.cfg.dpo.log_prob_forward_micro_batch_size + if not packed: + # each minibatch has 2 comparisons so tensor shape will be mbs * 2 + micro_batch_size *= 2 + else: + assert micro_batch_size == 1, ( + f"Packed sequence is only supported with forward micro batch size 1," + f" but your forward micro batch size is {micro_batch_size}." + ) + data_iter = get_iterator_k_split(batch, num_microbatches) set_sync_funcs(self, forward_only=True) @@ -421,19 +561,23 @@ def get_logprob_batch(self, batch): num_microbatches=num_microbatches, forward_only=True, seq_length=seq_length, - micro_batch_size=self.cfg.dpo.log_prob_forward_micro_batch_size * 2, + micro_batch_size=micro_batch_size, collect_non_loss_data=True, ) if len(logprobs_list) > 0: - chosen_logprobs_list = [] - rejected_logprobs_list = [] - for item in logprobs_list: - chosen_logprobs, rejected_logprobs = self.split_output_tensor(item["logprobs"]) - chosen_logprobs_list.append(chosen_logprobs) - rejected_logprobs_list.append(rejected_logprobs) - - logprobs = torch.cat([torch.cat(chosen_logprobs_list), torch.cat(rejected_logprobs_list)], dim=0) + if not packed: + chosen_logprobs_list = [] + rejected_logprobs_list = [] + for item in logprobs_list: + chosen_logprobs, rejected_logprobs = self.split_output_tensor(item["logprobs"]) + chosen_logprobs_list.append(chosen_logprobs) + rejected_logprobs_list.append(rejected_logprobs) + + logprobs = torch.cat([torch.cat(chosen_logprobs_list), torch.cat(rejected_logprobs_list)], dim=0) + else: + logprobs_list = [item["logprobs"] for item in logprobs_list] + logprobs = torch.cat(logprobs_list, dim=0) else: logprobs = None @@ -448,7 +592,6 @@ def get_logprob_batch(self, batch): return logprobs def get_ref_policy_logprobs(self, batch): - if self.use_peft and self.ref_policy_state_dict is None: # when using adapters instead of full-tuning, the actor is reference model + adapters with adapter_control(self): diff --git a/nemo_aligner/utils/distributed.py b/nemo_aligner/utils/distributed.py index 07378b5d9..654502ae4 100755 --- a/nemo_aligner/utils/distributed.py +++ b/nemo_aligner/utils/distributed.py @@ -351,16 +351,23 @@ def calculate_distributed_entropy(vocab_parallel_logits, mask=None): return calculate_entropy(full_log_probs, mask) -def from_parallel_logits_to_logprobs(vocab_parallel_logits, target, inference_only=False, higher_stability=False): +def from_parallel_logits_to_logprobs( + vocab_parallel_logits, target, inference_only=False, higher_stability=False, ignore_last=True +): """get log probs out of a B x S x V//TP tensor NOTE: this function shifts the target, which means you must give it the unmodified targets Returns a B x S-1 tensor """ - target = target.roll(shifts=-1, dims=-1) - return DistributedLogprob.apply(vocab_parallel_logits, target, inference_only, higher_stability)[ - :, :-1 - ].contiguous() + + if ignore_last: + target = target.roll(shifts=-1, dims=-1) + probs = DistributedLogprob.apply(vocab_parallel_logits, target, inference_only, higher_stability).contiguous() + ### ignore_last should be true if labels are not shifted as a data preparation step + if ignore_last: + return probs[:, :-1] + else: + return probs def all_reduce_dict(dictionary, dtype=torch.float32, group=None, op=torch.distributed.ReduceOp.SUM): diff --git a/tests/functional/dpo.sh b/tests/functional/dpo.sh index bc073dcde..6dc939c62 100755 --- a/tests/functional/dpo.sh +++ b/tests/functional/dpo.sh @@ -1,7 +1,6 @@ #!/bin/bash -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) -cd $SCRIPT_DIR +DATA_DIR=${DATA_DIR} set -eoux pipefail export NCCL_ALGO=Tree @@ -12,10 +11,10 @@ GBS=${GBS:-4} PRETRAINED_CHECKPOINT_NEMO_FILE=${PRETRAINED_CHECKPOINT_NEMO_FILE} -TRAIN_DATA_PATH=$SCRIPT_DIR/test_data/dummy-dpo.jsonl -VALID_DATA_PATH=$SCRIPT_DIR/test_data/dummy-dpo.jsonl +TRAIN_DATA_PATH=${TRAIN_DATA_PATH:-"${DATA_DIR}/dummy-dpo.jsonl"} +VALID_DATA_PATH=$TRAIN_DATA_PATH -NAME="dpo_test" +NAME=${NAME:-"dpo_test"} # PARAMETERS RESULTS_DIR="/tmp/${NAME}" @@ -38,7 +37,6 @@ torchrun --nproc-per-node 2 ${GPFS}/examples/nlp/gpt/train_gpt_dpo.py \ --config-name=${CONF_NAME} \ trainer.num_nodes=1 \ trainer.devices=2 \ - ++model.data.data_impl=jsonl \ ++model.data.seq_length=128 \ ++model.global_batch_size=${GBS} \ ++model.micro_batch_size=1 \ @@ -55,16 +53,17 @@ torchrun --nproc-per-node 2 ${GPFS}/examples/nlp/gpt/train_gpt_dpo.py \ model.data.num_workers=2 \ ++model.tensor_model_parallel_size=1 \ ++model.pipeline_model_parallel_size=1 \ - trainer.dpo.max_steps=3 \ - trainer.dpo.val_check_interval=3 \ + trainer.dpo.max_steps=${MAX_STEPS:-3} \ + trainer.dpo.val_check_interval=${MAX_STEPS:-3} \ trainer.dpo.limit_val_batches=8 \ trainer.dpo.save_interval=0 \ exp_manager.explicit_log_dir=${RESULTS_DIR} \ ++model.activations_checkpoint_granularity=full \ ++model.activations_checkpoint_method=uniform \ ++model.activations_checkpoint_num_layers=1 \ - ++model.dist_ckpt_load_strictness=log_all + ++model.dist_ckpt_load_strictness=log_all \ + "$@" } log_file=$(mktemp /tmp/dpo-log-XXXXXX) -dpo | tee $log_file \ No newline at end of file +dpo "$@" | tee $log_file \ No newline at end of file diff --git a/tests/functional/test_cases/dpo-llama3 b/tests/functional/test_cases/dpo-llama3 index 8e40e94c8..1ba79a9af 100755 --- a/tests/functional/test_cases/dpo-llama3 +++ b/tests/functional/test_cases/dpo-llama3 @@ -14,9 +14,11 @@ # limitations under the License. SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +export DATA_DIR=$SCRIPT_DIR/../test_data cd $SCRIPT_DIR set -eoux pipefail PRETRAINED_CHECKPOINT_NEMO_FILE=${ALIGNER_CI_DIR}/checkpoints/tiny-llama3-results-nlayers2-hidden128-ffn448-nhead4-qgroup2-megatron_gpt.nemo \ -bash ../dpo.sh +bash ../dpo.sh \ + ++model.data.data_impl=jsonl \ diff --git a/tests/functional/test_cases/dpo-llama3-pack b/tests/functional/test_cases/dpo-llama3-pack new file mode 100755 index 000000000..f5bf81f4b --- /dev/null +++ b/tests/functional/test_cases/dpo-llama3-pack @@ -0,0 +1,26 @@ +#!/bin/bash +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +export SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +cd $SCRIPT_DIR + +set -eoux pipefail + +PRETRAINED_CHECKPOINT_NEMO_FILE=${ALIGNER_CI_DIR}/checkpoints/tiny-llama3-results-nlayers2-hidden128-ffn448-nhead4-qgroup2-megatron_gpt.nemo \ +TRAIN_DATA_PATH=$SCRIPT_DIR/../test_data/dummy_dpo_packed_90.npy \ +NAME=dpo_pack_test \ +MAX_STEPS=10 \ +bash ../dpo.sh \ + ++model.data.data_impl=packed_jsonl diff --git a/tests/functional/test_data/dummy-dpo.jsonl b/tests/functional/test_data/dummy-dpo.jsonl index 3fd76dd85..a2ea95fe2 100644 --- a/tests/functional/test_data/dummy-dpo.jsonl +++ b/tests/functional/test_data/dummy-dpo.jsonl @@ -1,100 +1,100 @@ {"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} +{"prompt": "System\n\nUser\nThis is another test question?\nAssistant\n", "chosen_response": "This is an alternate chosen response.\n", "rejected_response": "This is another rejected response.\n"} +{"prompt": "System\n\nUser\nThis is a slightly longer question?\nAssistant\n", "chosen_response": "Here's a chosen response.\n", "rejected_response": "And here's a rejected response.\n"} +{"prompt": "System\n\nUser\nThis is the final question?\nAssistant\n", "chosen_response": "This the fourth chosen response.\n", "rejected_response": "This is a rejected response.\n"} {"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} +{"prompt": "System\n\nUser\nThis is another test question?\nAssistant\n", "chosen_response": "This is an alternate chosen response.\n", "rejected_response": "This is another rejected response.\n"} +{"prompt": "System\n\nUser\nThis is a slightly longer question?\nAssistant\n", "chosen_response": "Here's a chosen response.\n", "rejected_response": "And here's a rejected response.\n"} +{"prompt": "System\n\nUser\nThis is the final question?\nAssistant\n", "chosen_response": "This the fourth chosen response.\n", "rejected_response": "This is a rejected response.\n"} {"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} +{"prompt": "System\n\nUser\nThis is another test question?\nAssistant\n", "chosen_response": "This is an alternate chosen response.\n", "rejected_response": "This is another rejected response.\n"} +{"prompt": "System\n\nUser\nThis is a slightly longer question?\nAssistant\n", "chosen_response": "Here's a chosen response.\n", "rejected_response": "And here's a rejected response.\n"} +{"prompt": "System\n\nUser\nThis is the final question?\nAssistant\n", "chosen_response": "This the fourth chosen response.\n", "rejected_response": "This is a rejected response.\n"} {"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} +{"prompt": "System\n\nUser\nThis is another test question?\nAssistant\n", "chosen_response": "This is an alternate chosen response.\n", "rejected_response": "This is another rejected response.\n"} +{"prompt": "System\n\nUser\nThis is a slightly longer question?\nAssistant\n", "chosen_response": "Here's a chosen response.\n", "rejected_response": "And here's a rejected response.\n"} +{"prompt": "System\n\nUser\nThis is the final question?\nAssistant\n", "chosen_response": "This the fourth chosen response.\n", "rejected_response": "This is a rejected response.\n"} {"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} +{"prompt": "System\n\nUser\nThis is another test question?\nAssistant\n", "chosen_response": "This is an alternate chosen response.\n", "rejected_response": "This is another rejected response.\n"} +{"prompt": "System\n\nUser\nThis is a slightly longer question?\nAssistant\n", "chosen_response": "Here's a chosen response.\n", "rejected_response": "And here's a rejected response.\n"} +{"prompt": "System\n\nUser\nThis is the final question?\nAssistant\n", "chosen_response": "This the fourth chosen response.\n", "rejected_response": "This is a rejected response.\n"} {"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} +{"prompt": "System\n\nUser\nThis is another test question?\nAssistant\n", "chosen_response": "This is an alternate chosen response.\n", "rejected_response": "This is another rejected response.\n"} +{"prompt": "System\n\nUser\nThis is a slightly longer question?\nAssistant\n", "chosen_response": "Here's a chosen response.\n", "rejected_response": "And here's a rejected response.\n"} +{"prompt": "System\n\nUser\nThis is the final question?\nAssistant\n", "chosen_response": "This the fourth chosen response.\n", "rejected_response": "This is a rejected response.\n"} {"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} +{"prompt": "System\n\nUser\nThis is another test question?\nAssistant\n", "chosen_response": "This is an alternate chosen response.\n", "rejected_response": "This is another rejected response.\n"} +{"prompt": "System\n\nUser\nThis is a slightly longer question?\nAssistant\n", "chosen_response": "Here's a chosen response.\n", "rejected_response": "And here's a rejected response.\n"} +{"prompt": "System\n\nUser\nThis is the final question?\nAssistant\n", "chosen_response": "This the fourth chosen response.\n", "rejected_response": "This is a rejected response.\n"} {"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} +{"prompt": "System\n\nUser\nThis is another test question?\nAssistant\n", "chosen_response": "This is an alternate chosen response.\n", "rejected_response": "This is another rejected response.\n"} +{"prompt": "System\n\nUser\nThis is a slightly longer question?\nAssistant\n", "chosen_response": "Here's a chosen response.\n", "rejected_response": "And here's a rejected response.\n"} +{"prompt": "System\n\nUser\nThis is the final question?\nAssistant\n", "chosen_response": "This the fourth chosen response.\n", "rejected_response": "This is a rejected response.\n"} {"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} +{"prompt": "System\n\nUser\nThis is another test question?\nAssistant\n", "chosen_response": "This is an alternate chosen response.\n", "rejected_response": "This is another rejected response.\n"} +{"prompt": "System\n\nUser\nThis is a slightly longer question?\nAssistant\n", "chosen_response": "Here's a chosen response.\n", "rejected_response": "And here's a rejected response.\n"} +{"prompt": "System\n\nUser\nThis is the final question?\nAssistant\n", "chosen_response": "This the fourth chosen response.\n", "rejected_response": "This is a rejected response.\n"} {"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} +{"prompt": "System\n\nUser\nThis is another test question?\nAssistant\n", "chosen_response": "This is an alternate chosen response.\n", "rejected_response": "This is another rejected response.\n"} +{"prompt": "System\n\nUser\nThis is a slightly longer question?\nAssistant\n", "chosen_response": "Here's a chosen response.\n", "rejected_response": "And here's a rejected response.\n"} +{"prompt": "System\n\nUser\nThis is the final question?\nAssistant\n", "chosen_response": "This the fourth chosen response.\n", "rejected_response": "This is a rejected response.\n"} {"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} +{"prompt": "System\n\nUser\nThis is another test question?\nAssistant\n", "chosen_response": "This is an alternate chosen response.\n", "rejected_response": "This is another rejected response.\n"} +{"prompt": "System\n\nUser\nThis is a slightly longer question?\nAssistant\n", "chosen_response": "Here's a chosen response.\n", "rejected_response": "And here's a rejected response.\n"} +{"prompt": "System\n\nUser\nThis is the final question?\nAssistant\n", "chosen_response": "This the fourth chosen response.\n", "rejected_response": "This is a rejected response.\n"} {"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} +{"prompt": "System\n\nUser\nThis is another test question?\nAssistant\n", "chosen_response": "This is an alternate chosen response.\n", "rejected_response": "This is another rejected response.\n"} +{"prompt": "System\n\nUser\nThis is a slightly longer question?\nAssistant\n", "chosen_response": "Here's a chosen response.\n", "rejected_response": "And here's a rejected response.\n"} +{"prompt": "System\n\nUser\nThis is the final question?\nAssistant\n", "chosen_response": "This the fourth chosen response.\n", "rejected_response": "This is a rejected response.\n"} {"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} +{"prompt": "System\n\nUser\nThis is another test question?\nAssistant\n", "chosen_response": "This is an alternate chosen response.\n", "rejected_response": "This is another rejected response.\n"} +{"prompt": "System\n\nUser\nThis is a slightly longer question?\nAssistant\n", "chosen_response": "Here's a chosen response.\n", "rejected_response": "And here's a rejected response.\n"} +{"prompt": "System\n\nUser\nThis is the final question?\nAssistant\n", "chosen_response": "This the fourth chosen response.\n", "rejected_response": "This is a rejected response.\n"} {"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} +{"prompt": "System\n\nUser\nThis is another test question?\nAssistant\n", "chosen_response": "This is an alternate chosen response.\n", "rejected_response": "This is another rejected response.\n"} +{"prompt": "System\n\nUser\nThis is a slightly longer question?\nAssistant\n", "chosen_response": "Here's a chosen response.\n", "rejected_response": "And here's a rejected response.\n"} +{"prompt": "System\n\nUser\nThis is the final question?\nAssistant\n", "chosen_response": "This the fourth chosen response.\n", "rejected_response": "This is a rejected response.\n"} {"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} +{"prompt": "System\n\nUser\nThis is another test question?\nAssistant\n", "chosen_response": "This is an alternate chosen response.\n", "rejected_response": "This is another rejected response.\n"} +{"prompt": "System\n\nUser\nThis is a slightly longer question?\nAssistant\n", "chosen_response": "Here's a chosen response.\n", "rejected_response": "And here's a rejected response.\n"} +{"prompt": "System\n\nUser\nThis is the final question?\nAssistant\n", "chosen_response": "This the fourth chosen response.\n", "rejected_response": "This is a rejected response.\n"} {"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} +{"prompt": "System\n\nUser\nThis is another test question?\nAssistant\n", "chosen_response": "This is an alternate chosen response.\n", "rejected_response": "This is another rejected response.\n"} +{"prompt": "System\n\nUser\nThis is a slightly longer question?\nAssistant\n", "chosen_response": "Here's a chosen response.\n", "rejected_response": "And here's a rejected response.\n"} +{"prompt": "System\n\nUser\nThis is the final question?\nAssistant\n", "chosen_response": "This the fourth chosen response.\n", "rejected_response": "This is a rejected response.\n"} {"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} +{"prompt": "System\n\nUser\nThis is another test question?\nAssistant\n", "chosen_response": "This is an alternate chosen response.\n", "rejected_response": "This is another rejected response.\n"} +{"prompt": "System\n\nUser\nThis is a slightly longer question?\nAssistant\n", "chosen_response": "Here's a chosen response.\n", "rejected_response": "And here's a rejected response.\n"} +{"prompt": "System\n\nUser\nThis is the final question?\nAssistant\n", "chosen_response": "This the fourth chosen response.\n", "rejected_response": "This is a rejected response.\n"} {"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} +{"prompt": "System\n\nUser\nThis is another test question?\nAssistant\n", "chosen_response": "This is an alternate chosen response.\n", "rejected_response": "This is another rejected response.\n"} +{"prompt": "System\n\nUser\nThis is a slightly longer question?\nAssistant\n", "chosen_response": "Here's a chosen response.\n", "rejected_response": "And here's a rejected response.\n"} +{"prompt": "System\n\nUser\nThis is the final question?\nAssistant\n", "chosen_response": "This the fourth chosen response.\n", "rejected_response": "This is a rejected response.\n"} {"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} +{"prompt": "System\n\nUser\nThis is another test question?\nAssistant\n", "chosen_response": "This is an alternate chosen response.\n", "rejected_response": "This is another rejected response.\n"} +{"prompt": "System\n\nUser\nThis is a slightly longer question?\nAssistant\n", "chosen_response": "Here's a chosen response.\n", "rejected_response": "And here's a rejected response.\n"} +{"prompt": "System\n\nUser\nThis is the final question?\nAssistant\n", "chosen_response": "This the fourth chosen response.\n", "rejected_response": "This is a rejected response.\n"} {"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} +{"prompt": "System\n\nUser\nThis is another test question?\nAssistant\n", "chosen_response": "This is an alternate chosen response.\n", "rejected_response": "This is another rejected response.\n"} +{"prompt": "System\n\nUser\nThis is a slightly longer question?\nAssistant\n", "chosen_response": "Here's a chosen response.\n", "rejected_response": "And here's a rejected response.\n"} +{"prompt": "System\n\nUser\nThis is the final question?\nAssistant\n", "chosen_response": "This the fourth chosen response.\n", "rejected_response": "This is a rejected response.\n"} {"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} +{"prompt": "System\n\nUser\nThis is another test question?\nAssistant\n", "chosen_response": "This is an alternate chosen response.\n", "rejected_response": "This is another rejected response.\n"} +{"prompt": "System\n\nUser\nThis is a slightly longer question?\nAssistant\n", "chosen_response": "Here's a chosen response.\n", "rejected_response": "And here's a rejected response.\n"} +{"prompt": "System\n\nUser\nThis is the final question?\nAssistant\n", "chosen_response": "This the fourth chosen response.\n", "rejected_response": "This is a rejected response.\n"} {"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} +{"prompt": "System\n\nUser\nThis is another test question?\nAssistant\n", "chosen_response": "This is an alternate chosen response.\n", "rejected_response": "This is another rejected response.\n"} +{"prompt": "System\n\nUser\nThis is a slightly longer question?\nAssistant\n", "chosen_response": "Here's a chosen response.\n", "rejected_response": "And here's a rejected response.\n"} +{"prompt": "System\n\nUser\nThis is the final question?\nAssistant\n", "chosen_response": "This the fourth chosen response.\n", "rejected_response": "This is a rejected response.\n"} {"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} +{"prompt": "System\n\nUser\nThis is another test question?\nAssistant\n", "chosen_response": "This is an alternate chosen response.\n", "rejected_response": "This is another rejected response.\n"} +{"prompt": "System\n\nUser\nThis is a slightly longer question?\nAssistant\n", "chosen_response": "Here's a chosen response.\n", "rejected_response": "And here's a rejected response.\n"} +{"prompt": "System\n\nUser\nThis is the final question?\nAssistant\n", "chosen_response": "This the fourth chosen response.\n", "rejected_response": "This is a rejected response.\n"} {"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} +{"prompt": "System\n\nUser\nThis is another test question?\nAssistant\n", "chosen_response": "This is an alternate chosen response.\n", "rejected_response": "This is another rejected response.\n"} +{"prompt": "System\n\nUser\nThis is a slightly longer question?\nAssistant\n", "chosen_response": "Here's a chosen response.\n", "rejected_response": "And here's a rejected response.\n"} +{"prompt": "System\n\nUser\nThis is the final question?\nAssistant\n", "chosen_response": "This the fourth chosen response.\n", "rejected_response": "This is a rejected response.\n"} {"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} -{"prompt": "System\n\nUser\nThis is a test question?\nAssistant\n", "chosen_response": "This is the chosen response.\n", "rejected_response": "This is the rejected response.\n"} \ No newline at end of file +{"prompt": "System\n\nUser\nThis is another test question?\nAssistant\n", "chosen_response": "This is an alternate chosen response.\n", "rejected_response": "This is another rejected response.\n"} +{"prompt": "System\n\nUser\nThis is a slightly longer question?\nAssistant\n", "chosen_response": "Here's a chosen response.\n", "rejected_response": "And here's a rejected response.\n"} +{"prompt": "System\n\nUser\nThis is the final question?\nAssistant\n", "chosen_response": "This the fourth chosen response.\n", "rejected_response": "This is a rejected response.\n"} \ No newline at end of file diff --git a/tests/functional/test_data/dummy_dpo_packed_90.npy b/tests/functional/test_data/dummy_dpo_packed_90.npy new file mode 100644 index 0000000000000000000000000000000000000000..fb032943a3c8e7247ad8ad06abca171a7bea33cb GIT binary patch literal 66341 zcmeI*cYG8@8wc=&PNb2ur$x3X^99JeNo^q9(7^=U6(Ut2jh?Zt_qLxz;}FX_9m??3-j zbU^Fsy2jaa=MSo$Koj~>V_hxpy{iM>S~())+tOD zEOJ_nF067|FLv6LIc*ck!eq@Nr(Na?YZ~UyO&2-s%Ucx`6pSC{bXaoOVy9!7QZ3tPUplzr%R%|qxUs4>*h8#Oq*F#?{uwjx+Rjm zQqPV^O>CVUne3YEof^|I)wM8?>o)HC(S-%chf`m+Pt7Y#J=ihTt!?Vu&Z#y%f?t&E zJ+693XMUg|liRS>XNs#B}3AXSGDF zhj(@B@z#1`do33m#SiQ<#nyUE*JzGutw*zW6U|!B;hF)5edw^Sro#?}sggvhWA~Q& z=;8fRlUk*`uZ*IJX7ddbol<4(GT&42GYyUX;-|5xAuXd|vFo5s&pMl`doI^fxm+7A z2d3e2ohFx)(I3~P%k?x}_AE?IYS$7TtA)y zy`LOqmEtPUd%~(hdjbck1P=BE4haYxN`V{LFL1-<0z-?=^uU${Zln^pu`h6wfWS>D zaG0h*^N!fq7j70?gm#^)aY`Aka=5vFO4*`mH^|fS2s#|8>CpRGFl{$DDxL_vN-Fov zp*vU917x)u+)^cQE58$N9q5FkDR3K2f!Q2emvQXVyV#h`H_Yb5=JmsE<42KFQp+e< z)SRLjFO^Nzz0NpB<#MclPTsDmGxD6gJzb8|bZKB`OvaC+o#*5oR030eXDkhL#_<$5 zK~tcad&C{%iO@SL+qXw7Q#qXI?-3^j_K1_|a3@WN26n=oMy3?xqqr%@??PKwu>WR%r?}um@JhQ=ol4utp^??F*b95IBPZ zYc&O$xf`4rPlQ(P2KP`o+|%C;&I;@XXVYPwrb7ce;hcCPw67D+RS9(bPUr?YVLb&l zXbLoNH`o|Yf%feN=cxqlkP?|%dx z9;xXto7=g|xE=bR-9J7mej4-O{_)W&m&d?m)9b8{ZF0%`$H&p-@tQ6T+#{Y4KaO_Z zKR!_<@FahQJ~^;LpF)ABY6>*DGoBVdih0l(Pgl7-!|#k|20G(eba}R>OEY(l=fqEB z<}_g6>l)5gIXutbIi4TbIbJ}A7v_b-i) zyF&P~%Hb=%!&d_iU!%jC0uM+ryFYv>Fz>g^KV@-ix4>b1uG(U+S#XNX5!>1~jpJ8oodY#wjO=~k>=k*0$ zeyQ&=%l8*-{=SsZmOF~-rLrE6+e#t;j{TFg+ARzJ%Im&RfYE6spe~yz;FD{ z_-&vwen)}d>kBltGyV`iih0l(e^j~r$?uFm2Rh>~bor~k%Pf0U^FQgU?nu-bj`u(aldCYVWzxtS)NW5qG-_3$ z$FsIuu8MLiWUfYYqFb)cNOsF!2qDw^v^{ShswLCFZrK~N49zI&0qpFSYogW)lYOX7 zbj!YsWVh^x5GL1R-gc8~W0p~q3O$~kCf7l&6(-lEHeqr-M$%+4LYPc2uVvEU+lg*} zOgA*9oD5*~Cl9y-P60hjP4G`EOX*3%%GW# zrnf_~=t{af1hrO}97=7Xdv3r;cFzqF!sJHG+ir4W%rXq0oIKV}lbfK{3X_{sn=mZ|k&NVuJ_;eEZpnP!NZkt4 z4UH-%1K2HfYZP7~buo8}JT|lawC^5zDax&oIiBXkwf6~( z6kZ{97a9~&cV#4{Rv?7b-I&*sYTyoZ8fKXjUttEda|gORYOOF?No`^WTE$52K&uhL zWDWCLCe8bGTsMtbhT)Tw$J%-Ib2@6RFgb(TgvnY)(&S8pFu4cwwwv4&v&@OBFoW7@ zau#Z>FgcsrgvmNa(&QY3FgcfbEt6)R)E!JTjG&x6*4mT0i*hSu*3+CgsW&i^C-p{z zkU5X}Jdn8;CYlpnVFtBU=H4i`Lgsv$6EgQ{JQggy zV%qf{l5$qTC9Z$I+bFRUXes?kZ0}2$?4`uO-vuJCi42x?w2g?y+{??>QNT zS4cgD2F3lJQyIzoJ*Odr)YF;I8>wety1CO9W@Njio{7RMq@G2CLh9L!q||c|Lh8B9 z=Z)0!Fx@boa`#xfrJj$%E2LgPgF@zUUzY3%P2-GJ%lPFa|d?OxY!MBx=uZ=yl5uHVc^uIslTgw$J^*OF@Hx_%ob z8b(r19&7Eoemlyoka-8qiFN%>Msi)h3n66Q&AgUOldm`3gX!kPSm;6R?xFXh@CvE- z(V*y|_cM|`^Z|sB`XKYVQcZhzg@-W9(3YYez|I}$!>F~wd^<7XoR4m z87i}=+xI>SAEMj}nIF-dxO((4BYE}c6NHfYDf3z~O}={c8KxTzs;Eb?dtLt=g;z-Z zf(FI9{v{*1u78CPQom+iOR9nE`Zt(m)U-m6XXm>9Eo!YW`5m>1b^Uura$WxcAx!?r zyzM4`!Yo4@ih2M$P5z8pD@^`EZNlWQjHJom5W?i|%xjr6bFcXaCK|P>(BoOV*ZdRZ zR>=H|=EPp}Z$@&jxfCH}7PRvI!}`z5D8wv7Gm3fuJG*5o)LLP(HMNOu*@ltqmTeKj zWIN{bpj)=bM59&}dOT~pWe1d7A+saRiEdfMNOsE=5JF}r=Cx!R_75^K`YeJjsI|gmS85a8vKu4WEmuSclPfW=Ytq=eKr3UqQQL}UJiD(Jbw}Y9QhU&# zxLULdBYCx`CqhVFmHE7px*Db%jjX6gv0LiuD7->yFB%k5*I*>2_C^S)Ycj7T)xhhT zeK5PcmPs?OA+3vvhGrD?0M?$0)m?|>LhAZ7DAx5sjO4mL7$Kw% zVP03NX|Hx4idjYjD#B6hJil*%S}RO$NNr+W--wZ1*EdE8lbbNFWzxXC=BAisgqorm z%Fey!Fw|OMax-cZCWkYUCO1b2lUp#aWzysoeFUZ(ji{(cv3o@yiNY(Sj-o-aqHoDa zuIO7Kgw(B>*OF@Hsc1AN8nv#_<5_zu+6LuT$lR9Z#HnZuBY7$six4umV_r+9$vt#? zOgFTroD5)h4;_cXE2JiAQ1s9p7|9-*LI|m)%;$~N@tAH-poJdPZmAPcc!ks*X;4Tl zVK?Y4SS1 z3#J>jt!T!xyL;}6!Yia!(4gp^yD^g8a~eWO-JN+|sis~1ti&v%0TtmWc3%CgLah}h ztEo+N&l*Oud!`Y>o?cuv=zj#~ BHvs?u literal 0 HcmV?d00001 diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 79fb5e77d..ec5c04ddf 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -15,13 +15,18 @@ from functools import partial from tempfile import TemporaryDirectory +import numpy as np import pytest import torch.distributed from omegaconf import OmegaConf from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer from nemo_aligner.algorithms.dpo import dpo_custom_collate -from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_dpo_datasets +from nemo_aligner.data.nlp.builders import ( + build_dataloader, + build_train_valid_test_dpo_datasets, + build_train_valid_test_dpo_packed_datasets, +) from nemo_aligner.data.nlp.scripts.undo_special_tokens import format_conversation from nemo_aligner.utils import parallel_state @@ -359,3 +364,173 @@ def test_dpo_loader_pad_to_multiple(init_model_parallel, make_tmp_jsonl, str_to_ num_mini_batches += 1 assert num_mini_batches == 2 + + +@pytest.mark.run_only_on("GPU") +def test_packed_dpo_loader(init_model_parallel, tmp_path, llama3_tokenizer): + init_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1) + + np_data = np.array( + [ + { + "input_ids": np.ones(15), + "labels": np.concatenate((-100 * np.ones(7), np.ones(8))), + "reward": np.ones(4), + "lengths": [5, 3, 4, 3], + "seq_boundaries": [0, 5, 8, 12, 15], + }, + ] + * 8 + ) + + data_path = tmp_path / "data.npy" + np.save(data_path, np_data) + + cfg = OmegaConf.create( + { + "model": { + "data": { + "data_prefix": {"train": [data_path], "validation": [data_path], "test": [data_path]}, + "splits_string": None, + "num_workers": 2, + }, + "seed": 42, + } + } + ) + mbs = 1 + minibs = 2 + gbs = minibs * torch.distributed.get_world_size() + + train_ds, _, _ = build_train_valid_test_dpo_packed_datasets( + cfg=cfg.model, + data_prefix=cfg.model.data.data_prefix, + data_impl="packed_jsonl", + splits_string=None, + train_valid_test_num_samples=[-1 * gbs] * 3, + seq_length=1024, + seed=cfg.model.seed, + tokenizer=llama3_tokenizer, + ) + + train_dataloader = build_dataloader( + cfg=cfg, + dataset=train_ds, + consumed_samples=0, + mbs=mbs, + gbs=gbs, + load_gbs=True, + pad_samples_to_global_batch_size=False, + collate_fn=lambda x: x, + ) + + distributed_collate_fn = partial(train_ds.global_collate_fn, eos_id=llama3_tokenizer.eos_id,) + + num_mini_batches = 0 + for mbatch in train_dataloader: + mbatch = distributed_collate_fn(mbatch) + padded_seq_len = mbatch["input_ids"].shape[1] + for in_name, in_tensor in mbatch.items(): + assert in_tensor.shape[0] == minibs, f"Expected {in_name}.shape={in_tensor.shape} first dim to be {minibs}" + + assert mbatch["input_ids"].shape == (minibs, padded_seq_len) + assert mbatch["labels"].shape == (minibs, padded_seq_len) + assert mbatch["lengths"].shape == (minibs, len(np_data[0]["lengths"])) + assert mbatch["rewards"].shape == (minibs, len(np_data[0]["lengths"])) + ### last cu_seqlen set to max_length, the we add one padding element which gets removed during training + assert torch.equal(mbatch["cu_seqlens"][0], torch.tensor([0, 4, 6, 9, 16, -1])) + assert mbatch["cu_seqlens_argmin"][0] == torch.tensor([5]) + ### this will end up being the final example because it's padded + ### should be fine because final padding tokens are not included in the loss + assert mbatch["max_seqlen"][0] == torch.tensor([7]) + + num_mini_batches += 1 + + assert num_mini_batches == 2 + + +@pytest.mark.run_only_on("GPU") +def test_packed_dpo_loader_pad_to_multiple(init_model_parallel, tmp_path, str_to_list_tokenizer): + init_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1) + + np_data = np.array( + [ + { + "input_ids": np.ones(15), + "labels": np.concatenate((-100 * np.ones(7), np.ones(8))), + "reward": np.ones(8), + "lengths": [5, 3, 4, 3], + "seq_boundaries": [0, 5, 8, 12, 15], + }, + ] + * 8 + ) + + data_path = tmp_path / "data.npy" + np.save(data_path, np_data) + + cfg = OmegaConf.create( + { + "model": { + "data": { + "data_prefix": {"train": [data_path], "validation": [data_path], "test": [data_path]}, + "splits_string": None, + "num_workers": 2, + }, + "seed": 42, + } + } + ) + mbs = 1 + minibs = 2 + gbs = minibs * torch.distributed.get_world_size() + expected_seq_len_multiple = 29 # pick a prime to make sure + + train_ds, _, _ = train_ds, _, _ = build_train_valid_test_dpo_packed_datasets( + cfg=cfg.model, + data_prefix=cfg.model.data.data_prefix, + data_impl="packed_jsonl", + splits_string=None, + train_valid_test_num_samples=[-1 * gbs] * 3, + seq_length=1024, + seed=cfg.model.seed, + tokenizer=str_to_list_tokenizer, + ) + + train_dataloader = build_dataloader( + cfg=cfg, + dataset=train_ds, + consumed_samples=0, + mbs=mbs, + gbs=gbs, + load_gbs=True, + pad_samples_to_global_batch_size=False, + collate_fn=lambda x: x, + ) + + distributed_collate_fn = partial( + train_ds.global_collate_fn, + eos_id=str_to_list_tokenizer.eos_id, + reset_position_ids=False, + reset_attention_mask=False, + eod_mask_loss=False, + pad_length_to_multiple_of=expected_seq_len_multiple, + ) + + num_mini_batches = 0 + for mbatch in train_dataloader: + + mbatch = distributed_collate_fn(mbatch) + for k in ["input_ids", "labels", "position_ids"]: + assert mbatch[k].shape[1] % expected_seq_len_multiple == 0 + + # Check that all ranks have the same length + max_chosen_seq_length = torch.tensor(mbatch["input_ids"].shape[1], device="cuda") + torch.distributed.all_reduce( + max_chosen_seq_length, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_data_parallel_group() + ) + assert mbatch["input_ids"].shape[1] == max_chosen_seq_length.item() + + num_mini_batches += 1 + + assert num_mini_batches == 2 diff --git a/tests/test_distributed.py b/tests/test_distributed.py index 2e379db01..43745d2f6 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -32,7 +32,7 @@ """A file to test the core distributed function calls in RLHF""" -def slow_from_parallel_logits_to_logprobs(parallel_logits, tokens): +def slow_from_parallel_logits_to_logprobs(parallel_logits, tokens, ignore_last=True): """a slow but very safe way of computing logits -> logprobs. Uses a lot of memory but good for testing""" # Gather logits across all TP ranks for testing logits = tensor_parallel.gather_from_tensor_model_parallel_region(parallel_logits) @@ -40,8 +40,12 @@ def slow_from_parallel_logits_to_logprobs(parallel_logits, tokens): # Convert from logits to log-probs. full_log_probs = torch.nn.functional.log_softmax(logits, dim=2) - full_log_probs = full_log_probs[:, :-1, :].contiguous() - indices = tokens[:, 1:].unsqueeze(-1) + if ignore_last: + full_log_probs = full_log_probs[:, :-1, :] + indices = tokens[:, 1:].unsqueeze(-1) + else: + indices = tokens.unsqueeze(-1) + full_log_probs = full_log_probs.contiguous() log_probs = torch.gather(input=full_log_probs, dim=2, index=indices).squeeze(dim=-1).contiguous() return log_probs @@ -150,23 +154,27 @@ def test_distributed_masked_global_mean_var(init_model_parallel): @pytest.mark.run_only_on("GPU") @pytest.mark.parametrize( - "batch_size,seed,dtype,atol,rtol,higher_stability", + "batch_size,seed,dtype,atol,rtol,higher_stability,ignore_last", [ - (1, 9999, torch.float32, 1e-08, 1e-05, False), - (4, 100, torch.float32, 1e-08, 1e-05, False), - (8, 1234, torch.float32, 1e-08, 1e-05, False), - (1, 9999, torch.float32, 1e-08, 1e-05, True), - (4, 100, torch.float32, 1e-08, 1e-05, True), - (8, 1234, torch.float32, 1e-08, 1e-05, True), - (1, 746, torch.bfloat16, 0.005, 0.01, False), - (4, 334, torch.bfloat16, 0.005, 0.01, False), - (8, 123456, torch.bfloat16, 0.005, 0.01, False), - (1, 746, torch.bfloat16, 0.005, 0.01, True), - (4, 334, torch.bfloat16, 0.005, 0.01, True), - (8, 123456, torch.bfloat16, 0.005, 0.01, True), + (1, 9999, torch.float32, 1e-08, 1e-05, False, True), + (4, 100, torch.float32, 1e-08, 1e-05, False, True), + (8, 1234, torch.float32, 1e-08, 1e-05, False, True), + (1, 9999, torch.float32, 1e-08, 1e-05, True, True), + (4, 100, torch.float32, 1e-08, 1e-05, True, True), + (8, 1234, torch.float32, 1e-08, 1e-05, True, True), + (1, 746, torch.bfloat16, 0.005, 0.01, False, True), + (4, 334, torch.bfloat16, 0.005, 0.01, False, True), + (8, 123456, torch.bfloat16, 0.005, 0.01, False, True), + (1, 746, torch.bfloat16, 0.005, 0.01, True, True), + (4, 334, torch.bfloat16, 0.005, 0.01, True, True), + (8, 123456, torch.bfloat16, 0.005, 0.01, True, True), + (1, 9999, torch.float32, 1e-08, 1e-05, True, False), + (8, 1234, torch.float32, 1e-08, 1e-05, True, False), ], ) -def test_distributed_log_probs(init_model_parallel, batch_size, seed, dtype, atol, rtol, higher_stability): +def test_distributed_log_probs( + init_model_parallel, batch_size, seed, dtype, atol, rtol, higher_stability, ignore_last +): """This function is used to test our custom log prob function, we compare it against the more memory intensive naive implementation in the fwd and bwd pass """ @@ -197,11 +205,13 @@ def test_distributed_log_probs(init_model_parallel, batch_size, seed, dtype, ato target = torch.randint(0, V_total, size=(B, S), device=device, generator=generator) with torch.no_grad(): - log_probs_fast = from_parallel_logits_to_logprobs(fake_output, target, higher_stability=higher_stability) - log_probs_slow = slow_from_parallel_logits_to_logprobs(fake_output, target) + log_probs_fast = from_parallel_logits_to_logprobs( + fake_output, target, higher_stability=higher_stability, ignore_last=ignore_last + ) + log_probs_slow = slow_from_parallel_logits_to_logprobs(fake_output, target, ignore_last=ignore_last) log_probs_slow_inf_only = from_parallel_logits_to_logprobs( - fake_output, target, inference_only=True, higher_stability=higher_stability + fake_output, target, inference_only=True, higher_stability=higher_stability, ignore_last=ignore_last, ) torch.testing.assert_close( @@ -219,12 +229,14 @@ def test_distributed_log_probs(init_model_parallel, batch_size, seed, dtype, ato msg="forward pass between fast, slow and log prob calculation is not the same!", ) - slow_from_parallel_logits_to_logprobs(fake_output, target).sum().backward() + slow_from_parallel_logits_to_logprobs(fake_output, target, ignore_last=ignore_last).sum().backward() fake_output_grad_slow = fake_output.grad.detach().clone() fake_output.grad = None - from_parallel_logits_to_logprobs(fake_output, target, higher_stability=higher_stability).sum().backward() + from_parallel_logits_to_logprobs( + fake_output, target, higher_stability=higher_stability, ignore_last=ignore_last + ).sum().backward() fake_output_grad_fast = fake_output.grad.detach().clone() torch.testing.assert_close(