Commit 5dad62b: update support of Starling RM and ultraRM

jdf-prog committed Feb 2, 2024 (1 parent: 8911cdb)

Showing 6 changed files with 349 additions and 40 deletions.

61 changes: 40 additions & 21 deletions llm_blender/blender/blender.py
@@ -74,12 +74,13 @@ def __init__(
 
     def loadranker(self, ranker_path:str, device:str=None, **kwargs):
         """Load ranker from a path
-        Supportted rankers:
+        Supported rankers:
             - llm-blender/pair-ranker
             - llm-blender/pair-reward-model
             - llm-blender/PairRM
             - OpenAssistant/reward-model-deberta-v3-large-v2
-            - Other rankers that can be loaded by transformers.AutoModelForSequenceClassification
+            - openbmb/UltraRM-13b
+            - berkeley-nest/Starling-RM-7B-alpha
             - Local path, e.g. "/path/to/ranker"
         Args:
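Note: with this change, both new reward models load through the same entry point as the existing rankers. A minimal usage sketch, assuming the package's top-level `Blender` API as documented in the repository README (downloading either checkpoint requires the usual Hugging Face access):

```python
import llm_blender

blender = llm_blender.Blender()
# Any name from the supported list above should work; the two new reward
# models are routed to dedicated ranker types by the detection logic below.
blender.loadranker("berkeley-nest/Starling-RM-7B-alpha")  # -> ranker_type "starling-rm"
# blender.loadranker("openbmb/UltraRM-13b")               # -> ranker_type "ultra-rm"
```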
@@ -113,29 +114,43 @@ def loadranker(self, ranker_path:str, device:str=None, **kwargs):
 
         # load ranker config from ranker_path
         ranker_path = Path(ranker_path)
-        if not os.path.exists(ranker_path / "config.json"):
-            # other ranker type
-            ranker_config_json = {
-                "ranker_type": "other",
-                "model_type": "other",
-                "model_name": str(ranker_path),
-                "cache_dir": cache_dir,
-            }
-            ranker_config = RankerConfig.from_dict(ranker_config_json)
-            self.ranker_config = ranker_config
-            for k, v in kwargs.items():
-                setattr(self.ranker_config, k, v)
-        else:
+        if os.path.exists(ranker_path / "config.json"):
             with open(ranker_path / "config.json", "r") as f:
                 ranker_config_json = json.load(f)
             ranker_config = RankerConfig.from_dict(ranker_config_json)
             ranker_config.load_checkpoint = str(ranker_path)
             ranker_config.cache_dir = cache_dir
             self.ranker_config = ranker_config
             for k, v in kwargs.items():
+                if k in ['load_checkpoint', 'cache_dir']:
+                    continue
                 setattr(self.ranker_config, k, v)
+        else:
+            ranker_config_json = {
+                "ranker_type": None,
+                "model_type": None,
+                "model_name": str(ranker_path),
+                "cache_dir": cache_dir,
+            }
+            ranker_config = RankerConfig.from_dict(ranker_config_json)
+            self.ranker_config = ranker_config
+            for k, v in kwargs.items():
+                setattr(self.ranker_config, k, v)
+        if ranker_config.model_name is None:
+            ranker_config.model_name = str(ranker_path)
 
+        # for other rms
+        if ranker_config.ranker_type not in ["pairranker", "summareranker", "simcls"]:
+            # tell from the ranker_path
+            if ranker_config.model_name.endswith("OpenAssistant/reward-model-deberta-v3-large-v2"):
+                ranker_config.ranker_type = "deberta-rm"
+                ranker_config.model_type = "deberta-rm"
+            elif ranker_config.model_name.endswith("berkeley-nest/Starling-RM-7B-alpha"):
+                ranker_config.ranker_type = "starling-rm"
+                ranker_config.model_type = "starling-rm"
+            elif ranker_config.model_name.endswith("openbmb/UltraRM-13b"):
+                ranker_config.ranker_type = "ultra-rm"
+                ranker_config.model_type = "ultra-rm"
+            else:
+                raise ValueError(f"reward model type {ranker_config.model_name} not supported")
+            ranker_config.load_checkpoint = None
 
         self.ranker_config.device = device or self.ranker_config.device or self.blender_config.device
 
         self.ranker, self.ranker_tokenizer, self.ranker_collator = load_ranker(ranker_config)
@@ -211,10 +226,14 @@ def rank(
             elif self.ranker_config.ranker_type in ["summareranker", "simcls"]:
                 outputs = self.ranker(**batch)
                 batch_scores = outputs['logits'].detach().cpu().numpy()
-            elif self.ranker_config.ranker_type == "other":
+            elif self.ranker_config.ranker_type in ["deberta-rm"]:
                 outputs = self.ranker(**batch)
                 batch_scores = outputs.logits.detach().cpu().numpy()
-                batch_scores = batch_scores.squeeze(-1).reshape(batch_size, len(candidates[0]))
+                batch_scores = batch_scores.squeeze(-1).reshape(-1, len(candidates[0]))
+            else:
+                outputs = self.ranker(**batch) # outputs is a list of scores
+                batch_scores = outputs.detach().cpu().numpy()
+                batch_scores = batch_scores.reshape(-1, len(candidates[0]))
             scores.append(batch_scores)
         scores = np.concatenate(scores, axis=0)
         if return_scores:
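Note on the `rank` change above: replacing the hard-coded `batch_size` with `-1` in the reshape means the last, possibly smaller batch no longer has to match the configured batch size. A self-contained sketch of just that shape logic (plain NumPy; the numbers are made up):

```python
import numpy as np

n_candidates = 4                                 # len(candidates[0])
flat_scores = np.random.rand(2 * n_candidates)   # a final batch with only 2 inputs

# reshape(batch_size, n_candidates) would raise here if batch_size were, say, 8;
# -1 lets NumPy infer the number of rows from the flat score count.
batch_scores = flat_scores.reshape(-1, n_candidates)
print(batch_scores.shape)  # (2, 4)
```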
101 changes: 98 additions & 3 deletions llm_blender/pair_ranker/collator.py

@@ -222,7 +222,7 @@ def __call__(self, batch):
             "scores" : scores,
         }
 
-class OtherRMCollator(object):
+class DebertaRMCollator(object):
     def __init__(
         self,
         source_maxlength,
@@ -241,7 +241,7 @@ def __init__(
         self.separate_token = self.sep_token
         self.source_prefix = source_prefix if source_prefix is not None else ""
         self.candidate_prefix = candidate_prefix if candidate_prefix is not None else ""
-        self.model_max_length = min(tokenizer.model_max_length, self.source_maxlength+self.candidate_maxlength+3)
+        self.model_max_length = tokenizer.model_max_length
 
 
     def __call__(self, batch):
@@ -255,9 +255,104 @@ def __call__(self, batch):
         encodings = self.tokenizer(
             [s for s in batch_source for _ in range(len(batch_candidates[0]))],
             [c for cs in batch_candidates for c in cs],
-            padding='max_length',
+            padding='longest',
             return_tensors='pt',
             truncation=False,
             max_length=self.model_max_length,
         )
 
         return {**encodings}
+
+
+class StarlingRMCollator(object):
+    template = "<s>[INST] {instruction} </s> [/INST] {completion}</s>"
+    def __init__(
+        self,
+        source_maxlength,
+        tokenizer,
+        candidate_maxlength,
+        source_prefix=None,
+        candidate_prefix=None,
+    ):
+        self.tokenizer = tokenizer
+        self.source_maxlength = source_maxlength
+        self.candidate_maxlength = candidate_maxlength
+
+        self.sep_token = tokenizer.sep_token if tokenizer.sep_token is not None else tokenizer.eos_token
+        self.cls_token = tokenizer.cls_token if tokenizer.cls_token is not None else tokenizer.bos_token
+        assert self.sep_token is not None, 'sep_token is not found in the tokenizer'
+        self.separate_token = self.sep_token
+        self.source_prefix = source_prefix if source_prefix is not None else ""
+        self.candidate_prefix = candidate_prefix if candidate_prefix is not None else ""
+        self.model_max_length = tokenizer.model_max_length
+
+
+    def __call__(self, batch):
+        batch_size = len(batch)
+        batch_source = [b['source'] for b in batch]
+        batch_candidates = [b['candidates'] for b in batch]
+
+        batch_source = get_truncated_text(batch_source, self.tokenizer, self.source_maxlength)
+        batch_candidates = [get_truncated_text(c, self.tokenizer, self.candidate_maxlength) for c in batch_candidates]
+
+        input_texts = []
+        for i in range(batch_size):
+            for j in range(len(batch_candidates[i])):
+                input_texts.append(self.template.format(instruction=batch_source[i], completion=batch_candidates[i][j]))
+
+        encodings = self.tokenizer(
+            input_texts,
+            truncation=True,
+            max_length=2048,
+            padding="max_length",
+            return_tensors="pt",
+        )
+
+        return {**encodings}
+
+
+class UltraRMCollator(object):
+    template = "Human: {instruction}\n\nAssistant: {completion}"
+
+    def __init__(
+        self,
+        source_maxlength,
+        tokenizer,
+        candidate_maxlength,
+        source_prefix=None,
+        candidate_prefix=None,
+    ):
+        self.tokenizer = tokenizer
+        self.source_maxlength = source_maxlength
+        self.candidate_maxlength = candidate_maxlength
+
+        self.sep_token = tokenizer.sep_token if tokenizer.sep_token is not None else tokenizer.eos_token
+        self.cls_token = tokenizer.cls_token if tokenizer.cls_token is not None else tokenizer.bos_token
+        assert self.sep_token is not None, 'sep_token is not found in the tokenizer'
+        self.separate_token = self.sep_token
+        self.source_prefix = source_prefix if source_prefix is not None else ""
+        self.candidate_prefix = candidate_prefix if candidate_prefix is not None else ""
+        self.model_max_length = tokenizer.model_max_length
+
+
+    def __call__(self, batch):
+        batch_size = len(batch)
+        batch_source = [b['source'] for b in batch]
+        batch_candidates = [b['candidates'] for b in batch]
+
+        batch_source = get_truncated_text(batch_source, self.tokenizer, self.source_maxlength)
+        batch_candidates = [get_truncated_text(c, self.tokenizer, self.candidate_maxlength) for c in batch_candidates]
+
+        input_texts = []
+        for i in range(batch_size):
+            for j in range(len(batch_candidates[i])):
+                input_texts.append(self.template.format(instruction=batch_source[i], completion=batch_candidates[i][j]))
+
+        encodings = self.tokenizer(
+            input_texts,
+            padding='longest',
+            return_tensors='pt',
+            truncation=False,
+            max_length=self.model_max_length,
+        )
+
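Note: both new collators flatten each (source, candidates) batch into one prompt string per candidate before tokenizing; only the template and the padding strategy differ. The Starling collator pads to a fixed 2048 tokens, while the UltraRM collator pads to the longest sequence in the batch, presumably matching how each reward model was trained. A quick sketch of the two templates in isolation (pure string formatting; the instruction and completion are made-up examples):

```python
starling_template = "<s>[INST] {instruction} </s> [/INST] {completion}</s>"
ultra_template = "Human: {instruction}\n\nAssistant: {completion}"

instruction = "Summarize the following article in one sentence."
completion = "The article argues that open reward models are catching up."

print(starling_template.format(instruction=instruction, completion=completion))
# <s>[INST] Summarize the following article in one sentence. </s> [/INST] The article argues that open reward models are catching up.</s>

print(ultra_template.format(instruction=instruction, completion=completion))
# Human: Summarize the following article in one sentence.
#
# Assistant: The article argues that open reward models are catching up.
```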
12 changes: 6 additions & 6 deletions llm_blender/pair_ranker/config.py

@@ -5,14 +5,14 @@
 @dataclass
 class RankerConfig:
     ranker_type:str = field(
-        default="pairranker",
+        default=None,
         metadata={"help": "Ranker type, pairranker or reranker \
             choices: summareranker, dual, pairranker, other;"},
     )
-    model_type:str = field(default="deberta",
-        metadata={"help": "Model type, deberta or roberta"}
+    model_type:str = field(default=None,
+        metadata={"help": "Model type, deberta or roberta or other"}
     )
-    model_name:str = field(default="microsoft/deberta-v3-large",
+    model_name:str = field(default=None,
         metadata={"help": "Model name"}
     )
     cache_dir:str = field(default=None,
@@ -21,10 +21,10 @@ class RankerConfig:
     load_checkpoint:str = field(default=None,
         metadata={"help": "Load checkpoint path"}
     )
-    source_maxlength:int = field(default=128,
+    source_maxlength:int = field(default=None,
         metadata={"help": "Max length of the source sequence"}
     )
-    candidate_maxlength:int = field(default=128,
+    candidate_maxlength:int = field(default=None,
         metadata={"help": "Max length of the candidate sequence"}
     )
     n_tasks:int = field(default=1,
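Note: the config defaults above change from PairRanker-oriented values (`"pairranker"`, `"deberta"`, `128`) to `None`, so a config built for one of the new reward models no longer inherits misleading defaults; `loadranker` now fills these fields in explicitly. A sketch of an explicit config after this change (values illustrative, mirroring the dict that `loadranker` builds):

```python
from llm_blender.pair_ranker.config import RankerConfig

config = RankerConfig.from_dict({
    "ranker_type": "deberta-rm",
    "model_type": "deberta-rm",
    "model_name": "OpenAssistant/reward-model-deberta-v3-large-v2",
    "cache_dir": None,
})
# Unset fields such as source_maxlength now stay None instead of
# silently defaulting to PairRanker values like 128.
```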
47 changes: 37 additions & 10 deletions llm_blender/pair_ranker/model_util.py

@@ -11,8 +11,12 @@
     DualCollator,
     SCRCollator,
     CrossCompareCollator,
-    OtherRMCollator,
+    DebertaRMCollator,
+    StarlingRMCollator,
+    UltraRMCollator
 )
+from .other_rms.starling_rm import StarlingRM
+from .other_rms.ultra_rm import UltraRM
 from transformers import (
     RobertaModel,
     BertModel,
@@ -36,6 +40,8 @@ def build_pretrained_model(model_type, model_name, **kwargs):
         model = T5ForConditionalGeneration.from_pretrained(model_name, **kwargs)
     elif model_type.startswith("bart"):
         model = BartForConditionalGeneration.from_pretrained(model_name, **kwargs)
+    elif model_type.startswith("deberta-rm"):
+        model = AutoModelForSequenceClassification.from_pretrained(model_name, **kwargs)
     elif model_type.startswith("deberta"):
         from transformers import AutoModel
         model = AutoModel.from_pretrained(model_name, **kwargs)
@@ -48,6 +54,11 @@ def build_pretrained_model(model_type, model_name, **kwargs):
         model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **kwargs)
     elif model_type.startswith("opt"):
         model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
+    elif model_type.startswith("starling-rm"):
+        model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", **kwargs)
+    elif model_type.startswith("ultra-rm"):
+
+        model = UltraRM.from_pretrained(model_name, **kwargs)
     elif model_type.startswith("other"):
         model = AutoModelForSequenceClassification.from_pretrained(model_name, **kwargs)
     else:
@@ -66,6 +77,10 @@ def build_tokenizer(model_name, **kwargs):
     if "alpaca" in model_name or "llama" in model_name:
         # padding left
         tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", **kwargs)
+    elif "starling-rm" in model_name.lower():
+        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", **kwargs)
+        tokenizer.pad_token = tokenizer.unk_token
+        tokenizer.truncation_side = "left"
     else:
         tokenizer = AutoTokenizer.from_pretrained(model_name, **kwargs)
     if tokenizer.pad_token is None:
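Note: Starling-RM-7B-alpha is built on a Llama-2 backbone, which is why the tokenizer is taken from `meta-llama/Llama-2-7b-chat-hf` (a gated checkpoint, so loading it requires accepting Meta's license on the Hugging Face Hub). A sketch of the same setup in isolation:

```python
from transformers import AutoTokenizer

# Assumes access to the gated meta-llama checkpoint.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.pad_token = tokenizer.unk_token  # Llama tokenizers ship without a pad token
tokenizer.truncation_side = "left"         # drop prompt tokens, keep the completion end
```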
@@ -76,36 +91,48 @@ def build_ranker(ranker_type, model_type, model_name, cache_dir, config, tokenizer):
 def build_ranker(ranker_type, model_type, model_name, cache_dir, config, tokenizer):
     ranker = None
     pretrained_model = build_pretrained_model(model_type, model_name, cache_dir=cache_dir)
-    pretrained_model.resize_token_embeddings(len(tokenizer))
     if ranker_type == "summareranker":
+        pretrained_model.resize_token_embeddings(len(tokenizer))
         ranker = SummaReranker(pretrained_model, config, tokenizer)
     elif ranker_type == "dual":
+        pretrained_model.resize_token_embeddings(len(tokenizer))
         ranker = DualReranker(pretrained_model, config, tokenizer)
     elif ranker_type == "pairranker":
+        pretrained_model.resize_token_embeddings(len(tokenizer))
         ranker = CrossCompareReranker(pretrained_model, config, tokenizer)
-    elif ranker_type == "other":
+    elif ranker_type == "deberta-rm":
         ranker = pretrained_model
+    elif ranker_type == "starling-rm":
+        ranker = StarlingRM(pretrained_model, config, tokenizer)
+    elif ranker_type == "ultra-rm":
+        ranker = pretrained_model
     else:
         raise ValueError(f"ranker_type {ranker_type} not supported")
     return ranker
 
 def build_collator(
-    model_type:str,
+    ranker_type:str,
     tokenizer,
     source_maxlength:int,
     candidate_maxlength:int,
     source_prefix:str = None,
     candidate1_prefix:str = None,
     candidate2_prefix:str = None,
 ):
-    if model_type == "summareranker":
+    if ranker_type == "summareranker":
         return SCRCollator(source_maxlength, tokenizer, candidate_maxlength, source_prefix, candidate1_prefix)
-    elif model_type == "dual":
+    elif ranker_type == "dual":
         return DualCollator(source_maxlength, tokenizer, candidate_maxlength, source_prefix, candidate1_prefix)
-    elif model_type == "pairranker":
+    elif ranker_type == "pairranker":
         return CrossCompareCollator(source_maxlength, tokenizer, candidate_maxlength, source_prefix, candidate1_prefix, candidate2_prefix)
-    elif model_type == "other":
-        return OtherRMCollator(source_maxlength, tokenizer, candidate_maxlength, "", "")
+    elif ranker_type == "deberta-rm":
+        return DebertaRMCollator(source_maxlength, tokenizer, candidate_maxlength)
+    elif ranker_type == "starling-rm":
+        return StarlingRMCollator(source_maxlength, tokenizer, candidate_maxlength)
+    elif ranker_type == "ultra-rm":
+        return UltraRMCollator(source_maxlength, tokenizer, candidate_maxlength)
     else:
-        raise ValueError(f"model_type {model_type} not supported")
+        raise ValueError(f"ranker_type {ranker_type} not supported")
 
 
 def get_torch_dtype(dtype_str):
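Note: `build_ranker` and `build_collator` now both dispatch on `ranker_type`; the old `build_collator` compared a parameter named `model_type` against ranker-type strings, which only worked while the two notions coincided. A minimal sketch of the updated flow, roughly what the loading path presumably does internally (config values illustrative):

```python
from llm_blender.pair_ranker.config import RankerConfig
from llm_blender.pair_ranker.model_util import (
    build_tokenizer, build_ranker, build_collator,
)

config = RankerConfig.from_dict({
    "ranker_type": "ultra-rm",
    "model_type": "ultra-rm",
    "model_name": "openbmb/UltraRM-13b",
    "cache_dir": None,
})
tokenizer = build_tokenizer(config.model_name)
ranker = build_ranker(config.ranker_type, config.model_type, config.model_name,
                      config.cache_dir, config, tokenizer)
collator = build_collator(config.ranker_type, tokenizer,
                          config.source_maxlength, config.candidate_maxlength)
```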
(The diffs for the remaining two changed files were not loaded in this view.)