Commit 348dec9: update

jdf-prog committed Nov 24, 2023
1 parent f26df0b

Showing 3 changed files with 9 additions and 10 deletions.
1 change: 0 additions & 1 deletion README.md
@@ -27,7 +27,6 @@
  ## 🔥News

  - [11/10] Glad to announce that our pairwise reward model, 🤗[PairRM](https://huggingface.co/llm-blender/PairRM), has been released. It's trained on a high-quality, large-scale human-preference dataset and approaches GPT-4's alignment with human preferences at an extremely small model size (0.4B).
- - [10/28] We release a newly trained PairRanker used for reward model at 🤗 [llm-blender/pair-reward-model](https://huggingface.co/llm-blender/pair-reward-model)

  - [10/24] The pre-trained PairRanker can be loaded directly from 🤗 Hugging Face Models [llm-blender/PairRM](https://huggingface.co/llm-blender/PairRM) within 3 lines of code. See the guidance for [Rank & Fusion](#rank-and-fusion) for details.
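The [10/24] entry above mentions loading PairRM within 3 lines of code. A minimal sketch of that usage, assuming the llm_blender package's Blender API as described in the repository README:

```python
import llm_blender

# Load the pre-trained PairRM ranker from the Hugging Face Hub.
blender = llm_blender.Blender()
blender.loadranker("llm-blender/PairRM")
```

From there the README's rank-and-fuse workflow applies, e.g. `blender.rank(inputs, candidates)` to score candidate responses pairwise (call name per the repository README; treat the exact signature as an assumption here).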
4 changes: 2 additions & 2 deletions llm_blender/pair_ranker/trainer.py
@@ -28,8 +28,8 @@ def save_model(self, output_dir: Optional[str] = None, **kwargs):
         if self.is_world_process_zero():
             super().save_model(output_dir, **kwargs)
             model = self.model.module if hasattr(self.model, "module") else self.model
-            json.dump(asdict(model.args), open(os.path.join(output_dir, "ranker_config.json"), "w"), indent=4)
-
+            json.dump(asdict(model.args), open(os.path.join(output_dir, "config.json"), "w"), indent=4)
+
 class FiDTrainer(Seq2SeqTrainer):
     def compute_loss(self, model, inputs, return_outputs=False):
         """
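The rename from ranker_config.json to config.json matches the filename that train_ranker.py now reads back (see the next file). A minimal, self-contained sketch of the round trip, using a hypothetical stand-in for the repository's RankerConfig dataclass (the real class carries many more fields):

```python
import json
import os
from dataclasses import asdict, dataclass

# Hypothetical stand-in for llm_blender's RankerConfig; fields assumed.
# from_json_file mirrors the helper that train_ranker.py calls below.
@dataclass
class RankerConfig:
    ranker_type: str = "pairranker"
    model_type: str = "deberta"

    @classmethod
    def from_json_file(cls, path):
        with open(path) as f:
            return cls(**json.load(f))

output_dir = "checkpoint-best"
os.makedirs(output_dir, exist_ok=True)

# Save side: what the patched save_model writes into the checkpoint.
with open(os.path.join(output_dir, "config.json"), "w") as f:
    json.dump(asdict(RankerConfig()), f, indent=4)

# Load side: what train_ranker.py reads when resuming from a checkpoint.
config = RankerConfig.from_json_file(os.path.join(output_dir, "config.json"))
print(config)
```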
14 changes: 7 additions & 7 deletions train_ranker.py
@@ -84,13 +84,13 @@ def main(args):
     args.n_tasks = predict_dataset.n_tasks

     # set up model
-    config = RankerConfig()
-    for k in args.__dict__:
-        if k in config.__dict__:
-            print(k, getattr(args, k))
-            setattr(config, k, getattr(args, k))
-
     if args.load_checkpoint:
-        # config = torch.load(os.path.join(args.load_checkpoint, "config.bin"))
+        config = RankerConfig.from_json_file(os.path.join(args.load_checkpoint, "config.json"))
+        for k in args.__dict__:
+            if k in config.__dict__:
+                print(k, getattr(args, k))
+                setattr(config, k, getattr(args, k))
         model = build_ranker(
             args.ranker_type,
             args.model_type,
@@ -106,6 +106,7 @@ def main(args):
         else:
             logging.info(f"Successfully loaded checkpoint from '{args.load_checkpoint}'")
     else:
+        config = RankerConfig()
         model = build_ranker(
             args.ranker_type,
             args.model_type,
@@ -201,7 +202,6 @@ def main(args):
         logging.info("Saving model")
         best_checkpoint_folder = os.path.join(args.output_dir, "checkpoint-best")
         trainer.save_model(best_checkpoint_folder)
-        torch.save(model.args, os.path.join(best_checkpoint_folder, "config.bin"))

     if args.do_predict:
         logging.info("Start predicting")
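Net effect of the three hunks: when resuming, the checkpoint's config.json (written by trainer.save_model) becomes the source of truth, with matching command-line arguments layered on top; a fresh run starts from plain defaults; and the separate torch.save of config.bin is gone. A condensed sketch of the resulting control flow, reusing the hypothetical RankerConfig stand-in sketched earlier:

```python
import os

def resolve_config(args):
    # Condensed restatement of the patched logic in train_ranker.py.
    if args.load_checkpoint:
        # Resume: read the config.json the trainer writes, then let any
        # matching command-line arguments override its fields.
        config = RankerConfig.from_json_file(
            os.path.join(args.load_checkpoint, "config.json")
        )
        for k in args.__dict__:
            if k in config.__dict__:
                setattr(config, k, getattr(args, k))
    else:
        # Fresh run: plain defaults; no override loop in this branch.
        config = RankerConfig()
    return config
```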
