Add VAS to TRL #2195

Open
wants to merge 17 commits into main
165 changes: 165 additions & 0 deletions examples/scripts/vas.py
@@ -0,0 +1,165 @@
"""
python vas.py \
    --log_with=wandb \
    --ref_model_name hanseungwook/vas-llama-2-7b-hh-sft \
    --model_name hanseungwook/vas-tiny-llama-1.1b-hh-sft
"""

import os
import json
from dataclasses import dataclass, field
from typing import Optional

import torch
from datasets import load_dataset
from peft import LoraConfig, prepare_model_for_kbit_training, PeftModel, get_peft_model
from tqdm import tqdm
from transformers import AutoTokenizer, HfArgumentParser, pipeline, LogitsProcessorList, BitsAndBytesConfig, AutoModelForSequenceClassification

from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, set_seed, VASTrainer, VASConfig


tqdm.pandas()


@dataclass
class ScriptArguments:
    use_seq2seq: bool = field(default=False, metadata={"help": "whether to use seq2seq"})
    trust_remote_code: bool = field(default=False, metadata={"help": "Enable `trust_remote_code`"})

    # LoraConfig
    use_peft: bool = field(default=False, metadata={"help": "whether to use peft"})
    lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
    lora_r: Optional[int] = field(default=16, metadata={"help": "the lora r parameter"})

    # Generation kwargs
    generation_batch_size: Optional[int] = field(default=16, metadata={"help": "The batch size for generation"})
    temperature: Optional[float] = field(default=1.0, metadata={"help": "The temperature for generation"})
    top_k: Optional[float] = field(default=0.0, metadata={"help": "The top_k for generation"})
    top_p: Optional[float] = field(default=1.0, metadata={"help": "The top_p for generation"})

parser = HfArgumentParser((ScriptArguments, VASConfig))
args, vas_config = parser.parse_args_into_dataclasses()

trl_model_class = AutoModelForCausalLMWithValueHead if not args.use_seq2seq else AutoModelForSeq2SeqLMWithValueHead

def build_response_train_dataset(config, dataset_name='Anthropic/hh-rlhf'):
    ds = load_dataset(dataset_name, split='train')
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.pad_token_id = tokenizer.unk_token_id

    def tokenize(sample):
        # keep everything up to and including the final "Assistant:" turn as the query
        query = sample["chosen"][: sample["chosen"].rfind('Assistant:') + len('Assistant:')].replace('\n', ' ').strip()
        sample["query"] = tokenizer.encode(query)
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds


dataset = build_response_train_dataset(vas_config)

def collator(data):
    return {key: [d[key] for d in data] for key in data[0]}

# set seed before initializing value head for deterministic eval
set_seed(vas_config.seed)

# Now let's build the model, the reference model, and the tokenizer.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=False,
    load_in_8bit=False,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
)

model = trl_model_class.from_pretrained(
    vas_config.model_name,
    quantization_config=quantization_config,
    trust_remote_code=args.trust_remote_code,
    device_map='auto',
)

if args.use_peft:
    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model.pretrained_model = prepare_model_for_kbit_training(model.pretrained_model, use_gradient_checkpointing=True)
    model.pretrained_model = get_peft_model(model.pretrained_model, peft_config)
    model.is_peft_model = True

# Initializing the value head with zeros leads to better performance
torch.nn.init.zeros_(model.v_head.summary.weight)
torch.nn.init.zeros_(model.v_head.summary.bias)

# Disable dropout
for module in model.modules():
    if isinstance(module, torch.nn.Dropout):
        module.p = 0

ref_model = trl_model_class.from_pretrained(
    vas_config.ref_model_name,
    quantization_config=quantization_config,
    trust_remote_code=args.trust_remote_code,
    device_map='auto',
)

tokenizer = ref_tokenizer = AutoTokenizer.from_pretrained(vas_config.model_name)

# Some tokenizers don't have a padding token by default, so we set one here.
tokenizer.pad_token_id = tokenizer.unk_token_id
tokenizer.add_eos_token = True

# We then build the VASTrainer, passing the model, the reference model, the tokenizer
vas_trainer = VASTrainer(vas_config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)

device = vas_trainer.accelerator.device
if vas_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug

reward_model_name = "OpenAssistant/reward-model-deberta-v3-large-v2"
reward_model = AutoModelForSequenceClassification.from_pretrained(reward_model_name).to(vas_trainer.accelerator.device)
reward_model = vas_trainer.accelerator.prepare(reward_model)
reward_model.requires_grad_(False)
reward_tokenizer = AutoTokenizer.from_pretrained(reward_model_name)
reward_model.eval()

generation_kwargs = {
    "top_k": args.top_k,
    "top_p": args.top_p,
    "temperature": args.temperature,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "max_new_tokens": 100,
}

for _epoch, batch in tqdm(enumerate(vas_trainer.dataloader)):
    query_tensors = batch["query"]
    response_tensors = vas_trainer.generate(query_tensors, batch_size=args.generation_batch_size, return_prompt=False, **generation_kwargs)

    # Compute score
    full_responses = [torch.cat([query, response]) for query, response in zip(query_tensors, response_tensors)]
    texts = tokenizer.batch_decode(full_responses, skip_special_tokens=True)
    rewards = []
    for text in texts:
        inputs_ids = reward_tokenizer.encode(text, return_tensors='pt').to(reward_model.device)
        reward_outputs = reward_model(inputs_ids)
        reward = reward_outputs.logits[0]
        rewards.append(reward.squeeze())

    # Run VAS step
    stats = vas_trainer.step(query_tensors, response_tensors, rewards)
    vas_trainer.log_stats(stats, batch, rewards, columns_to_log=["query"])

vas_trainer.save_pretrained("/data/pulkitag/models/idanshen/trl/example")

# Decoding example
# query = "Human: How are you doing today? Assistant:"
# inputs = ref_tokenizer.encode(query, return_tensors='pt').to(reward_model.device)
# output = vas_trainer.generate(inputs, vas_generation=True, beta=3.0)
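#
# A minimal sketch of the same decoding done with the `VASSampler` added in this PR
# (`trl/extras/vas_sampler.py`). Illustrative only: it assumes the reference model
# supplies the logits and the VAS-trained model above serves as the value model.
# from trl.extras.vas_sampler import VASSampler
# sampler = VASSampler(model=ref_model, value_model=model, beta=3.0, top_k=10)
# inputs = ref_tokenizer.encode("Human: How are you doing today? Assistant:", return_tensors='pt')
# outputs = sampler.generate(inputs, device=device, max_new_tokens=100)
# print(ref_tokenizer.batch_decode(outputs, skip_special_tokens=True))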

6 changes: 6 additions & 0 deletions trl/__init__.py
@@ -87,6 +87,8 @@
"RLOOTrainer",
"SFTConfig",
"SFTTrainer",
"VASConfig",
"VASTrainer",
"WinRateCallback",
"XPOConfig",
"XPOTrainer",
@@ -168,6 +170,8 @@
PairRMJudge,
PPOConfig,
PPOTrainer,
PPOv2Config,
PPOv2Trainer,
RandomPairwiseJudge,
RandomRankJudge,
RewardConfig,
@@ -176,6 +180,8 @@
RLOOTrainer,
SFTConfig,
SFTTrainer,
VASConfig,
VASTrainer,
WinRateCallback,
XPOConfig,
XPOTrainer,
187 changes: 187 additions & 0 deletions trl/extras/vas_sampler.py
@@ -0,0 +1,187 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, List, Optional, Union

import torch
from transformers import (
    LogitsProcessor,
    LogitsProcessorList,
    GenerationConfig
)
from ..core import set_seed
from ..models import (
    SUPPORTED_ARCHITECTURES,
    PreTrainedModelWrapper,
    AutoModelForSeq2SeqLMWithValueHead,
    AutoModelForCausalLMWithValueHead
)


class VASSampler:
    def __init__(
        self,
        model: PreTrainedModelWrapper,
        value_model: Union[AutoModelForSeq2SeqLMWithValueHead, AutoModelForCausalLMWithValueHead],
        beta: float = 1.0,
        top_k: int = 10,
        value_model_batch_size: int = 1,
        seed: Optional[int] = None,
        generation_config: Optional[GenerationConfig] = None,
    ) -> None:
        """
        VASSampler generates responses from a model trained with the VAS framework (see `trl/trainer/vas_trainer.py`).

        Args:
            model (`PreTrainedModelWrapper`):
                The pretrained model to use for generation.
            value_model (`AutoModelForSeq2SeqLMWithValueHead` or `AutoModelForCausalLMWithValueHead`):
                The pretrained value model used to augment the sampling process.
            beta (`float`):
                The weight of the value model outputs relative to the logits of the LLM.
            top_k (`int`):
                The number of top-k tokens that will be evaluated with the value model.
            value_model_batch_size (`int`):
                Batch size for the value model; it can be different from the batch size of the LLM.
            seed (`int`, *optional*):
                Random seed used to control generation.
            generation_config (`GenerationConfig`, *optional*):
                Generation config passed to the underlying model's `generate` method.
                See `GenerationConfig` (https://huggingface.co/docs/transformers/v4.29.1/en/main_classes/text_generation#transformers.GenerationConfig) for more details.
        """
        if seed is not None:
            set_seed(seed)

        if not isinstance(model, SUPPORTED_ARCHITECTURES):
            raise ValueError(
                f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}"
            )

        if not isinstance(value_model, SUPPORTED_ARCHITECTURES):
            raise ValueError(
                f"value_model must be a PreTrainedModelWrapper, got {type(value_model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}"
            )

        self.model = model
        self.value_model = value_model
        self.beta = beta
        self.top_k = top_k
        self.value_model_batch_size = value_model_batch_size
        self.gen_config = generation_config

        # Create a VAS logits processor
        self.logits_processor = VASLogitsProcessor(
            self.value_model,
            beta=self.beta,
            top_k=self.top_k,
            value_model_batch_size=self.value_model_batch_size,
        )

    def generate(
        self,
        tokenized_query: Union[List[int], torch.Tensor, List[torch.Tensor], List[List[int]]],
        device: Optional[Union[str, torch.device]] = None,
        **generation_kwargs,
    ) -> List[List[str]]:
        """
        Generate a response using the VAS framework.

        Args:
            tokenized_query (`List[int]` or `torch.Tensor` or `List[torch.Tensor]` or `List[List[int]]`):
                Either a single tokenized query (a single tensor or a list of integers) or a batch of tokenized queries (a list of tensors or a list of lists of integers).
            device (`str` or `torch.device`, *optional*):
                The device on which the model will be loaded.
            **generation_kwargs (`dict`, *optional*):
                Additional keyword arguments passed along to the underlying model's `generate` method.
                These override the generation config.

        Returns:
            List[List[str]]: A list of lists of generated texts.
        """

        # generate the response, augmenting the logits with the VAS logits processor
        outputs = self.model.generate(
            tokenized_query.to(device),
            logits_processor=LogitsProcessorList([self.logits_processor]),
            generation_config=self.gen_config,
            **generation_kwargs,
        )

        return outputs


class VASLogitsProcessor(LogitsProcessor, torch.nn.Module):
    """
    A `LogitsProcessor` that augments the LM logits with the outputs of a value model, following the VAS decoding scheme.
    See the `__init__` docstring for the parameter descriptions.
    """

    def __init__(
        self,
        value_model: Union[AutoModelForSeq2SeqLMWithValueHead, AutoModelForCausalLMWithValueHead],
        beta: float = 1.0,
        top_k: int = 10,
        value_model_batch_size: int = 1,
    ):
        """
        A logits processor that augments the output logits with a value model, as per the VAS decoding scheme.

        Args:
            value_model (`AutoModelForSeq2SeqLMWithValueHead` or `AutoModelForCausalLMWithValueHead`):
                The pretrained value model used to augment the sampling process.
            beta (`float`):
                The weight of the value model outputs relative to the logits of the LLM.
            top_k (`int`):
                The number of top-k tokens that will be evaluated with the value model.
            value_model_batch_size (`int`):
                Batch size for the value model; it can be different from the batch size of the LLM.
        """
        super().__init__()
        self.value_model = value_model
        self.beta = beta
        self.top_k = top_k
        self.value_model_batch_size = value_model_batch_size

        assert self.top_k > 0, "top_k must be larger than zero"

        self.last_input_ids = None
        self.past_key_values = None

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        augmented_outputs = torch.clone(scores)
        batch_size = input_ids.shape[0]

        orig_input_ids = input_ids

        if self.last_input_ids is not None and (input_ids[0, :-1].shape == self.last_input_ids.shape) and torch.all(input_ids[0, :-1] == self.last_input_ids):
            # if the previous input ids are a prefix of the current input ids, we can reuse the cached past key values
            _, _, _, past_key_values = self.value_model(input_ids, past_key_values=self.past_key_values, return_past_key_values=True)
        else:
            _, _, _, past_key_values = self.value_model(input_ids, return_past_key_values=True)
        self.past_key_values = past_key_values
        self.last_input_ids = input_ids[0, :]

        values = torch.zeros_like(scores, device=scores.device)
        topk_ids = torch.topk(scores, self.top_k, dim=-1).indices

        # Evaluate the value model on each top-k candidate token, in chunks of `value_model_batch_size`
        for i in range(0, topk_ids.shape[1], self.value_model_batch_size):
            curr_topk_ids = topk_ids[:, i : i + self.value_model_batch_size]
            curr_input_ids = orig_input_ids.unsqueeze(1).repeat(1, curr_topk_ids.shape[1], 1)
            curr_input_ids = torch.cat([curr_input_ids, curr_topk_ids.unsqueeze(-1)], dim=-1)
            curr_input_ids = curr_input_ids.reshape((batch_size * curr_topk_ids.shape[1], -1))

            _, _, value, _ = self.value_model(curr_input_ids, past_key_values=tuple((t1.repeat(curr_topk_ids.shape[1], 1, 1, 1), t2.repeat(curr_topk_ids.shape[1], 1, 1, 1)) for t1, t2 in self.past_key_values), return_past_key_values=True)
            value = value.reshape((batch_size, curr_topk_ids.shape[1], -1))[:, :, -1]
            values = values.scatter_(1, curr_topk_ids, value)

        # Center the values over the top-k candidates, then add them to the logits scaled by beta
        values = values.scatter_(1, topk_ids, values.gather(1, topk_ids) - values.gather(1, topk_ids).mean(-1, keepdim=True))
        augmented_outputs += self.beta * values

        return augmented_outputs
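
For reviewers, a minimal self-contained sketch of the logit-augmentation rule that `VASLogitsProcessor.__call__` applies. Dummy tensors stand in for the LM logits and the value-model outputs; all names and shapes here are illustrative and not part of the PR.

import torch

batch_size, vocab_size, top_k, beta = 2, 8, 4, 1.0
scores = torch.randn(batch_size, vocab_size)  # LM logits for the next token
values = torch.zeros_like(scores)

# only the top-k candidate tokens are evaluated by the value model (random stand-in here)
topk_ids = torch.topk(scores, top_k, dim=-1).indices
candidate_values = torch.randn(batch_size, top_k)
values.scatter_(1, topk_ids, candidate_values)

# center the values over the top-k candidates, then add them to the logits scaled by beta
topk_values = values.gather(1, topk_ids)
values.scatter_(1, topk_ids, topk_values - topk_values.mean(-1, keepdim=True))
augmented_scores = scores + beta * values
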
4 changes: 4 additions & 0 deletions trl/trainer/__init__.py
@@ -54,6 +54,8 @@
"orpo_trainer": ["ORPOTrainer"],
"ppo_config": ["PPOConfig"],
"ppo_trainer": ["PPOTrainer"],
"vas_config": ["VASConfig"],
"vas_trainer": ["VASTrainer"],
"ppov2_config": ["PPOv2Config"],
"ppov2_trainer": ["PPOv2Trainer"],
"reward_config": ["RewardConfig"],
@@ -118,6 +120,8 @@
from .orpo_trainer import ORPOTrainer
from .ppo_config import PPOConfig
from .ppo_trainer import PPOTrainer
from .vas_config import VASConfig
from .vas_trainer import VASTrainer
from .ppov2_config import PPOv2Config
from .ppov2_trainer import PPOv2Trainer
from .reward_config import RewardConfig