From 1d50757c857f79ee1378e0fa2e2ee0a0f1e50507 Mon Sep 17 00:00:00 2001 From: Zhilin Wang Date: Sun, 3 Dec 2023 17:26:41 -0800 Subject: [PATCH] Update readme with links to arxiv and ai playground (#25) * add steerlm Signed-off-by: Zhilin Wang * add license for common.py Signed-off-by: jiaqiz * fix build Signed-off-by: Gerald Shen * concat datasets before training since only 1 epoch is supported now Signed-off-by: jiaqiz * concat AC-SFT training data to get 2 epochs Signed-off-by: jiaqiz --------- Signed-off-by: Zhilin Wang Signed-off-by: jiaqiz Signed-off-by: Gerald Shen Co-authored-by: jiaqiz Co-authored-by: Gerald Shen --- README.md | 2 +- docs/README.md | 2 +- docs/user-guide/SteerLM.rst | 395 +++++++++--------- .../nlp/data/steerlm/attribute_annotate.py | 154 +++++++ examples/nlp/data/steerlm/common.py | 32 ++ .../data/steerlm/preprocess_helpsteer_data.py | 82 ++++ .../steerlm/preprocess_openassistant_data.py | 159 +++++++ .../steerlm/process_to_regression_format.py | 92 ++++ 8 files changed, 714 insertions(+), 204 deletions(-) create mode 100644 examples/nlp/data/steerlm/attribute_annotate.py create mode 100644 examples/nlp/data/steerlm/common.py create mode 100644 examples/nlp/data/steerlm/preprocess_helpsteer_data.py create mode 100644 examples/nlp/data/steerlm/preprocess_openassistant_data.py create mode 100644 examples/nlp/data/steerlm/process_to_regression_format.py diff --git a/README.md b/README.md index 7e0a22c29..a974e7c07 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ The toolkit is currently in it's early stages, and we are committed to improving ## Key features -* **SteerLM** +* **SteerLM: Attribute Conditioned SFT as an (User-Steerable) Alternative to RLHF.** Learn more at our [SteerLM](https://arxiv.org/abs/2310.05344) and [HelpSteer](https://arxiv.org/abs/2311.09528) papers. Try it instantly for free on [NVIDIA AI Playground](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/ai-foundation/models/llama2-70b-steerlm) * **Supervised Fine Tuning** * **Reward Model Training** * **Reinforcement Learning from Human Feedback using the PPO Algorithm** diff --git a/docs/README.md b/docs/README.md index 08800b2f7..ba0a484bf 100644 --- a/docs/README.md +++ b/docs/README.md @@ -3,7 +3,7 @@ ## Custom Trainers NeMo-Aligner uses custom trainers to coordinate all aspects of training. There are currently 3 custom trainers: -1. [SupervisedTrainer](/nemo_aligner/algorithms/supervised.py): for SFT and Reward modeling. +1. [SupervisedTrainer](/nemo_aligner/algorithms/supervised.py): for SFT, SteerLM and Reward modeling. 2. [CriticServerTrainer](/nemo_aligner/algorithms/critic_server_trainer.py): trains the RL critic via PyTriton requests. It will also run the reward model depending on the configuration. 3. [PPOTrainer](/nemo_aligner/algorithms/ppo.py): performs the RLHF PPO training, since PPO has components such as the Critic, this trainer will send inference and train requests via [PyTriton](https://github.com/triton-inference-server/pytriton) to the CriticServerTrainer to train and run inference on the critic. diff --git a/docs/user-guide/SteerLM.rst b/docs/user-guide/SteerLM.rst index c995dd7a9..aa3893ac3 100644 --- a/docs/user-guide/SteerLM.rst +++ b/docs/user-guide/SteerLM.rst @@ -1,10 +1,10 @@ .. include:: /content/nemo.rsts Model Alignment by SteerLM Method -@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ +@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ -**SteerLM** is a novel approach developed by the NVIDIA Research Team, introduced as part of NVIDIA NeMo Alignment methods. 
It simplifies the customization of large language models (LLMs) and empowers users with dynamic control over model outputs by specifying desired attributes. Despite remarkable progress in natural language generation driven by LLMs like GPT-3, Megatron-Turing, Chinchilla, PaLM-2, Falcon, and Llama 2, these foundational models often fall short in delivering nuanced and user-aligned responses. The current approach for LLM improvement combines supervised fine-tuning and reinforcement learning from human feedback, but it comes with complexities and limited user control. SteerLM addresses these challenges and represents a significant advancement in the field, making it easier to tailor LLMs to specific needs and preferences. This document delves into how SteerLM operates and offers guidance on training a SteerLM model. +**SteerLM** is a novel approach developed by the NVIDIA NeMo Team, introduced as part of NVIDIA NeMo Alignment methods. It simplifies the customization of large language models (LLMs) and empowers users with dynamic control over model outputs by specifying desired attributes. Despite remarkable progress in natural language generation driven by LLMs like GPT-3, Megatron-Turing, Chinchilla, PaLM-2, Falcon, and Llama 2, these foundational models often fall short in delivering nuanced and user-aligned responses. The current approach for LLM improvement combines supervised fine-tuning and reinforcement learning from human feedback, but it comes with complexities and limited user control. SteerLM addresses these challenges and represents a significant advancement in the field, making it easier to tailor LLMs to specific needs and preferences. This document delves into how SteerLM operates and offers guidance on training a SteerLM model. SteerLM ############### @@ -21,241 +21,178 @@ SteerLM leverages a supervised fine-tuning method that empowers you to control r .. image:: https://developer-blogs.nvidia.com/wp-content/uploads/2023/08/steerlm-four-steps.png :alt: SteerLM four steps -By relying solely on the standard language modeling objective, SteerLM simplifies alignment compared to RLHF. It supports user-steerable AI by enabling you to adjust attributes at inference time. This enables the developer to define preferences relevant to the application, unlike other techniques that require using predetermined preferences. +SteerLM simplifies alignment compared to RLHF. It supports user-steerable AI by enabling you to adjust attributes at inference time. This enables the developer to define preferences relevant to the application, unlike other techniques that require using predetermined preferences. SteerLM vs RLHF -################ +############### + Reinforcement Learning from Human Feedback (RLHF) and SteerLM are two methods aimed at aligning language models to human preferences. RLHF trains language models by providing positive or negative feedback on generated responses, reinforcing good behaviors. Specifically, the model is encouraged to generate more text similar to responses that receive positive feedback, and less like those with negative feedback. SteerLM takes a different approach to model alignment. Rather than solely reinforcing "good" behaviors, it categorizes the space of possible model responses using steering labels. At inference time, the model generates based on these categorical labels that steer its output. So while RLHF uses direct feedback on model generations, SteerLM aligns by mapping responses into labeled categories associated with human preferences. 
The two methods tackle model alignment from different angles - RLHF by directly reinforcing desired model behaviors, and SteerLM by steering generation based on categorical labels. Both aim to produce language model outputs better aligned with human values and preferences. .. note:: - For details of steerLM, please refer to our paper `SteerLM: Attribute Conditioned SFT as an (User-Steerable) Alternative to RLHF `_. + For details of SteerLM, please refer to our paper `SteerLM: Attribute Conditioned SFT as an (User-Steerable) Alternative to RLHF `_. + For details of HelpSteer dataset, please refer to our paper `HelpSteer: Multi-attribute Helpfulness Dataset for SteerLM `_. Train a SteerLM model ##################### -This section is a step-by-step tutorial that walks you through how to run a full SteerLM pipeline on OASST data with a Llama2 7B LLM model. It includes the following: -Data cleaning and preprocessing -Training the attribute prediction (value model) -Training the attribute-conditioned SFT (SteerLM model) -Inference on the SteerLM model with different attribute values +This section is a step-by-step tutorial that walks you through how to run a full SteerLM pipeline with a Llama2 70B LLM model. It includes the following: -Step 1: Install requirements -############################# -Start by installing the necessary Python libraries: +1. Data download and preprocessing -.. code-block:: bash +2. Training the attribute prediction model (aka regression reward model) - pip install fire langchain==0.0.133 +3. Training the attribute-conditioned SFT -Get access to NeMo. +4. Inference on the SteerLM model with different attribute values -Step 2: Download and subset data -################################## -This document uses a small subset of the OASST dataset. OASST contains open-domain conversations with human annotations for 13 different quality attributes. -First download and subset it: +Step 1: Download Llama 2 LLM model +############################################################# +Download the Llama 2 70B LLM model from HF into the models folder. + +Then convert the Llama 2 LLM into .nemo format: .. code-block:: bash - mkdir -p data - cd data + mkdir -p /models/llama70b/ + python /opt/NeMo/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py --in-file /path/to/llama --out-file /models/llama70b/llama70b.nemo - wget https://huggingface.co/datasets/OpenAssistant/oasst1/resolve/main/2023-04-12_oasst_all.trees.jsonl.gz +Download and convert to .nemo format for the 13B model as well, which is needed for the Attribute Prediction Modelling step. - gunzip -f 2023-04-12_oasst_all.trees.jsonl.gz +Untar the .nemo file to obtain the tokenizer in NeMo format (only for the 70B model): - mv 2023-04-12_oasst_all.trees.jsonl data.jsonl +.. code-block:: bash - head -n 5000 data.jsonl > subset_data.jsonl + cd /models/llama70b + tar xvf llama70b.nemo . + rm llama70b.nemo - cd .. + mv _tokenizer.model tokenizer.model -Step 3: Download Llama 2 LLM model and tokenizer and convert -############################################################# -Download the Llama 2 7B LLM model and tokenizer into the models folder. +The prefix for the tokenizer would be different when extracted. Ensure that the correct tokenizer file is used when running the preceding command. 
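+
+As noted above, the 13B model used for the Attribute Prediction Modelling step also needs to be converted to .nemo format. A sketch of that conversion is below; the ``--in-file`` path is illustrative and should point to wherever the 13B Hugging Face checkpoint was downloaded, while the output path matches what Step 3 expects:
+
+.. code-block:: bash
+
+    mkdir -p /models/llama13b/
+    python /opt/NeMo/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py --in-file /path/to/llama-13b --out-file /models/llama13b/llama13b.nemo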
-Then convert the Llama 2 LLM into .nemo format: +Step 2: Download and Preprocess data for Attribute Prediction Modelling +####################################################################### + +First, download and convert both datasets into a common format. .. code-block:: bash - python /opt/NeMo/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py --in-file /path/to/llama --out-file /output_path/llama7b.nemo + python /opt/NeMo-Aligner/examples/nlp/data/steerlm/preprocess_openassistant_data.py --output_directory=data/oasst + + python /opt/NeMo-Aligner/examples/nlp/data/steerlm/preprocess_helpsteer_data.py --output_directory=data/helpsteer -Untar the .nemo file to obtain the tokenizer in NeMo format: +Then, merge the two datasets for the train and val subset respectively. .. code-block:: bash - tar xfv /llama7b.nemo . + cat data/oasst/train.jsonl data/helpsteer/train.jsonl | awk '{for(i=1;i<=4;i++) print}' > data/merge_train.jsonl - mv _tokenizer.model tokenizer.model + cat data/oasst/val.jsonl data/helpsteer/val.jsonl > data/merge_val.jsonl -The prefix for the tokenizer would be different when extracted. Ensure that the correct tokenizer file is used when running the preceding command. - -Step 4: Preprocess OASST data -############################# -Preprocess the data using the NeMo preprocessing scripts. Then create separate text-to-value and value-to-text versions: +Finally, preprocess the data into regression reward model training format. .. code-block:: bash - python /opt/NeMo/scripts/nlp_language_modeling/sft/preprocessing.py \ - --input_file=data/subset_data.jsonl \ - --output_file_prefix=data/subset_data_output \ - --mask_role=User \ - --type=TEXT_TO_VALUE \ - --split_ratio=0.95 \ - --seed=10 - - python /opt/NeMo/scripts/nlp_language_modeling/sft/preprocessing.py \ - --input_file=data/subset_data.jsonl \ - --output_file_prefix=data/subset_data_output_v2t \ - --mask_role=User \ - --type=VALUE_TO_TEXT \ - --split_ratio=0.95 \ - --seed=10 - -Step 5: Clean text-to-value data -################################# -Running the following script will remove the records if all the tokens are masked due to truncation by sequence length. + python /opt/NeMo-Aligner/examples/nlp/data/steerlm/process_to_regression_format.py \ + --input-file=data/merge_train.jsonl \ + --output-file=data/merge_train_reg.jsonl -.. code-block:: bash + python /opt/NeMo-Aligner/examples/nlp/data/steerlm/process_to_regression_format.py \ + --input-file=data/merge_val.jsonl \ + --output-file=data/merge_val_reg.jsonl - python /opt/NeMo/scripts/nlp_language_modeling/sft/data_clean.py \ - --dataset_file=data/subset_data_output_train.jsonl \ - --output_file=data/subset_data_output_train_clean.jsonl \ - --library sentencepiece \ - --model_file tokenizer.model \ - --seq_len 4096 - python /opt/NeMo/scripts/nlp_language_modeling/sft/data_clean.py \ - --dataset_file=data/subset_data_output_val.jsonl \ - --output_file=data/subset_data_output_val_clean.jsonl \ - --library sentencepiece \ - --model_file tokenizer.model \ - --seq_len 4096 +Step 3: Train the regression reward model on OASST+HelpSteer data +################################################################# -Step 6: Train the value model on cleaned OASST data -################################################### -For this tutorial, train the value model for 1K steps. Note that we recommend training much longer on more data to get a good value model. +For this tutorial, train the regression reward model for 800 steps. 
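+
+Before launching training, it can be useful to sanity-check the regression-format files produced in Step 2. The snippet below is a minimal illustrative check (not part of the official pipeline); it assumes the merged files from Step 2 and verifies that every record carries one label per attribute, matching model.regression.num_attributes=9 in the training command below:
+
+.. code-block:: python
+
+    import json
+
+    # 4 OpenAssistant attributes + 5 HelpSteer attributes = 9 labels per record;
+    # attributes missing from a record are filled with -100 by process_to_regression_format.py.
+    num_attributes = 9
+
+    for path in ["data/merge_train_reg.jsonl", "data/merge_val_reg.jsonl"]:
+        with open(path, encoding="utf-8") as f:
+            for line in f:
+                record = json.loads(line)
+                assert isinstance(record["text"], str) and record["text"]
+                assert len(record["label"]) == num_attributes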
+ +Note that you would need to set up multi-node training in your cluster env, depending on the type of cluster you use. For details, please refer to https://lightning.ai/docs/pytorch/stable/clouds/cluster.html .. code-block:: bash - python examples/nlp/gpt/train_gpt_sft.py \ - trainer.num_nodes=1 \ - trainer.devices=4 \ - trainer.precision=bf16 \ - trainer.sft.limit_val_batches=40 \ - trainer.sft.max_epochs=1 \ - trainer.sft.max_steps=1000 \ - trainer.sft.val_check_interval=200 \ - trainer.sft.save_interval=200 \ - model.megatron_amp_O2=True \ - model.restore_from_path=/models/llama7b.nemo \ - model.tensor_model_parallel_size=2 \ - model.pipeline_model_parallel_size=1 \ - model.optim.lr=6e-6 \ - model.optim.name=distributed_fused_adam \ - model.optim.weight_decay=0.01 \ - model.optim.sched.constant_steps=200 \ - model.optim.sched.warmup_steps=1 \ - model.optim.sched.min_lr=5e-6 \ - model.answer_only_loss=True \ - model.activations_checkpoint_granularity=selective \ - model.activations_checkpoint_method=uniform \ - model.data.chat=True \ - model.data.num_workers=0 \ - model.data.chat_prompt_tokens.system_turn_start=\x00 \ - model.data.chat_prompt_tokens.turn_start=\x11 \ - model.data.chat_prompt_tokens.label_start=\x12 \ - model.data.train_ds.max_seq_length=4096 \ - model.data.train_ds.micro_batch_size=2 \ - model.data.train_ds.global_batch_size=128 \ - model.data.train_ds.file_path=data/subset_data_output_train_clean.jsonl \ - model.data.train_ds.index_mapping_dir=/indexmap_dir \ - model.data.train_ds.add_eos=False \ - model.data.train_ds.hf_dataset=True \ - model.data.validation_ds.max_seq_length=4906 \ - model.data.validation_ds.file_path=data/subset_data_output_val_clean.jsonl \ - model.data.validation_ds.micro_batch_size=2 \ - model.data.validation_ds.global_batch_size=128 \ - model.data.validation_ds.index_mapping_dir=/indexmap_dir \ - model.data.validation_ds.add_eos=False \ - model.data.validation_ds.hf_dataset=True \ - exp_manager.create_wandb_logger=True \ - exp_manager.explicit_log_dir=/results \ - exp_manager.resume_if_exists=True \ - exp_manager.resume_ignore_no_checkpoint=True \ - exp_manager.create_checkpoint_callback=True - -Step 7: Generate annotations + python /opt/NeMo-Aligner/examples/nlp/gpt/train_reward_model.py \ + trainer.num_nodes=32 \ + trainer.devices=8 \ + ++model.micro_batch_size=2 \ + ++model.global_batch_size=512 \ + ++model.data.data_impl=jsonl \ + pretrained_checkpoint.restore_from_path=/models/llama13b/llama13b.nemo \ + "model.data.data_prefix={train: ["data/merge_train_reg.jsonl"], validation: ["data/merge_val_reg.jsonl"], test: ["data/merge_val_reg.jsonl"]}" \ + exp_manager.explicit_log_dir=/results/reward_model_13b \ + trainer.rm.val_check_interval=10 \ + exp_manager.create_wandb_logger=True \ + exp_manager.wandb_logger_kwargs.project=steerlm \ + exp_manager.wandb_logger_kwargs.name=rm_training \ + trainer.rm.save_interval=10 \ + trainer.rm.max_steps=800 \ + ++model.tensor_model_parallel_size=4 \ + ++model.pipeline_model_parallel_size=1 \ + ++model.activations_checkpoint_granularity="selective" \ + ++model.activations_checkpoint_method="uniform" \ + model.global_batch_size=512 \ + model.optim.sched.constant_steps=0 \ + model.reward_model_type="regression" \ + model.regression.num_attributes=9 + + +Step 4: Generate annotations ############################ -To generate annotation, run the following command in the background to run an inference server: +To generate annotations, run the following command in the background to launch an inference server: .. 
code-block:: bash - python /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_eval.py \ - gpt_model_file=/models/ \ - pipeline_model_parallel_split_rank=0 \ - server=True \ - tensor_model_parallel_size=2 \ - pipeline_model_parallel_size=1 \ - trainer.precision=bf16 \ - trainer.devices=1 \ - trainer.num_nodes=1 \ - web_server=False \ - port=1424 + python /opt/NeMo-Aligner/examples/nlp/gpt/serve_reward_model.py \ + rm_model_file=/results/reward_model_13b/checkpoints/megatron_gpt.nemo \ + trainer.num_nodes=1 \ + trainer.devices=8 \ + ++model.tensor_model_parallel_size=4 \ + ++model.pipeline_model_parallel_size=1 \ + inference.micro_batch_size=2 \ + inference.port=1424 + Now execute: .. code-block:: bash - python /opt/NeMo/scripts/nlp_language_modeling/sft/attribute_annotate.py --batch_size=1 --host=localhost --input_file_name=data/subset_data_output_train_clean.jsonl --output_file_name=data/subset_data_output_train_value_output.jsonl --port_num=1424 + python /opt/NeMo-Aligner/examples/nlp/data/steerlm/attribute_annotate.py \ + --input-file=data/oasst/train.jsonl \ + --output-file=data/oasst/train_labeled.jsonl \ + --port=1424 - python /opt/NeMo/scripts/nlp_language_modeling/sft/attribute_annotate.py --batch_size=1 --host=localhost --input_file_name=data/subset_data_output_val_clean.jsonl --output_file_name=data/subset_data_output_val_value_output.jsonl --port_num=1424 + python /opt/NeMo-Aligner/examples/nlp/data/steerlm/attribute_annotate.py \ + --input-file=data/oasst/val.jsonl \ + --output-file=data/oasst/val_labeled.jsonl \ + --port=1424 -.. note:: - This step can take a long time to run. For the purposes of this tutorial, we use a small subset of the data and a single inference server. For optimal results, use the full dataset and multiple inference servers to run data annotation in parallel. - -Step 8: Clean the value-to-text data -#################################### -Remove the record if all tokens are masked after truncation by sequence length: + cat data/oasst/train_labeled.jsonl data/oasst/train_labeled.jsonl > data/oasst/train_labeled_2ep.jsonl -.. code-block:: bash - python /opt/NeMo/scripts/data_clean.py \ - --dataset_file=data/subset_data_output_train_value_output.jsonl \ - --output_file=data/subset_data_output_train_value_output_clean.jsonl \ - --library sentencepiece \ - --model_file tokenizer.model \ - --seq_len 4096 +Step 5: Train the Attribute-Conditioned SFT model +################################################# - python /opt/NeMo/scripts/data_clean.py \ - --dataset_file=data/subset_data_output_val_value_output.jsonl \ - --output_file=data/subset_data_output_val_value_output_clean.jsonl \ - --library sentencepiece \ - --model_file tokenizer.model \ - --seq_len 4096 - -Step 9: Train the SteerLM model -############################### -For the purposes of this tutorial, the SteerLM model is trained for 1K steps. Note that we recommend training much longer and on more data to get a well-tuned model. +For the purposes of this tutorial, the Attribute-Conditioned SFT model is trained for 800 steps. .. 
code-block:: bash python examples/nlp/gpt/train_gpt_sft.py \ - trainer.num_nodes=1 \ - trainer.devices=4 \ + trainer.num_nodes=32 \ + trainer.devices=8 \ trainer.precision=bf16 \ trainer.sft.limit_val_batches=40 \ trainer.sft.max_epochs=1 \ - trainer.sft.max_steps=1000 \ - trainer.sft.val_check_interval=200 \ - trainer.sft.save_interval=200 \ + trainer.sft.max_steps=800 \ + trainer.sft.val_check_interval=800 \ + trainer.sft.save_interval=800 \ model.megatron_amp_O2=True \ - model.restore_from_path=/models/llama7b.nemo \ - model.tensor_model_parallel_size=2 \ - model.pipeline_model_parallel_size=1 \ + model.restore_from_path=/models/llama70b \ + model.tensor_model_parallel_size=8 \ + model.pipeline_model_parallel_size=2 \ model.optim.lr=6e-6 \ model.optim.name=distributed_fused_adam \ model.optim.weight_decay=0.01 \ @@ -267,88 +204,142 @@ For the purposes of this tutorial, the SteerLM model is trained for 1K steps. No model.activations_checkpoint_method=uniform \ model.data.chat=True \ model.data.num_workers=0 \ - model.data.chat_prompt_tokens.system_turn_start=\x00 \ - model.data.chat_prompt_tokens.turn_start=\x11 \ - model.data.chat_prompt_tokens.label_start=\x12 \ + model.data.chat_prompt_tokens.system_turn_start=\'\\' \ + model.data.chat_prompt_tokens.turn_start=\'\\' \ + model.data.chat_prompt_tokens.label_start=\'\\' \ model.data.train_ds.max_seq_length=4096 \ - model.data.train_ds.micro_batch_size=2 \ + model.data.train_ds.micro_batch_size=1 \ model.data.train_ds.global_batch_size=128 \ - model.data.train_ds.file_path=data/subset_data_v2t_train_value_output_clean.jsonl \ + model.data.train_ds.file_path=data/oasst/train_labeled_2ep.jsonl \ model.data.train_ds.index_mapping_dir=/indexmap_dir \ model.data.train_ds.add_eos=False \ model.data.train_ds.hf_dataset=True \ - model.data.validation_ds.max_seq_length=4906 \ - model.data.validation_ds.file_path=data/subset_data_v2t_val_value_output_clean.jsonl \ - model.data.validation_ds.micro_batch_size=2 \ + model.data.validation_ds.max_seq_length=4096 \ + model.data.validation_ds.file_path=data/oasst/val_labeled.jsonl \ + model.data.validation_ds.micro_batch_size=1 \ model.data.validation_ds.global_batch_size=128 \ model.data.validation_ds.index_mapping_dir=/indexmap_dir \ model.data.validation_ds.add_eos=False \ model.data.validation_ds.hf_dataset=True \ exp_manager.create_wandb_logger=True \ - exp_manager.explicit_log_dir=/results \ - exp_manager.resume_if_exists=True \ - exp_manager.resume_ignore_no_checkpoint=True \ - exp_manager.create_checkpoint_callback=True + exp_manager.wandb_logger_kwargs.project=steerlm \ + exp_manager.wandb_logger_kwargs.name=acsft_training \ + exp_manager.explicit_log_dir=/results/acsft_70b \ + exp_manager.checkpoint_callback_params.save_nemo_on_train_end=True + + -Step 10: Inference +Step 6: Inference ################## To start inference, run an inference server in the background using the following command: .. code-block:: bash python /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_eval.py \ - gpt_model_file=/models/ \ + gpt_model_file=/results/acsft_70b/checkpoints/megatron_gpt_sft.nemo \ pipeline_model_parallel_split_rank=0 \ server=True \ - tensor_model_parallel_size=1 \ + tensor_model_parallel_size=8 \ pipeline_model_parallel_size=1 \ trainer.precision=bf16 \ - trainer.devices=1 \ + trainer.devices=8 \ trainer.num_nodes=1 \ web_server=False \ - port=1427 + port=1427 + +Please wait for the server to be ready before proceeeding. Next, create Python helper functions: .. 
code-block:: python - - def get_answer(question, max_tokens, values, eval_port='1427'): - prompt ="System\nA chat between a curious user and an artificial intelligence assistant. \nThe assistant gives helpful, detailed, and polite answers to the user's questions.\n\nUser\n{question}\nAssistant\n{values}\n" - prompts = [prompt.format(question=question, values=values))] - data = {"sentences": prompts, "tokens_to_generate": max_tokens, "top_k": 1, 'greedy': True, 'end_strings': ["", "quality:", "quality:4", "quality:0"]} + + import requests + from collections import OrderedDict + + def get_answer(question, max_tokens, values, eval_port=1427): + prompt = ( + "System\nA chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.\n" + "User\n{question}\nAssistant\n{values}\n" + ) + prompts = [prompt.format(question=question, values=values)] + data = { + "sentences": prompts, + "tokens_to_generate": max_tokens, + "top_k": 1, + "greedy": True, + "end_strings": [""], + } url = f"http://localhost:{eval_port}/generate" response = requests.put(url, json=data) json_response = response.json() - response_sentence = json_response['sentences'][0][len(prompt):] + response_sentence = json_response["sentences"][0][len(prompt):] return response_sentence .. code-block:: python def encode_labels(labels): - items = [] - for key in labels: - value = labels[key] - items.append(f'{key}:{value}') - return ','.join(items) + return ",".join(f"{key}:{value}" for key, value in labels.items()) Next, change the values below to steer the language model: .. code-block:: python - values = OrderedDict([('quality', 4), ('toxicity', 0), ('humor', 0), ('creativity', 0), ('violence', 0), ('helpfulness', 4), ('not_appropriate', 0), ('hate_speech', 0), ('sexual_content', 0), ('fails_task', 0), ('political_content', 0), ('moral_judgement', 0)]) + values = OrderedDict( + [ + ("quality", 4), + ("toxicity", 0), + ("humor", 0), + ("creativity", 0), + ("helpfulness", 4), + ("correctness", 4), + ("coherence", 4), + ("complexity", 4), + ("verbosity", 4), + ] + ) values = encode_labels(values) Finally, ask questions and generate responses: .. code-block:: python - question = """Where and when did techno music originate?""" - print (get_answer(question, 4096, values)) + question = "Write a poem on NVIDIA in the style of Shakespeare" + print(get_answer(question, 512, values)) + +Response is as below + +.. code-block:: python + + """ + In days of yore, in tech's great hall, + A company arose, NVIDIA its call. + With graphics cards, it did astound, + And gaming world with awe did abound. + + But NVIDIA's reach far more than play, + Its GPUs now deep learning's sway. + With neural nets and data vast, + AI's rise, it did forecast. + + From self-driving cars to medical scans, + Its tech now touches all life's plans. + With each new day, its impact grows, + In science, research, and industry's prose. + + So here's to NVIDIA, whose name we praise, + For tech that captivates in countless ways. + With Shakespearean verse, we now impart, + Our thanks and admiration from the heart. + + """ + .. note:: - This tutorial covers only steps 1-3: training the value model, generating annotations, and initial SteerLM model training. Step 4 bootstraps the SteerLM model by sampling responses conditioned on high quality, evaluating them with the value model, and fine-tuning the SteerLM model on this new data. 
This closing of the loop continually improves the SteerLM model. Be sure to fully train models, use full datasets, and perform bootstrapping for optimal accuracy.
+   This tutorial covers only Phases 1-3: training the value model, generating annotations, and initial SteerLM model training. Phase 4 bootstraps the SteerLM model by sampling responses conditioned on high quality; it is omitted from this tutorial for simplicity.
+
+SteerLM: Novel Technique for Simple and Controllable Model Alignment
+####################################################################
 
-The future of AI with SteerLM
-##############################
 SteerLM provides a novel technique for realizing a new generation of AI systems aligned with human preferences in a controllable manner. Its conceptual simplicity, performance gains, and customizability highlight the transformative possibilities of user-steerable AI. To learn more, please check out our paper `SteerLM: Attribute Conditioned SFT as an (User-Steerable) Alternative to RLHF <https://arxiv.org/abs/2310.05344>`_.
\ No newline at end of file
diff --git a/examples/nlp/data/steerlm/attribute_annotate.py b/examples/nlp/data/steerlm/attribute_annotate.py
new file mode 100644
index 000000000..3bfa1865e
--- /dev/null
+++ b/examples/nlp/data/steerlm/attribute_annotate.py
@@ -0,0 +1,154 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script annotates a dataset with attribute values by sending requests to a regression reward model server.
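+
+Example usage, mirroring the commands in docs/user-guide/SteerLM.rst (assumes a regression
+reward model server is already being served on the given host and port):
+
+    python attribute_annotate.py --input-file=data/oasst/train.jsonl --output-file=data/oasst/train_labeled.jsonl --port=1424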
+""" + + +import argparse +import json +import os +from typing import List + +import jsonlines +import numpy as np +from common import ( + ALL_STEERLM_ATTRIBUTES, + ASSISTANT_TURN_TEMPLATE, + LABEL_PREFIX, + SYSTEM_PROMPT, + SYSTEM_PROMPT_TEMPLATE, + USER_TURN_TEMPLATE, +) +from pytriton.client import FuturesModelClient +from tqdm import tqdm, trange + + +def _str_list2numpy(str_list: List[str]) -> np.ndarray: + str_ndarray = np.array(str_list)[..., np.newaxis] + return np.char.encode(str_ndarray, "utf-8") + + +def prepare_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--output-file", type=str, required=True) + parser.add_argument("--input-file", type=str, required=True) + parser.add_argument("--port", type=int, default=5555) + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--model_name", type=str, default="reward_model") + parser.add_argument("--add-eos", action="store_true") + return parser.parse_args() + + +def get_reward( + sentences: List[str], add_EOS=False, host="localhost", port=5555, model_name="reward_model", +): + sentences = _str_list2numpy(sentences) + + futures = [] + + with FuturesModelClient(f"{host}:{port}", model_name) as client: + for sen in np.split(sentences, sentences.shape[0]): + add_EOS_arr = np.ones_like(sen, dtype=bool) * add_EOS + future = client.infer_batch(sentences=sen, add_EOS=add_EOS_arr) + futures.append(future) + + all_result_dicts = [f.result() for f in futures] + + all_rewards, all_exceeded = [], [] + + for output_dict in all_result_dicts: + reward_out = output_dict["rewards"].flatten().tolist() + + all_rewards.append(reward_out) + all_exceeded += output_dict["exceeded"].tolist() + + return all_rewards, all_exceeded + + +def get_key(l): + convs = [c["value"] for c in l["conversations"]] + return "".join(convs) + + +def main(args): + inference_output = args.output_file + + exist = set() + if os.path.exists(inference_output): + with jsonlines.open(inference_output) as reader: + for obj in tqdm(reader): + exist.add(get_key(obj)) + + fout = open(inference_output, "a", encoding="utf-8") + + # to warm up the jit + _ = get_reward(["hello world!"], add_EOS=args.add_eos, host=args.host, port=args.port, model_name=args.model_name) + + all_samples, inputs = [], [] + + with jsonlines.open(args.input_file) as reader: + for obj in tqdm(reader): + if get_key(obj) in exist: + continue + user = obj["mask"] + turns = [] + text = SYSTEM_PROMPT_TEMPLATE.format(value=SYSTEM_PROMPT) + for turn in obj["conversations"]: + value = turn["value"] + if turn["from"] == user: + text += USER_TURN_TEMPLATE.format(value=value) + else: + text += ASSISTANT_TURN_TEMPLATE.format(value=value) + if "label" in turn and turn["label"] is not None: + out_text = text + LABEL_PREFIX + turns.append(out_text) + + all_samples.append(turns) + inputs.append(obj) + + print(f"exist {len(exist)}, rest {len(inputs)}") + if len(inputs) == 0: + exit(0) + + for idx in trange(0, len(all_samples)): + input = inputs[idx] + sample = all_samples[idx] + rewards_all, _ = get_reward( + sample, add_EOS=args.add_eos, host=args.host, port=args.port, model_name=args.model_name + ) + + t = 0 + for turn in input["conversations"]: + if "label" in turn and turn["label"] is not None: + reward = rewards_all[t] + t += 1 + + reward_each = [min(4.0, max(0.0, float(r))) for r in reward] + reward_each = [round(r) for r in reward_each] + + reward_string = ",".join(f"{a}:{r}" for a, r in zip(ALL_STEERLM_ATTRIBUTES, reward_each)) + turn["label"] = reward_string + + assert t == 
len(rewards_all) + + fout.write(json.dumps(input) + "\n") + + print("all annotations finished") + fout.close() + + +if __name__ == "__main__": + main(prepare_args()) diff --git a/examples/nlp/data/steerlm/common.py b/examples/nlp/data/steerlm/common.py new file mode 100644 index 000000000..2ca120b85 --- /dev/null +++ b/examples/nlp/data/steerlm/common.py @@ -0,0 +1,32 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +SYSTEM_PROMPT = ( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." +) + +SYSTEM_PROMPT_TEMPLATE = "System\n{value}\n" + +USER_TURN_TEMPLATE = "User\n{value}\n" + +ASSISTANT_TURN_TEMPLATE = "Assistant\n{value}\n" + +LABEL_PREFIX = "" + +OPEN_ASSISTANT_ATTRIBUTES = ["quality", "toxicity", "humor", "creativity"] + +HELPSTEER_ATTRIBUTES = ["helpfulness", "correctness", "coherence", "complexity", "verbosity"] + +ALL_STEERLM_ATTRIBUTES = OPEN_ASSISTANT_ATTRIBUTES + HELPSTEER_ATTRIBUTES diff --git a/examples/nlp/data/steerlm/preprocess_helpsteer_data.py b/examples/nlp/data/steerlm/preprocess_helpsteer_data.py new file mode 100644 index 000000000..12df98f16 --- /dev/null +++ b/examples/nlp/data/steerlm/preprocess_helpsteer_data.py @@ -0,0 +1,82 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script is to preprocess HelpSteer dataset from HuggingFace format into Attribute-conditioned SFT training format. 
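+
+Example usage, as in docs/user-guide/SteerLM.rst (writes train.jsonl and val.jsonl into the
+given output directory):
+
+    python preprocess_helpsteer_data.py --output_directory=data/helpsteer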
+""" + +import argparse +import json +import os + +from common import HELPSTEER_ATTRIBUTES, SYSTEM_PROMPT +from datasets import load_dataset + + +def download_helpsteer(): + ds = load_dataset("nvidia/HelpSteer") + train = ds["train"] + val = ds["validation"] + return train, val + + +def format_label(dp): + label_list = [] + for attr in HELPSTEER_ATTRIBUTES: + label_list.append(f"{attr}:{dp[attr]}") + return ",".join(label_list) + + +def process_dataset(data): + output = [] + for dp in data: + conversation_obj = {} + conversation_obj["conversations"] = [ + {"value": dp["prompt"], "from": "User", "label": None}, + {"value": dp["response"], "from": "Assistant", "label": format_label(dp)}, + ] + conversation_obj["system"] = SYSTEM_PROMPT + conversation_obj["mask"] = "User" + conversation_obj["type"] = "VALUE_TO_TEXT" + output.append(conversation_obj) + return output + + +def main(output_dir): + train, val = download_helpsteer() + + os.makedirs(output_dir, exist_ok=True) + processed_train = process_dataset(train) + with open(f"{output_dir}/train.jsonl", "w", encoding="utf-8") as f: + for record in processed_train: + f.write(json.dumps(record, ensure_ascii=False) + "\n") + + processed_val = process_dataset(val) + with open(f"{output_dir}/val.jsonl", "w", encoding="utf-8") as f: + for record in processed_val: + f.write(json.dumps(record, ensure_ascii=False) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-dir", + "--output_directory", + required=True, + help="folder to store the created train.jsonl and val.jsonl; will be created if it does not exist", + ) + args = parser.parse_args() + + main(args.output_directory) diff --git a/examples/nlp/data/steerlm/preprocess_openassistant_data.py b/examples/nlp/data/steerlm/preprocess_openassistant_data.py new file mode 100644 index 000000000..d22381e66 --- /dev/null +++ b/examples/nlp/data/steerlm/preprocess_openassistant_data.py @@ -0,0 +1,159 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script is to preprocess OpenAssistant dataset from HuggingFace format into Attribute-conditioned SFT training format. 
+""" + +import argparse +import gzip +import json +import os +import random + +import requests +from common import OPEN_ASSISTANT_ATTRIBUTES, SYSTEM_PROMPT + +likert_scale = 5 + + +def encode_labels(labels): + items = [] + for key in OPEN_ASSISTANT_ATTRIBUTES: + if key in labels: + value = labels[key]["value"] + items.append(f"{key}:{round(value*(likert_scale-1))}") + return ",".join(items) + + +def parse_conversations(tree_obj): + """ recusive function that returns all the sub converstaions in a list starting from node tree_obj + + Args: + tree_obj (obj): current conversation node + + Returns: + a list of sub conversation threads including the current conversation node + """ + if "prompt" in tree_obj: + prompt_obj = tree_obj["prompt"] + elif "text" in tree_obj and "role" in tree_obj: + prompt_obj = tree_obj + else: + return [[]] + + if prompt_obj["role"] == "prompter": + role = "User" + elif prompt_obj["role"] == "assistant": + role = "Assistant" + else: + raise ValueError(f'unknown role {prompt_obj["role"]}') + + turn = {"value": prompt_obj["text"], "from": role} + + if "labels" in prompt_obj: + turn["label"] = encode_labels(prompt_obj["labels"]) + all_conversations = [] + multiple_sub_threads = [] + for next_obj in prompt_obj["replies"]: + multiple_threads = parse_conversations(next_obj) + multiple_sub_threads.extend(multiple_threads) + if len(multiple_sub_threads) != 0: + for sub_thread in multiple_sub_threads: + all_conversations.append([turn] + sub_thread) + else: + all_conversations.append([turn]) + return all_conversations + + +def get_data_records(objs, mask_role, type): + output = [] + for obj in objs: + multi_conversations = parse_conversations(obj) + for conversations in multi_conversations: + if len(conversations) <= 1: + # remove single turn conversations + continue + + # mask out labels from user turns + updated_conversation = [] + for turn in conversations: + if turn["from"] == "User": + turn["label"] = None + updated_conversation.append(turn) + + conversation_obj = { + "conversations": updated_conversation, + "system": SYSTEM_PROMPT, + "mask": mask_role, + "type": type, + } + output.append(conversation_obj) + return output + + +def download_open_assistant(output_directory): + filename = f"{output_directory}/2023-04-12_oasst_all.trees.jsonl.gz" + + # only download if doesn't exist + if not os.path.isfile(filename): + url = "https://huggingface.co/datasets/OpenAssistant/oasst1/resolve/main/2023-04-12_oasst_all.trees.jsonl.gz" + response = requests.get(url) + with open(filename, mode="wb") as fw: + fw.write(response.content) + + with gzip.open(filename) as f: + file_content = f.readlines() + + data = [json.loads(dp.decode("utf-8")) for dp in file_content] + return data + + +def main(output_directory, proportion_of_train=0.95, seed=10): + os.makedirs(args.output_directory, exist_ok=True) + all_objs = download_open_assistant(output_directory) + + # Note that we manually shuffle and split the dataset into train / valid sets as we do not use + # the official train / valid splits from Hugging Face. This is because we use the full dataset that + # also includes low-quality data (since SteerLM can still learn from such data), instead of + # the smaller "ready for export" dataset. 
+ random.seed(seed) + random.shuffle(all_objs) + + train_num = int(len(all_objs) * proportion_of_train) + train_objs = all_objs[:train_num] + val_objs = all_objs[train_num:] + train_records = get_data_records(train_objs, "User", "VALUE_TO_TEXT") + val_records = get_data_records(val_objs, "User", "VALUE_TO_TEXT") + + with open(f"{output_directory}/train.jsonl", "w", encoding="utf-8") as f: + for record in train_records: + f.write(json.dumps(record, ensure_ascii=False) + "\n") + + with open(f"{output_directory}/val.jsonl", "w", encoding="utf-8") as f: + for record in val_records: + f.write(json.dumps(record, ensure_ascii=False) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-dir", + "--output_directory", + required=True, + help="folder to store the created train.jsonl and val.jsonl; will be created if not exist", + ) + args = parser.parse_args() + main(args.output_directory) diff --git a/examples/nlp/data/steerlm/process_to_regression_format.py b/examples/nlp/data/steerlm/process_to_regression_format.py new file mode 100644 index 000000000..867fed4c2 --- /dev/null +++ b/examples/nlp/data/steerlm/process_to_regression_format.py @@ -0,0 +1,92 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script is for processing data from Attribute-conditioned SFT training format into regression reward model training format. 
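+
+Example usage, as in docs/user-guide/SteerLM.rst:
+
+    python process_to_regression_format.py --input-file=data/merge_train.jsonl --output-file=data/merge_train_reg.jsonl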
+""" + + +import argparse +import json + +from common import ( + ALL_STEERLM_ATTRIBUTES, + ASSISTANT_TURN_TEMPLATE, + LABEL_PREFIX, + SYSTEM_PROMPT, + SYSTEM_PROMPT_TEMPLATE, + USER_TURN_TEMPLATE, +) + + +def prepare_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--output-file", type=str, required=True, + ) + parser.add_argument( + "--input-file", type=str, required=True, + ) + return parser.parse_args() + + +def parse(s): + # Split the string by comma + try: + pairs = s.split(",") + + # Split each pair by colon to separate key and value + result = {pair.split(":")[0]: pair.split(":")[1] for pair in pairs} + assert len(result) > 0, "At least one attribute should be present" + return result + except Exception: + raise Exception("invalid sample", s) + + +def process_sample(line, fout): + text = SYSTEM_PROMPT_TEMPLATE.format(value=SYSTEM_PROMPT) + conversations = line["conversations"] + user = line["mask"] + for turn in conversations: + value = turn["value"] + if turn["from"] == user: + text += USER_TURN_TEMPLATE.format(value=value) + else: + text += ASSISTANT_TURN_TEMPLATE.format(value=value) + + if "label" in turn and turn["label"]: # label field is present and not None or empty + out_text = text + LABEL_PREFIX + given_attrs = parse(turn["label"]) + labels = [float(given_attrs.get(a, -100)) for a in ALL_STEERLM_ATTRIBUTES] + newline = {"text": out_text, "label": labels} + + fout.write(json.dumps(newline, ensure_ascii=False) + "\n") + + +def main(args): + f = open(args.input_file, "r", encoding="utf-8") + fout = open(args.output_file, "w", encoding="utf-8") + + lines = f.readlines() + + for line in lines: + jline = json.loads(line) + process_sample(jline, fout) + + f.close() + fout.close() + + +if __name__ == "__main__": + main(prepare_args())